Coverage for astrocyte/pipeline/section_recall.py: 84%

142 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Section recall orchestrator (M9 PR2 commit B). 

2 

3Runs the five Hindsight-pattern parallel strategies — semantic, keyword, 

4entity, temporal, graph-expand — over the ``PageIndexStore`` SPI, then 

5fuses their ranked outputs with RRF (Reciprocal Rank Fusion, k=60). 

6 

7This is the **retrieval** layer. The cross-encoder rerank + picker-as- 

8reranker step (PR2 commit C) consumes this layer's output and feeds the 

9synth. 

10 

11Per-mode strategy gating mirrors what we found in Phase A failure 

12analysis: 

13- **temporal questions** add the temporal strategy + a wider semantic 

14 net (questions mentioning "May 2023" want the May-2023 sessions even 

15 if the topic words don't match). 

16- **multi-hop / multi-session questions** add the graph-expand strategy 

17 to bridge across sessions via section_links. 

18- **assistant-recall questions** (LME) override the keyword strategy 

19 with a ``speaker='assistant'`` filter. 

20 

21Mode dispatch is simple regex/heuristic at PR2 commit B; PR2 commit D 

22replaces it with a 1-token LLM classifier when the heuristic 

23mis-routes. 

24 

25See: 

26- ``docs/_design/recall.md`` §6 (recall pipeline) 

27- ``docs/_design/adr/adr-006-three-layer-recall-stack.md`` 

28""" 

29 

30from __future__ import annotations 

31 

32from dataclasses import dataclass, field 

33from datetime import datetime 

34from typing import TYPE_CHECKING 

35 

36from astrocyte.pipeline.fusion import DEFAULT_RRF_K 

37 

38if TYPE_CHECKING: 

39 from astrocyte.provider import LLMProvider, PageIndexStore 

40 

41 

42# ── Result types ───────────────────────────────────────────────────── 

43 

44 

@dataclass
class StrategyResult:
    """Trace record for a single retrieval strategy run.

    Bundles the ranked hits a strategy produced with the strategy's name
    (so traces can attribute hits) and how long it took (so performance
    regressions are visible per strategy).
    """

    # Strategy label, e.g. "semantic" or "keyword".
    strategy: str
    # Ranked hits as ``(document_id, line_num, score)`` tuples, best first.
    hits: list[tuple[str, int, float]]
    # Wall-clock duration of the strategy call, in milliseconds.
    elapsed_ms: float
    # "<ExceptionType>: <message>" when the strategy failed; None on success.
    error: str | None = None

54 

55 

@dataclass
class FusedHit:
    """A single section as it comes out of RRF fusion.

    ``rrf_score`` accumulates ``1 / (k + rank)`` over every strategy that
    returned the section. ``per_strategy_rank`` records the rank each
    strategy assigned, and is retained for tracing and as reranker input.
    """

    # Section identity is the composite (document_id, line_num).
    document_id: str
    line_num: int
    # Summed reciprocal-rank contributions across strategies.
    rrf_score: float
    # strategy name -> rank within that strategy's hit list.
    per_strategy_rank: dict[str, int] = field(default_factory=dict)

66 

67 

@dataclass
class SectionRecallResult:
    """Everything one section-recall call produced.

    Alongside the fused ranking it keeps each strategy's own trace, so a
    regression found during failure analysis can be pinned on the
    component that caused it.
    """

    # Fused section hits.
    fused: list[FusedHit]
    # One trace entry per strategy that ran.
    strategies: list[StrategyResult]
    # Mode label the call was dispatched with.
    mode: str
    # Total wall-clock duration of the call, in milliseconds.
    elapsed_ms: float
    # M10.1: wiki page hits found at recall time. Kept apart from
    # ``fused`` because they carry pre-aggregated text that the bench
    # prepends to the synth excerpts as ``[OBSERVATION]`` blocks instead
    # of routing it through the picker. Stays an empty list when there is
    # no wiki tier or no hit cleared the score threshold.
    wiki_hits: list = field(default_factory=list)

84 

85 

86# ── Section-grain RRF fusion (specialised for tuple hits) ───────────── 

87 

88 

def _rrf_fuse_section_hits(
    ranked_lists: list[StrategyResult],
    k: int = DEFAULT_RRF_K,
) -> list[FusedHit]:
    """Reciprocal Rank Fusion over ``(document_id, line_num, score)`` hits.

    ``astrocyte.pipeline.fusion.rrf_fusion`` keys on a single string
    ``id``; sections are identified by the composite ``(doc, line)``, so
    this is a section-grain specialisation of the same idea.

    Args:
        ranked_lists: Per-strategy results. Entries whose ``error`` is set
            are ignored entirely.
        k: RRF smoothing constant. Hindsight uses k=60 (the original
            Cormack et al. default) and that default is kept here.

    Returns:
        One ``FusedHit`` per distinct section, ordered by ``rrf_score``
        descending. ``per_strategy_rank`` is filled in so the reranker
        (PR2 commit C) can see why a section was promoted.
    """
    by_section: dict[tuple[str, int], FusedHit] = {}
    for result in ranked_lists:
        # A failed strategy contributes nothing to the fusion.
        if result.error:
            continue
        for position, hit in enumerate(result.hits, start=1):
            doc_id, line_num = hit[0], hit[1]
            section = (doc_id, line_num)
            if section not in by_section:
                by_section[section] = FusedHit(
                    document_id=doc_id,
                    line_num=line_num,
                    rrf_score=0.0,
                )
            merged = by_section[section]
            merged.rrf_score += 1.0 / (k + position)
            merged.per_strategy_rank[result.strategy] = position
    # Stable sort: sections with equal scores keep first-seen order.
    ordered = list(by_section.values())
    ordered.sort(key=lambda merged: merged.rrf_score, reverse=True)
    return ordered

119 

120 

121# ── Mode dispatch ──────────────────────────────────────────────────── 

122 

123 

def select_strategies_for_mode(mode: str) -> set[str]:
    """Return the set of strategy names to run for ``mode``.

    Every mode gets the always-on trio: semantic, keyword and entity.
    Temporal modes additionally get ``temporal``; multi-hop,
    multi-session and knowledge-update modes additionally get
    ``graph_expand``. Assistant-recall modes use the baseline set: their
    speaker-filtered keyword variant is applied inline by the
    orchestrator, not selected here. Unrecognised modes also get the
    baseline.

    PR2 commit D may replace this with an LLM-driven weighted mix.
    """
    # Modes that add one strategy on top of the baseline.
    extra_by_mode = {
        "temporal": "temporal",
        "temporal-reasoning": "temporal",
        "multi-hop": "graph_expand",
        "multi-session": "graph_expand",
        "knowledge-update": "graph_expand",
    }
    # Always-on baseline; a fresh set is built on every call.
    chosen = {"semantic", "keyword", "entity"}
    extra = extra_by_mode.get(mode)
    if extra is not None:
        chosen.add(extra)
    return chosen

145 

146 

147# ── Orchestrator ───────────────────────────────────────────────────── 

148 

149 

150import asyncio # noqa: E402 — placed after types/helpers per module style 

151import logging # noqa: E402 

152import time # noqa: E402 

153 

154logger = logging.getLogger("astrocyte.pipeline.section_recall") 

155 

156 

def _elapsed_ms_since(started: float) -> float:
    """Milliseconds elapsed since ``started`` (a ``time.monotonic()`` reading)."""
    return (time.monotonic() - started) * 1000.0


async def _timed_strategy(name: str, search) -> StrategyResult:
    """Run one strategy search, timing it and capturing any failure.

    ``search`` is a zero-argument callable returning an awaitable that
    yields the hit list. It is *called* inside the ``try`` so that errors
    raised while building the call (e.g. an unexpected keyword argument)
    are captured in the trace just like errors raised while awaiting it.
    """
    started = time.monotonic()
    try:
        hits = await search()
    except Exception as exc:  # noqa: BLE001 — one failing strategy must not sink the recall
        return StrategyResult(
            strategy=name,
            hits=[],
            elapsed_ms=_elapsed_ms_since(started),
            error=f"{type(exc).__name__}: {exc}",
        )
    return StrategyResult(
        strategy=name,
        hits=hits,
        elapsed_ms=_elapsed_ms_since(started),
    )


async def section_recall(
    *,
    store: PageIndexStore,
    bank_id: str,
    question: str,
    mode: str,
    embedding_provider: LLMProvider,
    question_entities: list[str] | None = None,
    date_range: tuple[datetime, datetime] | None = None,
    semantic_seed_count: int = 20,
    rrf_k: int = DEFAULT_RRF_K,
    per_strategy_top_k: int = 20,
    wiki_enabled: bool = False,
    wiki_document_id: str | None = None,
    wiki_min_score: float = 0.55,
    wiki_top_k: int = 3,
    enable_spreading_activation: bool = False,
    spreading_seed_count: int = 10,
    spreading_top_k: int = 10,
    session_filter: str | None = None,
) -> SectionRecallResult:
    """Run all selected strategies in parallel, RRF-fuse, return.

    Operates on **sections** (PageIndex tree nodes) — the M9 middle
    recall layer in ``recall.md``'s three-layer stack. Wiki recall
    sits above this; raw memory_units below.

    A failing strategy never raises out of this function: the failure is
    recorded in that strategy's ``StrategyResult.error`` and the strategy
    contributes no hits.

    Args:
        store: PageIndexStore SPI handle (in-memory or postgres).
        bank_id: Scope to one bank (multi-bank later).
        question: Raw question text — passed as-is to the keyword and
            embedding strategies.
        mode: Pre-computed mode label (e.g. "multi-hop", "temporal").
            Drives which strategies fire; the assistant-recall modes also
            restrict the keyword strategy to ``speaker='assistant'``.
        embedding_provider: LLM provider with an ``embed`` method. Called
            by the semantic strategy, and by the wiki tier only when no
            question embedding is available from the semantic strategy.
        question_entities: Pre-extracted entities for the entity strategy.
            When None or empty, the entity strategy returns no hits
            without querying the store (caller didn't prepare them —
            typically because PR2 commit D's question-annotator hasn't
            run).
        date_range: Pre-parsed date window for the temporal strategy.
            When None, the temporal strategy returns no hits without
            querying the store.
        semantic_seed_count: Top-K for the semantic call.
        rrf_k: RRF smoothing constant.
        per_strategy_top_k: Top-K limit per strategy before fusion.
        wiki_enabled: When True, semantic-search the bank's wiki pages
            and return the hits that clear ``wiki_min_score``.
        wiki_document_id: Forwarded as ``document_id`` to the wiki search.
        wiki_min_score: Minimum score a wiki hit needs to be returned.
        wiki_top_k: Top-K for the wiki search.
        enable_spreading_activation: When True (and fusion produced at
            least one hit), expand the top fused hits through shared
            entities and append the newly found sections to ``fused``.
        spreading_seed_count: How many top fused hits seed the spread.
        spreading_top_k: Top-K for the spread expansion.
        session_filter: Forwarded to the semantic, keyword, entity and
            temporal store searches (M31 Fix 2). Not applied to
            graph-expand, spreading, or the wiki search.

    Returns:
        ``SectionRecallResult`` with the fused list plus per-strategy
        traces for debugging. RRF-fused hits are sorted by rrf_score
        descending; spreading-activation hits, when enabled, are appended
        after them without re-sorting. The "wiki" trace is recorded in
        ``strategies`` but never fused — wiki hits are returned
        separately in ``wiki_hits``.
    """
    t0 = time.monotonic()
    selected = select_strategies_for_mode(mode)

    # Question embedding produced by the semantic strategy, kept so the
    # wiki tier can reuse it instead of paying for a second embed call.
    question_vec = None

    async def _semantic_search():
        nonlocal question_vec
        embeds = await embedding_provider.embed([question])
        if embeds:
            qvec = embeds[0]
            question_vec = qvec
        else:
            qvec = []
        return await store.search_sections_semantic(
            bank_id,
            qvec,
            top_k=semantic_seed_count,
            session_filter=session_filter,  # M31 Fix 2
        )

    async def _semantic() -> StrategyResult:
        return await _timed_strategy("semantic", _semantic_search)

    async def _keyword() -> StrategyResult:
        # Assistant-recall modes only look at what the assistant said.
        speaker = "assistant" if mode in {"single-session-assistant", "assistant-recall"} else None
        return await _timed_strategy(
            "keyword",
            lambda: store.search_sections_keyword(
                bank_id,
                question,
                top_k=per_strategy_top_k,
                speaker=speaker,
                session_filter=session_filter,  # M31 Fix 2
            ),
        )

    async def _entity() -> StrategyResult:
        if not question_entities:
            # Nothing to look up: empty result, no store call.
            return StrategyResult(strategy="entity", hits=[], elapsed_ms=0.0)
        return await _timed_strategy(
            "entity",
            lambda: store.search_sections_by_entities(
                bank_id,
                question_entities,
                top_k=per_strategy_top_k,
                session_filter=session_filter,  # M31 Fix 2
            ),
        )

    async def _temporal() -> StrategyResult:
        if date_range is None:
            # No date window: empty result, no store call.
            return StrategyResult(strategy="temporal", hits=[], elapsed_ms=0.0)
        return await _timed_strategy(
            "temporal",
            lambda: store.search_sections_temporal(
                bank_id,
                date_range,
                top_k=per_strategy_top_k,
                session_filter=session_filter,  # M31 Fix 2
            ),
        )

    # Fire the independent strategies concurrently. Dict order fixes the
    # order of the traces: semantic, keyword, entity, temporal.
    runners = {
        "semantic": _semantic,
        "keyword": _keyword,
        "entity": _entity,
        "temporal": _temporal,
    }
    tasks = [asyncio.create_task(run()) for name, run in runners.items() if name in selected]
    initial_results = list(await asyncio.gather(*tasks))
    by_name = {r.strategy: r for r in initial_results}

    # Graph-expand needs seeds — it uses the union of semantic + entity
    # hits as inputs, so it is sequenced AFTER those have finished.
    if "graph_expand" in selected:
        seeds: list[tuple[str, int]] = []
        for name in ("semantic", "entity"):
            r = by_name.get(name)
            if r is not None:
                # Take top-5 from each as seeds — enough for 1-hop
                # bridge without exploding the join.
                seeds.extend((d, ln) for d, ln, _ in r.hits[:5])
        # Dedupe seeds preserving order.
        unique_seeds = list(dict.fromkeys(seeds))
        initial_results.append(
            await _timed_strategy(
                "graph_expand",
                lambda: store.expand_section_links(
                    unique_seeds,
                    top_k=per_strategy_top_k,
                ),
            )
        )

    fused = _rrf_fuse_section_hits(initial_results, k=rrf_k)

    # M18a-3 — entity-co-occurrence spread (gated by `enable_spreading_activation`).
    # After RRF fusion, expand the top-K seeds through the dense
    # `astrocyte_pi_section_entities` table. Spread hits are appended to
    # `fused` with synthetic FusedHit entries (`strategy="spreading"` rank,
    # rrf_score scaled by spread score) so downstream callers iterate one
    # uniform list. Distinct from `graph_expand` which uses the sparse
    # LLM-link table. See `astrocyte.pipeline.spreading_activation`.
    if enable_spreading_activation and fused:
        from astrocyte.pipeline.spreading_activation import (  # noqa: PLC0415
            expand_via_shared_entities,
        )

        spread_seeds = [(h.document_id, h.line_num) for h in fused[:spreading_seed_count]]
        already_in_fused = {(h.document_id, h.line_num) for h in fused}
        try:
            spread_hits = await expand_via_shared_entities(
                store=store,
                bank_id=bank_id,
                seeds=spread_seeds,
                top_k=spreading_top_k,
            )
        except Exception as exc:  # noqa: BLE001 — spreading is best-effort
            logger.warning("section_recall: spreading_activation failed: %s", exc)
            spread_hits = []
        for doc_id, line_num, score in spread_hits:
            section = (doc_id, line_num)
            if section in already_in_fused:
                continue
            # Track it so a repeated spread hit is not appended twice.
            already_in_fused.add(section)
            fused.append(
                FusedHit(
                    document_id=doc_id,
                    line_num=line_num,
                    # Scale spread score onto the RRF range so it competes
                    # plausibly with fused entries — small but non-zero so
                    # the rerank still sees them as candidates worth scoring.
                    rrf_score=score * 0.1,
                    per_strategy_rank={"spreading": 0},
                )
            )

    # M10.1 wiki tier — pre-aggregated observations sit ABOVE sections.
    # When enabled, we semantic-search the bank's wiki pages and surface
    # any high-confidence hit. The bench prepends these as
    # ``[OBSERVATION]`` blocks to the synth excerpts so multi-session
    # / preference questions read pre-aggregated facts instead of
    # making the LLM aggregate from raw section text.
    wiki_hits: list = []
    if wiki_enabled:
        ts = time.monotonic()
        try:
            # Reuse the embedding from the semantic strategy when present
            # so we don't pay for a second embed call per question; embed
            # here only when the semantic strategy did not run or did not
            # produce a vector.
            if question_vec is not None:
                qvec = question_vec
            else:
                qvec = (await embedding_provider.embed([question]))[0]
            raw_hits = await store.search_wiki_pages_semantic(
                bank_id,
                qvec,
                top_k=wiki_top_k,
                document_id=wiki_document_id,
            )
            confident = [h for h in raw_hits if h.score >= wiki_min_score]
            wiki_trace = StrategyResult(
                strategy="wiki",
                hits=[(h.page_id, 0, h.score) for h in confident],
                elapsed_ms=_elapsed_ms_since(ts),
            )
        except Exception as exc:  # noqa: BLE001 — wiki tier is best-effort
            initial_results.append(
                StrategyResult(
                    strategy="wiki",
                    hits=[],
                    elapsed_ms=_elapsed_ms_since(ts),
                    error=f"{type(exc).__name__}: {exc}",
                )
            )
        else:
            # Publish hits only once the trace was built successfully, so
            # an errored wiki trace never coexists with returned hits.
            wiki_hits = confident
            initial_results.append(wiki_trace)

    return SectionRecallResult(
        fused=fused,
        strategies=initial_results,
        mode=mode,
        elapsed_ms=_elapsed_ms_since(t0),
        wiki_hits=wiki_hits,
    )