Coverage for astrocyte/audit.py: 64%

143 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

"""Domain-level audit logging for Astrocyte operations.

Adopted from Hindsight (`hindsight_api/engine/audit.py`). Captures every
retain/recall/classify/consolidate operation with structured input,
output, and metadata to a single ``astrocyte_audit_log`` table.

Design properties:

- **Fire-and-forget.** ``log()`` schedules an async write and returns
  immediately. Audit writes never block the calling operation.
- **Failure-isolating.** A failed audit write is logged at WARNING and
  swallowed. It never propagates to the caller.
- **Late-bound pool / schema.** Uses callables for pool + schema lookup
  so audit doesn't tightly couple to a specific store's lifecycle.
- **Opt-in, with an action allowlist.** Nothing is audited unless the
  logger is constructed with ``enabled=True`` (the default is False).
  When enabled, ``allowed_actions`` restricts auditing to the listed
  actions; if it is omitted, every action is audited. Production
  deployments should list their actions explicitly.
- **Retention sweep (opt-in).** Background task deletes entries older
  than ``retention_days``. Disabled by default (``retention_days=-1``).

Public API (``AuditLogger`` arguments are keyword-only):
    AuditEntry(action, transport, bank_id, request, response, metadata)
    AuditLogger(*, pool_getter, schema_getter=None, enabled=False,
                allowed_actions=None, retention_days=-1,
                table="astrocyte_audit_log")
        .log(entry)
        .log_with_timing(entry)    # async context manager
        .start_retention_sweep() / await .stop_retention_sweep()

Wiring pattern:
    audit = AuditLogger(
        pool_getter=lambda: store._pool,
        schema_getter=lambda: store._current_schema_or_public(),
        enabled=True,
        allowed_actions=["retain", "recall", "classify"],
    )
    audit.log(AuditEntry(action="recall", transport="bench",
                         bank_id=bank_id,
                         request={"query_len": len(q), "top_k": 20},
                         response={"n_results": 18}))
"""

39 

40from __future__ import annotations 

41 

42import asyncio 

43import json 

44import logging 

45import uuid 

46from contextlib import asynccontextmanager 

47from dataclasses import dataclass, field 

48from datetime import datetime, timedelta, timezone 

49from typing import TYPE_CHECKING, Any, Callable 

50 

51if TYPE_CHECKING: 

52 from collections.abc import AsyncIterator 

53 

# Diagnostics for the audit machinery itself (skipped or failed writes and
# sweeps). Distinct from ``AuditLogger``, which writes the audit rows.
logger = logging.getLogger(name=__name__)

55 

56 

57# ─── data ───────────────────────────────────────────────────────────── 

58 

59 

@dataclass
class AuditEntry:
    """One row destined for the audit table.

    Keep ``request`` / ``response`` normalized: prefer lengths, counts and
    ids over raw text, so rows stay compact and sensitive content is not
    persisted.

    ``started_at`` defaults to a timezone-aware UTC timestamp taken at
    construction. ``ended_at`` defaults to ``None``; ``AuditLogger.log``
    fills it in when it is still unset.
    """

    # Operation name (e.g. "recall"); checked against the logger's allowlist.
    action: str
    # Originating channel: "bench", "harness", "system", "http", "mcp", ...
    transport: str
    bank_id: str | None = field(default=None)
    started_at: datetime = field(default_factory=lambda: datetime.now(tz=timezone.utc))
    ended_at: datetime | None = field(default=None)
    request: dict[str, Any] | None = field(default=None)
    response: dict[str, Any] | None = field(default=None)
    # default_factory gives every entry its own dict (no shared mutable state).
    metadata: dict[str, Any] = field(default_factory=dict)

77 

78 

79# ─── JSON helpers ───────────────────────────────────────────────────── 

80 

81 

def _json_default(obj: Any) -> Any:
    """``json.dumps`` fallback for values the stdlib encoder rejects.

    Returns a JSON-serializable stand-in:
    - ISO-8601 strings for datetimes;
    - canonical strings for UUIDs;
    - a list for sets/frozensets;
    - the placeholder ``"<bytes>"`` for binary data, so raw bytes are never
      written to the audit table.
    Anything else falls back to ``str(obj)``.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, uuid.UUID):
        return str(obj)
    # bytearray/memoryview too: their str() would expose the raw content.
    if isinstance(obj, (bytes, bytearray, memoryview)):
        return "<bytes>"
    # json re-encodes whatever ``default`` returns, so a list is fine here.
    if isinstance(obj, (set, frozenset)):
        return list(obj)
    return str(obj)


def _safe_json(data: Any) -> str | None:
    """Serialize ``data`` for a ``jsonb`` column, or return ``None``.

    Returns ``None`` for ``None`` input and whenever serialization fails
    (e.g. circular references); the failure is logged at DEBUG.

    ``allow_nan=False`` makes NaN/Infinity fail here rather than emitting
    the non-standard ``NaN``/``Infinity`` tokens. PostgreSQL's ``jsonb``
    input rejects those tokens, which would otherwise make the entire
    audit INSERT fail instead of just dropping this payload.
    """
    if data is None:
        return None
    try:
        return json.dumps(data, default=_json_default, allow_nan=False)
    except Exception:  # noqa: BLE001
        logger.debug("audit: JSON serialization failed", exc_info=True)
        return None

102 

103 

104# ─── logger ─────────────────────────────────────────────────────────── 

105 

106 

# Pause between retention-sweep passes: one hour, in seconds.
_SWEEP_INTERVAL_SECONDS = 60 * 60
# Default audit table name; schema-qualified at write time unless "public".
_DEFAULT_TABLE = "astrocyte_audit_log"

109 

110 

class AuditLogger:
    """Fire-and-forget audit log writer.

    Pool and schema are looked up via callables at write time, so a
    single ``AuditLogger`` can serve multiple tenants whose schema
    context changes per request.

    Args:
        pool_getter: Returns the pool to write through, or ``None`` to skip
            the write. Either a psycopg-style pool (``connection()``) or an
            asyncpg-style pool (``acquire()``) is accepted.
        schema_getter: Returns the schema that qualifies the table name.
            Defaults to ``"public"``, in which case the bare table name is
            used.
        enabled: Master switch. When False, nothing is written.
        allowed_actions: Actions to audit. ``None`` audits every action;
            an empty list audits none.
        retention_days: Age in days after which the retention sweep deletes
            entries. ``<= 0`` disables the sweep.
        table: Audit table name.

    NOTE: the schema and table names are interpolated into SQL verbatim.
    They must come from trusted configuration, never from request input.
    """

    def __init__(
        self,
        *,
        pool_getter: Callable[[], Any | None],
        schema_getter: Callable[[], str] | None = None,
        enabled: bool = False,
        allowed_actions: list[str] | None = None,
        retention_days: int = -1,
        table: str = _DEFAULT_TABLE,
    ) -> None:
        self._pool_getter = pool_getter
        self._schema_getter = schema_getter or (lambda: "public")
        self._enabled = enabled
        # ``is not None`` rather than truthiness: an explicit empty allowlist
        # means "audit nothing", not a silent fallback to "audit everything".
        self._allowed: frozenset[str] | None = (
            frozenset(allowed_actions) if allowed_actions is not None else None
        )
        self._retention_days = retention_days
        self._table = table
        self._sweep_task: asyncio.Task[None] | None = None
        # Strong references to in-flight write tasks. The event loop keeps
        # only weak references to tasks, so an unreferenced fire-and-forget
        # task may be garbage-collected before it finishes.
        self._pending_writes: set[asyncio.Task[None]] = set()

    # ── status ────────────────────────────────────────────────────────

    def is_enabled(self, action: str) -> bool:
        """Return True if entries for ``action`` should be written."""
        if not self._enabled:
            return False
        if self._allowed is not None:
            return action in self._allowed
        return True

    def _qualified_table(self) -> str:
        """Return the table name, schema-qualified unless the schema is ``public``."""
        schema = self._schema_getter()
        return self._table if schema == "public" else f"{schema}.{self._table}"

    # ── write ─────────────────────────────────────────────────────────

    def log(self, entry: AuditEntry) -> None:
        """Schedule an audit write as a background task (fire-and-forget).

        Returns immediately. Does nothing if ``entry.action`` is not
        enabled. Sets ``entry.ended_at`` to the current UTC time if it is
        unset. Outside a running event loop, the entry is dropped and a
        DEBUG message is logged.
        """
        if not self.is_enabled(entry.action):
            return
        if entry.ended_at is None:
            entry.ended_at = datetime.now(timezone.utc)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop (e.g. shutdown / sync context). Skip.
            # Checked before creating the coroutine, so no never-awaited
            # coroutine object is leaked.
            logger.debug("audit: no running event loop; skip %s", entry.action)
            return
        task = loop.create_task(self._safe_write(entry))
        self._pending_writes.add(task)
        task.add_done_callback(self._pending_writes.discard)

    @asynccontextmanager
    async def log_with_timing(self, entry: AuditEntry) -> AsyncIterator[AuditEntry]:
        """Async context manager that times the wrapped block and logs ``entry``.

        Resets ``started_at`` on entry and sets ``ended_at`` on exit. The
        entry is passed to :meth:`log` even if the block raises; the
        exception still propagates.

        Usage:
            async with audit.log_with_timing(AuditEntry(action="recall", ...)) as e:
                results = await do_recall(...)
                e.response = {"n_results": len(results)}
        """
        entry.started_at = datetime.now(timezone.utc)
        try:
            yield entry
        finally:
            entry.ended_at = datetime.now(timezone.utc)
            self.log(entry)

    async def _safe_write(self, entry: AuditEntry) -> None:
        """Write ``entry``; never raises (failures are logged at WARNING)."""
        try:
            # The getter calls are inside the try, so a failing getter is
            # isolated too instead of surfacing as an unretrieved task error.
            pool = self._pool_getter()
            if pool is None:
                logger.debug("audit: pool unavailable; skip %s", entry.action)
                return
            await self._do_write(pool, self._qualified_table(), entry)
        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "audit: write failed action=%s bank=%s: %s",
                entry.action,
                entry.bank_id,
                exc,
            )

    @staticmethod
    def _to_dollar_placeholders(sql: str) -> str:
        """Rewrite psycopg ``%s`` placeholders as numbered asyncpg ``$1, $2, ...``."""
        head, *rest = sql.split("%s")
        return head + "".join(f"${i}{chunk}" for i, chunk in enumerate(rest, start=1))

    async def _do_write(self, pool: Any, table: str, entry: AuditEntry) -> None:
        """Insert ``entry`` into ``table`` through ``pool``.

        Write via psycopg pool (matches astrocyte_postgres adapter pattern).
        Also tolerates an asyncpg-style ``acquire()`` pool by
        feature-detecting. For any other pool shape, logs a warning and
        writes nothing. Database errors propagate to ``_safe_write``.
        """
        sql = f"""
            INSERT INTO {table}
                (id, action, transport, bank_id, started_at, ended_at, request, response, metadata)
            VALUES
                (%s, %s, %s, %s, %s, %s, %s::jsonb, %s::jsonb, %s::jsonb)
        """
        params = (
            str(uuid.uuid4()),
            entry.action,
            entry.transport,
            entry.bank_id,
            entry.started_at,
            entry.ended_at,
            _safe_json(entry.request),
            _safe_json(entry.response),
            _safe_json(entry.metadata) or "{}",
        )
        # psycopg AsyncConnectionPool
        if hasattr(pool, "connection"):
            async with pool.connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(sql, params)
            return
        # asyncpg-style fallback (Hindsight uses this): needs $1..$n placeholders.
        if hasattr(pool, "acquire"):
            async with pool.acquire() as conn:
                await conn.execute(self._to_dollar_placeholders(sql), *params)
            return
        logger.warning("audit: unrecognized pool shape: %r", type(pool))

    # ── retention sweep ───────────────────────────────────────────────

    def start_retention_sweep(self) -> None:
        """Start the periodic retention sweep as a background task.

        This is a no-op in any of these cases:
        - retention is disabled (``retention_days <= 0``);
        - the logger is disabled;
        - a sweep task is already running;
        - there is no running event loop.
        """
        if self._retention_days <= 0 or not self._enabled:
            return
        if self._sweep_task is not None and not self._sweep_task.done():
            return  # already running; don't orphan the existing task
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("audit: cannot start sweep (no event loop)")
            return
        self._sweep_task = loop.create_task(self._sweep_loop())

    async def stop_retention_sweep(self) -> None:
        """Cancel the sweep task (if any) and wait for it to finish."""
        task = self._sweep_task
        if task is None:
            return
        if not task.done():
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                # Expected — we just cancelled it above. Awaiting a
                # cancelled task always raises CancelledError; consuming
                # it here is the canonical shutdown pattern.
                pass
        self._sweep_task = None

    async def _sweep_loop(self) -> None:
        """Run a sweep, then sleep ``_SWEEP_INTERVAL_SECONDS``; repeat until cancelled."""
        while True:
            try:
                await self._run_sweep()
            except Exception as exc:  # noqa: BLE001
                logger.warning("audit: sweep iteration failed: %s", exc)
            await asyncio.sleep(_SWEEP_INTERVAL_SECONDS)

    async def _run_sweep(self) -> None:
        """Delete entries whose ``started_at`` is older than ``retention_days``."""
        pool = self._pool_getter()
        if pool is None:
            return
        sql = f"DELETE FROM {self._qualified_table()} WHERE started_at < %s"
        cutoff = datetime.now(timezone.utc) - timedelta(days=self._retention_days)
        try:
            if hasattr(pool, "connection"):
                async with pool.connection() as conn:
                    async with conn.cursor() as cur:
                        await cur.execute(sql, (cutoff,))
            elif hasattr(pool, "acquire"):
                async with pool.acquire() as conn:
                    await conn.execute(self._to_dollar_placeholders(sql), cutoff)
            else:
                logger.warning("audit: unrecognized pool shape: %r", type(pool))
        except Exception as exc:  # noqa: BLE001
            logger.warning("audit: sweep delete failed: %s", exc)

278 

279 

280# ─── module-level convenience: a default no-op logger ───────────────── 

281 

# Process-wide default logger. Created lazily by ``get_default_audit_logger``
# and replaced through ``set_default_audit_logger``.
_DEFAULT_LOGGER: AuditLogger | None = None


def get_default_audit_logger() -> AuditLogger:
    """Return the process-wide default audit logger, creating it on first use.

    Until ``set_default_audit_logger`` installs a real one (e.g. at gateway
    startup), the default is disabled and has no pool, so every ``log``
    call is a no-op. Pipeline code can therefore call
    ``get_default_audit_logger().log(...)`` without checking whether
    auditing is configured.
    """
    global _DEFAULT_LOGGER
    current = _DEFAULT_LOGGER
    if current is None:
        current = AuditLogger(pool_getter=lambda: None, enabled=False)
        _DEFAULT_LOGGER = current
    return current

300 

301 

def set_default_audit_logger(logger_instance: AuditLogger) -> None:
    """Install ``logger_instance`` as the process-wide default audit logger.

    Later ``get_default_audit_logger()`` calls return this instance. This is
    meant to be called during setup (e.g. at gateway startup).
    """
    global _DEFAULT_LOGGER
    _DEFAULT_LOGGER = logger_instance