Coverage for astrocyte/pipeline/section_recall.py: 84%
142 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Section recall orchestrator (M9 PR2 commit B).

Runs the five Hindsight-pattern strategies — semantic, keyword, entity,
temporal, graph-expand — over the ``PageIndexStore`` SPI, then fuses their
ranked outputs with RRF (Reciprocal Rank Fusion, k=60). The first four run
in parallel; graph-expand runs afterwards because it is seeded from the
semantic and entity hits.

This is the **retrieval** layer. The cross-encoder rerank + picker-as-
reranker step (PR2 commit C) consumes this layer's output and feeds the
synth.

Per-mode strategy gating mirrors what we found in Phase A failure
analysis:

- **temporal questions** add the temporal strategy + a wider semantic
  net (questions mentioning "May 2023" want the May-2023 sessions even
  if the topic words don't match).
- **multi-hop / multi-session questions** add the graph-expand strategy
  to bridge across sessions via section_links.
- **assistant-recall questions** (LME) override the keyword strategy
  with a ``speaker='assistant'`` filter.

Mode dispatch is a simple regex/heuristic at PR2 commit B; PR2 commit D
replaces it with a 1-token LLM classifier when the heuristic
mis-routes.

See:

- ``docs/_design/recall.md`` §6 (recall pipeline)
- ``docs/_design/adr/adr-006-three-layer-recall-stack.md``
"""
30from __future__ import annotations
32from dataclasses import dataclass, field
33from datetime import datetime
34from typing import TYPE_CHECKING
36from astrocyte.pipeline.fusion import DEFAULT_RRF_K
38if TYPE_CHECKING:
39 from astrocyte.provider import LLMProvider, PageIndexStore
# ── Result types ─────────────────────────────────────────────────────


@dataclass
class StrategyResult:
    """Ranked output of a single retrieval strategy.

    Alongside the hits it carries the strategy name (for tracing) and the
    time the strategy took (for spotting performance regressions).
    """

    # Strategy label, e.g. "semantic", "keyword" or "graph_expand".
    strategy: str
    # Ranked ``(document_id, line_num, score)`` tuples; list position is the rank.
    hits: list[tuple[str, int, float]]
    # Wall-clock duration of the strategy, in milliseconds.
    elapsed_ms: float
    # ``"ExceptionType: message"`` when the strategy failed, else ``None``.
    error: str | None = field(default=None)
@dataclass
class FusedHit:
    """A single section after RRF fusion.

    ``rrf_score`` accumulates ``1 / (k + rank)`` for every strategy that
    returned the section; ``per_strategy_rank`` keeps those ranks around
    for tracing and as reranker input.
    """

    document_id: str
    line_num: int
    rrf_score: float
    # strategy name -> rank of this section in that strategy's output.
    per_strategy_rank: dict[str, int] = field(default_factory=dict)
@dataclass
class SectionRecallResult:
    """Everything produced by one section-recall call.

    Per-strategy traces travel alongside the fused list so failure analysis
    can attribute a regression to the component that caused it.
    """

    # Fused section hits.
    fused: list[FusedHit]
    # One trace per strategy that ran.
    strategies: list[StrategyResult]
    # Mode label that drove strategy selection.
    mode: str
    # Wall-clock time of the whole call, in milliseconds.
    elapsed_ms: float
    # M10.1: wiki page hits found at recall time. Kept apart from ``fused``
    # because they carry pre-aggregated text that the bench prepends to synth
    # excerpts as ``[OBSERVATION]`` blocks instead of routing through the
    # picker. Empty when there is no wiki tier or no hit cleared the score
    # threshold.
    wiki_hits: list = field(default_factory=list)
# ── Section-grain RRF fusion (specialised for tuple hits) ─────────────


def _rrf_fuse_section_hits(
    ranked_lists: list[StrategyResult],
    k: int = DEFAULT_RRF_K,
) -> list[FusedHit]:
    """RRF over ``(document_id, line_num, score)`` tuples.

    The existing ``astrocyte.pipeline.fusion.rrf_fusion`` is keyed on a
    string ``id``. Section grain is a composite ``(doc, line)`` key, so we
    specialise here.

    Hindsight uses k=60 (the original Cormack et al. RRF default), and we
    keep that. ``per_strategy_rank`` is preserved so the reranker (PR2
    commit C) can inspect why a section was promoted.

    Strategies that reported an ``error`` are ignored. A section counts at
    most once per strategy, at its best (first) rank, so a strategy that
    repeats a section cannot inflate its score or overwrite its rank with
    a worse one. The result is sorted by ``rrf_score`` descending. Ties keep
    first-seen order because ``sorted`` is stable.
    """
    accum: dict[tuple[str, int], FusedHit] = {}
    for sr in ranked_lists:
        if sr.error:
            continue
        for rank, (doc_id, line_num, _score) in enumerate(sr.hits, start=1):
            key = (doc_id, line_num)
            entry = accum.get(key)
            if entry is None:
                entry = FusedHit(
                    document_id=doc_id,
                    line_num=line_num,
                    rrf_score=0.0,
                )
                accum[key] = entry
            elif sr.strategy in entry.per_strategy_rank:
                # Duplicate within this strategy's list: its better rank
                # was already counted.
                continue
            entry.rrf_score += 1.0 / (k + rank)
            entry.per_strategy_rank[sr.strategy] = rank
    return sorted(accum.values(), key=lambda h: h.rrf_score, reverse=True)
# ── Mode dispatch ────────────────────────────────────────────────────


def select_strategies_for_mode(mode: str) -> set[str]:
    """Return the names of the strategies to run for question ``mode``.

    The orchestrator fires only the named strategies. PR2 commit D may swap
    this for an LLM-driven weighted mix.

    Every mode gets the always-on trio of semantic, keyword and entity.
    Temporal modes add ``temporal``. Multi-hop, multi-session and
    knowledge-update modes add ``graph_expand``. Anything else, including
    unknown labels, gets just the trio. A fresh set is returned on each
    call.
    """
    chosen = {"semantic", "keyword", "entity"}
    if mode in {"temporal", "temporal-reasoning"}:
        chosen.add("temporal")
    elif mode in {"multi-hop", "multi-session", "knowledge-update"}:
        chosen.add("graph_expand")
    # "single-session-assistant" / "assistant-recall" intentionally keep the
    # baseline set: the orchestrator swaps in a speaker-filtered keyword
    # search for them rather than adding a strategy here.
    return chosen
# ── Orchestrator ─────────────────────────────────────────────────────

# Stdlib imports for the orchestrator live here, after the result types
# and helpers, following this module's layout (hence the E402 waivers).
import asyncio  # noqa: E402
import logging  # noqa: E402
import time  # noqa: E402

# Logger name is spelled out explicitly rather than derived from __name__.
logger = logging.getLogger("astrocyte.pipeline.section_recall")
def _elapsed_ms(started: float) -> float:
    """Return milliseconds elapsed since ``started`` (a ``time.monotonic()`` reading)."""
    return (time.monotonic() - started) * 1000.0


async def _run_strategy(strategy: str, search) -> StrategyResult:
    """Await ``search()`` and package its hits, or its failure, as a StrategyResult.

    ``search`` is a zero-argument callable returning an awaitable of hit
    tuples. Any exception is recorded on ``error`` instead of propagating,
    so one failing strategy cannot sink the whole recall call.
    """
    started = time.monotonic()
    try:
        hits = await search()
    except Exception as exc:  # noqa: BLE001 — per-strategy isolation is deliberate
        return StrategyResult(
            strategy=strategy,
            hits=[],
            elapsed_ms=_elapsed_ms(started),
            error=f"{type(exc).__name__}: {exc}",
        )
    return StrategyResult(strategy=strategy, hits=hits, elapsed_ms=_elapsed_ms(started))


async def _skipped_strategy(strategy: str) -> StrategyResult:
    """Return an empty, zero-time result for a strategy whose inputs were not supplied."""
    return StrategyResult(strategy=strategy, hits=[], elapsed_ms=0.0)


def _graph_expand_seeds(
    by_name: dict[str, StrategyResult],
    per_strategy: int = 5,
) -> list[tuple[str, int]]:
    """Collect de-duplicated ``(document_id, line_num)`` seeds for graph expansion.

    Takes the top ``per_strategy`` hits from the semantic and then the
    entity result. That is enough for a 1-hop bridge without exploding the
    join. Repeats are dropped and first-seen order is kept.
    """
    seeds: list[tuple[str, int]] = []
    for name in ("semantic", "entity"):
        result = by_name.get(name)
        if result is not None:
            seeds.extend((doc_id, line_num) for doc_id, line_num, _ in result.hits[:per_strategy])
    # dict preserves insertion order, so this is an order-stable dedupe.
    return list(dict.fromkeys(seeds))


async def section_recall(
    *,
    store: PageIndexStore,
    bank_id: str,
    question: str,
    mode: str,
    embedding_provider: LLMProvider,
    question_entities: list[str] | None = None,
    date_range: tuple[datetime, datetime] | None = None,
    semantic_seed_count: int = 20,
    rrf_k: int = DEFAULT_RRF_K,
    per_strategy_top_k: int = 20,
    wiki_enabled: bool = False,
    wiki_document_id: str | None = None,
    wiki_min_score: float = 0.55,
    wiki_top_k: int = 3,
    enable_spreading_activation: bool = False,
    spreading_seed_count: int = 10,
    spreading_top_k: int = 10,
    session_filter: str | None = None,
) -> SectionRecallResult:
    """Run the selected strategies, RRF-fuse their hits, and return the result.

    Operates on **sections** (PageIndex tree nodes), the M9 middle recall
    layer in ``recall.md``'s three-layer stack. Wiki recall sits above this;
    raw memory_units sit below.

    The semantic, keyword, entity and temporal strategies run concurrently.
    Graph-expand, when selected, runs afterwards because it is seeded from
    the semantic and entity hits. A strategy that raises is recorded with
    ``error`` set and contributes nothing to fusion.

    Args:
        store: PageIndexStore SPI handle (in-memory or postgres).
        bank_id: Scope to one bank (multi-bank later).
        question: Raw question text. Passed as-is to the keyword strategy
            and embedded for the semantic strategy.
        mode: Pre-computed mode label (e.g. "multi-hop", "temporal").
            Drives which strategies fire. Assistant-recall modes also
            restrict keyword matches to ``speaker='assistant'``.
        embedding_provider: Provider with an async ``embed`` method, used to
            embed ``question`` for the semantic strategy. The wiki tier
            reuses that embedding and embeds again only if none was
            obtained.
        question_entities: Pre-extracted entities for the entity strategy.
            When None or empty, the entity strategy is skipped (typically
            because PR2 commit D's question-annotator hasn't run).
        date_range: Pre-parsed date window for the temporal strategy.
            When None, the temporal strategy is skipped.
        semantic_seed_count: Top-K for the semantic call.
        rrf_k: RRF smoothing constant.
        per_strategy_top_k: Top-K for the keyword, entity, temporal and
            graph-expand strategies.
        wiki_enabled: Also semantic-search the bank's wiki pages.
        wiki_document_id: Passed to the wiki search as ``document_id``.
        wiki_min_score: Wiki hits scoring below this are dropped.
        wiki_top_k: Top-K for the wiki search.
        enable_spreading_activation: After fusion, expand the top fused
            sections through shared entities and append unseen sections.
        spreading_seed_count: How many top fused sections seed the spread.
        spreading_top_k: Top-K for the spread.
        session_filter: Forwarded to the semantic, keyword, entity and
            temporal searches.

    Returns:
        ``SectionRecallResult``. Its ``fused`` list is sorted by
        ``rrf_score`` descending, followed by any spreading-activation
        additions, which are appended after the sorted hits and not
        re-sorted. ``strategies`` holds one trace per strategy that ran.
        ``wiki_hits`` holds the wiki pages that cleared ``wiki_min_score``.
    """
    t0 = time.monotonic()
    selected = select_strategies_for_mode(mode)

    # Filled by the semantic strategy so the wiki tier can reuse the
    # question embedding instead of paying for a second embed call.
    question_vectors: list = []

    async def _semantic_search():
        embeds = await embedding_provider.embed([question])
        if embeds:
            qvec = embeds[0]
            question_vectors.append(qvec)
        else:
            qvec = []
        return await store.search_sections_semantic(
            bank_id,
            qvec,
            top_k=semantic_seed_count,
            session_filter=session_filter,  # M31 Fix 2
        )

    async def _keyword_search():
        # Assistant-recall modes only match assistant turns.
        speaker = "assistant" if mode in {"single-session-assistant", "assistant-recall"} else None
        return await store.search_sections_keyword(
            bank_id,
            question,
            top_k=per_strategy_top_k,
            speaker=speaker,
            session_filter=session_filter,  # M31 Fix 2
        )

    async def _entity_search():
        return await store.search_sections_by_entities(
            bank_id,
            question_entities,
            top_k=per_strategy_top_k,
            session_filter=session_filter,  # M31 Fix 2
        )

    async def _temporal_search():
        return await store.search_sections_temporal(
            bank_id,
            date_range,
            top_k=per_strategy_top_k,
            session_filter=session_filter,  # M31 Fix 2
        )

    # Independent strategies run concurrently. Graph-expand needs their
    # seeds, so it is sequenced after them below.
    pending = []
    if "semantic" in selected:
        pending.append(_run_strategy("semantic", _semantic_search))
    if "keyword" in selected:
        pending.append(_run_strategy("keyword", _keyword_search))
    if "entity" in selected:
        pending.append(
            _run_strategy("entity", _entity_search)
            if question_entities
            else _skipped_strategy("entity")
        )
    if "temporal" in selected:
        pending.append(
            _run_strategy("temporal", _temporal_search)
            if date_range is not None
            else _skipped_strategy("temporal")
        )
    initial_results: list[StrategyResult] = list(await asyncio.gather(*pending))
    by_name = {r.strategy: r for r in initial_results}

    if "graph_expand" in selected:
        unique_seeds = _graph_expand_seeds(by_name)
        initial_results.append(
            await _run_strategy(
                "graph_expand",
                lambda: store.expand_section_links(unique_seeds, top_k=per_strategy_top_k),
            )
        )

    fused = _rrf_fuse_section_hits(initial_results, k=rrf_k)

    # M18a-3: entity-co-occurrence spread (gated by `enable_spreading_activation`).
    # After RRF fusion, expand the top-K seeds through the dense
    # `astrocyte_pi_section_entities` table. Spread hits are appended to
    # `fused` as synthetic FusedHit entries (rank recorded under "spreading",
    # rrf_score scaled from the spread score) so downstream callers iterate
    # one uniform list. Distinct from `graph_expand`, which uses the sparse
    # LLM-link table. See `astrocyte.pipeline.spreading_activation`.
    if enable_spreading_activation and fused:
        from astrocyte.pipeline.spreading_activation import (  # noqa: PLC0415
            expand_via_shared_entities,
        )

        spread_seeds = [(h.document_id, h.line_num) for h in fused[:spreading_seed_count]]
        known = {(h.document_id, h.line_num) for h in fused}
        try:
            spread_hits = await expand_via_shared_entities(
                store=store,
                bank_id=bank_id,
                seeds=spread_seeds,
                top_k=spreading_top_k,
            )
        except Exception as exc:  # noqa: BLE001
            logger.warning("section_recall: spreading_activation failed: %s", exc)
            spread_hits = []
        for doc_id, line_num, score in spread_hits:
            key = (doc_id, line_num)
            if key in known:
                continue
            # Record appended keys too, so a section the spreader returns
            # more than once is only appended once.
            known.add(key)
            fused.append(
                FusedHit(
                    document_id=doc_id,
                    line_num=line_num,
                    # Scale spread score onto the RRF range so it competes
                    # plausibly with fused entries. Small but non-zero so the
                    # rerank still sees them as candidates worth scoring.
                    rrf_score=score * 0.1,
                    per_strategy_rank={"spreading": 0},
                )
            )

    # M10.1 wiki tier: pre-aggregated observations sit ABOVE sections. The
    # bench prepends these as ``[OBSERVATION]`` blocks to the synth excerpts
    # so multi-session / preference questions read pre-aggregated facts
    # instead of making the LLM aggregate from raw section text.
    wiki_hits: list = []
    if wiki_enabled:

        async def _wiki_search():
            nonlocal wiki_hits
            if question_vectors:
                qvec = question_vectors[0]
            else:
                # Semantic strategy did not produce a vector; embed now.
                qvec = (await embedding_provider.embed([question]))[0]
            raw_hits = await store.search_wiki_pages_semantic(
                bank_id,
                qvec,
                top_k=wiki_top_k,
                document_id=wiki_document_id,
            )
            wiki_hits = [h for h in raw_hits if h.score >= wiki_min_score]
            return [(h.page_id, 0, h.score) for h in wiki_hits]

        initial_results.append(await _run_strategy("wiki", _wiki_search))

    return SectionRecallResult(
        fused=fused,
        strategies=initial_results,
        mode=mode,
        elapsed_ms=_elapsed_ms(t0),
        wiki_hits=wiki_hits,
    )