Coverage for astrocyte/pipeline/retrieval.py: 96%
143 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Parallel multi-strategy retrieval — runs concurrent searches across stores.

Async (I/O-bound). See docs/_design/built-in-pipeline.md section 3.

Strategies fused via RRF:

- ``semantic`` — dense vector similarity (always runs).
- ``hyde`` — a second dense-vector pass using a precomputed hypothetical
  document (HyDE) embedding (runs only when ``hyde_vector`` is supplied).
- ``keyword`` — full-text search (runs when ``document_store`` is
  configured). Ranking depends on the store; see ``_keyword_search``.
- ``graph`` — entity-graph neighbor traversal (runs when ``graph_store`` is
  configured AND query entities are resolved).
- ``temporal`` — recency-ranked list of bank vectors (runs when the store
  exposes ``list_vectors`` AND the strategy is enabled via
  ``enable_temporal=True``, the default). Rescues recently-retained memories
  that lose the semantic cutoff to older near-matches. Inspired by
  Hindsight's 4-way parallel retrieval (see
  ``docs/_design/platform-positioning.md`` §Mystique).

When the vector store doubles as the document store and exposes
``search_hybrid_semantic_bm25``, ``semantic`` and ``keyword`` are answered by
one combined store call instead of two separate ones.
"""
19from __future__ import annotations
21import asyncio
22import logging
23import math
24import time
25from datetime import datetime, timezone
26from typing import TYPE_CHECKING, Any
28from astrocyte.pipeline.fusion import ScoredItem
30if TYPE_CHECKING:
31 from astrocyte.provider import DocumentStore, GraphStore, VectorStore
32 from astrocyte.types import VectorFilters, VectorItem
logger = logging.getLogger("astrocyte.retrieval")

#: Cap on how many vectors the temporal strategy will scan per recall.
#: For large banks we can't enumerate everything on every query — 500 is a
#: reasonable default that still surfaces recent writes without crushing
#: the store. Callers can lower it per call through ``parallel_retrieve``'s
#: ``temporal_scan_cap`` argument. NOTE(review): the orchestrator-level knob
#: was previously referred to as ``max_temporal_scan`` — confirm that name.
DEFAULT_TEMPORAL_SCAN_CAP = 500

#: Half-life for the temporal score's exponential decay. A memory this many
#: days old scores 0.5; older memories fall off; fresher memories climb
#: toward 1.0. Tuned to "a week feels recent, a month feels old" for
#: conversational memory workloads. Overridable per call through
#: ``parallel_retrieve``'s ``temporal_half_life_days`` argument.
DEFAULT_TEMPORAL_HALF_LIFE_DAYS = 7.0
async def parallel_retrieve(
    query_vector: list[float],
    query_text: str,
    bank_id: str,
    vector_store: VectorStore,
    graph_store: GraphStore | None = None,
    document_store: DocumentStore | None = None,
    entity_ids: list[str] | None = None,
    limit: int = 30,
    filters: VectorFilters | None = None,
    *,
    enable_temporal: bool = True,
    temporal_scan_cap: int = DEFAULT_TEMPORAL_SCAN_CAP,
    temporal_half_life_days: float = DEFAULT_TEMPORAL_HALF_LIFE_DAYS,
    hyde_vector: list[float] | None = None,
    strategy_timings_ms: dict[str, float] | None = None,
    strategy_candidate_counts: dict[str, int] | None = None,
    use_bm25_idf: bool = False,
) -> dict[str, list[ScoredItem]]:
    """Run parallel retrieval across all configured stores.

    Returns a dict of ``{strategy_name: list[ScoredItem]}``. Strategies run
    concurrently as ``asyncio`` tasks. A strategy that raises is logged and
    reported as an empty list, so a failure in one strategy never blocks
    the others. If this coroutine is cancelled or otherwise unwinds early,
    strategy tasks that are still running are cancelled rather than left
    orphaned in the background.

    Strategy selection:

    * ``semantic`` + ``keyword`` in one store call
      (``vector_store.search_hybrid_semantic_bm25``) when that method exists,
      ``document_store is vector_store``, ``query_text`` is not blank, no
      ``hyde_vector`` is given and ``use_bm25_idf`` is off. If that combined
      call fails, both strategies are retried as separate per-store calls.
    * Otherwise ``semantic`` always runs, and ``keyword`` runs when
      ``document_store`` is configured.
    * ``hyde`` when ``hyde_vector`` is given.
    * ``graph`` when ``graph_store`` is configured and ``entity_ids`` is
      non-empty.
    * ``temporal`` when ``enable_temporal`` is set and the vector store
      exposes ``list_vectors`` or a callable ``list_recent_vectors``.

    Args:
        query_vector: Query embedding used by the semantic strategy.
        query_text: Raw query text used by the keyword strategy.
        bank_id: Memory bank to search; forwarded to every store call.
        vector_store: Store used for semantic, hybrid and temporal search.
        graph_store: Optional entity graph for the graph strategy.
        document_store: Optional full-text store for the keyword strategy.
        entity_ids: Entities resolved from the query; seeds graph traversal.
        limit: Maximum hits requested from each strategy.
        filters: Optional filters forwarded to the semantic, hybrid and
            temporal strategies. ``filters.as_of`` also drives the temporal
            strategy's time-travel cut-off.
        enable_temporal: Gate the temporal strategy. Default on. Turn off
            for workloads where recency is not a signal (static document
            corpora).
        temporal_scan_cap: Cap on vectors scanned per recall for temporal
            ranking. Guards against O(bank) cost on large stores.
        temporal_half_life_days: Exponential decay half-life for temporal
            score. Tune shorter (e.g. 1.0) for fast-moving chat workloads,
            longer (e.g. 30.0) for slower knowledge bases.
        hyde_vector: Optional pre-computed HyDE embedding (hypothetical
            document embedding). When provided, an additional ``"hyde"``
            strategy runs semantic search with this vector and its results
            are fused via RRF alongside the standard ``"semantic"`` strategy.
            Generate with :func:`astrocyte.pipeline.hyde.generate_hyde_vector`.
        strategy_timings_ms: Optional out-parameter. When given, it receives
            each strategy's wall-clock time in milliseconds (0.0 for a
            failed strategy). Strategies produced by the combined hybrid
            call share that call's time.
        strategy_candidate_counts: Optional out-parameter. When given, it
            receives each strategy's hit count (0 for a failed strategy).
        use_bm25_idf: When ``True`` AND the document store advertises
            ``search_fulltext_bm25`` (PostgresStore via migration 013),
            route the keyword strategy through the BM25-with-IDF
            materialized-view path instead of the classic ``ts_rank_cd``
            path. Skips the hybrid CTE (which uses ts_rank_cd) — keyword
            and semantic run as separate strategies. Default ``False``.
    """
    tasks: dict[str, asyncio.Task[tuple[list[ScoredItem], float]]] = {}
    hybrid_task: asyncio.Task[tuple[dict[str, list[ScoredItem]], float]] | None = None

    hybrid_search = getattr(vector_store, "search_hybrid_semantic_bm25", None)
    use_hybrid_search = bool(
        callable(hybrid_search)
        and document_store is vector_store
        and query_text.strip()
        and hyde_vector is None
        # BM25-IDF requires its own keyword path (the materialized-view
        # query). When enabled, skip the hybrid CTE so keyword goes
        # through ``_keyword_search`` → ``search_fulltext_bm25``.
        and not use_bm25_idf
    )

    # PostgresStore can answer semantic and keyword retrieval in one SQL round trip.
    # Other adapters keep the portable per-strategy path.
    if use_hybrid_search:
        hybrid_task = asyncio.create_task(
            _timed_hybrid_semantic_keyword_search(
                vector_store,
                query_vector,
                query_text,
                bank_id,
                limit,
                filters,
            )
        )
    else:
        tasks["semantic"] = asyncio.create_task(
            _timed(_semantic_search(vector_store, query_vector, bank_id, limit, filters))
        )

    # HyDE (R1): second semantic pass with hypothetical-document embedding.
    # Runs concurrently with the standard semantic strategy; RRF fusion merges
    # both result sets. No-op when hyde_vector is None (feature disabled or
    # generation failed upstream).
    if hyde_vector is not None:
        tasks["hyde"] = asyncio.create_task(
            _timed(_semantic_search(vector_store, hyde_vector, bank_id, limit, filters))
        )

    # Graph search if store configured and entities found
    if graph_store and entity_ids:
        tasks["graph"] = asyncio.create_task(_timed(_graph_search(graph_store, entity_ids, bank_id, limit)))

    # Full-text search if document store configured
    if document_store and not use_hybrid_search:
        tasks["keyword"] = asyncio.create_task(
            _timed(_keyword_search(document_store, query_text, bank_id, limit, use_bm25_idf=use_bm25_idf))
        )

    # Temporal search if the vector store can enumerate. ``_temporal_search``
    # prefers ``list_recent_vectors`` and only falls back to paginated
    # ``list_vectors``, so either capability is enough. The capped scan keeps
    # cost bounded; ranking is recency decay over occurred_at/_created_at.
    can_enumerate = hasattr(vector_store, "list_vectors") or callable(
        getattr(vector_store, "list_recent_vectors", None)
    )
    as_of = filters.as_of if filters is not None else None
    if enable_temporal and can_enumerate:
        tasks["temporal"] = asyncio.create_task(
            _timed(
                _temporal_search(
                    vector_store,
                    bank_id,
                    limit,
                    scan_cap=temporal_scan_cap,
                    half_life_days=temporal_half_life_days,
                    as_of=as_of,
                    filters=filters,
                )
            )
        )

    # Track every task we spawn (including any fallback tasks below) so the
    # ``finally`` can cancel whatever is still running if we unwind early,
    # e.g. when the caller cancels this coroutine mid-await.
    spawned: list[asyncio.Task[Any]] = list(tasks.values())
    if hybrid_task is not None:
        spawned.append(hybrid_task)

    results: dict[str, list[ScoredItem]] = {}
    try:
        if hybrid_task is not None:
            try:
                hybrid_results, elapsed_ms = await hybrid_task
            except Exception as exc:
                # Hybrid CTE failed (transient pool error, deadlock, lock-wait,
                # OOM, etc.). DON'T clobber semantic + keyword to [] — that
                # turns one transient DB hiccup into a recall failure for the
                # entire question. Fall back to running the same two strategies
                # as separate per-store calls (the portable path other adapters
                # use). Each fallback is isolated, so if e.g. semantic succeeds
                # but keyword fails, semantic still flows through.
                logger.warning(
                    "retrieval strategy hybrid_semantic_bm25 failed (%s); falling back to per-strategy semantic + keyword",
                    exc,
                )
                fallback_tasks = {
                    "semantic": asyncio.create_task(
                        _timed(_semantic_search(vector_store, query_vector, bank_id, limit, filters))
                    ),
                    "keyword": asyncio.create_task(
                        _timed(_keyword_search(document_store, query_text, bank_id, limit))
                    ),
                }
                spawned.extend(fallback_tasks.values())
                for name, fallback_task in fallback_tasks.items():
                    try:
                        items, fallback_ms = await fallback_task
                    except Exception as inner_exc:
                        logger.warning(
                            "fallback strategy %s also failed: %s",
                            name,
                            inner_exc,
                        )
                        _record_strategy_outcome(
                            results, strategy_timings_ms, strategy_candidate_counts, name, [], 0.0
                        )
                    else:
                        _record_strategy_outcome(
                            results, strategy_timings_ms, strategy_candidate_counts, name, items, fallback_ms
                        )
            else:
                for name, items in hybrid_results.items():
                    _record_strategy_outcome(
                        results, strategy_timings_ms, strategy_candidate_counts, name, items, elapsed_ms
                    )

        for name, task in tasks.items():
            try:
                items, elapsed_ms = await task
            except Exception as exc:  # pragma: no cover — per-strategy isolation
                logger.warning("retrieval strategy %s failed: %s", name, exc)
                # Strategy failure should not block others.
                _record_strategy_outcome(
                    results, strategy_timings_ms, strategy_candidate_counts, name, [], 0.0
                )
            else:
                _record_strategy_outcome(
                    results, strategy_timings_ms, strategy_candidate_counts, name, items, elapsed_ms
                )
    finally:
        # On the normal path every task is already done and this is a no-op.
        for task in spawned:
            if not task.done():
                task.cancel()

    return results


def _record_strategy_outcome(
    results: dict[str, list[ScoredItem]],
    timings_ms: dict[str, float] | None,
    candidate_counts: dict[str, int] | None,
    name: str,
    items: list[ScoredItem],
    elapsed_ms: float,
) -> None:
    """Store one strategy's hits and, when requested, its timing and hit count."""
    results[name] = items
    if timings_ms is not None:
        timings_ms[name] = elapsed_ms
    if candidate_counts is not None:
        candidate_counts[name] = len(items)
232async def _timed(coro) -> tuple[list[ScoredItem], float]:
233 start = time.perf_counter()
234 result = await coro
235 elapsed_ms = (time.perf_counter() - start) * 1000
236 return result, elapsed_ms
239async def _semantic_search(
240 vector_store: VectorStore,
241 query_vector: list[float],
242 bank_id: str,
243 limit: int,
244 filters: VectorFilters | None,
245) -> list[ScoredItem]:
246 """Vector similarity search."""
247 hits = await vector_store.search_similar(query_vector, bank_id, limit=limit, filters=filters)
248 return [
249 ScoredItem(
250 id=h.id,
251 text=h.text,
252 score=h.score,
253 fact_type=h.fact_type,
254 metadata=h.metadata,
255 tags=h.tags,
256 occurred_at=h.occurred_at,
257 retained_at=getattr(h, "retained_at", None),
258 chunk_id=getattr(h, "chunk_id", None),
259 )
260 for h in hits
261 ]
async def _timed_hybrid_semantic_keyword_search(
    vector_store: VectorStore,
    query_vector: list[float],
    query_text: str,
    bank_id: str,
    limit: int,
    filters: VectorFilters | None,
) -> tuple[dict[str, list[ScoredItem]], float]:
    """Run the store's combined semantic + keyword query and time it.

    Calls ``vector_store.search_hybrid_semantic_bm25`` once and converts the
    returned payload (read by ``_hybrid_hits_to_scored_items`` via its
    ``"semantic"`` / ``"keyword"`` keys) into ScoredItem lists. The reported
    milliseconds cover only the store call, not the conversion.
    """
    started_at = time.perf_counter()
    raw_hits = await vector_store.search_hybrid_semantic_bm25(
        query_vector, query_text, bank_id, limit=limit, filters=filters
    )
    store_ms = 1000 * (time.perf_counter() - started_at)
    converted = _hybrid_hits_to_scored_items(raw_hits)
    return converted, store_ms
284def _hybrid_hits_to_scored_items(raw: dict[str, list[Any]]) -> dict[str, list[ScoredItem]]:
285 results: dict[str, list[ScoredItem]] = {"semantic": [], "keyword": []}
286 for hit in raw.get("semantic", []):
287 results["semantic"].append(
288 ScoredItem(
289 id=hit.id,
290 text=hit.text,
291 score=hit.score,
292 fact_type=hit.fact_type,
293 metadata=hit.metadata,
294 tags=hit.tags,
295 occurred_at=hit.occurred_at,
296 retained_at=getattr(hit, "retained_at", None),
297 chunk_id=getattr(hit, "chunk_id", None),
298 )
299 )
300 for hit in raw.get("keyword", []):
301 results["keyword"].append(
302 ScoredItem(
303 id=hit.document_id,
304 text=hit.text,
305 score=hit.score,
306 metadata=hit.metadata,
307 )
308 )
309 return results
312async def _graph_search(
313 graph_store: GraphStore,
314 entity_ids: list[str],
315 bank_id: str,
316 limit: int,
317) -> list[ScoredItem]:
318 """Graph neighbor traversal."""
319 hits = await graph_store.query_neighbors(entity_ids, bank_id, max_depth=2, limit=limit)
320 return [
321 ScoredItem(
322 id=h.memory_id,
323 text=h.text,
324 score=h.score,
325 fact_type=None,
326 )
327 for h in hits
328 ]
331async def _keyword_search(
332 document_store: DocumentStore,
333 query_text: str,
334 bank_id: str,
335 limit: int,
336 *,
337 use_bm25_idf: bool = False,
338) -> list[ScoredItem]:
339 """Full-text search.
341 Routes through :meth:`PostgresStore.search_fulltext_bm25` (proper BM25
342 with corpus IDF + length normalisation) when ``use_bm25_idf=True`` AND
343 the store advertises that method; otherwise falls through to the
344 classic :meth:`DocumentStore.search_fulltext` (``ts_rank_cd``).
346 Stores that don't expose ``search_fulltext_bm25`` (in_memory,
347 elasticsearch adapter, etc.) silently use the classic path even when
348 the flag is on — the flag is "use BM25-IDF if available," not "fail
349 if unavailable."
350 """
351 bm25_method = getattr(document_store, "search_fulltext_bm25", None)
352 if use_bm25_idf and callable(bm25_method):
353 hits = await bm25_method(query_text, bank_id, limit=limit)
354 else:
355 hits = await document_store.search_fulltext(query_text, bank_id, limit=limit)
356 return [
357 ScoredItem(
358 id=h.document_id,
359 text=h.text,
360 score=h.score,
361 metadata=h.metadata,
362 )
363 for h in hits
364 ]
367async def _temporal_search(
368 vector_store: VectorStore,
369 bank_id: str,
370 limit: int,
371 *,
372 scan_cap: int,
373 half_life_days: float,
374 as_of: datetime | None = None,
375 filters: VectorFilters | None = None,
376) -> list[ScoredItem]:
377 """Recency-ranked strategy.
379 Enumerates up to ``scan_cap`` vectors from ``bank_id`` via
380 ``list_vectors``, ranks them by recency decay over
381 ``metadata["_created_at"]`` (falling back to ``occurred_at`` when
382 present), and returns the top ``limit`` as :class:`ScoredItem`.
384 The decay is exponential with half-life ``half_life_days``:
385 ``score = 2 ** (-age_days / half_life_days)``. A memory exactly
386 ``half_life_days`` old scores 0.5; a fresh memory approaches 1.0.
388 Vectors without a usable timestamp are skipped (not ranked at the
389 bottom — they simply don't contribute, because "we don't know when"
390 is a different signal than "we know it's old"). When every candidate
391 lacks timestamps, the result is an empty list and RRF ignores the
392 strategy entirely.
393 """
394 recent_vectors = getattr(vector_store, "list_recent_vectors", None)
395 if callable(recent_vectors):
396 scanned = await recent_vectors(bank_id, limit=scan_cap, filters=filters)
397 else:
398 # Accumulate via paginated list_vectors so large banks don't blow memory.
399 scanned = []
400 offset = 0
401 batch = min(200, scan_cap)
402 while len(scanned) < scan_cap:
403 page = await vector_store.list_vectors(bank_id, offset=offset, limit=batch)
404 if not page:
405 break
406 scanned.extend(page)
407 if len(page) < batch:
408 break # last page
409 offset += batch
411 if not scanned:
412 return []
414 now = datetime.now(timezone.utc)
415 scored: list[tuple[float, VectorItem]] = []
416 for item in scanned:
417 # M9: time-travel filter — skip items retained after as_of
418 if as_of is not None and item.retained_at is not None and item.retained_at > as_of:
419 continue
420 timestamp = _extract_timestamp(item)
421 if timestamp is None:
422 continue
423 age_days = max((now - timestamp).total_seconds() / 86400.0, 0.0)
424 # Exponential decay with configured half-life.
425 score = math.pow(2.0, -age_days / half_life_days) if half_life_days > 0 else 1.0
426 scored.append((score, item))
428 scored.sort(key=lambda pair: pair[0], reverse=True)
429 top = scored[:limit]
431 return [
432 ScoredItem(
433 id=item.id,
434 text=item.text,
435 score=score,
436 fact_type=item.fact_type,
437 metadata=item.metadata,
438 tags=item.tags,
439 memory_layer=item.memory_layer,
440 occurred_at=item.occurred_at,
441 retained_at=getattr(item, "retained_at", None),
442 chunk_id=getattr(item, "chunk_id", None),
443 )
444 for score, item in top
445 ]
448def _extract_timestamp(item: VectorItem) -> datetime | None:
449 """Best-effort timestamp extraction for temporal ranking.
451 Precedence: ``occurred_at`` (event/session time when the caller set it)
452 → ``metadata["_created_at"]`` (retain time used for lifecycle/TTL).
453 ISO strings and datetime instances are both accepted; naive datetimes are
454 interpreted as UTC.
455 """
456 metadata = item.metadata or {}
457 dt = item.occurred_at
458 if dt is None:
459 raw = metadata.get("_created_at")
460 if isinstance(raw, str):
461 try:
462 dt = datetime.fromisoformat(raw)
463 except ValueError:
464 dt = None
465 elif isinstance(raw, datetime):
466 dt = raw
468 if dt is None:
469 return None
470 if dt.tzinfo is None:
471 dt = dt.replace(tzinfo=timezone.utc)
472 return dt