Coverage for astrocyte/pipeline/retain_fsm/checkpoint.py: 98%
89 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""M14.0: checkpoint persistence for the retain FSM.

After each state transition the engine optionally hands the partially-
populated ``RetainContext`` to a :class:`Checkpoint`, which persists it
to disk (or a future Postgres-backed backend). Resume reads the latest
checkpoint for a ``(bank_id, source_id)`` pair and re-enters the engine
at ``ctx.last_state``.

Two backends for M14.0:

- :class:`FileCheckpoint` — JSON files under a configurable root.
  Default; sufficient for bench runs and unit tests.
- :class:`InMemoryCheckpoint` — dict; for tests that want resume
  semantics without touching the filesystem.

Postgres-backed checkpoint is deferred. The interface accepts arbitrary
backends; switching is a matter of subclassing :class:`Checkpoint` and
implementing ``save`` / ``load`` / ``delete``.

Only fields that are JSON-serialisable round-trip cleanly. Datetime
fields are stored as ISO strings; PageIndexSection / PageIndexFact
dataclasses are NOT persisted (they're recoverable from the bank
post-resume by re-reading from the store), and ``md_text`` is reduced
to its length. For M14.0 we persist only the control-plane fields
(state, errors, step_log, document_id, ids of created/updated wikis);
state implementations can mark fields as "derived" so they're
reconstructed on resume.
"""
29from __future__ import annotations
31import json
32import logging
33from abc import ABC, abstractmethod
34from datetime import datetime, timezone
35from pathlib import Path
36from typing import TYPE_CHECKING, Any
# ``RetainContext`` is needed here only for annotations; the runtime
# import is deferred to ``_deserialise``.
if TYPE_CHECKING:
    from astrocyte.pipeline.retain_fsm.context import RetainContext

# Explicit logger name (rather than ``__name__``) so the channel is
# stable regardless of how the module is imported.
logger = logging.getLogger("astrocyte.pipeline.retain_fsm.checkpoint")
class Checkpoint(ABC):
    """Abstract persistence backend for ``RetainContext`` snapshots.

    Concrete backends implement three coroutines, all keyed by the
    ``(bank_id, source_id)`` pair: :meth:`save`, :meth:`load` and
    :meth:`delete`.
    """

    @abstractmethod
    async def save(self, ctx: RetainContext) -> None:
        """Store a snapshot of ``ctx``; the engine invokes this after
        every state transition."""

    @abstractmethod
    async def load(self, bank_id: str, source_id: str) -> RetainContext | None:
        """Return the most recent snapshot for ``(bank_id, source_id)``,
        or ``None`` when no checkpoint exists for that pair."""

    @abstractmethod
    async def delete(self, bank_id: str, source_id: str) -> bool:
        """Discard the snapshot once the run has completed successfully.

        Returns ``True`` only if a checkpoint existed and was removed.
        """


# ── In-memory backend (tests, ephemeral runs) ───────────────────────────


class InMemoryCheckpoint(Checkpoint):
    """Dict-backed checkpoint. Loses state on process exit — used by
    tests that need round-trip semantics without filesystem coupling.

    Stored snapshots are deep-copied on both ``save`` and ``load``, so
    they behave like independent files rather than live views of a
    context.
    """

    def __init__(self) -> None:
        # keyed by (bank_id, source_id) → serialised dict
        self._store: dict[tuple[str, str], dict[str, Any]] = {}

    async def save(self, ctx: RetainContext) -> None:
        """Store a private copy of the serialised ``ctx``."""
        import copy

        # ``_serialise`` passes nested objects (e.g. each step's ``notes``
        # dict) through by reference; without a deep copy, later mutation
        # of ``ctx`` would rewrite the stored snapshot.
        self._store[(ctx.bank_id, ctx.source_id)] = copy.deepcopy(_serialise(ctx))

    async def load(
        self,
        bank_id: str,
        source_id: str,
    ) -> RetainContext | None:
        """Rebuild the stored snapshot, or return ``None`` if absent."""
        import copy

        raw = self._store.get((bank_id, source_id))
        if raw is None:
            return None
        # ``_deserialise`` reuses nested objects from its input (e.g.
        # ``notes``). Hand it a copy so that mutating the returned context
        # cannot corrupt the snapshot seen by a later ``load``.
        return _deserialise(copy.deepcopy(raw))

    async def delete(self, bank_id: str, source_id: str) -> bool:
        """Drop the snapshot; ``True`` if one was present."""
        return self._store.pop((bank_id, source_id), None) is not None


# ── Filesystem backend (default) ────────────────────────────────────────


class FileCheckpoint(Checkpoint):
    """JSON-on-disk checkpoint. Files named
    ``<root>/<bank_id>/<source_id>.json``, where both segments are passed
    through ``_safe_segment``. Safe across process restarts; not safe
    across concurrent writes to the same source (writers share one
    ``.json.tmp`` staging file).

    Sanitisation is lossy: distinct ids such as ``"a/b"`` and ``"a_b"``
    map to the same file. :meth:`load` therefore checks that the stored
    ``bank_id`` / ``source_id`` match the requested pair, and ignores the
    file (returning ``None``) when they don't.
    """

    def __init__(self, root: Path | str) -> None:
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def _path_for(
        self,
        bank_id: str,
        source_id: str,
        *,
        create: bool = True,
    ) -> Path:
        """Return the checkpoint file path for ``(bank_id, source_id)``.

        When ``create`` is true (the default) the per-bank directory is
        created. Read-only callers pass ``create=False`` so that probing
        for a missing checkpoint leaves no empty directories behind.
        """
        # Sanitise both ids so each becomes exactly one path segment.
        bank_dir = self.root / _safe_segment(bank_id)
        if create:
            bank_dir.mkdir(parents=True, exist_ok=True)
        return bank_dir / f"{_safe_segment(source_id)}.json"

    async def save(self, ctx: RetainContext) -> None:
        """Write ``ctx`` to its JSON file, replacing any previous snapshot."""
        path = self._path_for(ctx.bank_id, ctx.source_id)
        payload = _serialise(ctx)
        # Atomic-ish: write the full payload to a sibling temp file, then
        # rename it over the target so a reader never sees a half-written
        # checkpoint. ``default=str`` stringifies any non-JSON value (e.g.
        # inside a step's ``notes``).
        tmp = path.with_suffix(".json.tmp")
        tmp.write_text(
            json.dumps(payload, indent=2, default=str),
            encoding="utf-8",
        )
        tmp.replace(path)

    async def load(
        self,
        bank_id: str,
        source_id: str,
    ) -> RetainContext | None:
        """Return the stored snapshot, or ``None`` if the file is missing,
        unreadable as a JSON object, or belongs to a different
        ``(bank_id, source_id)`` (sanitisation collision).
        """
        path = self._path_for(bank_id, source_id, create=False)
        # EAFP: the file may vanish between an existence check and the read.
        try:
            text = path.read_text(encoding="utf-8")
        except FileNotFoundError:
            return None
        except UnicodeDecodeError as exc:
            logger.warning(
                "checkpoint load: undecodable file at %s: %s",
                path,
                exc,
            )
            return None
        try:
            raw = json.loads(text)
        except json.JSONDecodeError as exc:
            logger.warning(
                "checkpoint load: malformed JSON at %s: %s",
                path,
                exc,
            )
            return None
        if not isinstance(raw, dict):
            logger.warning(
                "checkpoint load: expected a JSON object at %s, got %s",
                path,
                type(raw).__name__,
            )
            return None
        if raw.get("bank_id") != bank_id or raw.get("source_id") != source_id:
            # Two ids sanitised to the same file name; this snapshot is
            # not the one the caller asked for.
            logger.warning(
                "checkpoint load: %s holds (%r, %r), not (%r, %r); ignoring",
                path,
                raw.get("bank_id"),
                raw.get("source_id"),
                bank_id,
                source_id,
            )
            return None
        return _deserialise(raw)

    async def delete(self, bank_id: str, source_id: str) -> bool:
        """Remove the checkpoint file; ``True`` if it existed."""
        path = self._path_for(bank_id, source_id, create=False)
        try:
            path.unlink()
        except FileNotFoundError:
            return False
        return True


# ── Serialisation helpers ──────────────────────────────────────────────


def _safe_segment(s: str) -> str:
    """Map an arbitrary id to a single, safe filesystem path segment.

    Characters outside ``[a-zA-Z0-9._-]`` become ``_``, the result is
    capped at 128 characters, and an empty result becomes ``"_"``.

    The names ``"."`` and ``".."`` pass that character filter, but they
    resolve to the current / parent directory. Left as-is, a bank id of
    ``".."`` would escape the checkpoint root, so their dots are replaced
    with ``_`` as well.

    The mapping is lossy (e.g. ``"a/b"`` and ``"a_b"`` collide).
    """
    import re

    safe = re.sub(r"[^a-zA-Z0-9._-]", "_", s)[:128] or "_"
    if safe in (".", ".."):
        return "_" * len(safe)
    return safe
def _serialise(ctx: RetainContext) -> dict[str, Any]:
    """Flatten ``ctx`` into a JSON-friendly dict (``schema_version`` 1).

    Only control-plane fields and small primitive lists are kept. The
    markdown body is reduced to its length, and sections / facts are
    omitted entirely; they are expected to be re-read from the bank
    store after resume. Datetimes are converted with ``_iso``.
    """

    def _step(entry: Any) -> dict[str, Any]:
        # One step-log entry, with its timestamps rendered as ISO strings.
        return {
            "state": entry.state,
            "started_at": _iso(entry.started_at),
            "completed_at": _iso(entry.completed_at),
            "duration_ms": entry.duration_ms,
            "error": entry.error,
            "notes": entry.notes,
        }

    return {
        "schema_version": 1,
        "bank_id": ctx.bank_id,
        "source_id": ctx.source_id,
        # Only the length is kept; the full text is not checkpointed.
        "md_text_len": len(ctx.md_text),
        "reference_date": _iso(ctx.reference_date),
        "document_id": ctx.document_id,
        "entities": list(ctx.entities),
        "wikis_created": list(ctx.wikis_created),
        "wikis_updated": list(ctx.wikis_updated),
        "supersedes_edges": [list(edge) for edge in ctx.supersedes_edges],
        "last_state": ctx.last_state,
        "step_log": list(map(_step, ctx.step_log)),
        "errors": list(ctx.errors),
        "started_at": _iso(ctx.started_at),
        "completed_at": _iso(ctx.completed_at),
    }
def _deserialise(raw: dict[str, Any]) -> RetainContext:
    """Rebuild a ``RetainContext`` from a dict produced by ``_serialise``.

    ``md_text`` comes back as ``""`` because only its length was stored;
    a resuming caller must re-supply it if later states need it. Sections
    and facts are likewise absent; states that need them must reload
    them from the bank store (e.g. ``store.load_sections_with_embeddings``).

    Missing or empty optional fields fall back as follows:
    - list fields become ``[]``;
    - ``last_state`` becomes ``"INIT"``;
    - ``started_at`` becomes the current UTC time;
    - a step without ``started_at`` inherits the context's ``started_at``.

    The stored ``schema_version`` is not checked.
    """
    from astrocyte.pipeline.retain_fsm.context import (
        RetainContext,
        StepLogEntry,
    )

    def _items(key: str) -> list[Any]:
        # A list field, treating a missing / null / empty value as [].
        return list(raw.get(key) or [])

    ctx = RetainContext(
        bank_id=raw["bank_id"],
        source_id=raw["source_id"],
        md_text="",  # not persisted; the resume caller supplies it
    )
    ctx.reference_date = _parse_iso(raw.get("reference_date"))
    ctx.document_id = raw.get("document_id")
    ctx.entities = _items("entities")
    ctx.wikis_created = _items("wikis_created")
    ctx.wikis_updated = _items("wikis_updated")
    ctx.supersedes_edges = [tuple(edge) for edge in _items("supersedes_edges")]
    ctx.last_state = raw.get("last_state") or "INIT"
    ctx.errors = _items("errors")

    started = _parse_iso(raw.get("started_at"))
    if started is None:
        started = datetime.now(tz=timezone.utc)
    ctx.started_at = started

    ctx.completed_at = _parse_iso(raw.get("completed_at"))
    ctx.step_log = [
        StepLogEntry(
            state=entry["state"],
            started_at=_parse_iso(entry.get("started_at")) or ctx.started_at,
            completed_at=_parse_iso(entry.get("completed_at")),
            duration_ms=entry.get("duration_ms"),
            error=entry.get("error"),
            notes=entry.get("notes") or {},
        )
        for entry in _items("step_log")
    ]
    return ctx
def _iso(dt: datetime | None) -> str | None:
    """Render ``dt`` as an ISO-8601 string; ``None`` passes through."""
    if dt is None:
        return None
    return dt.isoformat()
def _parse_iso(s: str | None) -> datetime | None:
    """Parse an ISO-8601 timestamp.

    Returns ``None`` for ``None``, non-string input, or any string that
    does not parse.

    A trailing ``Z`` (UTC designator) is first rewritten to ``+00:00``.
    ``datetime.fromisoformat`` only accepts ``Z`` from Python 3.11, so
    without this step externally-produced UTC timestamps would silently
    parse as ``None`` on older interpreters. Strings produced by
    ``datetime.isoformat()`` never end in ``Z``, so they parse exactly as
    before.
    """
    if s is None:
        return None
    if isinstance(s, str) and s.endswith("Z"):
        s = s[:-1] + "+00:00"
    try:
        return datetime.fromisoformat(s)
    except (TypeError, ValueError):
        return None