Coverage for astrocyte/pipeline/tiered_retrieval.py: 97%

91 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Adaptive tiered retrieval — progressive escalation from cache to agentic recall. 

2 

3Inspired by ByteRover's tiered approach: most queries resolve cheaply. 

4Only novel/ambiguous queries escalate to full multi-strategy retrieval. 

5 

6Sync decision logic, async execution — Rust migration candidate for the sync parts. 

7""" 

8 

9from __future__ import annotations 

10 

11from collections.abc import Awaitable, Callable 

12from typing import TYPE_CHECKING 

13 

14from astrocyte.pipeline.recall_cache import RecallCache 

15from astrocyte.pipeline.recent_buffer import RecentMemoryBuffer 

16from astrocyte.recall.merge_result import merge_external_into_recall_result 

17from astrocyte.types import MemoryHit, RecallRequest, RecallResult, RecallTrace 

18 

19if TYPE_CHECKING: 

20 from astrocyte.pipeline.orchestrator import PipelineOrchestrator 

21 

# Escalation target for tier 3 (full multi-strategy recall) and tier 4 (the retry
# after LLM query reformulation): an async callable that takes a RecallRequest and
# returns a RecallResult. When none is injected into TieredRetriever, the built-in
# pipeline recall is used instead.
FullRecallFn = Callable[[RecallRequest], Awaitable[RecallResult]]

24 

25 

class TieredRetriever:
    """5-tier progressive retrieval — escalates only when needed.

    Tier 0: Recall cache hit (~0ms, free)
    Tier 1: Fuzzy text match on recent memories (~5ms, free)
    Tier 2: BM25 keyword search only (~50ms, free)
    Tier 3: Full multi-strategy parallel retrieval (~200ms, embedding cost)
    Tier 4: Agentic recall — LLM reformulates query + retry (~3-10s, LLM cost)

    Tiers above ``max_tier`` (clamped to at most 4) are never attempted. Tiers 1
    and 2 resolve a query only when they find at least ``min_results`` hits (tier 2
    additionally needs a mean score of at least ``min_score``). Tier 3 returns when
    it reaches ``min_results`` or when it is the highest enabled tier; otherwise
    tier 4 runs. With ``max_tier`` below 3, a query that no enabled tier resolves
    yields an empty result.
    """

    def __init__(
        self,
        pipeline: PipelineOrchestrator,
        recall_cache: RecallCache | None = None,
        recent_buffer: RecentMemoryBuffer | None = None,
        min_results: int = 3,
        min_score: float = 0.3,
        max_tier: int = 3,
        *,
        full_recall: FullRecallFn | None = None,
    ) -> None:
        """Configure the retriever.

        Args:
            pipeline: Orchestrator supplying ``llm_provider`` (embeddings and query
                reformulation), ``document_store`` (tier-2 full-text search) and the
                default ``recall`` used for tiers 3-4.
            recall_cache: Optional tier-0 cache; tier 0 is skipped when None.
            recent_buffer: Optional tier-1 buffer of recently retained memories;
                tier 1 is skipped when None.
            min_results: Hit count a tier must reach to be considered sufficient.
            min_score: Minimum fuzzy-match score for tier 1 and minimum mean BM25
                score for tier 2.
            max_tier: Highest tier to attempt; values above 4 are clamped to 4.
            full_recall: Optional replacement for ``pipeline.recall`` at tiers 3-4
                (e.g. a hybrid recall).
        """
        self.pipeline = pipeline
        self.recall_cache = recall_cache
        self.recent_buffer = recent_buffer
        self.min_results = min_results
        self.min_score = min_score
        self.max_tier = min(max_tier, 4)
        # When None, tier 3+ uses ``pipeline.recall`` (pipeline-only tiering).
        self._full_recall: FullRecallFn | None = full_recall

    async def _invoke_full_recall(self, request: RecallRequest) -> RecallResult:
        """Run the tier-3/4 recall: the injected callable if any, else ``pipeline.recall``."""
        if self._full_recall is not None:
            return await self._full_recall(request)
        return await self.pipeline.recall(request)

    def _cache_result(self, bank_id: str, query_vector, result: RecallResult, allow_cache: bool) -> None:
        """Store ``result`` in the recall cache when caching applies to this call.

        Nothing is stored when no cache is configured, when caching is disallowed
        for the request (federated / proxy context), or when tier 0 did not compute
        a query vector to key the entry by.
        """
        # Explicit None checks: truthiness would skip an empty (len() == 0) cache
        # and is ambiguous for array-like vectors.
        if allow_cache and self.recall_cache is not None and query_vector is not None:
            self.recall_cache.put(bank_id, query_vector, result)

    def notify_retain(self, bank_id: str, memory_id: str, text: str, metadata: dict | None = None) -> None:
        """Called after a successful retain to populate the recent buffer.

        Should be called by the dispatcher or orchestrator for each stored chunk.
        Invalidates the recall cache for the bank (contents changed).
        """
        if self.recent_buffer is not None:
            self.recent_buffer.add(bank_id, memory_id, text, metadata)
        if self.recall_cache is not None:
            self.recall_cache.invalidate_bank(bank_id)

    async def retrieve(self, request: RecallRequest) -> RecallResult:
        """Run tiered retrieval, escalating until sufficient results are found.

        Args:
            request: The recall request. A non-empty ``external_context``
                (federated / proxy hits) disables caching for this call; tiers 1,
                2 and the empty fallback merge it into their result here, while
                tiers 3 and 4 forward it to the full-recall callable.

        Returns:
            The result of the first tier that resolves the query; its trace
            records ``tier_used`` and ``cache_hit``.
        """
        from astrocyte.pipeline.embedding import generate_embeddings

        # Federated / proxy hits must not use stale cache entries and are not cached themselves.
        allow_cache = not request.external_context
        # Computed only by tier 0; while None, no tier writes to the cache.
        query_vector = None

        # NOTE(review): tiers 0-2 do not apply request.tags / fact_types / time_range —
        # the cache is keyed only by (bank_id, query vector) and the fuzzy / BM25
        # searches receive no filters. Confirm this is intended for filtered requests.

        # ── Tier 0: Cache ──
        # Compare against None instead of relying on truthiness: a cache object that
        # defines __len__ is falsy while empty and would then never be read or filled.
        if self.recall_cache is not None and self.max_tier >= 0 and allow_cache:
            query_embeddings = await generate_embeddings([request.query], self.pipeline.llm_provider)
            query_vector = query_embeddings[0]

            cached = self.recall_cache.get(request.bank_id, query_vector)
            if cached is not None:
                return RecallResult(
                    hits=cached.hits[: request.max_results],
                    total_available=cached.total_available,
                    # The entry may have been stored for a request with a larger
                    # max_results, in which case the slice above truncates it.
                    truncated=cached.truncated or len(cached.hits) > request.max_results,
                    trace=RecallTrace(
                        strategies_used=["cache"],
                        total_candidates=len(cached.hits),
                        fusion_method="cache",
                        tier_used=0,
                        cache_hit=True,
                    ),
                )

        # ── Tier 1: Fuzzy text match on recent memories ──
        if self.recent_buffer is not None and self.max_tier >= 1:
            fuzzy_hits = self.recent_buffer.search(
                request.query,
                request.bank_id,
                limit=request.max_results * 2,
                min_score=self.min_score,
            )
            if len(fuzzy_hits) >= self.min_results:
                result = RecallResult(
                    hits=fuzzy_hits[: request.max_results],
                    total_available=len(fuzzy_hits),
                    truncated=len(fuzzy_hits) > request.max_results,
                    trace=RecallTrace(
                        strategies_used=["fuzzy_recent"],
                        total_candidates=len(fuzzy_hits),
                        fusion_method="fuzzy_recent",
                        tier_used=1,
                        cache_hit=False,
                    ),
                )
                if request.external_context:
                    result = merge_external_into_recall_result(result, request.external_context, request.max_results)
                self._cache_result(request.bank_id, query_vector, result, allow_cache)
                return result

        # ── Tier 2: BM25 keyword only ──
        if self.pipeline.document_store and self.max_tier >= 2:
            keyword_hits = await self.pipeline.document_store.search_fulltext(
                request.query, request.bank_id, limit=request.max_results * 2
            )
            if len(keyword_hits) >= self.min_results:
                # max(..., 1) guards the division when min_results is 0.
                avg_score = sum(h.score for h in keyword_hits) / max(len(keyword_hits), 1)
                if avg_score >= self.min_score:
                    hits = [
                        MemoryHit(
                            text=h.text,
                            score=h.score,
                            metadata=h.metadata,
                            memory_id=h.document_id,
                            bank_id=request.bank_id,
                        )
                        for h in keyword_hits[: request.max_results]
                    ]
                    result = RecallResult(
                        hits=hits,
                        total_available=len(keyword_hits),
                        truncated=len(keyword_hits) > request.max_results,
                        trace=RecallTrace(
                            strategies_used=["keyword"],
                            total_candidates=len(keyword_hits),
                            fusion_method="bm25_only",
                            tier_used=2,
                            cache_hit=False,
                        ),
                    )
                    if request.external_context:
                        result = merge_external_into_recall_result(
                            result, request.external_context, request.max_results
                        )
                    self._cache_result(request.bank_id, query_vector, result, allow_cache)
                    return result

        # ── Tier 3: Full multi-strategy (pipeline recall or injected full recall, e.g. hybrid) ──
        if self.max_tier >= 3:
            result = await self._invoke_full_recall(request)
            if result.trace is not None:
                result.trace.tier_used = 3
                result.trace.cache_hit = False

            # Sufficient results, or no higher tier to escalate to: return as-is.
            if len(result.hits) >= self.min_results or self.max_tier <= 3:
                self._cache_result(request.bank_id, query_vector, result, allow_cache)
                return result

        # ── Tier 4: Agentic recall — LLM reformulates query ──
        if self.max_tier >= 4:
            reformulated = await self._reformulate_query(request.query)
            reformulated_request = RecallRequest(
                query=reformulated,
                bank_id=request.bank_id,
                max_results=request.max_results,
                max_tokens=request.max_tokens,
                tags=request.tags,
                fact_types=request.fact_types,
                time_range=request.time_range,
                include_sources=request.include_sources,
                layer_weights=request.layer_weights,
                detail_level=request.detail_level,
                external_context=request.external_context,
            )
            result = await self._invoke_full_recall(reformulated_request)
            if result.trace is not None:
                result.trace.tier_used = 4
                result.trace.cache_hit = False
                result.trace.strategies_used = (result.trace.strategies_used or []) + ["agentic_reformulation"]

            # Cached under the ORIGINAL query's vector so a repeat query skips reformulation.
            self._cache_result(request.bank_id, query_vector, result, allow_cache)
            return result

        # Fallback: empty result (still surface federated hits if present)
        empty = RecallResult(
            hits=[],
            total_available=0,
            truncated=False,
            trace=RecallTrace(tier_used=0, cache_hit=False),
        )
        if request.external_context:
            return merge_external_into_recall_result(empty, request.external_context, request.max_results)
        return empty

    async def _reformulate_query(self, query: str) -> str:
        """Use the LLM to reformulate an ambiguous query for better retrieval.

        Best-effort: only the first 2000 characters of ``query`` are sent. If the
        LLM call fails, or returns empty / blank text, the original query is
        returned unchanged. Failures are logged.
        """
        import logging

        from astrocyte.types import Message

        system_msg = (
            "Reformulate the user's query to improve memory search results. "
            "Add synonyms, expand abbreviations, and rephrase for clarity. "
            "Return only the reformulated query, nothing else."
        )
        user_msg = f"<query>\n{query[:2000]}\n</query>"
        try:
            completion = await self.pipeline.llm_provider.complete(
                messages=[
                    Message(role="system", content=system_msg),
                    Message(role="user", content=user_msg),
                ],
                max_tokens=200,
                temperature=0.3,
            )
            # A missing text payload is treated like a blank answer, not an error.
            return (completion.text or "").strip() or query
        except Exception:
            # Reformulation is only an optimisation; it must never fail the recall.
            logging.getLogger(__name__).warning(
                "Query reformulation failed; falling back to the original query", exc_info=True
            )
            return query