Coverage for astrocyte/hybrid.py: 73%

139 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Hybrid engine + storage pipeline as a single :class:`~astrocyte.provider.EngineProvider`. 

2 

3**Selection rules (summary)** 

4 

51. **Retain path** — Exactly one backend receives writes, chosen by ``retain_target``: 

6 ``"engine"`` → ``EngineProvider.retain`` (requires ``engine=``); 

7 ``"pipeline"`` → ``PipelineOrchestrator.retain`` (requires ``pipeline=``). 

8 

92. **Recall path** — If both sides are configured, ``recall`` runs them **concurrently** 

10 (``asyncio.gather``), tags hits with ``source`` ``tier2_engine`` vs ``tier1_pipeline`` 

11 (legacy tag values, kept for back-compat with downstream filters), 

12 optionally **dedupes** by text (keeping best score), then applies per-source **weights** 

13 (``engine_recall_weight`` / ``pipeline_recall_weight``). Results are sorted by weighted 

14 score and trimmed to ``max_results`` / token budget. 

15 

163. **Adaptive routing** (optional, ``adaptive_routing=True``) — :class:`AdaptiveRouter` 

17 adjusts engine vs pipeline weights per query using heuristics: temporal cues (regex), 

18 entity density (capitalized words), question shape (how/why vs what/who/list), and 

19 short queries; engine boosts require matching ``EngineCapabilities`` (e.g. temporal/graph). 

20 

214. **Reflect / forget** — Prefer engine when it advertises capability; otherwise pipeline 

22 or raise ``NotImplementedError`` (see implementation). 

23 

24When changing merge semantics or router heuristics, update **tests**: 

25``tests/test_astrocyte_tier2.py`` (engine provider), ``tests/test_hybrid_engine.py`` (merge + 

26``retain_target``), and ``tests/test_phase2_innovations.py`` (adaptive router). 

27""" 

28 

29from __future__ import annotations 

30 

31import asyncio 

32import re 

33from dataclasses import replace 

34from typing import TYPE_CHECKING, Any, ClassVar, Literal 

35 

36from astrocyte.policy.homeostasis import enforce_token_budget 

37 

38if TYPE_CHECKING: 

39 from astrocyte.pipeline.orchestrator import PipelineOrchestrator 

40 from astrocyte.provider import EngineProvider 

41 

42from astrocyte.types import ( 

43 EngineCapabilities, 

44 ForgetRequest, 

45 ForgetResult, 

46 HealthStatus, 

47 MemoryHit, 

48 RecallRequest, 

49 RecallResult, 

50 ReflectRequest, 

51 ReflectResult, 

52 RetainRequest, 

53 RetainResult, 

54) 

55 

56 

def _dedupe_hits_prefer_score(hits: list[MemoryHit]) -> list[MemoryHit]:
    """Collapse hits that share identical ``text``, keeping the highest-scoring one.

    On a score tie the earliest hit wins. Survivors are returned ordered by descending
    ``score``; equal scores keep the order in which their text was first seen (dicts keep
    first-insertion key order and the sort is stable).
    """
    winners: dict[str, MemoryHit] = {}
    for candidate in hits:
        incumbent = winners.get(candidate.text)
        # ``not >`` (rather than ``<=``) keeps NaN scores behaving exactly as a strict ``>`` test.
        if incumbent is not None and not candidate.score > incumbent.score:
            continue
        winners[candidate.text] = candidate
    ranked = list(winners.values())
    ranked.sort(key=lambda hit: hit.score, reverse=True)
    return ranked

64 

65 

class HybridEngineProvider:
    """Merges recall from an ``EngineProvider`` and a ``PipelineOrchestrator``.

    Typical use: long-lived agent memory in the hosted engine plus local embeddings / RAG
    in the pipeline, both indexed for the same logical ``bank_id``.

    Writes go to exactly one backend (``retain_target``). Recall fans out to every
    configured backend and merges the hits. Reflect / forget prefer the engine when it
    advertises the capability and otherwise fall back to the pipeline.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(
        self,
        *,
        engine: EngineProvider | Any | None = None,
        pipeline: PipelineOrchestrator | Any | None = None,
        retain_target: Literal["engine", "pipeline"] = "engine",
        engine_recall_weight: float = 1.0,
        pipeline_recall_weight: float = 1.0,
        dedup_across_sources: bool = True,
        adaptive_routing: bool = False,
    ) -> None:
        """Validate and store the backend configuration.

        Raises:
            ValueError: if neither backend is given, if ``retain_target`` is not
                ``"engine"`` or ``"pipeline"``, or if the backend it names is missing.
        """
        if engine is None and pipeline is None:
            raise ValueError("HybridEngineProvider requires at least one of engine= or pipeline=")
        if retain_target not in ("engine", "pipeline"):
            # Without this check any other value (e.g. a typo) silently routed writes to
            # the pipeline branch of retain(), even when no pipeline was configured.
            raise ValueError(
                f"retain_target must be 'engine' or 'pipeline', got {retain_target!r}"
            )
        if retain_target == "engine" and engine is None:
            raise ValueError("retain_target='engine' requires engine=")
        if retain_target == "pipeline" and pipeline is None:
            raise ValueError("retain_target='pipeline' requires pipeline=")
        self._engine = engine
        self._pipeline = pipeline
        self._retain_target = retain_target
        self._engine_w = float(engine_recall_weight)
        self._pipeline_w = float(pipeline_recall_weight)
        self._dedup_across_sources = dedup_across_sources
        self._adaptive_routing = adaptive_routing
        self._router = AdaptiveRouter() if adaptive_routing else None

    def capabilities(self) -> EngineCapabilities:
        """Report what this hybrid provider can do.

        With an engine, its capabilities are passed through. Reflect, forget and semantic
        search are additionally advertised when a pipeline is configured, because the
        pipeline is the fallback for those. Without an engine, a fixed pipeline-only set
        is returned.
        """
        if self._engine is not None:
            c = self._engine.capabilities()
            has_pipeline = self._pipeline is not None
            return EngineCapabilities(
                supports_reflect=c.supports_reflect or has_pipeline,
                supports_forget=c.supports_forget or has_pipeline,
                supports_semantic_search=c.supports_semantic_search or has_pipeline,
                supports_keyword_search=c.supports_keyword_search,
                supports_graph_search=c.supports_graph_search,
                supports_temporal_search=c.supports_temporal_search,
                supports_dispositions=c.supports_dispositions,
                supports_consolidation=c.supports_consolidation,
                supports_entities=c.supports_entities,
                supports_tags=c.supports_tags,
                supports_metadata=c.supports_metadata,
                max_retain_bytes=c.max_retain_bytes,
                max_recall_results=c.max_recall_results,
                max_embedding_dims=c.max_embedding_dims,
            )
        # Pipeline-only: __init__ guarantees a pipeline exists when there is no engine.
        # NOTE(review): forget via the pipeline only works when memory_ids are given (see forget()).
        return EngineCapabilities(
            supports_reflect=True,
            supports_forget=self._pipeline is not None,
            supports_semantic_search=True,
            supports_tags=True,
            supports_metadata=True,
        )

    async def health(self) -> HealthStatus:
        """Check each configured backend in turn and combine the results.

        Healthy only if every checked backend reports healthy. The message joins one
        ``engine=...`` / ``pipeline=...`` part per backend, each using that backend's
        message or, when it has none, its boolean health. The pipeline is checked
        through its ``vector_store``.
        """
        parts: list[str] = []
        healthy = True
        if self._engine is not None:
            h = await self._engine.health()
            healthy = healthy and h.healthy
            parts.append(f"engine={h.message or h.healthy}")
        if self._pipeline is not None:
            h = await self._pipeline.vector_store.health()
            healthy = healthy and h.healthy
            parts.append(f"pipeline={h.message or h.healthy}")
        return HealthStatus(healthy=healthy, message="; ".join(parts))

    async def retain(self, request: RetainRequest) -> RetainResult:
        """Write *request* to the single backend selected by ``retain_target``."""
        if self._retain_target == "engine":
            return await self._engine.retain(request)
        return await self._pipeline.retain(request)

    async def recall(self, request: RecallRequest) -> RecallResult:
        """Query every configured backend concurrently and merge their hits.

        1. With adaptive routing (and an engine), per-query weights come from
           :class:`AdaptiveRouter`; otherwise the constructor weights are used.
        2. Backends are awaited via ``asyncio.gather(..., return_exceptions=True)``, so
           all finish before a failure is re-raised. The engine's failure is raised
           first if both fail.
        3. Hits lacking ``bank_id`` / ``source`` get the request's bank id and the
           legacy tag ``"tier2_engine"`` / ``"tier1_pipeline"``.
        4. Optional text dedup (best raw score wins) runs *before* weighting, so backend
           weights do not influence which duplicate survives.
        5. Each hit is weighted by the backend it actually came from. This is not read
           from its ``source`` tag, which a backend may have set itself. Hits are then
           re-sorted, cut to ``max_results`` and, if ``max_tokens`` is set, to the token
           budget.

        ``total_available`` is the sum of the backends' counts (duplicates included).
        """
        engine_w, pipeline_w = self._engine_w, self._pipeline_w
        if self._router is not None and self._engine is not None:
            caps = self._engine.capabilities()
            engine_w, pipeline_w = self._router.route(
                request.query, caps, self._engine_w, self._pipeline_w
            )

        # (is_engine, default source tag, pending recall) per configured backend;
        # engine first so its failure takes precedence.
        backends: list[tuple[bool, str, Any]] = []
        if self._engine is not None:
            backends.append((True, "tier2_engine", self._engine.recall(request)))
        if self._pipeline is not None:
            backends.append((False, "tier1_pipeline", self._pipeline.recall(request)))
        raw = await asyncio.gather(*(call for _, _, call in backends), return_exceptions=True)

        all_hits: list[MemoryHit] = []
        # id() of every engine-origin hit. Each entry is a fresh object produced by
        # ``replace`` and stays alive in ``all_hits`` for the rest of this call, so ids
        # cannot be reused while we consult this set.
        engine_hit_ids: set[int] = set()
        total_available = 0
        for (is_engine, default_source, _), res in zip(backends, raw):
            if isinstance(res, BaseException):
                raise res
            total_available += res.total_available
            for h in res.hits:
                tagged = replace(
                    h,
                    bank_id=h.bank_id or request.bank_id,
                    source=h.source or default_source,
                )
                if is_engine:
                    engine_hit_ids.add(id(tagged))
                all_hits.append(tagged)

        # Dedup before weighting to avoid score inflation from backend weights
        merged = (
            _dedupe_hits_prefer_score(all_hits)
            if self._dedup_across_sources
            else sorted(all_hits, key=lambda x: x.score, reverse=True)
        )

        # Apply weights after dedup, keyed on origin rather than on the source tag.
        if engine_w != 1.0 or pipeline_w != 1.0:
            weighted = [
                replace(h, score=h.score * (engine_w if id(h) in engine_hit_ids else pipeline_w))
                for h in merged
            ]
            weighted.sort(key=lambda x: x.score, reverse=True)
            merged = weighted

        trimmed = merged[: request.max_results]
        truncated = False
        if request.max_tokens:
            trimmed, truncated = enforce_token_budget(trimmed, request.max_tokens)
        return RecallResult(hits=trimmed, total_available=total_available, truncated=truncated)

    async def reflect(self, request: ReflectRequest) -> ReflectResult:
        """Reflect via the engine if it supports it, otherwise via the pipeline.

        Raises:
            NotImplementedError: if neither path is available.
        """
        if self._engine is not None and self._engine.capabilities().supports_reflect:
            return await self._engine.reflect(request)
        if self._pipeline is not None:
            return await self._pipeline.reflect(request)
        raise NotImplementedError("reflect is not available on this HybridEngineProvider")

    async def forget(self, request: ForgetRequest) -> ForgetResult:
        """Forget via the engine if it supports it, else delete ids from the pipeline store.

        The pipeline fallback only handles explicit ``memory_ids``. It deletes them from
        ``pipeline.vector_store`` for ``request.bank_id``.

        Raises:
            NotImplementedError: if the engine cannot forget and the pipeline fallback
                does not apply (no pipeline, or no ``memory_ids``).
        """
        # NOTE(review): with retain_target="pipeline" and an engine that supports forget,
        # deletes go to the engine rather than to the store that received the writes —
        # confirm this is intended.
        if self._engine is not None and self._engine.capabilities().supports_forget:
            return await self._engine.forget(request)
        if self._pipeline is not None and request.memory_ids:
            deleted = await self._pipeline.vector_store.delete(request.memory_ids, request.bank_id)
            return ForgetResult(deleted_count=deleted)
        raise NotImplementedError("forget is not available on this HybridEngineProvider")

220 

221 

222# --------------------------------------------------------------------------- 

223# Adaptive query routing (Phase 2 innovation) 

224# --------------------------------------------------------------------------- 

225 

226 

# Temporal cue vocabulary used by AdaptiveRouter (case-insensitive, whole-word matches).
# NOTE(review): "may" / "march" also match the modal verb / the verb, and "when",
# "before", "after" are generic. Bare years such as "2024" are not matched, only
# YYYY-MM / YYYY/MM forms. Heuristic by design; tune alongside the router tests.
_TEMPORAL_PATTERNS = re.compile(
    r"\b("
    r"yesterday|today|last\s+week|last\s+month|before|after|when|recently|"
    r"\d{4}[-/]\d{2}|"
    r"january|february|march|april|may|june|july|august|"
    r"september|october|november|december|"
    r"monday|tuesday|wednesday|thursday|friday|saturday|sunday"
    r")\b",
    flags=re.IGNORECASE,
)


class AdaptiveRouter:
    """Adjust engine / pipeline recall weights per query from cheap lexical cues.

    Each rule multiplies one of the two weights. Rules are applied in this order:

    - Temporal cues (``_TEMPORAL_PATTERNS``): engine x(1.5 + 0.2 * min(matches, 3)) when
      the engine supports temporal search, otherwise pipeline x1.2.
    - Entity density (more than 30% of whitespace tokens are longer than one character
      and start with an uppercase letter): engine x1.4 with graph search support,
      otherwise engine x1.1.
    - Analytical openers (how / why / explain / analyze / compare): engine x1.3, only
      when the engine supports reflect. Factual openers (what is / who is / where is /
      list): pipeline x1.2.
    - Short queries (three tokens or fewer): pipeline x1.2.

    Holds no state; ``route`` depends only on its arguments — Rust migration candidate.
    """

    def route(
        self,
        query: str,
        engine_caps: EngineCapabilities | None = None,
        base_engine_weight: float = 1.0,
        base_pipeline_weight: float = 1.0,
    ) -> tuple[float, float]:
        """Return ``(engine_weight, pipeline_weight)`` for *query*.

        Starts from the base weights and applies the multiplicative rules described on
        the class. A missing ``engine_caps`` is treated as "no optional capabilities".
        """
        eng = base_engine_weight
        pipe = base_pipeline_weight
        tokens = query.split()

        # Temporal: only an engine with temporal search benefits; else lean on the pipeline.
        cue_count = len(_TEMPORAL_PATTERNS.findall(query))
        if cue_count:
            if engine_caps and engine_caps.supports_temporal_search:
                eng *= 1.5 + 0.2 * min(cue_count, 3)
            else:
                pipe *= 1.2

        # Entity density, with capitalized tokens as a proxy for named entities.
        # NOTE(review): a capitalized sentence-initial word (e.g. "What") counts too.
        proper_nouns = sum(1 for tok in tokens if len(tok) > 1 and tok[:1].isupper())
        if proper_nouns / max(len(tokens), 1) > 0.3:
            eng *= 1.4 if (engine_caps and engine_caps.supports_graph_search) else 1.1

        # Question shape: analytical openers favour the engine, plain lookups the pipeline.
        normalized = query.lower().strip()
        if normalized.startswith(("how ", "why ", "explain ", "analyze ", "compare ")):
            if engine_caps and engine_caps.supports_reflect:
                eng *= 1.3
        elif normalized.startswith(("what is ", "who is ", "where is ", "list ")):
            pipe *= 1.2

        # Very short queries: keyword matching in the pipeline is usually enough.
        if len(tokens) <= 3:
            pipe *= 1.2

        return eng, pipe