Coverage for astrocyte/audit.py: 64%
143 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Domain-level audit logging for Astrocyte operations.
3Adopted from Hindsight (`hindsight_api/engine/audit.py`). Captures every
4retain/recall/classify/consolidate operation with structured input,
5output, and metadata to a single ``astrocyte_audit_log`` table.
7Design properties:
9- **Fire-and-forget.** ``log()`` schedules an async write and returns
10 immediately. Audit writes never block the calling operation.
11- **Failure-isolating.** A failed audit write is logged at WARNING and
12 swallowed. Never propagates to the caller.
13- **Late-bound pool / schema.** Uses callables for pool + schema lookup
14 so audit doesn't tightly couple to a specific store's lifecycle.
- **Action allowlist.** Auditing is off unless ``enabled=True``; then
  ``allowed_actions`` limits which actions are recorded (``None`` records
  every action). Production deployments opt actions in explicitly.
17- **Retention sweep (opt-in).** Background task deletes entries older
18 than ``retention_days``. Disabled by default (``retention_days=-1``).
20Public API:
21 AuditEntry(action, transport, bank_id, request, response, metadata)
22 AuditLogger(pool_getter, schema_getter, enabled, allowed_actions,
23 retention_days=-1)
24 .log(entry)
25 .log_with_timing(entry) # context manager
26 .start_retention_sweep() / .stop_retention_sweep()
28Wiring pattern:
29 logger = AuditLogger(
30 pool_getter=lambda: store._pool,
31 schema_getter=lambda: store._current_schema_or_public(),
32 enabled=True,
33 allowed_actions=["retain", "recall", "classify"],
34 )
35 logger.log(AuditEntry(action="recall", transport="bench",
36 bank_id=bank_id, request={"query": q, "top_k": 20},
37 response={"n_results": 18}))
38"""
40from __future__ import annotations
42import asyncio
43import json
44import logging
45import uuid
46from contextlib import asynccontextmanager
47from dataclasses import dataclass, field
48from datetime import datetime, timedelta, timezone
49from typing import TYPE_CHECKING, Any, Callable
51if TYPE_CHECKING:
52 from collections.abc import AsyncIterator
# Module-level logger for audit diagnostics (write/sweep failures, skips).
logger: logging.Logger = logging.getLogger(__name__)
57# ─── data ─────────────────────────────────────────────────────────────
def _utc_now() -> datetime:
    """Return the current moment as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


@dataclass
class AuditEntry:
    """One audit record describing a single operation.

    Keep ``request`` and ``response`` normalized: prefer lengths, counts
    and ids over raw text. This keeps rows small and keeps sensitive
    content out of the audit table.
    """

    # Operation name (e.g. "retain", "recall"); matched against the
    # logger's action allowlist.
    action: str
    # Caller surface: "bench", "harness", "system", "http", "mcp", ...
    transport: str
    # Bank the operation targeted, if any.
    bank_id: str | None = None
    # Start time; defaults to "now" in UTC when the entry is created.
    started_at: datetime = field(default_factory=_utc_now)
    # End time; the logger fills it in if still unset when logging.
    ended_at: datetime | None = None
    # Normalized request / response payloads.
    request: dict[str, Any] | None = None
    response: dict[str, Any] | None = None
    # Free-form extra context; a fresh dict per entry.
    metadata: dict[str, Any] = field(default_factory=dict)
79# ─── JSON helpers ─────────────────────────────────────────────────────
def _json_default(obj: Any) -> Any:
    """``json.dumps`` fallback for values the stock encoder rejects.

    - ``datetime`` becomes its ISO-8601 string and ``uuid.UUID`` its
      canonical string.
    - Binary payloads (``bytes``, ``bytearray``, ``memoryview``) are masked
      as ``"<bytes>"`` so raw content never reaches the audit table.
    - ``set`` and ``frozenset`` become lists. They are sorted when the
      elements are mutually comparable, so the stored JSON is stable across
      runs.
    - Anything else falls back to ``str(obj)``.
    """
    if isinstance(obj, datetime):
        return obj.isoformat()
    if isinstance(obj, uuid.UUID):
        return str(obj)
    if isinstance(obj, (bytes, bytearray, memoryview)):
        return "<bytes>"
    if isinstance(obj, (set, frozenset)):
        try:
            return sorted(obj)
        except TypeError:
            # Mixed, non-comparable element types: keep iteration order.
            return list(obj)
    return str(obj)
def _safe_json(data: Any) -> str | None:
    """Serialize *data* to a JSON string, or return ``None``.

    ``None`` input yields ``None``. Values the encoder cannot handle are
    routed through ``_json_default``. If serialization still fails, the
    error is logged at DEBUG and ``None`` is returned instead of raising,
    so an odd payload never breaks an audit write.
    """
    if data is not None:
        try:
            encoded = json.dumps(data, default=_json_default)
        except Exception:  # noqa: BLE001
            logger.debug("audit: JSON serialization failed", exc_info=True)
        else:
            return encoded
    return None
104# ─── logger ───────────────────────────────────────────────────────────
# Pause between retention sweeps, in seconds (one hour).
_SWEEP_INTERVAL_SECONDS = 60 * 60
# Audit table name used when the caller does not pass ``table=``.
_DEFAULT_TABLE = "astrocyte_audit_log"
class AuditLogger:
    """Fire-and-forget audit log writer.

    Pool and schema are looked up via callables at write time, so a
    single ``AuditLogger`` can serve multiple tenants whose schema
    context changes per request.

    Args:
        pool_getter: Returns the pool to write through, or ``None`` to skip
            the write. Pools exposing ``connection()`` (psycopg
            ``AsyncConnectionPool`` style) or ``acquire()`` (asyncpg style)
            are accepted.
        schema_getter: Returns the schema holding the audit table. Defaults
            to ``"public"``, which is left unqualified in SQL.
        enabled: Master switch; when ``False`` nothing is written.
        allowed_actions: Action allowlist. ``None`` records every action
            (when enabled); an explicit empty list records none.
        retention_days: Age in days after which the retention sweep deletes
            entries. Values ``<= 0`` disable the sweep.
        table: Unqualified audit table name.
    """

    def __init__(
        self,
        *,
        pool_getter: Callable[[], Any | None],
        schema_getter: Callable[[], str] | None = None,
        enabled: bool = False,
        allowed_actions: list[str] | None = None,
        retention_days: int = -1,
        table: str = _DEFAULT_TABLE,
    ) -> None:
        self._pool_getter = pool_getter
        self._schema_getter = schema_getter or (lambda: "public")
        self._enabled = enabled
        # ``is not None`` rather than truthiness: an explicit empty allowlist
        # must mean "audit nothing", not "audit everything".
        self._allowed: frozenset[str] | None = (
            frozenset(allowed_actions) if allowed_actions is not None else None
        )
        self._retention_days = retention_days
        self._table = table
        self._sweep_task: asyncio.Task[None] | None = None
        # Strong references to in-flight write tasks. The event loop only
        # keeps weak references to tasks, so an unreferenced fire-and-forget
        # task may be garbage-collected before it finishes.
        self._pending_writes: set[asyncio.Task[None]] = set()

    # ── status ────────────────────────────────────────────────────────

    def is_enabled(self, action: str) -> bool:
        """Return whether entries for *action* should be written."""
        if not self._enabled:
            return False
        if self._allowed is not None:
            return action in self._allowed
        return True

    # ── write ─────────────────────────────────────────────────────────

    def log(self, entry: AuditEntry) -> None:
        """Schedule an audit write as a background task (fire-and-forget).

        Does nothing when ``entry.action`` is not enabled. Fills in
        ``entry.ended_at`` if it is unset. With no running event loop the
        entry is dropped (logged at DEBUG) instead of raising.
        """
        if not self.is_enabled(entry.action):
            return
        if entry.ended_at is None:
            entry.ended_at = datetime.now(timezone.utc)
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop (e.g. shutdown / sync context). Skip.
            # The check runs before the coroutine is created, so no
            # never-awaited coroutine object is left behind.
            logger.debug("audit: no running event loop; skip %s", entry.action)
            return
        task = loop.create_task(self._safe_write(entry))
        self._pending_writes.add(task)
        task.add_done_callback(self._pending_writes.discard)

    @asynccontextmanager
    async def log_with_timing(self, entry: AuditEntry) -> AsyncIterator[AuditEntry]:
        """Context manager that records ``ended_at`` on exit and logs.

        ``started_at`` is reset on entry. The entry is logged on exit even
        when the body raises; the exception still propagates.

        Usage:
            async with audit.log_with_timing(AuditEntry(action="recall", ...)) as e:
                results = await do_recall(...)
                e.response = {"n_results": len(results)}
        """
        entry.started_at = datetime.now(timezone.utc)
        try:
            yield entry
        finally:
            entry.ended_at = datetime.now(timezone.utc)
            self.log(entry)

    def _qualified_table(self) -> str:
        """Return the audit table name, schema-qualified unless ``public``."""
        schema = self._schema_getter()
        # NOTE(review): schema and table are interpolated into SQL unquoted;
        # they must come from trusted configuration, never from user input.
        return f"{schema}.{self._table}" if schema != "public" else self._table

    @staticmethod
    def _to_asyncpg_sql(sql: str) -> str:
        """Rewrite psycopg ``%s`` placeholders as asyncpg ``$1``, ``$2``, ...

        Each placeholder gets its own ordinal, matching the order of the
        positional arguments passed to ``execute``.
        """
        head, *rest = sql.split("%s")
        return head + "".join(f"${n}{tail}" for n, tail in enumerate(rest, start=1))

    async def _safe_write(self, entry: AuditEntry) -> None:
        """Write *entry*, logging at WARNING (never raising) on failure."""
        try:
            # Pool/schema lookups sit inside the try: a failing getter must
            # not escape as an unhandled background-task exception.
            pool = self._pool_getter()
            if pool is None:
                logger.debug("audit: pool unavailable; skip %s", entry.action)
                return
            await self._do_write(pool, self._qualified_table(), entry)
        except Exception as exc:  # noqa: BLE001
            logger.warning(
                "audit: write failed action=%s bank=%s: %s",
                entry.action,
                entry.bank_id,
                exc,
            )

    async def _do_write(self, pool: Any, table: str, entry: AuditEntry) -> None:
        """Insert *entry* into *table* through *pool*.

        Tolerates either a ``connection()`` async-context-manager (psycopg
        ``AsyncConnectionPool``) or an ``acquire()`` (asyncpg-style) pool by
        feature-detecting. Any other pool shape is logged and skipped.
        """
        sql = f"""
            INSERT INTO {table}
                (id, action, transport, bank_id, started_at, ended_at, request, response, metadata)
            VALUES
                (%s, %s, %s, %s, %s, %s, %s::jsonb, %s::jsonb, %s::jsonb)
        """
        params = (
            str(uuid.uuid4()),
            entry.action,
            entry.transport,
            entry.bank_id,
            entry.started_at,
            entry.ended_at,
            _safe_json(entry.request),
            _safe_json(entry.response),
            _safe_json(entry.metadata) or "{}",
        )
        # psycopg AsyncConnectionPool
        if hasattr(pool, "connection"):
            async with pool.connection() as conn:
                async with conn.cursor() as cur:
                    await cur.execute(sql, params)
            return
        # asyncpg-style fallback (Hindsight uses this): numbered $n placeholders.
        if hasattr(pool, "acquire"):
            async with pool.acquire() as conn:
                await conn.execute(self._to_asyncpg_sql(sql), *params)
            return
        logger.warning("audit: unrecognized pool shape: %r", type(pool))

    # ── retention sweep ───────────────────────────────────────────────

    def start_retention_sweep(self) -> None:
        """Start the background retention sweep.

        This is a no-op when any of the following holds:

        - auditing is disabled;
        - ``retention_days <= 0``;
        - a sweep is already running;
        - there is no running event loop.
        """
        if self._retention_days <= 0 or not self._enabled:
            return
        if self._sweep_task is not None and not self._sweep_task.done():
            # A second loop would be orphaned: stop_retention_sweep can only
            # cancel the task referenced here.
            return
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            logger.debug("audit: cannot start sweep (no event loop)")
            return
        self._sweep_task = loop.create_task(self._sweep_loop())

    async def stop_retention_sweep(self) -> None:
        """Cancel the retention sweep, if running, and wait for it to exit."""
        task = self._sweep_task
        if task is None or task.done():
            self._sweep_task = None
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            # Expected — we just cancelled it above. Awaiting a cancelled
            # task raises CancelledError; consuming it here is the
            # canonical shutdown pattern.
            pass
        self._sweep_task = None

    async def _sweep_loop(self) -> None:
        """Run a sweep every ``_SWEEP_INTERVAL_SECONDS`` until cancelled."""
        while True:
            try:
                await self._run_sweep()
            except Exception as exc:  # noqa: BLE001
                logger.warning("audit: sweep iteration failed: %s", exc)
            await asyncio.sleep(_SWEEP_INTERVAL_SECONDS)

    async def _run_sweep(self) -> None:
        """Delete entries whose ``started_at`` predates the retention window."""
        pool = self._pool_getter()
        if pool is None:
            return
        table = self._qualified_table()
        cutoff = datetime.now(timezone.utc) - timedelta(days=self._retention_days)
        sql = f"DELETE FROM {table} WHERE started_at < %s"
        try:
            if hasattr(pool, "connection"):
                async with pool.connection() as conn:
                    async with conn.cursor() as cur:
                        await cur.execute(sql, (cutoff,))
            elif hasattr(pool, "acquire"):
                async with pool.acquire() as conn:
                    await conn.execute(self._to_asyncpg_sql(sql), cutoff)
        except Exception as exc:  # noqa: BLE001
            logger.warning("audit: sweep delete failed: %s", exc)
280# ─── module-level convenience: a default no-op logger ─────────────────
# Process-wide default audit logger; ``None`` until first requested or set.
_DEFAULT_LOGGER: AuditLogger | None = None


def get_default_audit_logger() -> AuditLogger:
    """Return the process-wide default :class:`AuditLogger`.

    Until :func:`set_default_audit_logger` installs a real one (e.g. at
    gateway startup), this lazily builds a disabled logger with no pool.
    Pipeline code can therefore call ``get_default_audit_logger().log(...)``
    unconditionally, without checking whether audit is configured.
    """
    global _DEFAULT_LOGGER
    current = _DEFAULT_LOGGER
    if current is not None:
        return current
    current = AuditLogger(pool_getter=lambda: None, enabled=False)
    _DEFAULT_LOGGER = current
    return current
def set_default_audit_logger(logger_instance: AuditLogger) -> None:
    """Install *logger_instance* as the process-wide default audit logger.

    Every later :func:`get_default_audit_logger` call returns this instance
    until it is replaced again.
    """
    global _DEFAULT_LOGGER
    _DEFAULT_LOGGER = logger_instance