Coverage for astrocyte/pipeline/pageindex_pipeline.py: 77%
81 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""M32 — PageIndex recall pipeline (unifies bench + production stacks).
3Implements the same ``async def recall(request: RecallRequest) -> RecallResult``
4contract as the legacy ``orchestrator.recall()`` so it slots into the
5existing ``ProviderDispatcher`` without changes to ``Astrocyte.recall()``.
7Why this exists
8---------------
10Before M32, Astrocyte had two parallel retrieval stacks:
12- **PageIndex** (bench): ``astrocyte_client.search()`` → ``fact_recall`` +
13 ``section_recall`` + cross-encoder rerank → ``PostgresPageIndexStore``.
14 Every bench score since M14 measured this stack. All the M14-M31 cycle
15 work (RRF fusion, fact↔chunk pairing, per-Q-type prompts, M27 fields,
16 M28-M29 coreference, M30 parallelization, M31 session_filter +
17 event_date) lives here.
18- **Vector store** (public ``Astrocyte.recall()``): orchestrator → vector
19 store + graph store. M9-era plumbing; none of the M14-M31 improvements
20 ever landed here.
22The v0.15.0 ship audit surfaced the drift: README badges describe the
23PageIndex stack but users calling ``Astrocyte.recall()`` get the
24vector-store stack. M32 closes that gap by making PageIndex the
25production recall pipeline, so future bench scores actually represent
26what ``pip install astrocyte`` produces.
28Design notes
29------------
31- **No new retrieval logic.** This pipeline is a thin adapter around
32 the existing ``fact_recall`` + ``section_recall`` primitives + the
33 bench-validated rerank. It's not a re-implementation; it's a re-shape
34 of the result type.
35- **Result-shape adapter.** Fact-grain and section-grain candidates
36 become ``MemoryHit`` instances with ``memory_layer`` set so downstream
37 consumers can tell them apart.
38- **Honours ``RecallRequest`` fields** the bench has historically
39 threaded through: ``session_id`` (M31 Fix 2), ``time_range`` (M9
40 temporal filter), ``fact_types``, ``max_results``,
41 ``query_reference_date``, ``as_of``. Multi-bank ``banks=[...]``
42 routing stays on the legacy ``orchestrator.recall()`` for now —
43 PageIndex is single-bank/single-document per call.
44"""
46from __future__ import annotations
48import asyncio
49import logging
50import time
51from typing import TYPE_CHECKING, Any
53from astrocyte.types import (
54 MemoryHit,
55 RecallRequest,
56 RecallResult,
57 RecallTrace,
58)
60if TYPE_CHECKING:
61 from astrocyte.config import AstrocyteConfig
62 from astrocyte.provider import LLMProvider, PageIndexStore
# Module-scoped logger for this pipeline. The dotted logger name is spelled
# out explicitly instead of being derived from ``__name__``.
_logger = logging.getLogger(name="astrocyte.pipeline.pageindex_pipeline")
class PageIndexPipeline:
    """Recall pipeline that drives the PageIndex stack.

    Implements the ``async recall(request) -> RecallResult`` contract
    so ``ProviderDispatcher`` treats it like any other pipeline.
    """

    def __init__(
        self,
        store: "PageIndexStore",
        embedding_provider: "LLMProvider",
        config: "AstrocyteConfig | None" = None,
        *,
        document_resolver: Any | None = None,
    ) -> None:
        """
        Args:
            store: PageIndex SPI handle (Postgres or in-memory). Passed to
                both ``fact_recall`` and ``section_recall``.
            embedding_provider: For query-embedding the search text; also
                handed to ``section_recall``.
            config: Astrocyte config; forwarded to ``fact_recall`` and used
                to gate optional retrieval features (episodic,
                link-expansion). Defaults to a fresh ``AstrocyteConfig()``.
            document_resolver: Optional callable mapping
                ``(bank_id) -> document_id | None``. Used when the caller
                wants single-document scope (matches the bench's
                ``self._doc_ids[user_id]`` pattern). When ``None``, the
                pipeline runs bank-wide (``document_id=None`` to the SPI).
        """
        if config is None:
            # Imported lazily; only needed when no config is supplied.
            from astrocyte.config import AstrocyteConfig  # noqa: PLC0415

            config = AstrocyteConfig()
        self._store = store
        self._provider = embedding_provider
        self._config = config
        self._document_resolver = document_resolver

    async def recall(self, request: RecallRequest) -> RecallResult:
        """Run PageIndex retrieval + rerank and return a ``RecallResult``.

        Steps: embed the query, resolve the optional document scope and
        temporal range, run ``fact_recall`` and ``section_recall``
        concurrently, merge both grains into score-sorted ``MemoryHit``
        objects, then truncate to ``request.max_results``.

        A failed or empty query embedding yields an empty result (without a
        trace). A failure in either sub-recall is logged and that grain is
        treated as empty; the other grain is still returned.
        """
        t0 = time.monotonic()

        # 1. Embed query.
        try:
            embeds = await self._provider.embed([request.query])
            query_vec = embeds[0] if embeds else []
        except Exception as exc:  # noqa: BLE001
            _logger.warning("pageindex_pipeline: embed failed: %s", exc)
            return RecallResult(hits=[], total_available=0, truncated=False)

        if not query_vec:
            return RecallResult(hits=[], total_available=0, truncated=False)

        # 2. Resolve optional document scope (None → bank-wide).
        document_id = self._resolve_document_id(request.bank_id)

        # 3. Temporal range: caller-supplied, otherwise inferred from the
        #    query text by the analyzer.
        date_range = await self._resolve_date_range(request)

        # 4. Parallel fact + section recall (M30-L1 pattern).
        from astrocyte.pipeline.fact_recall import fact_recall  # noqa: PLC0415
        from astrocyte.pipeline.section_recall import (  # noqa: PLC0415
            section_recall,
        )

        fact_coro = fact_recall(
            store=self._store,
            bank_id=request.bank_id,
            document_id=document_id,
            query=request.query,
            query_embedding=query_vec,
            config=self._config,
            temporal_range=date_range,
            session_filter=request.session_id,  # M31 Fix 2
            # M34-4 — per-fact-type segmentation when caller specifies
            # which types to retrieve; default None preserves single-pool.
            fact_types=request.fact_types,
            # M35-2 — token budget cap. None → no cap (legacy
            # callers); otherwise tiktoken-counted pack from
            # token_budget.pack_to_budget.
            max_tokens=request.max_tokens,
        )

        recall_mode = "temporal" if date_range is not None else "single-hop"
        section_coro = section_recall(
            store=self._store,
            bank_id=request.bank_id,
            question=request.query,
            mode=recall_mode,
            embedding_provider=self._provider,
            date_range=date_range,
            wiki_enabled=False,  # Wiki tier handled by orchestrator's _try_wiki_tier
            session_filter=request.session_id,  # M31 Fix 2
        )

        # return_exceptions=True: one grain failing must not discard the other.
        fact_result, section_result = await asyncio.gather(
            fact_coro, section_coro, return_exceptions=True,
        )

        fact_hits: list = []
        if isinstance(fact_result, BaseException):
            _logger.warning(
                "pageindex_pipeline: fact_recall failed: %s", fact_result,
            )
        else:
            fact_hits = fact_result

        section_recall_result = None
        if isinstance(section_result, BaseException):
            _logger.warning(
                "pageindex_pipeline: section_recall failed: %s",
                section_result,
            )
        else:
            section_recall_result = section_result

        # 5. Convert to MemoryHit. Fact-grain hits are filtered to
        #    ``request.fact_types`` here; section-grain hits carry no
        #    fact_type and are not filtered.
        hits = self._to_memory_hits(
            fact_hits=fact_hits,
            section_result=section_recall_result,
            request=request,
        )

        # 6. Honour max_results truncation; report truncated flag.
        total = len(hits)
        truncated = total > request.max_results
        hits = hits[: request.max_results]

        elapsed_ms = (time.monotonic() - t0) * 1000.0
        trace = RecallTrace(
            strategies_used=["fact_recall", "section_recall"],
            total_candidates=total,
            fusion_method="rrf+rerank",
            latency_ms=elapsed_ms,
        )
        return RecallResult(
            hits=hits,
            total_available=total,
            truncated=truncated,
            trace=trace,
        )

    def _resolve_document_id(self, bank_id: Any) -> str | None:
        """Map ``bank_id`` to a single document via the optional resolver.

        Returns ``None`` (bank-wide recall) when no resolver is configured,
        when ``bank_id`` is falsy, or when the resolver raises. A raising
        resolver is logged, since falling back to bank-wide scope silently
        widens the search.
        """
        if self._document_resolver is None or not bank_id:
            return None
        try:
            return self._document_resolver(bank_id)
        except Exception as exc:  # noqa: BLE001
            _logger.warning(
                "pageindex_pipeline: document_resolver failed for bank %r: %s",
                bank_id,
                exc,
            )
            return None

    async def _resolve_date_range(self, request: RecallRequest) -> Any:
        """Return the temporal range to filter on, or ``None``.

        A caller-supplied ``request.time_range`` always wins. Otherwise the
        query analyzer resolves relative phrases like "last week" using
        ``query_reference_date`` (or ``as_of``) as the anchor; only a
        bounded constraint produces a ``(start_date, end_date)`` tuple.
        Analyzer failures are logged and treated as "no temporal range".
        """
        if request.time_range is not None:
            return request.time_range
        try:
            from astrocyte.pipeline.query_analyzer import (  # noqa: PLC0415
                analyze_query,
            )

            anchor = request.query_reference_date or request.as_of
            analysis = await analyze_query(
                request.query,
                reference_date=anchor,
                llm_provider=None,
                allow_llm_fallback=False,
                allow_temporal_expansion=True,
            )
            constraint = analysis.temporal_constraint
            if constraint and constraint.is_bounded():
                return (constraint.start_date, constraint.end_date)
        except Exception as exc:  # noqa: BLE001
            _logger.warning(
                "pageindex_pipeline: query analysis failed; "
                "no temporal range applied: %s",
                exc,
            )
        return None

    def _to_memory_hits(
        self,
        *,
        fact_hits: list,
        section_result: Any,
        request: RecallRequest,
    ) -> list[MemoryHit]:
        """Shape fact/section results into the public ``MemoryHit`` list.

        Fact-grain → ``memory_layer="fact"``; section-grain →
        ``memory_layer="section"``. M31 ``event_date`` is preferred over
        ``occurred_start`` for the ``occurred_at`` surface field, since
        ``event_date`` is the deterministically-resolved canonical date
        for the fact's primary event. The combined list is returned
        sorted by score, highest first.
        """
        # Apply fact_types filter early (fact grain only).
        if request.fact_types:
            wanted = set(request.fact_types)
            fact_hits = [
                fh for fh in fact_hits if (fh.fact_type or "") in wanted
            ]

        out: list[MemoryHit] = []
        for fh in fact_hits:
            occurred = getattr(fh, "event_date", None) or fh.occurred_start
            # M32 — stash the fact-grain metadata (M27 confidence_score,
            # M27 mentioned_at, M31 Fix 4 event_date, line_num for
            # bench-side source-chunk rendering) into the metadata dict
            # so downstream consumers (bench harness, production agents)
            # can read them without adding fields to MemoryHit's public
            # surface. ``None`` values are still included so consumers
            # can do a single `hit.metadata.get("confidence_score")`
            # check without an attribute branch.
            meta: dict[str, Any] = {
                "grain": "fact",
                "confidence_score": getattr(fh, "confidence_score", None),
                "mentioned_at": getattr(fh, "mentioned_at", None),
                "event_date": getattr(fh, "event_date", None),
                "line_num": fh.line_num,
                "document_id": fh.document_id,
                "speaker": getattr(fh, "speaker", None),
                "entities": list(getattr(fh, "entities", None) or []),
            }
            out.append(
                MemoryHit(
                    text=fh.text,
                    score=fh.score,
                    fact_type=fh.fact_type,
                    metadata=meta,
                    occurred_at=occurred,
                    source=None,
                    memory_id=fh.fact_id,
                    bank_id=request.bank_id,
                    memory_layer="fact",
                    chunk_id=fh.chunk_id,
                )
            )

        # Section-grain: ``section_recall`` returns a ``SectionRecallResult``
        # with ``fused`` (FusedHit list) and ``wiki_hits``. We surface the
        # fused section hits as MemoryHit(memory_layer="section") so the
        # public API mirrors the bench's multi-grain shape. ``wiki_hits``
        # are skipped here — the orchestrator's _try_wiki_tier handles them
        # at a different layer.
        if section_result is not None and getattr(section_result, "fused", None):
            for sh in section_result.fused:
                title = getattr(sh, "title", None) or ""
                text = title  # PageIndex stores summary; bench excerpts at search-time
                meta_section: dict[str, Any] = {
                    "grain": "section",
                    "line_num": sh.line_num,
                    "document_id": sh.document_id,
                    "session_date": getattr(sh, "session_date", None),
                }
                out.append(
                    MemoryHit(
                        text=text,
                        score=getattr(sh, "rrf_score", 0.0),
                        fact_type=None,
                        metadata=meta_section,
                        occurred_at=None,
                        source=None,
                        memory_id=f"section:{sh.document_id}:{sh.line_num}",
                        bank_id=request.bank_id,
                        memory_layer="section",
                    )
                )

        # Sort by score desc so the caller's max_results cut picks the
        # top hits across both grains. Python's sort is stable, so on ties
        # fact hits keep their place ahead of section hits.
        out.sort(key=lambda h: h.score, reverse=True)
        return out
# Public API of this module: only the pipeline class is exported; the
# module logger and helpers stay private.
__all__ = ["PageIndexPipeline"]