Coverage for astrocyte/pipeline/tiered_retrieval.py: 97%
91 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Adaptive tiered retrieval — progressive escalation from cache to agentic recall.
3Inspired by ByteRover's tiered approach: most queries resolve cheaply.
4Only novel/ambiguous queries escalate to full multi-strategy retrieval.
6Sync decision logic, async execution — Rust migration candidate for the sync parts.
7"""
9from __future__ import annotations
11from collections.abc import Awaitable, Callable
12from typing import TYPE_CHECKING
14from astrocyte.pipeline.recall_cache import RecallCache
15from astrocyte.pipeline.recent_buffer import RecentMemoryBuffer
16from astrocyte.recall.merge_result import merge_external_into_recall_result
17from astrocyte.types import MemoryHit, RecallRequest, RecallResult, RecallTrace
19if TYPE_CHECKING:
20 from astrocyte.pipeline.orchestrator import PipelineOrchestrator
22# Escalation target for tier 3+ (and for the tier 4 retry after query reformulation): an async callable that takes a RecallRequest and resolves to a RecallResult. Default when none is injected: built-in pipeline recall.
23FullRecallFn = Callable[[RecallRequest], Awaitable[RecallResult]]
class TieredRetriever:
    """5-tier progressive retrieval — escalates only when needed.

    Tier 0: Recall cache hit (~0ms, free)
    Tier 1: Fuzzy text match on recent memories (~5ms, free)
    Tier 2: BM25 keyword search only (~50ms, free)
    Tier 3: Full multi-strategy parallel retrieval (~200ms, embedding cost)
    Tier 4: Agentic recall — LLM reformulates query + retry (~3-10s, LLM cost)

    A tier is attempted only when its backing component is configured and
    ``max_tier`` allows it. The first tier that yields a sufficient result
    returns immediately.
    """

    def __init__(
        self,
        pipeline: PipelineOrchestrator,
        recall_cache: RecallCache | None = None,
        recent_buffer: RecentMemoryBuffer | None = None,
        min_results: int = 3,
        min_score: float = 0.3,
        max_tier: int = 3,
        *,
        full_recall: FullRecallFn | None = None,
    ) -> None:
        """Configure the retriever.

        Args:
            pipeline: Orchestrator supplying ``llm_provider``, ``document_store``
                and the default ``recall`` used for tier 3+.
            recall_cache: Optional cache consulted at tier 0 and filled by later tiers.
            recent_buffer: Optional buffer of recently retained memories (tier 1).
            min_results: Minimum number of hits for a tier to count as sufficient.
            min_score: Score floor passed to the recent-buffer search and applied
                to the average keyword score at tier 2.
            max_tier: Highest tier that may be used; values above 4 are clamped to 4.
            full_recall: Optional replacement for ``pipeline.recall`` at tier 3+.
        """
        self.pipeline = pipeline
        self.recall_cache = recall_cache
        self.recent_buffer = recent_buffer
        self.min_results = min_results
        self.min_score = min_score
        self.max_tier = min(max_tier, 4)
        # When None, tier 3+ uses ``pipeline.recall`` (pipeline-only tiering).
        self._full_recall: FullRecallFn | None = full_recall

    async def _invoke_full_recall(self, request: RecallRequest) -> RecallResult:
        """Run the injected full-recall callable, or ``pipeline.recall`` when none was given."""
        if self._full_recall is not None:
            return await self._full_recall(request)
        return await self.pipeline.recall(request)

    def notify_retain(self, bank_id: str, memory_id: str, text: str, metadata: dict | None = None) -> None:
        """Called after a successful retain to populate the recent buffer.

        Should be called by the dispatcher or orchestrator for each stored chunk.
        Invalidates the recall cache for the bank (contents changed).
        """
        if self.recent_buffer is not None:
            self.recent_buffer.add(bank_id, memory_id, text, metadata)
        if self.recall_cache is not None:
            self.recall_cache.invalidate_bank(bank_id)

    def _cache_result(self, request: RecallRequest, query_vector, result: RecallResult, allow_cache: bool) -> None:
        """Store ``result`` in the recall cache when caching applies to this request.

        Nothing is stored when caching is disallowed, no cache is configured, or
        no query embedding was computed. The checks use identity rather than
        truthiness so that an empty cache object or an array-like embedding
        cannot silently disable caching (or raise on truth-testing).
        """
        if not allow_cache or self.recall_cache is None or query_vector is None:
            return
        self.recall_cache.put(request.bank_id, query_vector, result)

    def _local_tier_result(
        self,
        request: RecallRequest,
        hits: list,
        candidate_count: int,
        *,
        strategy: str,
        fusion_method: str,
        tier: int,
        query_vector,
        allow_cache: bool,
    ) -> RecallResult:
        """Build, merge and cache the result of a cheap local tier (1 or 2).

        ``hits`` are the hits already limited to ``request.max_results``;
        ``candidate_count`` is the number of candidates before that limit.
        """
        result = RecallResult(
            hits=hits,
            total_available=candidate_count,
            truncated=candidate_count > request.max_results,
            trace=RecallTrace(
                strategies_used=[strategy],
                total_candidates=candidate_count,
                fusion_method=fusion_method,
                tier_used=tier,
                cache_hit=False,
            ),
        )
        if request.external_context:
            result = merge_external_into_recall_result(result, request.external_context, request.max_results)
        self._cache_result(request, query_vector, result, allow_cache)
        return result

    async def retrieve(self, request: RecallRequest) -> RecallResult:
        """Run tiered retrieval, escalating until sufficient results found.

        Returns the result of the first tier that is sufficient. When no tier
        produces one (for example ``max_tier`` below 3 and the cheap tiers came
        up short), an empty result is returned, merged with
        ``request.external_context`` when that is present.
        """
        # Federated / proxy hits must not use stale cache entries and are not cached themselves.
        allow_cache = not request.external_context
        # Only computed when tier 0 actually runs; later tiers cache under this key.
        query_vector = None

        # ── Tier 0: Cache ──
        if self.recall_cache is not None and self.max_tier >= 0 and allow_cache:
            from astrocyte.pipeline.embedding import generate_embeddings

            query_embeddings = await generate_embeddings([request.query], self.pipeline.llm_provider)
            query_vector = query_embeddings[0]

            cached = self.recall_cache.get(request.bank_id, query_vector)
            if cached is not None:
                return RecallResult(
                    hits=cached.hits[: request.max_results],
                    total_available=cached.total_available,
                    # The slice above may drop hits the cached entry still had.
                    truncated=cached.truncated or len(cached.hits) > request.max_results,
                    trace=RecallTrace(
                        strategies_used=["cache"],
                        total_candidates=len(cached.hits),
                        fusion_method="cache",
                        tier_used=0,
                        cache_hit=True,
                    ),
                )

        # ── Tier 1: Fuzzy text match on recent memories ──
        if self.recent_buffer is not None and self.max_tier >= 1:
            fuzzy_hits = self.recent_buffer.search(
                request.query,
                request.bank_id,
                limit=request.max_results * 2,
                min_score=self.min_score,
            )
            if len(fuzzy_hits) >= self.min_results:
                return self._local_tier_result(
                    request,
                    fuzzy_hits[: request.max_results],
                    len(fuzzy_hits),
                    strategy="fuzzy_recent",
                    fusion_method="fuzzy_recent",
                    tier=1,
                    query_vector=query_vector,
                    allow_cache=allow_cache,
                )

        # ── Tier 2: BM25 keyword only ──
        if self.pipeline.document_store is not None and self.max_tier >= 2:
            keyword_hits = await self.pipeline.document_store.search_fulltext(
                request.query, request.bank_id, limit=request.max_results * 2
            )
            if len(keyword_hits) >= self.min_results:
                # max(..., 1) guards the division when min_results is 0 and nothing matched.
                avg_score = sum(h.score for h in keyword_hits) / max(len(keyword_hits), 1)
                if avg_score >= self.min_score:
                    hits = [
                        MemoryHit(
                            text=h.text,
                            score=h.score,
                            metadata=h.metadata,
                            memory_id=h.document_id,
                            bank_id=request.bank_id,
                        )
                        for h in keyword_hits[: request.max_results]
                    ]
                    return self._local_tier_result(
                        request,
                        hits,
                        len(keyword_hits),
                        strategy="keyword",
                        fusion_method="bm25_only",
                        tier=2,
                        query_vector=query_vector,
                        allow_cache=allow_cache,
                    )

        # ── Tier 3: Full multi-strategy (pipeline recall or injected full recall, e.g. hybrid) ──
        if self.max_tier >= 3:
            result = await self._invoke_full_recall(request)
            if result.trace is not None:
                result.trace.tier_used = 3
                result.trace.cache_hit = False

            # ── Tier 4: Agentic recall — LLM reformulates query ──
            if len(result.hits) < self.min_results and self.max_tier >= 4:
                result = await self._agentic_retry(request, result)

            self._cache_result(request, query_vector, result, allow_cache)
            return result

        # Fallback: empty result (still surface federated hits if present)
        empty = RecallResult(
            hits=[],
            total_available=0,
            truncated=False,
            trace=RecallTrace(tier_used=0, cache_hit=False),
        )
        if request.external_context:
            return merge_external_into_recall_result(empty, request.external_context, request.max_results)
        return empty

    async def _agentic_retry(self, request: RecallRequest, tier3_result: RecallResult) -> RecallResult:
        """Re-run full recall with an LLM-reformulated query (tier 4).

        The retry result is adopted when it has at least as many hits as the
        tier-3 result; otherwise the tier-3 result is kept, so escalating can
        never reduce the number of hits returned.
        """
        reformulated = await self._reformulate_query(request.query)
        reformulated_request = RecallRequest(
            query=reformulated,
            bank_id=request.bank_id,
            max_results=request.max_results,
            max_tokens=request.max_tokens,
            tags=request.tags,
            fact_types=request.fact_types,
            time_range=request.time_range,
            include_sources=request.include_sources,
            layer_weights=request.layer_weights,
            detail_level=request.detail_level,
            external_context=request.external_context,
        )
        retry = await self._invoke_full_recall(reformulated_request)
        if len(retry.hits) < len(tier3_result.hits):
            return tier3_result
        if retry.trace is not None:
            retry.trace.tier_used = 4
            retry.trace.cache_hit = False
            retry.trace.strategies_used = (retry.trace.strategies_used or []) + ["agentic_reformulation"]
        return retry

    async def _reformulate_query(self, query: str) -> str:
        """Use LLM to reformulate an ambiguous query for better retrieval.

        Best-effort: any failure of the LLM call, or an empty completion, falls
        back to the original ``query``. Failures are logged, not raised.
        """
        import logging

        from astrocyte.types import Message

        system_msg = (
            "Reformulate the user's query to improve memory search results. "
            "Add synonyms, expand abbreviations, and rephrase for clarity. "
            "Return only the reformulated query, nothing else."
        )
        # Only the first 2000 characters of the query are sent to the LLM.
        user_msg = f"<query>\n{query[:2000]}\n</query>"
        try:
            completion = await self.pipeline.llm_provider.complete(
                messages=[
                    Message(role="system", content=system_msg),
                    Message(role="user", content=user_msg),
                ],
                max_tokens=200,
                temperature=0.3,
            )
            return completion.text.strip() or query
        except Exception:
            logging.getLogger(__name__).warning(
                "Query reformulation failed; falling back to the original query",
                exc_info=True,
            )
            return query  # Fallback to original