Coverage for astrocyte/pipeline/retain_fsm/engine.py: 90%

111 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""M14.0: FSM engine driving the retain pipeline. 

2 

3Pure-Python, no DSL, no Rust. States are async coroutines that take 

4``(ctx, services)`` and return one of: 

5- a state name string → next state 

6- :class:`Complete` → terminate successfully 

7- :class:`Failed` → terminate with error 

8- :class:`Parallel` → fan out to multiple states concurrently, then 

9 join at a named next state 

10 

11The engine handles checkpointing between transitions, error capture, and

12parallel-join semantics. State implementations live in 

13:mod:`astrocyte.pipeline.retain_fsm.states` and are registered via 

14``RetainFSM.register``. 

15 

16See ``docs/_design/m13-m14-roadmap.md`` §4. 

17""" 

18 

from __future__ import annotations

import asyncio
import logging
import time
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Awaitable, Callable, Union

from astrocyte.pipeline.retain_fsm.context import (
    RetainContext,
    RetainServices,
    StepLogEntry,
)

if TYPE_CHECKING:
    from astrocyte.pipeline.retain_fsm.checkpoint import Checkpoint

36 

# Module-level logger. The name is a fixed string literal rather than
# ``__name__``, so it does not depend on how this module is imported.
logger = logging.getLogger("astrocyte.pipeline.retain_fsm.engine")

38 

39 

40# ── Transition return types ──────────────────────────────────────────── 

41 

42 

@dataclass(eq=True, frozen=True)
class Complete:
    """Terminal success marker returned by a state.

    When a state returns an instance, the engine stamps
    ``ctx.completed_at`` and stops the run. Carries no payload.
    """

47 

48 

@dataclass(eq=True, frozen=True)
class Failed:
    """Terminal failure marker returned by a state.

    The engine records ``reason`` in ``ctx.errors`` (prefixed with the
    state name) and stops the run.
    """

    # Human-readable explanation of why the state gave up.
    reason: str

54 

55 

@dataclass(eq=True, frozen=True)
class Parallel:
    """Fan-out marker: run several states at once, then continue at one.

    The engine awaits every name in ``branches`` concurrently (via
    ``asyncio.gather``), treating each as if it were a top-level state
    that receives the same context. Only when all of them have finished
    does it move on to the state named by ``join``.

    The branches see one shared context object, so they MUST NOT write
    to the same field. Reserve this for embarrassingly-parallel work,
    e.g. extraction + entities + embeddings, which each fill disjoint
    context fields.
    """

    # State names to run concurrently.
    branches: tuple[str, ...]
    # State to transition to once every branch has finished.
    join: str

70 

71 

# What a state coroutine may return. This is a real type alias (not a
# string), so type checkers accept it in annotations such as
# ``Awaitable[StateResult]``:
#   - ``str``      -> name of the next state
#   - Complete     -> terminate successfully
#   - Failed       -> terminate with an error
#   - Parallel     -> fan out, then continue at ``Parallel.join``
StateResult = Union[str, Complete, Failed, Parallel]
# Signature every registered state coroutine must have.
StateFunc = Callable[[RetainContext, RetainServices], Awaitable[StateResult]]

74 

75 

76# ── Engine ───────────────────────────────────────────────────────────── 

77 

78 

class RetainFSM:
    """Drive a :class:`RetainContext` through registered states until
    termination.

    Usage::

        fsm = RetainFSM(services)
        fsm.register("INIT", state_init)
        fsm.register("READ", state_read)
        fsm.register("COMPLETE", state_complete)  # optional explicit
        ctx = RetainContext(bank_id="b1", source_id="s1", md_text="...")
        ctx = await fsm.run(ctx)
        assert ctx.completed_at is not None
    """

    def __init__(self, services: RetainServices) -> None:
        # Passed as the second argument to every state coroutine.
        self.services = services
        # State name -> state coroutine; populated via :meth:`register`.
        self._registry: dict[str, StateFunc] = {}

    # ── Registration ───────────────────────────────────────────────────

    def register(self, name: str, fn: StateFunc) -> None:
        """Register a state coroutine under ``name``.

        Subsequent registrations of the same name overwrite the previous
        function (intended for test stubbing).
        """
        self._registry[name] = fn

    def registered_states(self) -> tuple[str, ...]:
        """Return the names of all registered states, sorted."""
        return tuple(sorted(self._registry))

    # ── Run loop ────────────────────────────────────────────────────────

    async def run(
        self,
        ctx: RetainContext,
        *,
        initial_state: str = "INIT",
        checkpoint: Checkpoint | None = None,
    ) -> RetainContext:
        """Drive ``ctx`` from ``initial_state`` to termination.

        Each step looks up the current state, appends a
        :class:`StepLogEntry` to ``ctx.step_log``, awaits the state, and
        dispatches on its result:

        - :class:`Complete` — set ``ctx.completed_at`` and return.
        - :class:`Failed` — record the reason in ``ctx.errors`` and return.
        - :class:`Parallel` — run the branches concurrently. If any branch
          added an error, return; otherwise continue at ``join``.
        - ``str`` — continue at the named state.

        An unknown state name, an exception raised by a state (any
        ``Exception`` subclass), or an unsupported return type is recorded
        in ``ctx.errors`` and ends the run. These are not re-raised;
        callers inspect ``ctx.errors``.

        If ``checkpoint`` is supplied, ``ctx`` is persisted after every
        state transition and on every termination, so :meth:`resume` can
        pick up. Before a transition is persisted, ``ctx.last_state`` is
        advanced to the *next* state. A resume from that checkpoint
        therefore does not re-run the state that already finished.

        Returns the same ``ctx`` object that was passed in.
        """
        current = initial_state
        while True:
            ctx.last_state = current
            fn = self._registry.get(current)
            if fn is None:
                ctx.errors.append(f"unknown state: {current!r}")
                logger.warning("retain_fsm: unknown state %r", current)
                await self._save(ctx, checkpoint)
                return ctx

            entry = StepLogEntry(
                state=current,
                started_at=datetime.now(tz=timezone.utc),
            )
            ctx.step_log.append(entry)
            t0 = time.monotonic()

            try:
                result = await fn(ctx, self.services)
            except Exception as exc:  # noqa: BLE001 — state errors must surface
                self._finish_entry(entry, t0)
                entry.error = f"{type(exc).__name__}: {exc}"
                ctx.errors.append(f"{current}: {entry.error}")
                logger.warning(
                    "retain_fsm: state %r raised %s",
                    current,
                    entry.error,
                )
                await self._save(ctx, checkpoint)
                return ctx

            self._finish_entry(entry, t0)

            # ── Dispatch on result type ──
            if isinstance(result, Complete):
                ctx.completed_at = datetime.now(tz=timezone.utc)
                logger.info(
                    "retain_fsm: completed source=%s in %d states",
                    ctx.source_id,
                    len(ctx.step_log),
                )
                await self._save(ctx, checkpoint)
                return ctx

            if isinstance(result, Failed):
                # Record the reason on the step too, matching how failed
                # parallel branches are logged.
                entry.error = result.reason
                ctx.errors.append(f"{current}: {result.reason}")
                logger.warning(
                    "retain_fsm: state %r reported Failed: %s",
                    current,
                    result.reason,
                )
                await self._save(ctx, checkpoint)
                return ctx

            if isinstance(result, Parallel):
                # Only errors added by *this* fan-out are fatal. ctx.errors may
                # already contain entries from an earlier run (e.g. resume).
                errors_before = len(ctx.errors)
                await self._run_parallel(ctx, result)
                if len(ctx.errors) > errors_before:
                    # A branch failed; treat as terminal.
                    await self._save(ctx, checkpoint)
                    return ctx
                current = result.join
            elif isinstance(result, str):
                current = result
            else:
                entry.error = (
                    f"state returned unsupported type {type(result).__name__}"
                )
                ctx.errors.append(f"{current}: {entry.error}")
                await self._save(ctx, checkpoint)
                return ctx

            # Transition. Point last_state at the state we are about to enter
            # before persisting, so resuming from this checkpoint continues
            # there instead of re-running the state that just finished.
            ctx.last_state = current
            await self._save(ctx, checkpoint)

    # ── Parallel branch runner ─────────────────────────────────────────

    async def _run_parallel(
        self,
        ctx: RetainContext,
        spec: Parallel,
    ) -> None:
        """Run ``spec.branches`` concurrently against the same context.

        Each branch is a state name registered on this FSM. Branches
        return their own ``StateResult``, but only ``Failed`` (and a
        raised exception) is treated as an error. ``Complete``, a
        state-name string, or a nested ``Parallel`` just means "branch did
        its work": there are no nested chains within a parallel block, and
        the next state is always ``spec.join``. For M14.0 + M14.1 the
        parallel branches all do exactly one step then join.

        Each failed or unknown branch appends one
        ``"parallel:<branch>: <error>"`` entry to ``ctx.errors``.
        """

        async def _one(branch: str) -> tuple[str, str | None]:
            # Returns (branch, error message or None on success).
            fn = self._registry.get(branch)
            if fn is None:
                return branch, f"unknown parallel branch: {branch!r}"
            entry = StepLogEntry(
                state=f"parallel:{branch}",
                started_at=datetime.now(tz=timezone.utc),
            )
            ctx.step_log.append(entry)
            t0 = time.monotonic()
            try:
                result = await fn(ctx, self.services)
            except Exception as exc:  # noqa: BLE001
                self._finish_entry(entry, t0)
                entry.error = f"{type(exc).__name__}: {exc}"
                logger.warning(
                    "retain_fsm: parallel branch %r raised %s",
                    branch,
                    entry.error,
                )
                return branch, entry.error
            self._finish_entry(entry, t0)
            if isinstance(result, Failed):
                entry.error = result.reason
                return branch, result.reason
            # Complete / state-name / Parallel returned from a branch are
            # all treated as "branch did its work". The value is ignored
            # because the JOIN state is fixed in the spec.
            return branch, None

        results = await asyncio.gather(*(_one(b) for b in spec.branches))
        for branch, err in results:
            if err is not None:
                ctx.errors.append(f"parallel:{branch}: {err}")

    # ── Resume ─────────────────────────────────────────────────────────

    async def resume(
        self,
        ctx: RetainContext,
        *,
        checkpoint: Checkpoint | None = None,
    ) -> RetainContext:
        """Re-enter the run loop starting from ``ctx.last_state``.

        Caller is responsible for loading ``ctx`` (typically via
        :meth:`Checkpoint.load`) before calling. Which state runs first
        depends on when the checkpoint was written:

        - After a successful transition: ``last_state`` names the state
          that was about to run.
        - On a failure: ``last_state`` names the state that failed, so
          that state is run again.
        """
        return await self.run(
            ctx,
            initial_state=ctx.last_state,
            checkpoint=checkpoint,
        )

    # ── Helpers ────────────────────────────────────────────────────────

    @staticmethod
    async def _save(ctx: RetainContext, checkpoint: Checkpoint | None) -> None:
        """Persist ``ctx`` through ``checkpoint`` when one was supplied."""
        if checkpoint is not None:
            await checkpoint.save(ctx)

    @staticmethod
    def _finish_entry(entry: StepLogEntry, t0: float) -> None:
        """Stamp ``entry`` with its completion time and elapsed milliseconds.

        ``t0`` is the ``time.monotonic()`` reading taken when the step began.
        """
        entry.completed_at = datetime.now(tz=timezone.utc)
        entry.duration_ms = (time.monotonic() - t0) * 1000