Coverage for astrocyte/pipeline/retrieval.py: 96%

143 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Parallel multi-strategy retrieval — runs concurrent searches across stores. 

2 

3Async (I/O-bound). See docs/_design/built-in-pipeline.md section 3. 

4 

5Strategies fused via RRF: 

6 

7- ``semantic`` — dense vector similarity (always runs). 

8- ``keyword`` — BM25 full-text (runs when ``document_store`` is configured). 

9- ``graph`` — entity-graph neighbor traversal (runs when ``graph_store`` is 

10 configured AND query entities are resolved). 

11- ``temporal`` — recency-ranked list of bank vectors (runs when the store 

12 exposes ``list_vectors`` AND the strategy is enabled via 

13 ``enable_temporal=True``). Rescues recently-retained memories that 

14 lose the semantic cutoff to older near-matches. Inspired by Hindsight's 

15 4-way parallel retrieval (see 

16 ``docs/_design/platform-positioning.md`` §Mystique). 

17""" 

18 

19from __future__ import annotations 

20 

21import asyncio 

22import logging 

23import math 

24import time 

25from datetime import datetime, timezone 

26from typing import TYPE_CHECKING, Any 

27 

28from astrocyte.pipeline.fusion import ScoredItem 

29 

30if TYPE_CHECKING: 

31 from astrocyte.provider import DocumentStore, GraphStore, VectorStore 

32 from astrocyte.types import VectorFilters, VectorItem 

33 

logger = logging.getLogger("astrocyte.retrieval")

#: Upper bound on how many vectors the temporal strategy inspects per recall.
#: Enumerating a whole large bank on every query is too expensive, so the
#: scan is capped; 500 is a default that still surfaces recent writes without
#: hammering the store. Callers on latency-sensitive paths can pass a smaller
#: ``temporal_scan_cap`` to :func:`parallel_retrieve`.
DEFAULT_TEMPORAL_SCAN_CAP: int = 500

#: Half-life, in days, of the temporal score's exponential decay: an item
#: exactly this old scores 0.5, fresher items climb toward 1.0 and older ones
#: fall toward 0. Seven days aims at "a week feels recent, a month feels old"
#: for conversational memory; override via ``temporal_half_life_days``.
DEFAULT_TEMPORAL_HALF_LIFE_DAYS: float = 7.0

48 

49 

async def parallel_retrieve(
    query_vector: list[float],
    query_text: str,
    bank_id: str,
    vector_store: VectorStore,
    graph_store: GraphStore | None = None,
    document_store: DocumentStore | None = None,
    entity_ids: list[str] | None = None,
    limit: int = 30,
    filters: VectorFilters | None = None,
    *,
    enable_temporal: bool = True,
    temporal_scan_cap: int = DEFAULT_TEMPORAL_SCAN_CAP,
    temporal_half_life_days: float = DEFAULT_TEMPORAL_HALF_LIFE_DAYS,
    hyde_vector: list[float] | None = None,
    strategy_timings_ms: dict[str, float] | None = None,
    strategy_candidate_counts: dict[str, int] | None = None,
    use_bm25_idf: bool = False,
) -> dict[str, list[ScoredItem]]:
    """Run parallel retrieval across all configured stores.

    Returns a dict of ``{strategy_name: list[ScoredItem]}``. Every enabled
    strategy is started up front as an ``asyncio`` task, so they run
    concurrently; the tasks are then awaited one by one. A failure in one
    strategy is logged and recorded as an empty list, so it never blocks the
    others. If this coroutine is itself cancelled (or an exception escapes),
    every strategy task still in flight is cancelled rather than left running
    unobserved.

    Args:
        query_vector: Embedding of the query, used by the semantic strategy.
        query_text: Raw query text, used by the keyword strategy.
        bank_id: Memory bank to search.
        vector_store: Store used for semantic, HyDE and temporal retrieval.
        graph_store: Optional store for entity-graph traversal; the graph
            strategy runs only when this and ``entity_ids`` are both truthy.
        document_store: Optional full-text store for the keyword strategy.
        entity_ids: Entities resolved from the query (graph strategy input).
        limit: Per-strategy result cap.
        filters: Optional vector filters; ``filters.as_of`` also drives the
            temporal strategy's time-travel cutoff.
        enable_temporal: Gate the temporal strategy. Default on. Turn off
            for workloads where recency is not a signal (static document
            corpora). Requires ``vector_store.list_vectors``.
        temporal_scan_cap: Cap on vectors scanned per recall for temporal
            ranking. Guards against O(bank) cost on large stores.
        temporal_half_life_days: Exponential decay half-life for temporal
            score. Tune shorter (e.g. 1.0) for fast-moving chat workloads,
            longer (e.g. 30.0) for slower knowledge bases.
        hyde_vector: Optional pre-computed HyDE embedding (hypothetical
            document embedding). When provided, an additional ``"hyde"``
            strategy runs semantic search with this vector and its results
            are fused via RRF alongside the standard ``"semantic"`` strategy.
            Generate with :func:`astrocyte.pipeline.hyde.generate_hyde_vector`.
        strategy_timings_ms: Optional dict filled in place with each
            strategy's wall-clock time in ms (0.0 for a failed strategy).
        strategy_candidate_counts: Optional dict filled in place with each
            strategy's candidate count (0 for a failed strategy).
        use_bm25_idf: When ``True`` AND the document store advertises
            ``search_fulltext_bm25`` (PostgresStore via migration 013),
            route the keyword strategy through the BM25-with-IDF
            materialized-view path instead of the classic ``ts_rank_cd``
            path. Skips the hybrid CTE (which uses ts_rank_cd) — keyword
            and semantic run as separate strategies. Default ``False``.
    """
    results: dict[str, list[ScoredItem]] = {}
    tasks: dict[str, asyncio.Task[tuple[list[ScoredItem], float]]] = {}
    hybrid_task: asyncio.Task[tuple[dict[str, list[ScoredItem]], float]] | None = None
    # Every task this call creates, so the ``finally`` below can cancel any
    # that are still pending when we leave early (cancellation / error).
    spawned: list[asyncio.Task[Any]] = []

    def spawn(coro: Any) -> asyncio.Task[Any]:
        # Start ``coro`` as a task and remember it for cleanup.
        task = asyncio.create_task(coro)
        spawned.append(task)
        return task

    def record(name: str, items: list[ScoredItem], elapsed_ms: float) -> None:
        # Store one strategy's outcome plus the optional caller telemetry.
        results[name] = items
        if strategy_timings_ms is not None:
            strategy_timings_ms[name] = elapsed_ms
        if strategy_candidate_counts is not None:
            strategy_candidate_counts[name] = len(items)

    hybrid_search = getattr(vector_store, "search_hybrid_semantic_bm25", None)
    use_hybrid_search = (
        callable(hybrid_search)
        and document_store is vector_store
        and query_text.strip()
        and hyde_vector is None
        # BM25-IDF requires its own keyword path (the materialized-view
        # query). When enabled, skip the hybrid CTE so keyword goes
        # through ``_keyword_search`` → ``search_fulltext_bm25``.
        and not use_bm25_idf
    )

    try:
        # PostgresStore can answer semantic and keyword retrieval in one SQL
        # round trip. Other adapters keep the portable per-strategy path.
        if use_hybrid_search:
            hybrid_task = spawn(
                _timed_hybrid_semantic_keyword_search(
                    vector_store,
                    query_vector,
                    query_text,
                    bank_id,
                    limit,
                    filters,
                )
            )
        else:
            tasks["semantic"] = spawn(
                _timed(_semantic_search(vector_store, query_vector, bank_id, limit, filters))
            )

        # HyDE (R1): second semantic pass with hypothetical-document embedding.
        # Runs concurrently with the standard semantic strategy; RRF fusion
        # merges both result sets. No-op when hyde_vector is None (feature
        # disabled or generation failed upstream).
        if hyde_vector is not None:
            tasks["hyde"] = spawn(
                _timed(_semantic_search(vector_store, hyde_vector, bank_id, limit, filters))
            )

        # Graph search if store configured and entities found.
        if graph_store and entity_ids:
            tasks["graph"] = spawn(_timed(_graph_search(graph_store, entity_ids, bank_id, limit)))

        # Full-text search if document store configured (and not already
        # covered by the hybrid query).
        if document_store and not use_hybrid_search:
            tasks["keyword"] = spawn(
                _timed(_keyword_search(document_store, query_text, bank_id, limit, use_bm25_idf=use_bm25_idf))
            )

        # Temporal search if the vector store can enumerate. Capped scan keeps
        # cost bounded; ranking is by recency decay (see _temporal_search).
        as_of = filters.as_of if filters is not None else None
        if enable_temporal and hasattr(vector_store, "list_vectors"):
            tasks["temporal"] = spawn(
                _timed(
                    _temporal_search(
                        vector_store,
                        bank_id,
                        limit,
                        scan_cap=temporal_scan_cap,
                        half_life_days=temporal_half_life_days,
                        as_of=as_of,
                        filters=filters,
                    )
                )
            )

        if hybrid_task is not None:
            try:
                hybrid_results, elapsed_ms = await hybrid_task
            except Exception as exc:
                # Hybrid CTE failed (transient pool error, deadlock, lock-wait,
                # OOM, etc.). DON'T clobber semantic + keyword to [] — that
                # turns one transient DB hiccup into a recall failure for the
                # entire question. Fall back to running the same two
                # strategies as separate per-store calls (the portable path
                # other adapters use). Each fallback is isolated, so if e.g.
                # semantic succeeds but keyword fails, semantic still flows.
                logger.warning(
                    "retrieval strategy hybrid_semantic_bm25 failed (%s); falling back to per-strategy semantic + keyword",
                    exc,
                )
                fallback_tasks = {
                    "semantic": spawn(
                        _timed(_semantic_search(vector_store, query_vector, bank_id, limit, filters))
                    ),
                    "keyword": spawn(_timed(_keyword_search(document_store, query_text, bank_id, limit))),
                }
                for name, fallback_task in fallback_tasks.items():
                    try:
                        items, fallback_ms = await fallback_task
                    except Exception as inner_exc:
                        logger.warning(
                            "fallback strategy %s also failed: %s",
                            name,
                            inner_exc,
                        )
                        record(name, [], 0.0)
                    else:
                        record(name, items, fallback_ms)
            else:
                # One round trip served both strategies, so they share a timing.
                for name, items in hybrid_results.items():
                    record(name, items, elapsed_ms)

        for name, task in tasks.items():
            try:
                items, elapsed_ms = await task
            except Exception as exc:  # pragma: no cover — per-strategy isolation
                logger.warning("retrieval strategy %s failed: %s", name, exc)
                record(name, [], 0.0)  # Strategy failure should not block others
            else:
                record(name, items, elapsed_ms)
    finally:
        # On the normal path every task is already done and this is a no-op.
        # On cancellation or an escaping error it stops in-flight store
        # queries instead of orphaning them.
        for task in spawned:
            if not task.done():
                task.cancel()

    return results

230 

231 

232async def _timed(coro) -> tuple[list[ScoredItem], float]: 

233 start = time.perf_counter() 

234 result = await coro 

235 elapsed_ms = (time.perf_counter() - start) * 1000 

236 return result, elapsed_ms 

237 

238 

239async def _semantic_search( 

240 vector_store: VectorStore, 

241 query_vector: list[float], 

242 bank_id: str, 

243 limit: int, 

244 filters: VectorFilters | None, 

245) -> list[ScoredItem]: 

246 """Vector similarity search.""" 

247 hits = await vector_store.search_similar(query_vector, bank_id, limit=limit, filters=filters) 

248 return [ 

249 ScoredItem( 

250 id=h.id, 

251 text=h.text, 

252 score=h.score, 

253 fact_type=h.fact_type, 

254 metadata=h.metadata, 

255 tags=h.tags, 

256 occurred_at=h.occurred_at, 

257 retained_at=getattr(h, "retained_at", None), 

258 chunk_id=getattr(h, "chunk_id", None), 

259 ) 

260 for h in hits 

261 ] 

262 

263 

async def _timed_hybrid_semantic_keyword_search(
    vector_store: VectorStore,
    query_vector: list[float],
    query_text: str,
    bank_id: str,
    limit: int,
    filters: VectorFilters | None,
) -> tuple[dict[str, list[ScoredItem]], float]:
    """Run the store's combined semantic + BM25 query and time it.

    Issues a single ``vector_store.search_hybrid_semantic_bm25`` call, then
    converts its raw per-strategy payload with
    :func:`_hybrid_hits_to_scored_items`. The returned milliseconds cover
    only the store call, not the conversion.
    """
    began = time.perf_counter()
    raw_hits = await vector_store.search_hybrid_semantic_bm25(
        query_vector,
        query_text,
        bank_id,
        limit=limit,
        filters=filters,
    )
    duration_ms = (time.perf_counter() - began) * 1000
    converted = _hybrid_hits_to_scored_items(raw_hits)
    return converted, duration_ms

282 

283 

284def _hybrid_hits_to_scored_items(raw: dict[str, list[Any]]) -> dict[str, list[ScoredItem]]: 

285 results: dict[str, list[ScoredItem]] = {"semantic": [], "keyword": []} 

286 for hit in raw.get("semantic", []): 

287 results["semantic"].append( 

288 ScoredItem( 

289 id=hit.id, 

290 text=hit.text, 

291 score=hit.score, 

292 fact_type=hit.fact_type, 

293 metadata=hit.metadata, 

294 tags=hit.tags, 

295 occurred_at=hit.occurred_at, 

296 retained_at=getattr(hit, "retained_at", None), 

297 chunk_id=getattr(hit, "chunk_id", None), 

298 ) 

299 ) 

300 for hit in raw.get("keyword", []): 

301 results["keyword"].append( 

302 ScoredItem( 

303 id=hit.document_id, 

304 text=hit.text, 

305 score=hit.score, 

306 metadata=hit.metadata, 

307 ) 

308 ) 

309 return results 

310 

311 

312async def _graph_search( 

313 graph_store: GraphStore, 

314 entity_ids: list[str], 

315 bank_id: str, 

316 limit: int, 

317) -> list[ScoredItem]: 

318 """Graph neighbor traversal.""" 

319 hits = await graph_store.query_neighbors(entity_ids, bank_id, max_depth=2, limit=limit) 

320 return [ 

321 ScoredItem( 

322 id=h.memory_id, 

323 text=h.text, 

324 score=h.score, 

325 fact_type=None, 

326 ) 

327 for h in hits 

328 ] 

329 

330 

331async def _keyword_search( 

332 document_store: DocumentStore, 

333 query_text: str, 

334 bank_id: str, 

335 limit: int, 

336 *, 

337 use_bm25_idf: bool = False, 

338) -> list[ScoredItem]: 

339 """Full-text search. 

340 

341 Routes through :meth:`PostgresStore.search_fulltext_bm25` (proper BM25 

342 with corpus IDF + length normalisation) when ``use_bm25_idf=True`` AND 

343 the store advertises that method; otherwise falls through to the 

344 classic :meth:`DocumentStore.search_fulltext` (``ts_rank_cd``). 

345 

346 Stores that don't expose ``search_fulltext_bm25`` (in_memory, 

347 elasticsearch adapter, etc.) silently use the classic path even when 

348 the flag is on — the flag is "use BM25-IDF if available," not "fail 

349 if unavailable." 

350 """ 

351 bm25_method = getattr(document_store, "search_fulltext_bm25", None) 

352 if use_bm25_idf and callable(bm25_method): 

353 hits = await bm25_method(query_text, bank_id, limit=limit) 

354 else: 

355 hits = await document_store.search_fulltext(query_text, bank_id, limit=limit) 

356 return [ 

357 ScoredItem( 

358 id=h.document_id, 

359 text=h.text, 

360 score=h.score, 

361 metadata=h.metadata, 

362 ) 

363 for h in hits 

364 ] 

365 

366 

367async def _temporal_search( 

368 vector_store: VectorStore, 

369 bank_id: str, 

370 limit: int, 

371 *, 

372 scan_cap: int, 

373 half_life_days: float, 

374 as_of: datetime | None = None, 

375 filters: VectorFilters | None = None, 

376) -> list[ScoredItem]: 

377 """Recency-ranked strategy. 

378 

379 Enumerates up to ``scan_cap`` vectors from ``bank_id`` via 

380 ``list_vectors``, ranks them by recency decay over 

381 ``metadata["_created_at"]`` (falling back to ``occurred_at`` when 

382 present), and returns the top ``limit`` as :class:`ScoredItem`. 

383 

384 The decay is exponential with half-life ``half_life_days``: 

385 ``score = 2 ** (-age_days / half_life_days)``. A memory exactly 

386 ``half_life_days`` old scores 0.5; a fresh memory approaches 1.0. 

387 

388 Vectors without a usable timestamp are skipped (not ranked at the 

389 bottom — they simply don't contribute, because "we don't know when" 

390 is a different signal than "we know it's old"). When every candidate 

391 lacks timestamps, the result is an empty list and RRF ignores the 

392 strategy entirely. 

393 """ 

394 recent_vectors = getattr(vector_store, "list_recent_vectors", None) 

395 if callable(recent_vectors): 

396 scanned = await recent_vectors(bank_id, limit=scan_cap, filters=filters) 

397 else: 

398 # Accumulate via paginated list_vectors so large banks don't blow memory. 

399 scanned = [] 

400 offset = 0 

401 batch = min(200, scan_cap) 

402 while len(scanned) < scan_cap: 

403 page = await vector_store.list_vectors(bank_id, offset=offset, limit=batch) 

404 if not page: 

405 break 

406 scanned.extend(page) 

407 if len(page) < batch: 

408 break # last page 

409 offset += batch 

410 

411 if not scanned: 

412 return [] 

413 

414 now = datetime.now(timezone.utc) 

415 scored: list[tuple[float, VectorItem]] = [] 

416 for item in scanned: 

417 # M9: time-travel filter — skip items retained after as_of 

418 if as_of is not None and item.retained_at is not None and item.retained_at > as_of: 

419 continue 

420 timestamp = _extract_timestamp(item) 

421 if timestamp is None: 

422 continue 

423 age_days = max((now - timestamp).total_seconds() / 86400.0, 0.0) 

424 # Exponential decay with configured half-life. 

425 score = math.pow(2.0, -age_days / half_life_days) if half_life_days > 0 else 1.0 

426 scored.append((score, item)) 

427 

428 scored.sort(key=lambda pair: pair[0], reverse=True) 

429 top = scored[:limit] 

430 

431 return [ 

432 ScoredItem( 

433 id=item.id, 

434 text=item.text, 

435 score=score, 

436 fact_type=item.fact_type, 

437 metadata=item.metadata, 

438 tags=item.tags, 

439 memory_layer=item.memory_layer, 

440 occurred_at=item.occurred_at, 

441 retained_at=getattr(item, "retained_at", None), 

442 chunk_id=getattr(item, "chunk_id", None), 

443 ) 

444 for score, item in top 

445 ] 

446 

447 

448def _extract_timestamp(item: VectorItem) -> datetime | None: 

449 """Best-effort timestamp extraction for temporal ranking. 

450 

451 Precedence: ``occurred_at`` (event/session time when the caller set it) 

452 → ``metadata["_created_at"]`` (retain time used for lifecycle/TTL). 

453 ISO strings and datetime instances are both accepted; naive datetimes are 

454 interpreted as UTC. 

455 """ 

456 metadata = item.metadata or {} 

457 dt = item.occurred_at 

458 if dt is None: 

459 raw = metadata.get("_created_at") 

460 if isinstance(raw, str): 

461 try: 

462 dt = datetime.fromisoformat(raw) 

463 except ValueError: 

464 dt = None 

465 elif isinstance(raw, datetime): 

466 dt = raw 

467 

468 if dt is None: 

469 return None 

470 if dt.tzinfo is None: 

471 dt = dt.replace(tzinfo=timezone.utc) 

472 return dt