Coverage for astrocyte/pipeline/pageindex_pipeline.py: 77%

81 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""M32 — PageIndex recall pipeline (unifies bench + production stacks). 

2 

3Implements the same ``async def recall(request: RecallRequest) -> RecallResult`` 

4contract as the legacy ``orchestrator.recall()`` so it slots into the 

5existing ``ProviderDispatcher`` without changes to ``Astrocyte.recall()``. 

6 

7Why this exists 

8--------------- 

9 

10Before M32, Astrocyte had two parallel retrieval stacks: 

11 

12- **PageIndex** (bench): ``astrocyte_client.search()`` → ``fact_recall`` + 

13 ``section_recall`` + cross-encoder rerank → ``PostgresPageIndexStore``. 

14 Every bench score since M14 measured this stack. All the M14-M31 cycle 

15 work (RRF fusion, fact↔chunk pairing, per-Q-type prompts, M27 fields, 

16 M28-M29 coreference, M30 parallelization, M31 session_filter + 

17 event_date) lives here. 

18- **Vector store** (public ``Astrocyte.recall()``): orchestrator → vector 

19 store + graph store. M9-era plumbing; none of the M14-M31 improvements 

20 ever landed here. 

21 

22The v0.15.0 ship audit surfaced the drift: README badges describe the 

23PageIndex stack but users calling ``Astrocyte.recall()`` get the 

24vector-store stack. M32 closes that gap by making PageIndex the 

25production recall pipeline, so future bench scores actually represent 

26what ``pip install astrocyte`` produces. 

27 

28Design notes 

29------------ 

30 

31- **No new retrieval logic.** This pipeline is a thin adapter around 

32 the existing ``fact_recall`` + ``section_recall`` primitives + the 

33 bench-validated rerank. It's not a re-implementation; it's a re-shape 

34 of the result type. 

35- **Result-shape adapter.** Fact-grain and section-grain candidates 

36 become ``MemoryHit`` instances with ``memory_layer`` set so downstream 

37 consumers can tell them apart. 

38- **Honours ``RecallRequest`` fields** the bench has historically 

39 threaded through: ``session_id`` (M31 Fix 2), ``time_range`` (M9 

40 temporal filter), ``fact_types``, ``max_results``, 

41 ``query_reference_date``, ``as_of``. Multi-bank ``banks=[...]`` 

42 routing stays on the legacy ``orchestrator.recall()`` for now — 

43 PageIndex is single-bank/single-document per call. 

44""" 

45 

46from __future__ import annotations 

47 

48import asyncio 

49import logging 

50import time 

51from typing import TYPE_CHECKING, Any 

52 

53from astrocyte.types import ( 

54 MemoryHit, 

55 RecallRequest, 

56 RecallResult, 

57 RecallTrace, 

58) 

59 

60if TYPE_CHECKING: 

61 from astrocyte.config import AstrocyteConfig 

62 from astrocyte.provider import LLMProvider, PageIndexStore 

63 

# Module-level logger for this pipeline; embed / fact_recall / section_recall
# failures are reported here as warnings rather than raised to the caller.
_logger = logging.getLogger("astrocyte.pipeline.pageindex_pipeline")

65 

66 

class PageIndexPipeline:
    """Recall pipeline that drives the PageIndex stack.

    Implements the ``async recall(request) -> RecallResult`` contract
    so ``ProviderDispatcher`` treats it like any other pipeline.
    """

    def __init__(
        self,
        store: "PageIndexStore",
        embedding_provider: "LLMProvider",
        config: "AstrocyteConfig | None" = None,
        *,
        document_resolver: Any | None = None,
    ) -> None:
        """
        Args:
            store: PageIndex SPI handle (Postgres or in-memory).
            embedding_provider: For query-embedding the search text.
            config: Astrocyte config; used to gate optional retrieval
                features (episodic, link-expansion). Defaults to a fresh
                ``AstrocyteConfig()``.
            document_resolver: Optional callable mapping
                ``(bank_id) -> document_id | None``. Used when the caller
                wants single-document scope (matches the bench's
                ``self._doc_ids[user_id]`` pattern). May be a plain or an
                async callable; an awaitable return value is awaited. When
                ``None``, the pipeline runs bank-wide (``document_id=None``
                to the SPI).
        """
        if config is None:
            from astrocyte.config import AstrocyteConfig  # noqa: PLC0415

            config = AstrocyteConfig()
        self._store = store
        self._provider = embedding_provider
        self._config = config
        self._document_resolver = document_resolver

    async def recall(self, request: RecallRequest) -> RecallResult:
        """Run PageIndex retrieval + rerank, return RecallResult.

        Steps:

        1. Embed the query. An embed failure or empty vector returns an
           empty ``RecallResult`` (no trace).
        2. Resolve the optional single-document scope.
        3. Use ``request.time_range`` or infer a bounded range from the query.
        4. Run ``fact_recall`` and ``section_recall`` concurrently.
        5. Shape both legs into ``MemoryHit`` objects, sorted by score.
        6. Truncate to ``request.max_results`` and attach a ``RecallTrace``.

        A retrieval leg that raises is logged and contributes no hits, so one
        failing leg degrades the result instead of failing the whole call.
        """
        t0 = time.monotonic()

        # 1. Embed query.
        query_vec = await self._embed_query(request.query)
        if query_vec is None:
            return RecallResult(hits=[], total_available=0, truncated=False)

        # 2. Resolve optional document scope.
        document_id = await self._resolve_document_id(request.bank_id)

        # 3. Temporal range: caller-supplied wins; otherwise ask the query
        # analyzer (anchored on ``query_reference_date`` / ``as_of``).
        date_range = await self._resolve_date_range(request)

        # 4. Parallel fact + section recall (M30-L1 pattern).
        from astrocyte.pipeline.fact_recall import fact_recall  # noqa: PLC0415
        from astrocyte.pipeline.section_recall import (  # noqa: PLC0415
            section_recall,
        )

        fact_coro = fact_recall(
            store=self._store,
            bank_id=request.bank_id,
            document_id=document_id,
            query=request.query,
            query_embedding=query_vec,
            config=self._config,
            temporal_range=date_range,
            session_filter=request.session_id,  # M31 Fix 2
            # M34-4 — per-fact-type segmentation when caller specifies
            # which types to retrieve; default None preserves single-pool.
            fact_types=request.fact_types,
            # M35-2 — token budget cap. None → no cap (legacy
            # callers); otherwise tiktoken-counted pack from
            # token_budget.pack_to_budget.
            max_tokens=request.max_tokens,
        )

        recall_mode = "temporal" if date_range is not None else "single-hop"
        # NOTE(review): unlike the fact leg, ``document_id`` is not passed
        # here, so the section leg stays bank-wide even when a resolver
        # scoped the call to one document — confirm this is intended.
        section_coro = section_recall(
            store=self._store,
            bank_id=request.bank_id,
            question=request.query,
            mode=recall_mode,
            embedding_provider=self._provider,
            date_range=date_range,
            wiki_enabled=False,  # Wiki tier handled by orchestrator's _try_wiki_tier
            session_filter=request.session_id,  # M31 Fix 2
        )

        fact_result, section_result = await asyncio.gather(
            fact_coro, section_coro, return_exceptions=True,
        )
        # ``or []`` also covers a leg that returns ``None`` instead of a list.
        fact_hits = self._leg_result(fact_result, "fact_recall") or []
        section_recall_result = self._leg_result(section_result, "section_recall")

        # 5. Convert to MemoryHit. fact_types filter is applied here so
        # downstream consumers see only the requested fact types.
        hits = self._to_memory_hits(
            fact_hits=fact_hits,
            section_result=section_recall_result,
            request=request,
        )

        # 6. Honour max_results truncation; report truncated flag.
        total = len(hits)
        truncated = total > request.max_results
        hits = hits[: request.max_results]

        elapsed_ms = (time.monotonic() - t0) * 1000.0
        trace = RecallTrace(
            strategies_used=["fact_recall", "section_recall"],
            total_candidates=total,
            fusion_method="rrf+rerank",
            latency_ms=elapsed_ms,
        )
        return RecallResult(
            hits=hits,
            total_available=total,
            truncated=truncated,
            trace=trace,
        )

    async def _embed_query(self, query: str) -> Any | None:
        """Embed ``query``; return its vector, or ``None`` on failure / empty."""
        try:
            embeds = await self._provider.embed([query])
            # ``len()`` instead of truthiness: array-like embeddings have an
            # ambiguous truth value and would raise on ``if embeds``.
            query_vec = (
                embeds[0] if embeds is not None and len(embeds) > 0 else None
            )
        except Exception as exc:  # noqa: BLE001
            _logger.warning("pageindex_pipeline: embed failed: %s", exc)
            return None
        if query_vec is None or len(query_vec) == 0:
            return None
        return query_vec

    async def _resolve_document_id(self, bank_id: Any) -> Any | None:
        """Map ``bank_id`` to a document scope via the optional resolver.

        Returns ``None`` (bank-wide retrieval) when no resolver is set, the
        bank id is empty, or the resolver raises. Awaitable results from an
        async resolver are awaited.
        """
        if self._document_resolver is None or not bank_id:
            return None
        import inspect  # noqa: PLC0415

        try:
            document_id = self._document_resolver(bank_id)
            if inspect.isawaitable(document_id):
                document_id = await document_id
        except Exception as exc:  # noqa: BLE001
            # Best-effort: fall back to bank-wide scope, but don't hide it.
            _logger.warning(
                "pageindex_pipeline: document_resolver failed for bank %r: %s",
                bank_id,
                exc,
            )
            return None
        return document_id

    async def _resolve_date_range(self, request: RecallRequest) -> Any:
        """Return ``request.time_range`` or a query-inferred bounded range.

        Relative phrases ("last week") are resolved against
        ``query_reference_date`` (falling back to ``as_of``). Analyzer
        failures are logged at debug level and yield ``None``.
        """
        if request.time_range is not None:
            return request.time_range
        try:
            from astrocyte.pipeline.query_analyzer import (  # noqa: PLC0415
                analyze_query,
            )

            anchor = request.query_reference_date or request.as_of
            analysis = await analyze_query(
                request.query,
                reference_date=anchor,
                llm_provider=None,
                allow_llm_fallback=False,
                allow_temporal_expansion=True,
            )
            constraint = analysis.temporal_constraint
            if constraint and constraint.is_bounded():
                return (constraint.start_date, constraint.end_date)
        except Exception as exc:  # noqa: BLE001
            _logger.debug("pageindex_pipeline: query analysis failed: %s", exc)
        return None

    @staticmethod
    def _leg_result(result: Any, leg: str) -> Any | None:
        """Unwrap one ``asyncio.gather(return_exceptions=True)`` slot.

        Returns ``None`` after logging a warning when the leg raised;
        otherwise returns the leg's result unchanged.
        """
        if isinstance(result, BaseException):
            _logger.warning("pageindex_pipeline: %s failed: %s", leg, result)
            return None
        return result

    def _to_memory_hits(
        self,
        *,
        fact_hits: list,
        section_result: Any,
        request: RecallRequest,
    ) -> list[MemoryHit]:
        """Shape fact/section results into the public ``MemoryHit`` list.

        Fact-grain → ``memory_layer="fact"``; section-grain →
        ``memory_layer="section"``. M31 ``event_date`` is preferred over
        ``occurred_start`` for the ``occurred_at`` surface field, since
        ``event_date`` is the deterministically-resolved canonical date
        for the fact's primary event.

        Returns hits from both grains sorted by score, descending; hits with
        a ``None`` score sort last.
        """
        # Apply fact_types filter early (section hits carry no fact_type and
        # are not filtered by it).
        if request.fact_types:
            wanted = set(request.fact_types)
            fact_hits = [
                fh for fh in fact_hits if (fh.fact_type or "") in wanted
            ]

        out: list[MemoryHit] = []
        for fh in fact_hits:
            occurred = getattr(fh, "event_date", None) or fh.occurred_start
            # M32 — stash the fact-grain metadata (M27 confidence_score,
            # M27 mentioned_at, M31 Fix 4 event_date, line_num for
            # bench-side source-chunk rendering) into the metadata dict
            # so downstream consumers (bench harness, production agents)
            # can read them without adding fields to MemoryHit's public
            # surface. ``None`` values are still included so consumers
            # can do a single `hit.metadata.get("confidence_score")`
            # check without an attribute branch.
            meta: dict[str, Any] = {
                "grain": "fact",
                "confidence_score": getattr(fh, "confidence_score", None),
                "mentioned_at": getattr(fh, "mentioned_at", None),
                "event_date": getattr(fh, "event_date", None),
                "line_num": fh.line_num,
                "document_id": fh.document_id,
                "speaker": getattr(fh, "speaker", None),
                "entities": list(getattr(fh, "entities", None) or []),
            }
            out.append(
                MemoryHit(
                    text=fh.text,
                    score=fh.score,
                    fact_type=fh.fact_type,
                    metadata=meta,
                    occurred_at=occurred,
                    source=None,
                    memory_id=fh.fact_id,
                    bank_id=request.bank_id,
                    memory_layer="fact",
                    chunk_id=fh.chunk_id,
                )
            )

        # Section-grain: ``section_recall`` returns a ``SectionRecallResult``
        # with ``fused`` (FusedHit list) and ``wiki_hits``. We surface the
        # fused section hits as MemoryHit(memory_layer="section") so the
        # public API mirrors the bench's multi-grain shape. ``wiki_hits``
        # are skipped here — the orchestrator's _try_wiki_tier handles them
        # at a different layer.
        if section_result is not None and getattr(section_result, "fused", None):
            for sh in section_result.fused:
                title = getattr(sh, "title", None) or ""
                text = title  # PageIndex stores summary; bench excerpts at search-time
                meta_section: dict[str, Any] = {
                    "grain": "section",
                    "line_num": sh.line_num,
                    "document_id": sh.document_id,
                    "session_date": getattr(sh, "session_date", None),
                }
                out.append(
                    MemoryHit(
                        text=text,
                        score=getattr(sh, "rrf_score", 0.0),
                        fact_type=None,
                        metadata=meta_section,
                        occurred_at=None,
                        source=None,
                        memory_id=f"section:{sh.document_id}:{sh.line_num}",
                        bank_id=request.bank_id,
                        memory_layer="section",
                    )
                )

        # Sort by score desc so the caller's max_results cut picks the
        # top hits across both grains. A ``None`` score would make the
        # comparison raise, so it is ranked last instead. The sort is
        # stable, so ties keep fact-before-section order.
        # NOTE(review): fact scores and section ``rrf_score`` may be on
        # different scales — confirm the merged ordering is intended.
        out.sort(
            key=lambda h: float("-inf") if h.score is None else h.score,
            reverse=True,
        )
        return out

332 

333 

# Public export surface of this module: only the pipeline class.
__all__ = ["PageIndexPipeline"]