Coverage for astrocyte/pipeline/retain_fsm/engine.py: 90%
111 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""M14.0: FSM engine driving the retain pipeline.

Pure-Python, no DSL, no Rust. States are async coroutines that take
``(ctx, services)`` and return one of:

- a state name string → next state
- :class:`Complete` → terminate successfully
- :class:`Failed` → terminate with error
- :class:`Parallel` → fan out to multiple states concurrently, then
  join at a named next state

The engine handles checkpointing between transitions, error capture, and
parallel-join semantics. State implementations live in
:mod:`astrocyte.pipeline.retain_fsm.states` and are registered via
``RetainFSM.register``.

See ``docs/_design/m13-m14-roadmap.md`` §4.
"""
19from __future__ import annotations
21import asyncio
22import logging
23import time
24from dataclasses import dataclass
25from datetime import datetime, timezone
26from typing import TYPE_CHECKING, Awaitable, Callable
28from astrocyte.pipeline.retain_fsm.context import (
29 RetainContext,
30 RetainServices,
31 StepLogEntry,
32)
34if TYPE_CHECKING:
35 from astrocyte.pipeline.retain_fsm.checkpoint import Checkpoint
37logger = logging.getLogger("astrocyte.pipeline.retain_fsm.engine")
40# ── Transition return types ────────────────────────────────────────────
@dataclass(frozen=True)
class Complete:
    """Signal that the pipeline finished successfully.

    A state returns an instance of this marker to stop the engine; the
    engine then stamps ``ctx.completed_at`` and returns the context.
    """
@dataclass(frozen=True)
class Failed:
    """Signal that the pipeline must stop with an error.

    The engine records ``reason`` (prefixed with the name of the state
    that returned it) in ``ctx.errors`` and returns the context.
    """

    # Human-readable explanation of why the state gave up.
    reason: str
@dataclass(frozen=True)
class Parallel:
    """Fan out to ``branches`` concurrently; all must complete before
    transitioning to ``join``. Each branch is run as if it were a top-
    level state (with the same context, awaited via ``asyncio.gather``).

    Branches share the context by reference — they MUST not race on the
    same field. Use this for embarrassingly-parallel work (e.g.
    extraction + entities + embeddings, which write disjoint context
    fields).

    ``branches`` is normalised to a tuple on construction, so any
    iterable of state names is accepted and the instance stays immutable
    and hashable. A bare string is treated as a single branch name rather
    than being split into characters.
    """

    # Names of the registered states to run concurrently.
    branches: tuple[str, ...]
    # Name of the state the engine transitions to once every branch is done.
    join: str

    def __post_init__(self) -> None:
        """Coerce ``branches`` to a tuple of state names."""
        raw = self.branches
        normalised = (raw,) if isinstance(raw, str) else tuple(raw)
        # The dataclass is frozen, so bypass its ``__setattr__`` guard.
        object.__setattr__(self, "branches", normalised)
# Everything a state coroutine may return. Kept as a string so the ``|``
# union expression is not evaluated when the module is imported.
StateResult = "str | Complete | Failed | Parallel"

# A state is an async callable ``(ctx, services) -> StateResult``.
StateFunc = Callable[
    [RetainContext, RetainServices],
    Awaitable["StateResult"],
]
76# ── Engine ─────────────────────────────────────────────────────────────
class RetainFSM:
    """Drive a :class:`RetainContext` through registered states until
    termination.

    Usage::

        fsm = RetainFSM(services)
        fsm.register("INIT", state_init)
        fsm.register("READ", state_read)
        fsm.register("COMPLETE", state_complete)  # optional explicit
        ctx = RetainContext(bank_id="b1", source_id="s1", md_text="...")
        ctx = await fsm.run(ctx)
        assert ctx.completed_at is not None

    ``Exception`` subclasses raised by a state are captured as strings in
    ``ctx.errors`` rather than propagated; the partially-populated context
    is returned to the caller instead.
    """

    def __init__(self, services: RetainServices) -> None:
        """Create an engine with an empty state registry.

        ``services`` is passed unchanged as the second argument to every
        state coroutine.
        """
        self.services = services
        self._registry: dict[str, StateFunc] = {}

    # ── Registration ───────────────────────────────────────────────────

    def register(self, name: str, fn: StateFunc) -> None:
        """Register a state coroutine. Subsequent registrations of the
        same name overwrite (intended for test stubbing)."""
        self._registry[name] = fn

    def registered_states(self) -> tuple[str, ...]:
        """Return the registered state names in sorted order."""
        return tuple(sorted(self._registry))

    # ── Checkpoint helper ──────────────────────────────────────────────

    @staticmethod
    async def _save(ctx: RetainContext, checkpoint: Checkpoint | None) -> None:
        """Persist ``ctx`` through ``checkpoint`` when one was supplied."""
        if checkpoint is not None:
            await checkpoint.save(ctx)

    # ── Run loop ────────────────────────────────────────────────────────

    async def run(
        self,
        ctx: RetainContext,
        *,
        initial_state: str = "INIT",
        checkpoint: Checkpoint | None = None,
    ) -> RetainContext:
        """Drive ``ctx`` from ``initial_state`` to termination.

        If ``checkpoint`` is supplied, ``ctx`` is persisted after every
        state transition; on error the partial context is also persisted
        so :meth:`resume` can pick up.

        The loop ends when a state returns :class:`Complete` or
        :class:`Failed`, raises, returns an unsupported value, names an
        unregistered state, or when a :class:`Parallel` fan-out records an
        error. Every outcome except :class:`Complete` appends a message to
        ``ctx.errors``.

        Returns the same ``ctx`` object that was passed in.
        """
        current = initial_state
        while True:
            # Record the state being entered *before* running it, so a
            # persisted context always names the state to retry.
            ctx.last_state = current
            fn = self._registry.get(current)
            if fn is None:
                ctx.errors.append(f"unknown state: {current!r}")
                await self._save(ctx, checkpoint)
                return ctx

            entry = StepLogEntry(
                state=current,
                started_at=datetime.now(tz=timezone.utc),
            )
            ctx.step_log.append(entry)
            t0 = time.monotonic()

            try:
                result = await fn(ctx, self.services)
            except Exception as exc:  # noqa: BLE001 — state errors must surface
                entry.completed_at = datetime.now(tz=timezone.utc)
                entry.duration_ms = (time.monotonic() - t0) * 1000
                entry.error = f"{type(exc).__name__}: {exc}"
                ctx.errors.append(f"{current}: {entry.error}")
                logger.warning(
                    "retain_fsm: state %r raised %s",
                    current,
                    entry.error,
                )
                await self._save(ctx, checkpoint)
                return ctx

            entry.completed_at = datetime.now(tz=timezone.utc)
            entry.duration_ms = (time.monotonic() - t0) * 1000

            # ── Dispatch on result type ──
            if isinstance(result, Complete):
                ctx.completed_at = datetime.now(tz=timezone.utc)
                logger.info(
                    "retain_fsm: completed source=%s in %d states",
                    ctx.source_id,
                    len(ctx.step_log),
                )
                await self._save(ctx, checkpoint)
                return ctx

            if isinstance(result, Failed):
                ctx.errors.append(f"{current}: {result.reason}")
                logger.warning(
                    "retain_fsm: state %r reported Failed: %s",
                    current,
                    result.reason,
                )
                await self._save(ctx, checkpoint)
                return ctx

            if isinstance(result, Parallel):
                # Only errors recorded *during* the fan-out are terminal.
                # Entries already in ``ctx.errors`` (e.g. left over from
                # the failed attempt that is now being resumed) must not
                # abort a fan-out whose branches all succeeded.
                errors_before = len(ctx.errors)
                await self._run_parallel(ctx, result)
                if len(ctx.errors) > errors_before:
                    # A branch failed; treat as terminal.
                    await self._save(ctx, checkpoint)
                    return ctx
                current = result.join
                await self._save(ctx, checkpoint)
                continue

            # Must be a state-name string at this point.
            if not isinstance(result, str):
                ctx.errors.append(
                    f"{current}: state returned unsupported type {type(result).__name__}",
                )
                await self._save(ctx, checkpoint)
                return ctx

            current = result
            await self._save(ctx, checkpoint)

    # ── Parallel branch runner ─────────────────────────────────────────

    async def _run_parallel(
        self,
        ctx: RetainContext,
        spec: Parallel,
    ) -> None:
        """Run ``spec.branches`` concurrently against the same context.

        Each branch is a state name registered on this FSM. Branches
        return their own ``StateResult`` but we only honour ``Complete``
        / ``Failed`` / state-name (treated as a single-step branch — no
        nested chains within a parallel block; that's deferred to a
        future engine extension if needed). For M14.0 + M14.1 the
        parallel branches all do exactly one step then join.

        One ``parallel:<branch>: <error>`` message is appended to
        ``ctx.errors`` for every branch that is unregistered, raises, or
        returns :class:`Failed`. Nothing is returned.
        """

        async def _one(branch: str) -> tuple[str, str | None]:
            """Run a single branch; return ``(branch, error_or_None)``."""
            fn = self._registry.get(branch)
            if fn is None:
                return branch, f"unknown parallel branch: {branch!r}"
            entry = StepLogEntry(
                state=f"parallel:{branch}",
                started_at=datetime.now(tz=timezone.utc),
            )
            ctx.step_log.append(entry)
            t0 = time.monotonic()
            try:
                result = await fn(ctx, self.services)
            except Exception as exc:  # noqa: BLE001
                entry.completed_at = datetime.now(tz=timezone.utc)
                entry.duration_ms = (time.monotonic() - t0) * 1000
                entry.error = f"{type(exc).__name__}: {exc}"
                # Mirror the main loop, which logs every state failure.
                logger.warning(
                    "retain_fsm: parallel branch %r raised %s",
                    branch,
                    entry.error,
                )
                return branch, entry.error
            entry.completed_at = datetime.now(tz=timezone.utc)
            entry.duration_ms = (time.monotonic() - t0) * 1000
            if isinstance(result, Failed):
                entry.error = result.reason
                logger.warning(
                    "retain_fsm: parallel branch %r reported Failed: %s",
                    branch,
                    result.reason,
                )
                return branch, result.reason
            # Complete / state-name / Parallel returned from a branch
            # are all treated as "branch did its work" — we ignore the
            # value because the JOIN state is fixed in the spec.
            return branch, None

        results = await asyncio.gather(*[_one(b) for b in spec.branches])
        # Errors are appended after the gather, in branch declaration
        # order, so the messages are deterministic regardless of which
        # branch finished first.
        for branch, err in results:
            if err is not None:
                ctx.errors.append(f"parallel:{branch}: {err}")

    # ── Resume ─────────────────────────────────────────────────────────

    async def resume(
        self,
        ctx: RetainContext,
        *,
        checkpoint: Checkpoint | None = None,
    ) -> RetainContext:
        """Re-enter the run loop starting from ``ctx.last_state``.

        Caller is responsible for loading ``ctx`` (typically via
        :meth:`Checkpoint.load`) before calling.

        The state named by ``ctx.last_state`` is executed again, because
        :meth:`run` records a state there before running it.
        """
        return await self.run(
            ctx,
            initial_state=ctx.last_state,
            checkpoint=checkpoint,
        )