Coverage for astrocyte/pipeline/retain_fsm/checkpoint.py: 98%

89 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""M14.0: checkpoint persistence for the retain FSM. 

2 

3After each state transition the engine optionally hands the partially- 

4populated ``RetainContext`` to a :class:`Checkpoint`, which persists it 

5to disk (or a future Postgres-backed backend). Resume reads the latest 

6checkpoint for a ``(bank_id, source_id)`` and re-enters the engine at 

7``ctx.last_state``. 

8 

9Two backends for M14.0: 

10 

11- :class:`FileCheckpoint` — JSON files under a configurable root. 

12 Default; sufficient for bench runs and unit tests. 

13- :class:`InMemoryCheckpoint` — dict; for tests that want resume 

14 semantics without touching the filesystem. 

15 

16Postgres-backed checkpoint is deferred. The interface accepts arbitrary 

17backends; switching is a matter of subclassing :class:`Checkpoint` and 

18implementing ``save`` / ``load`` / ``delete``. 

19 

20Only fields that are JSON-serialisable round-trip cleanly. Datetime 

21fields are stored as ISO strings; PageIndexSection / PageIndexFact 

22dataclasses are NOT persisted (they're recoverable from the bank 

23post-resume by re-reading from the store). For M14.0 we persist only 

24the control-plane fields (state, errors, step_log, document_id, ids 

25of created/updated wikis); state implementations can mark fields as 

26"derived" so they're reconstructed on resume. 

27""" 

28 

29from __future__ import annotations 

30 

31import json 

32import logging 

33from abc import ABC, abstractmethod 

34from datetime import datetime, timezone 

35from pathlib import Path 

36from typing import TYPE_CHECKING, Any 

37 

38if TYPE_CHECKING: 

39 from astrocyte.pipeline.retain_fsm.context import RetainContext 

40 

41logger = logging.getLogger("astrocyte.pipeline.retain_fsm.checkpoint") 

42 

43 

class Checkpoint(ABC):
    """Abstract storage interface for ``RetainContext`` snapshots.

    Concrete backends provide three coroutines: ``save`` takes the
    context itself, while ``load`` and ``delete`` address a snapshot by
    its ``(bank_id, source_id)`` pair.
    """

    @abstractmethod
    async def save(self, ctx: RetainContext) -> None:
        """Store a snapshot of ``ctx``.

        The engine invokes this after every state transition.
        """

    @abstractmethod
    async def load(self, bank_id: str, source_id: str) -> RetainContext | None:
        """Return the latest snapshot for the (bank, source) pair.

        Returns ``None`` when no checkpoint exists for the pair.
        """

    @abstractmethod
    async def delete(self, bank_id: str, source_id: str) -> bool:
        """Drop the checkpoint after successful completion.

        Returns ``True`` when a checkpoint existed and was deleted.
        """

69 

70 

71# ── In-memory backend (tests, ephemeral runs) ─────────────────────────── 

72 

73 

class InMemoryCheckpoint(Checkpoint):
    """Checkpoint backend that keeps snapshots in a plain dict.

    Nothing survives process exit. Intended for tests that want
    save/load round-trips without touching the filesystem.
    """

    def __init__(self) -> None:
        # (bank_id, source_id) → serialised snapshot dict
        self._store: dict[tuple[str, str], dict[str, Any]] = {}

    async def save(self, ctx: RetainContext) -> None:
        """Serialise ``ctx`` and file it under its (bank, source) key."""
        key = (ctx.bank_id, ctx.source_id)
        self._store[key] = _serialise(ctx)

    async def load(self, bank_id: str, source_id: str) -> RetainContext | None:
        """Rebuild the stored snapshot for the pair, or return ``None``."""
        snapshot = self._store.get((bank_id, source_id))
        return None if snapshot is None else _deserialise(snapshot)

    async def delete(self, bank_id: str, source_id: str) -> bool:
        """Forget the pair's snapshot; ``True`` if one was stored."""
        removed = self._store.pop((bank_id, source_id), None)
        return removed is not None

98 

99 

100# ── Filesystem backend (default) ──────────────────────────────────────── 

101 

102 

class FileCheckpoint(Checkpoint):
    """JSON-on-disk checkpoint backend.

    Snapshots are stored at ``<root>/<bank_id>/<source_id>.json``, with
    both id segments passed through ``_safe_segment`` first. Safe across
    process restarts; not safe across concurrent writes to the same
    source.

    The coroutines perform ordinary blocking file I/O.
    """

    def __init__(self, root: Path | str) -> None:
        """Remember ``root`` and create it (with parents) if missing."""
        self.root = Path(root)
        self.root.mkdir(parents=True, exist_ok=True)

    def _path_for(self, bank_id: str, source_id: str) -> Path:
        """Return the snapshot path for the pair.

        Both ids are sanitised into single path segments. The bank
        directory is created as a side effect, including when the path
        is only being looked up.
        """
        bank_dir = self.root / _safe_segment(bank_id)
        bank_dir.mkdir(parents=True, exist_ok=True)
        return bank_dir / f"{_safe_segment(source_id)}.json"

    async def save(self, ctx: RetainContext) -> None:
        """Write the serialised ``ctx`` to its snapshot file."""
        path = self._path_for(ctx.bank_id, ctx.source_id)
        payload = _serialise(ctx)
        # Atomic-ish: write a sibling temp file, then rename it over the
        # target so readers never observe a half-written snapshot.
        tmp = path.with_suffix(".json.tmp")
        tmp.write_text(
            json.dumps(payload, indent=2, default=str),
            encoding="utf-8",
        )
        tmp.replace(path)

    async def load(
        self,
        bank_id: str,
        source_id: str,
    ) -> RetainContext | None:
        """Load the snapshot for the pair, or ``None`` if unusable.

        ``None`` is returned when the file does not exist, and — with a
        logged warning — when it is not decodable JSON, is not a JSON
        object, belongs to a different (bank, source) pair, or lacks a
        key that ``_deserialise`` requires.
        """
        path = self._path_for(bank_id, source_id)
        try:
            # EAFP: no exists() pre-check, so a file removed between the
            # check and the read cannot raise.
            raw = json.loads(path.read_text(encoding="utf-8"))
        except FileNotFoundError:
            return None
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            logger.warning(
                "checkpoint load: malformed JSON at %s: %s",
                path,
                exc,
            )
            return None

        if not isinstance(raw, dict):
            logger.warning(
                "checkpoint load: expected a JSON object at %s, got %s",
                path,
                type(raw).__name__,
            )
            return None

        # Sanitising and truncating ids means distinct pairs can map to
        # the same file; never hand back another source's snapshot.
        stored = (raw.get("bank_id"), raw.get("source_id"))
        if stored != (bank_id, source_id):
            logger.warning(
                "checkpoint load: %s holds snapshot for %r, not %r",
                path,
                stored,
                (bank_id, source_id),
            )
            return None

        try:
            return _deserialise(raw)
        except KeyError as exc:
            logger.warning(
                "checkpoint load: missing key %s in %s",
                exc,
                path,
            )
            return None

    async def delete(self, bank_id: str, source_id: str) -> bool:
        """Remove the pair's snapshot file; ``True`` if one was removed."""
        path = self._path_for(bank_id, source_id)
        try:
            path.unlink()
        except FileNotFoundError:
            return False
        return True

155 

156 

157# ── Serialisation helpers ────────────────────────────────────────────── 

158 

159 

160def _safe_segment(s: str) -> str: 

161 import re 

162 

163 return re.sub(r"[^a-zA-Z0-9._-]", "_", s)[:128] or "_" 

164 

165 

def _serialise(ctx: RetainContext) -> dict[str, Any]:
    """Flatten a ``RetainContext`` into a JSON-friendly dict.

    Only control-plane fields and small primitive lists are kept.
    ``md_text`` is recorded by length only, and sections / facts are
    left out — they are meant to be re-read from the bank store after
    resume. Datetimes are written as ISO strings.
    """

    def step_as_dict(entry: Any) -> dict[str, Any]:
        # One step-log record, with its timestamps rendered as ISO text.
        return {
            "state": entry.state,
            "started_at": _iso(entry.started_at),
            "completed_at": _iso(entry.completed_at),
            "duration_ms": entry.duration_ms,
            "error": entry.error,
            "notes": entry.notes,
        }

    return {
        "schema_version": 1,
        "bank_id": ctx.bank_id,
        "source_id": ctx.source_id,
        # Only the length is stored; the full text stays out of the checkpoint.
        "md_text_len": len(ctx.md_text),
        "reference_date": _iso(ctx.reference_date),
        "document_id": ctx.document_id,
        "entities": list(ctx.entities),
        "wikis_created": list(ctx.wikis_created),
        "wikis_updated": list(ctx.wikis_updated),
        "supersedes_edges": [list(edge) for edge in ctx.supersedes_edges],
        "last_state": ctx.last_state,
        "step_log": [step_as_dict(entry) for entry in ctx.step_log],
        "errors": list(ctx.errors),
        "started_at": _iso(ctx.started_at),
        "completed_at": _iso(ctx.completed_at),
    }

201 

202 

def _deserialise(raw: dict[str, Any]) -> RetainContext:
    """Rebuild a ``RetainContext`` from a dict made by ``_serialise``.

    ``md_text`` comes back as an empty string because only its length
    was stored; resume callers must re-supply it if remaining states
    need it. Sections / facts are not restored either — states that need
    them must reload from the bank store (e.g. via
    ``store.load_sections_with_embeddings``).

    Raises ``KeyError`` when ``bank_id``, ``source_id`` or a step-log
    entry's ``state`` is missing.
    """
    from astrocyte.pipeline.retain_fsm.context import (
        RetainContext,
        StepLogEntry,
    )

    def items(key: str) -> list[Any]:
        # Missing and null entries both collapse to an empty list.
        return list(raw.get(key) or [])

    restored = RetainContext(
        bank_id=raw["bank_id"],
        source_id=raw["source_id"],
        md_text="",  # NOT persisted; caller must supply on resume
    )
    restored.reference_date = _parse_iso(raw.get("reference_date"))
    restored.document_id = raw.get("document_id")
    restored.entities = items("entities")
    restored.wikis_created = items("wikis_created")
    restored.wikis_updated = items("wikis_updated")
    restored.supersedes_edges = [tuple(edge) for edge in items("supersedes_edges")]
    restored.last_state = raw.get("last_state") or "INIT"
    restored.errors = items("errors")

    # A missing or unparseable start time falls back to "now" in UTC.
    run_started = _parse_iso(raw.get("started_at"))
    if run_started is None:
        run_started = datetime.now(tz=timezone.utc)
    restored.started_at = run_started
    restored.completed_at = _parse_iso(raw.get("completed_at"))

    entries = []
    for item in raw.get("step_log") or []:
        entries.append(
            StepLogEntry(
                state=item["state"],
                # Steps without a usable start time inherit the run's.
                started_at=_parse_iso(item.get("started_at")) or restored.started_at,
                completed_at=_parse_iso(item.get("completed_at")),
                duration_ms=item.get("duration_ms"),
                error=item.get("error"),
                notes=item.get("notes") or {},
            )
        )
    restored.step_log = entries
    return restored

246 

247 

248def _iso(dt: datetime | None) -> str | None: 

249 return dt.isoformat() if dt is not None else None 

250 

251 

252def _parse_iso(s: str | None) -> datetime | None: 

253 if s is None: 

254 return None 

255 try: 

256 return datetime.fromisoformat(s) 

257 except (TypeError, ValueError): 

258 return None