Coverage for astrocyte/pipeline/fact_recall.py: 93%
141 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Fact-grain recall — core entry point with RRF fusion.

Unified fact-grain search that runs up to five strategies as PARALLEL
SIBLINGS and merges them via Reciprocal Rank Fusion (Hindsight-parity
architecture; see docs/_design/m18-quick-wins.md §3.3):

  - **Semantic fact search** — always runs. Cosine over fact-text
    embeddings via ``store.search_facts_semantic``.
  - **Episodic fact search** (M18a-4, gated) — when
    ``config.episodic_extract.enabled`` AND the question matches an
    episodic cue (``question_has_episodic_cue``), additionally search
    by ``EPISODIC_MARKER`` entity to surface facts tagged at retain
    time.
  - **Temporal fact search** (M18a-1 Pass B integration) — when the
    caller passes ``temporal_range=(start, end)``, additionally search
    by ``occurred_start`` overlapping the range via
    ``store.search_facts_temporal``.
  - **Link-expansion** (M27) — when the caller passes
    ``query_entities``, additionally fetch facts that share those
    entities across all sessions in the bank via
    ``store.search_facts_by_entity``.
  - **Keyword / BM25** (M34-5, intent-gated) — when the caller passes
    an ``intent`` whose BM25 weight is positive and the store provides
    ``store.search_facts_keyword``.

Why RRF instead of append-then-rerank:
    RRF scores each candidate by ``Σ 1/(k + rank)`` across the strategies
    that surfaced it. A junk candidate that only appears at rank-1 of a
    bogus temporal hit contributes ``1/(60+1) ≈ 0.016`` to the final
    fusion score; a real answer that's rank-1 in semantic AND rank-3 in
    temporal contributes ``1/61 + 1/63 ≈ 0.032`` — twice as much.
    False-positive dateparser hits get damped automatically, and
    cross-strategy agreement gets rewarded.

    Compare to the old "append everything, let the cross-encoder sort it
    out" path: the reranker is source-blind, so a junk candidate whose
    text happens to vaguely match the query can displace a real one
    purely from rerank score noise.

Public API:
    fact_recall(*, store, bank_id, document_id, query, query_embedding,
                config, temporal_range=None, query_entities=None,
                top_k_semantic=40, top_k_episodic=20, top_k_temporal=20,
                top_k_link_expansion=20, top_k_keyword=20,
                rrf_k=DEFAULT_RRF_K, session_filter=None, intent=None,
                fact_types=None, max_tokens=None)
        -> list[PageIndexFact]

Backward compatibility:
    Callers that pass none of ``temporal_range``, ``query_entities`` or
    ``intent`` get the same 2-strategy semantic+(optional episodic)
    recall behavior. The fused result is ordered by RRF score
    (descending), which is a different ordering than the legacy
    semantic-then-episodic-append, but the downstream cross-encoder
    rerank consumes the list as an unordered candidate pool — so the
    only observable change for those callers is that episodic hits
    which co-occur with semantic hits get a slightly higher
    fusion-induced ranking, which can only help the reranker.
"""
51from __future__ import annotations
53import asyncio
54import logging
55from typing import TYPE_CHECKING, Any
57from astrocyte.pipeline.fusion import DEFAULT_RRF_K
58from astrocyte.pipeline.intent_weights import weights_for_intent
60if TYPE_CHECKING:
61 from datetime import datetime
63 from astrocyte.config import AstrocyteConfig
64 from astrocyte.pipeline.query_intent import QueryIntent
65 from astrocyte.types import PageIndexFact
# Module-level logger. ``_safe_call`` reports swallowed per-branch
# recall failures through it at WARNING level.
_logger: logging.Logger = logging.getLogger("astrocyte.pipeline.fact_recall")
async def fact_recall(
    *,
    store: Any,
    bank_id: str,
    document_id: str | None,
    query: str,
    query_embedding: list[float],
    config: AstrocyteConfig,
    temporal_range: tuple[datetime, datetime] | None = None,
    query_entities: list[str] | None = None,
    top_k_semantic: int = 40,
    top_k_episodic: int = 20,
    top_k_temporal: int = 20,
    top_k_link_expansion: int = 20,
    top_k_keyword: int = 20,  # M31c BM25 over fact_text
    rrf_k: int = DEFAULT_RRF_K,
    session_filter: str | None = None,
    intent: QueryIntent | None = None,  # M34-2 — intent-weighted RRF
    fact_types: list[str] | None = None,  # M34-4 — per-fact-type segmentation
    max_tokens: int | None = None,  # M35-2 — token budget cap on merged output
) -> list[PageIndexFact]:
    """Run unified fact-grain recall and return an RRF-fused ranked list.

    Always runs semantic search. Optionally runs:

    - Episodic — when ``config.episodic_extract.enabled`` AND
      ``question_has_episodic_cue(query)`` matches.
    - Temporal — when ``temporal_range`` is provided (caller
      typically gets this from ``query_analyzer.analyze_query``).
    - **Link-expansion (M27)** — when ``query_entities`` is provided
      and non-empty. For each query entity, fetches facts whose
      ``entities`` array contains that entity, ACROSS ALL SESSIONS
      in the bank (no ``document_id`` filter — that's the whole
      point of the cross-session graph traversal). Mirrors
      Hindsight's ``link_expansion_retrieval`` strategy. Lets
      questions like "who else did I discuss this with" surface
      facts from sessions the query embedding wouldn't naturally
      reach.
    - **Keyword / BM25 (M34-5)** — only when ``intent`` is provided,
      the intent's ``bm25`` weight is positive, and ``store`` has a
      ``search_facts_keyword`` method. Returns at most
      ``top_k_keyword`` hits.

    Branches run in parallel via ``asyncio.gather``. Failures raised
    while a branch is being awaited are isolated (logged + treated as an
    empty list).

    M34-4 — when ``fact_types`` is provided, retrieval runs **per
    fact_type**: each fact_type gets its own channel run (filtered via
    the M34-3 SPI ``fact_type`` param) and its own RRF-fused pool. The
    pools are then round-robin interleaved (every type's top hit before
    any type's second hit), deduped by ``fact_id`` and, when
    ``max_tokens`` is set, trimmed to the token budget. This
    Hindsight-parity segmentation prevents a flood in one channel
    (e.g. temporal returning experience facts) from displacing
    relevant facts of other types (e.g. preference). See
    ``docs/_design/m34-query-intent-routing.md`` for the v015i/v015j
    forensic that motivated this.

    Args:
        store: Recall SPI backend. Must provide ``search_facts_semantic``;
            ``search_facts_by_entity`` when the episodic or link-expansion
            branches run; ``search_facts_temporal`` when ``temporal_range``
            is given. ``search_facts_keyword`` is optional — the keyword
            branch is skipped when it is absent.
        bank_id: Bank passed to every SPI call.
        document_id: Document scope for the semantic, episodic, temporal
            and keyword branches. Link-expansion ignores it.
        query: Raw question text, used for episodic-cue detection and BM25.
        query_embedding: Query vector for the semantic branch.
        config: Read for ``config.episodic_extract.enabled``.
        temporal_range: ``(start, end)`` for the temporal branch, or None.
        query_entities: Entities that drive link-expansion, or None.
        top_k_semantic: Cap for the semantic branch.
        top_k_episodic: Cap for the episodic branch.
        top_k_temporal: Cap for the temporal branch.
        top_k_link_expansion: Cap per query entity for link-expansion.
        top_k_keyword: Cap for the keyword (BM25) branch.
        rrf_k: RRF damping constant ``k``.
        session_filter: Applied to every branch except link-expansion.
        intent: Selects per-channel fusion weights via
            ``weights_for_intent``. When None, the semantic / episodic /
            temporal / link-expansion results are fused with equal weight
            and the keyword branch does not run.
        fact_types: When non-empty, enables per-type segmentation (above).
        max_tokens: When a positive int, the result is passed through
            ``pack_to_budget``. When None or <= 0, nothing is trimmed.

    Returns:
        Facts deduped by ``fact_id``. Without ``fact_types`` they are
        ordered by fused RRF score (highest first). With ``fact_types``
        they follow the round-robin interleave of the per-type pools. The
        caller's downstream cross-encoder rerank picks the final ranking;
        RRF here primarily ensures that source-blind rerank doesn't get
        polluted by single-strategy junk.
    """
    # Arguments shared by every channel run. Building them once keeps the
    # single-pool and per-type paths from drifting apart.
    channel_kwargs: dict[str, Any] = {
        "store": store,
        "bank_id": bank_id,
        "document_id": document_id,
        "query": query,
        "query_embedding": query_embedding,
        "config": config,
        "temporal_range": temporal_range,
        "query_entities": query_entities,
        "top_k_semantic": top_k_semantic,
        "top_k_episodic": top_k_episodic,
        "top_k_temporal": top_k_temporal,
        "top_k_link_expansion": top_k_link_expansion,
        "top_k_keyword": top_k_keyword,
        "rrf_k": rrf_k,
        "session_filter": session_filter,
        "intent": intent,
    }

    # M34-4 — when fact_types is provided, segment retrieval per type.
    # Default (None) preserves the pre-M34 single-pool behaviour for BC.
    if fact_types:
        # M35-2 — each per-type pool uses the same top_k_* limits as the
        # single-pool path; _merge_per_type_pools then trims the merged
        # result by token count, so each fact_type gets a fair shot at the
        # budget without a hard per-type item cap. Types are fetched one
        # after another; the channels inside each run are concurrent.
        per_type_pools: list[list[PageIndexFact]] = []
        for ft in fact_types:
            per_type_pools.append(
                await _run_channels_and_fuse(**channel_kwargs, fact_type=ft)
            )
        return _merge_per_type_pools(per_type_pools, max_tokens=max_tokens)

    pool = await _run_channels_and_fuse(**channel_kwargs, fact_type=None)
    # M35-2 — apply token budget to the single-pool path too.
    if max_tokens is not None and max_tokens > 0:
        from astrocyte.pipeline.token_budget import pack_to_budget  # noqa: PLC0415

        pool = pack_to_budget(
            pool,
            max_tokens=max_tokens,
            text_of=lambda f: getattr(f, "text", "") or "",
        )
    return pool


async def _run_channels_and_fuse(
    *,
    store: Any,
    bank_id: str,
    document_id: str | None,
    query: str,
    query_embedding: list[float],
    config: AstrocyteConfig,
    temporal_range: tuple[datetime, datetime] | None,
    query_entities: list[str] | None,
    top_k_semantic: int,
    top_k_episodic: int,
    top_k_temporal: int,
    top_k_link_expansion: int,
    rrf_k: int,
    session_filter: str | None,
    intent: QueryIntent | None,
    fact_type: str | None,
    top_k_keyword: int = 20,
) -> list[PageIndexFact]:
    """Run the recall channels in parallel and fuse them into one pool.

    Extracted from :func:`fact_recall` so per-fact-type segmentation can
    call it once per type. When ``fact_type`` is non-None, each SPI call
    filters to that single fact_type (M34-3).

    ``top_k_keyword`` caps the BM25 channel. It defaults to 20, the value
    that used to be hard-coded here.

    Returns:
        Equal-weight RRF over semantic / episodic / temporal /
        link-expansion hits when ``intent`` is None. Otherwise, RRF
        weighted by ``weights_for_intent(intent)`` over those four plus
        the keyword channel.
    """
    # Resolve both gates cheaply (no DB call) BEFORE scheduling anything,
    # so an exception here cannot leave already-started tasks orphaned.
    want_episodic = _is_episodic_enabled(config) and _question_has_episodic_cue(query)
    weights = weights_for_intent(intent) if intent is not None else None

    # (label, coroutine) pairs. The label names both the _safe_call branch
    # (used in failure logs) and the asyncio task ("fact_recall.<label>").
    # Store methods are invoked here, while the list is built. Only
    # exceptions raised when the resulting awaitables are awaited are
    # isolated by _safe_call.
    #
    # M31 Fix 2 — session_filter applies to semantic / episodic /
    # temporal / keyword branches (which would otherwise span all sessions
    # in the document) but DELIBERATELY NOT to link-expansion:
    # link-expansion's purpose is cross-session entity traversal, so
    # constraining it to one session defeats the point. Real systems
    # passing session_id still benefit from cross-session entity matches
    # surfacing in the candidate pool — the cross-encoder rerank picks the
    # best.
    branches: list[tuple[str, Any]] = [
        (
            "semantic",
            store.search_facts_semantic(
                bank_id,
                query_embedding,
                top_k=top_k_semantic,
                document_id=document_id,
                fact_type=fact_type,  # M34-3
                session_filter=session_filter,
            ),
        ),
    ]

    if want_episodic:
        branches.append(
            (
                "episodic",
                _search_episodic(
                    store,
                    bank_id,
                    document_id,
                    top_k_episodic,
                    fact_type=fact_type,
                    session_filter=session_filter,
                ),
            )
        )

    if temporal_range is not None:
        branches.append(
            (
                "temporal",
                store.search_facts_temporal(
                    bank_id,
                    temporal_range,
                    top_k=top_k_temporal,
                    document_id=document_id,
                    fact_type=fact_type,  # M34-3
                    session_filter=session_filter,
                ),
            )
        )

    # M27 — link-expansion: cross-session entity-graph traversal.
    # No ``document_id`` filter AND no ``session_filter`` — the whole
    # point is to surface facts from OTHER sessions that share entities
    # with the query (see the M31 Fix 2 note above).
    if query_entities:
        branches.append(
            (
                "link_expansion",
                _search_link_expansion(
                    store,
                    bank_id,
                    query_entities,
                    top_k_link_expansion,
                    fact_type=fact_type,
                ),
            )
        )

    # M34-5 — BM25 keyword channel, intent-gated. The 5th-sibling
    # regression in M31c was caused by uniform-weight RRF flooding
    # synthesis-heavy categories. With intent weights, BM25 only
    # contributes meaningfully when the intent prefers it. When intent is
    # None (pre-M34 BC path), BM25 stays off entirely. The store method
    # is optional, hence the hasattr check.
    if weights is not None and weights.bm25 > 0.0 and hasattr(store, "search_facts_keyword"):
        branches.append(
            (
                "keyword",
                store.search_facts_keyword(
                    bank_id,
                    query,
                    top_k=top_k_keyword,
                    document_id=document_id,
                    fact_type=fact_type,
                    session_filter=session_filter,
                ),
            )
        )

    tasks = [
        asyncio.create_task(_safe_call(label, coro), name=f"fact_recall.{label}")
        for label, coro in branches
    ]
    results = await asyncio.gather(*tasks)
    hits_by_channel: dict[str, list[PageIndexFact]] = {
        label: hits for (label, _coro), hits in zip(branches, results)
    }
    semantic_hits = hits_by_channel["semantic"]
    episodic_hits = hits_by_channel.get("episodic", [])
    temporal_hits = hits_by_channel.get("temporal", [])
    link_hits = hits_by_channel.get("link_expansion", [])
    keyword_hits = hits_by_channel.get("keyword", [])

    # M34-2 — intent-weighted RRF. When ``intent`` is None we fall back
    # to equal-weight fusion (identical to pre-M34 behaviour). When
    # provided, the intent's per-channel weights bias which strategy
    # contributes most. See ``astrocyte.pipeline.intent_weights`` for
    # the calibration table and rationale.
    if weights is None:
        return _rrf_fuse_fact_hits(
            [semantic_hits, episodic_hits, temporal_hits, link_hits],
            k=rrf_k,
        )

    return _rrf_fuse_fact_hits_weighted(
        [
            (semantic_hits, weights.semantic),
            (episodic_hits, weights.episodic),
            (temporal_hits, weights.temporal),
            (link_hits, weights.link_expansion),
            (keyword_hits, weights.bm25),
        ],
        k=rrf_k,
    )
def _merge_per_type_pools(
    pools: list[list[PageIndexFact]],
    *,
    max_tokens: int | None,
) -> list[PageIndexFact]:
    """M34-4 + M35-2 — combine the per-fact-type pools into one list.

    Steps:
      1. Round-robin across pools: position 0 of every pool (in pool
         order), then position 1 of every pool, and so on. Shorter pools
         simply stop contributing. This keeps each pool's internal
         (per-type RRF) order while not favouring whichever fact_type
         happens to come first in ``fact_types``.
      2. Drop hits whose ``fact_id`` is None or has already been emitted.
         The first occurrence wins.
      3. If ``max_tokens`` is a positive int, hand the result to
         ``pack_to_budget`` (imported lazily), measuring each fact's
         ``text`` attribute. Otherwise return every deduped hit, which
         lets legacy callers and tests opt out of trimming.
    """
    longest = max((len(p) for p in pools), default=0)
    round_robin = (
        pool[position]
        for position in range(longest)
        for pool in pools
        if position < len(pool)
    )

    emitted: set[str] = set()
    merged: list[PageIndexFact] = []
    for fact in round_robin:
        key = getattr(fact, "fact_id", None)
        if key is not None and key not in emitted:
            emitted.add(key)
            merged.append(fact)

    budget_applies = max_tokens is not None and max_tokens > 0
    if not budget_applies:
        return merged

    from astrocyte.pipeline.token_budget import pack_to_budget  # noqa: PLC0415

    return pack_to_budget(
        merged,
        max_tokens=max_tokens,
        text_of=lambda f: getattr(f, "text", "") or "",
    )
377# ─────────────────────────────────────────────────────────────────────────
378# Internal helpers
379# ─────────────────────────────────────────────────────────────────────────
async def _safe_call(
    branch_name: str,
    coro: Any,
) -> list[PageIndexFact]:
    """Await one recall branch, turning any ``Exception`` into ``[]``.

    Per-branch failure isolation is critical for the recall path —
    a temporary DB index issue on (say) the temporal SPI must NOT
    take down semantic recall too. The failure is logged at WARNING
    with ``branch_name`` so it stays visible. ``BaseException``
    subclasses that are not ``Exception`` (e.g. cancellation)
    propagate unchanged.
    """
    result: list[PageIndexFact]
    try:
        result = await coro
    except Exception as err:  # noqa: BLE001
        _logger.warning("fact_recall: %s branch failed: %s", branch_name, err)
        result = []
    return result
async def _search_episodic(
    store: Any,
    bank_id: str,
    document_id: str | None,
    top_k: int,
    *,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
) -> list[PageIndexFact]:
    """Fetch facts tagged with the episodic marker entity.

    ``EPISODIC_MARKER`` is imported inside the function, so
    ``episodic_extract`` is only loaded once the episodic branch actually
    runs. The lookup goes through ``store.search_facts_by_entity`` with
    the same document / fact_type / session scoping the caller passed in.

    M31 Fix 2: episodic facts are session-anchored (a section emits at
    most one EPISODIC_MARKER fact for its session), so ``session_filter``
    naturally scopes the result to the matching session's episodic facts.
    """
    from astrocyte.pipeline.episodic_extract import EPISODIC_MARKER  # noqa: PLC0415

    scope = {
        "top_k": top_k,
        "document_id": document_id,
        "fact_type": fact_type,
        "session_filter": session_filter,
    }
    return await store.search_facts_by_entity(bank_id, EPISODIC_MARKER, **scope)
async def _search_link_expansion(
    store: Any,
    bank_id: str,
    query_entities: list[str],
    top_k_per_entity: int,
    *,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
) -> list[PageIndexFact]:
    """M27 — cross-session entity-graph traversal.

    For each distinct query entity, fetch facts whose ``entities`` array
    contains it via ``store.search_facts_by_entity`` (matching is expected
    to be case-insensitive — that is up to the store implementation).
    Crucially passes ``document_id=None`` so the search spans ALL
    sessions in the bank — that's the "cross-session" part. Single-
    session matches will also show up via the semantic strategy; this
    branch's value-add is the multi-session hits.

    Per-entity results are round-robin interleaved (one hit from each
    entity per round, so every entity is fairly represented) with dedupe
    by ``fact_id``, falling back to ``id``. The cap is per entity to bound
    cost when the query has many entities; the RRF fusion downstream
    will pick the most-prominent facts across the combined pool. Each
    per-entity lookup is wrapped in ``_safe_call``, so one failing entity
    contributes no hits instead of failing the whole branch.

    Exact duplicates in ``query_entities`` are queried only once. A
    repeated entity would issue an identical store query whose hits the
    fact-id dedupe then discards, so the merged result is unchanged —
    only the wasted round-trips are saved.

    Note: hits kept here via the ``id`` fallback carry no ``fact_id`` and
    will be dropped by the downstream RRF fusers, which key on
    ``fact_id`` only.

    Hindsight reference: ``hindsight_api/engine/search/link_expansion_retrieval.py``.

    Returns:
        Deduped facts in round-robin order; ``[]`` when
        ``query_entities`` is empty.
    """
    if not query_entities:
        return []

    # Order-preserving exact dedupe of the entity list.
    unique_entities = list(dict.fromkeys(query_entities))

    # Run per-entity searches in parallel, dedupe by fact_id, preserve
    # first-seen ordering (later RRF reranks anyway).
    tasks = [
        asyncio.create_task(
            _safe_call(
                f"link_expansion[{ent}]",
                store.search_facts_by_entity(
                    bank_id, ent, top_k=top_k_per_entity, document_id=None,
                    fact_type=fact_type,  # M34-3
                ),
            ),
            name=f"fact_recall.link_expansion.{ent[:32]}",
        )
        for ent in unique_entities
    ]
    per_entity_results = await asyncio.gather(*tasks)

    seen: set[str] = set()
    merged: list[PageIndexFact] = []
    # Interleave: take 1 from each entity's list per round (round-robin)
    # to give every query entity fair representation, not just the
    # first one's top-K.
    max_len = max((len(r) for r in per_entity_results), default=0)
    for i in range(max_len):
        for ent_hits in per_entity_results:
            if i >= len(ent_hits):
                continue
            fact = ent_hits[i]
            fid = getattr(fact, "fact_id", None) or getattr(fact, "id", None)
            if fid is None or fid in seen:
                continue
            seen.add(fid)
            merged.append(fact)
    return merged
def _rrf_fuse_fact_hits(
    ranked_lists: list[list[PageIndexFact]],
    *,
    k: int = DEFAULT_RRF_K,
) -> list[PageIndexFact]:
    """Reciprocal Rank Fusion over fact-hit lists, dedupe by ``fact_id``.

    Each rank-r appearance contributes ``1.0 / (k + r + 1)`` to the
    fused score (r is 0-indexed). Hits without a ``fact_id`` are
    dropped (we cannot dedupe them safely). Empty lists contribute
    nothing.

    Returns facts ordered by descending fused score. When a fact
    appears in multiple ranked lists, the FIRST instance encountered
    (earliest list, then earliest rank) is the object returned. The
    choice of representative does not affect ordering, which depends
    only on the fused score. Facts with equal fused scores keep their
    first-seen order, because dicts preserve insertion order and
    ``sorted`` is stable.

    Produces the same result as :func:`_rrf_fuse_fact_hits_weighted`
    with every weight set to 1.0.
    """
    fused_score: dict[str, float] = {}
    representative: dict[str, PageIndexFact] = {}

    for ranked_list in ranked_lists:
        if not ranked_list:
            continue
        for rank, hit in enumerate(ranked_list):
            fid = getattr(hit, "fact_id", None)
            if fid is None:
                continue
            fused_score[fid] = fused_score.get(fid, 0.0) + 1.0 / (k + rank + 1)
            # First-seen instance is the representative; later
            # appearances only add to the score.
            representative.setdefault(fid, hit)

    sorted_ids = sorted(fused_score, key=fused_score.__getitem__, reverse=True)
    return [representative[fid] for fid in sorted_ids]
def _rrf_fuse_fact_hits_weighted(
    ranked_lists_with_weights: list[tuple[list[PageIndexFact], float]],
    *,
    k: int = DEFAULT_RRF_K,
) -> list[PageIndexFact]:
    """M34-2 — weighted Reciprocal Rank Fusion over fact-hit lists.

    A hit at 0-indexed position ``r`` in a list with weight ``w`` adds
    ``w / (k + r + 1)`` to that fact's fused score. Channels are checked
    in order:

    - a negative weight is a caller bug and raises ``ValueError`` as soon
      as it is reached;
    - a weight of exactly 0.0 mutes the channel, so its hits neither
      score nor enter the candidate pool;
    - empty lists are skipped, as are hits whose ``fact_id`` is None.

    The first instance seen for each ``fact_id`` is the object returned.
    Output is sorted by fused score, descending; ties keep first-seen
    order.

    Mirrors the contract of
    :func:`astrocyte.pipeline.fusion.weighted_rrf_fusion` but operates
    on :class:`PageIndexFact` instances directly so we don't pay the
    ScoredItem conversion cost. The two functions stay in sync — if
    you change one, change the other.

    When all weights are 1.0 this is mathematically identical to
    :func:`_rrf_fuse_fact_hits`.
    """
    totals: dict[str, float] = {}
    kept: dict[str, PageIndexFact] = {}

    for hits, weight in ranked_lists_with_weights:
        if weight < 0.0:
            raise ValueError(
                f"RRF weight must be >= 0.0; got {weight!r}. Pass weight=0.0 "
                "to mute a channel.",
            )
        if weight == 0.0 or not hits:
            continue
        for position, hit in enumerate(hits, start=1):
            fid = getattr(hit, "fact_id", None)
            if fid is not None:
                totals[fid] = totals.get(fid, 0.0) + weight / (k + position)
                kept.setdefault(fid, hit)

    ranking = sorted(totals, key=totals.__getitem__, reverse=True)
    return [kept[fid] for fid in ranking]
def _is_episodic_enabled(config: AstrocyteConfig) -> bool:
    """Return True only when ``config.episodic_extract.enabled`` is truthy.

    A missing or None ``episodic_extract`` section, or a section without
    an ``enabled`` attribute, counts as disabled.
    """
    episodic_cfg = getattr(config, "episodic_extract", None)
    return episodic_cfg is not None and bool(getattr(episodic_cfg, "enabled", False))
def _question_has_episodic_cue(query: str) -> bool:
    """Return whether ``query`` matches a known episodic cue.

    Delegates to ``episodic_extract.question_has_episodic_cue``. The
    import is deferred to call time so loading this module does not pull
    in ``episodic_extract``. If that import fails with ``ImportError``,
    the answer is False — no cue means no episodic branch.
    """
    try:
        from astrocyte.pipeline.episodic_extract import (  # noqa: PLC0415
            question_has_episodic_cue as _matches_cue,
        )
    except ImportError:
        return False
    else:
        return bool(_matches_cue(query))