Coverage for astrocyte/hybrid.py: 73%
139 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Hybrid engine + storage pipeline as a single :class:`~astrocyte.provider.EngineProvider`.
3**Selection rules (summary)**
51. **Retain path** — Exactly one backend receives writes, chosen by ``retain_target``:
6 ``"engine"`` → ``EngineProvider.retain`` (requires ``engine=``);
7 ``"pipeline"`` → ``PipelineOrchestrator.retain`` (requires ``pipeline=``).
92. **Recall path** — If both sides are configured, ``recall`` runs them **concurrently**
10 (``asyncio.gather``), tags hits with ``source`` ``tier2_engine`` vs ``tier1_pipeline``
11 (legacy tag values, kept for back-compat with downstream filters),
12 optionally **dedupes** by text (keeping best score), then applies per-source **weights**
13 (``engine_recall_weight`` / ``pipeline_recall_weight``). Results are sorted by weighted
14 score and trimmed to ``max_results`` / token budget.
163. **Adaptive routing** (optional, ``adaptive_routing=True``) — :class:`AdaptiveRouter`
17 adjusts engine vs pipeline weights per query using heuristics: temporal cues (regex),
18 entity density (capitalized words), question shape (how/why vs what/who/list), and
19 short queries; engine boosts require matching ``EngineCapabilities`` (e.g. temporal/graph).
214. **Reflect / forget** — Prefer engine when it advertises capability; otherwise pipeline
22 or raise ``NotImplementedError`` (see implementation).
24When changing merge semantics or router heuristics, update **tests**:
25``tests/test_astrocyte_tier2.py`` (engine provider), ``tests/test_hybrid_engine.py`` (merge +
26``retain_target``), and ``tests/test_phase2_innovations.py`` (adaptive router).
27"""
29from __future__ import annotations
31import asyncio
32import re
33from dataclasses import replace
34from typing import TYPE_CHECKING, Any, ClassVar, Literal
36from astrocyte.policy.homeostasis import enforce_token_budget
38if TYPE_CHECKING:
39 from astrocyte.pipeline.orchestrator import PipelineOrchestrator
40 from astrocyte.provider import EngineProvider
42from astrocyte.types import (
43 EngineCapabilities,
44 ForgetRequest,
45 ForgetResult,
46 HealthStatus,
47 MemoryHit,
48 RecallRequest,
49 RecallResult,
50 ReflectRequest,
51 ReflectResult,
52 RetainRequest,
53 RetainResult,
54)
57def _dedupe_hits_prefer_score(hits: list[MemoryHit]) -> list[MemoryHit]:
58 best: dict[str, MemoryHit] = {}
59 for h in hits:
60 prev = best.get(h.text)
61 if prev is None or h.score > prev.score:
62 best[h.text] = h
63 return sorted(best.values(), key=lambda x: x.score, reverse=True)
class HybridEngineProvider:
    """Merges recall from an ``EngineProvider`` and a ``PipelineOrchestrator``.

    Typical use: long-lived agent memory in the hosted engine plus local embeddings / RAG
    in the pipeline, both indexed for the same logical ``bank_id``.

    A backend counts as configured when it ``is not None``. Truthiness is never used,
    so a backend object that defines ``__len__`` / ``__bool__`` is not silently skipped
    after passing the constructor's ``is None`` validation.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(
        self,
        *,
        engine: EngineProvider | Any | None = None,
        pipeline: PipelineOrchestrator | Any | None = None,
        retain_target: Literal["engine", "pipeline"] = "engine",
        engine_recall_weight: float = 1.0,
        pipeline_recall_weight: float = 1.0,
        dedup_across_sources: bool = True,
        adaptive_routing: bool = False,
    ) -> None:
        """Configure the hybrid provider.

        Args:
            engine: Engine backend, or ``None`` when only the pipeline is used.
            pipeline: Pipeline backend, or ``None`` when only the engine is used.
            retain_target: Which backend receives ``retain`` calls.
            engine_recall_weight: Score multiplier for hits returned by the engine.
            pipeline_recall_weight: Score multiplier for hits returned by the pipeline.
            dedup_across_sources: Collapse hits with identical text before weighting.
            adaptive_routing: Adjust the two weights per query with ``AdaptiveRouter``.

        Raises:
            ValueError: If neither backend is given, or ``retain_target`` names a
                backend that was not provided.
        """
        if engine is None and pipeline is None:
            raise ValueError("HybridEngineProvider requires at least one of engine= or pipeline=")
        if retain_target == "engine" and engine is None:
            raise ValueError("retain_target='engine' requires engine=")
        if retain_target == "pipeline" and pipeline is None:
            raise ValueError("retain_target='pipeline' requires pipeline=")
        self._engine = engine
        self._pipeline = pipeline
        self._retain_target = retain_target
        self._engine_w = float(engine_recall_weight)
        self._pipeline_w = float(pipeline_recall_weight)
        self._dedup_across_sources = dedup_across_sources
        self._adaptive_routing = adaptive_routing
        self._router = AdaptiveRouter() if adaptive_routing else None

    def capabilities(self) -> EngineCapabilities:
        """Report combined capabilities.

        With an engine, its capabilities are passed through, and reflect / forget /
        semantic search are additionally advertised whenever a pipeline is configured.
        Without an engine, a fixed pipeline-only capability set is returned.
        """
        has_pipeline = self._pipeline is not None
        if self._engine is not None:
            c = self._engine.capabilities()
            return EngineCapabilities(
                supports_reflect=c.supports_reflect or has_pipeline,
                supports_forget=c.supports_forget or has_pipeline,
                supports_semantic_search=c.supports_semantic_search or has_pipeline,
                supports_keyword_search=c.supports_keyword_search,
                supports_graph_search=c.supports_graph_search,
                supports_temporal_search=c.supports_temporal_search,
                supports_dispositions=c.supports_dispositions,
                supports_consolidation=c.supports_consolidation,
                supports_entities=c.supports_entities,
                supports_tags=c.supports_tags,
                supports_metadata=c.supports_metadata,
                max_retain_bytes=c.max_retain_bytes,
                max_recall_results=c.max_recall_results,
                max_embedding_dims=c.max_embedding_dims,
            )
        return EngineCapabilities(
            supports_reflect=True,
            supports_forget=has_pipeline,
            supports_semantic_search=True,
            supports_tags=True,
            supports_metadata=True,
        )

    async def health(self) -> HealthStatus:
        """Aggregate health of every configured backend.

        The result is healthy only if all configured backends are healthy; the message
        joins one ``engine=...`` / ``pipeline=...`` fragment per backend with ``"; "``.
        Pipeline health is read from ``pipeline.vector_store.health()``.
        """
        parts: list[str] = []
        healthy = True
        if self._engine is not None:
            h = await self._engine.health()
            healthy = healthy and h.healthy
            parts.append(f"engine={h.message or h.healthy}")
        if self._pipeline is not None:
            h = await self._pipeline.vector_store.health()
            healthy = healthy and h.healthy
            parts.append(f"pipeline={h.message or h.healthy}")
        return HealthStatus(healthy=healthy, message="; ".join(parts))

    async def retain(self, request: RetainRequest) -> RetainResult:
        """Write to exactly one backend, selected by ``retain_target``."""
        if self._retain_target == "engine":
            return await self._engine.retain(request)
        return await self._pipeline.retain(request)

    async def recall(self, request: RecallRequest) -> RecallResult:
        """Query all configured backends concurrently and merge their hits.

        Hits are tagged with a default ``source`` (``tier2_engine`` / ``tier1_pipeline``)
        when the backend did not set one, optionally deduplicated by text, scaled by the
        weight of the backend that produced them, sorted by score, and trimmed to
        ``request.max_results`` and (if set) ``request.max_tokens``.

        Raises:
            BaseException: Whatever a backend's ``recall`` raised. Both backends are
                awaited first; the engine's failure is raised before the pipeline's.
        """
        # Adaptive routing: compute per-query weights
        engine_w = self._engine_w
        pipeline_w = self._pipeline_w
        if self._router is not None and self._engine is not None:
            caps = self._engine.capabilities()
            engine_w, pipeline_w = self._router.route(request.query, caps, self._engine_w, self._pipeline_w)

        # One (default source tag, weight, coroutine) entry per configured backend.
        backends: list[tuple[str, float, Any]] = []
        if self._engine is not None:
            backends.append(("tier2_engine", engine_w, self._engine.recall(request)))
        if self._pipeline is not None:
            backends.append(("tier1_pipeline", pipeline_w, self._pipeline.recall(request)))
        raw = await asyncio.gather(*(coro for _, _, coro in backends), return_exceptions=True)

        all_hits: list[MemoryHit] = []
        # Weight keyed by object identity of the tagged hit. The ``source`` tag cannot
        # be used for this: a backend may already have set its own source label, and an
        # engine hit labelled that way must still receive the engine weight.
        weight_by_hit: dict[int, float] = {}
        total_available = 0
        for (default_source, weight, _), result in zip(backends, raw):
            if isinstance(result, BaseException):
                raise result
            total_available += result.total_available
            for h in result.hits:
                tagged = replace(
                    h,
                    bank_id=h.bank_id or request.bank_id,
                    source=h.source or default_source,
                )
                weight_by_hit[id(tagged)] = weight
                all_hits.append(tagged)

        # Dedup before weighting to avoid score inflation from backend weights
        merged = (
            _dedupe_hits_prefer_score(all_hits)
            if self._dedup_across_sources
            else sorted(all_hits, key=lambda x: x.score, reverse=True)
        )

        # Apply weights after dedup
        if engine_w != 1.0 or pipeline_w != 1.0:
            weighted: list[MemoryHit] = []
            for h in merged:
                # Fall back to the source tag if the dedup helper returned a copy
                # rather than the tagged object itself.
                fallback = engine_w if h.source == "tier2_engine" else pipeline_w
                w = weight_by_hit.get(id(h), fallback)
                weighted.append(replace(h, score=h.score * w))
            weighted.sort(key=lambda x: x.score, reverse=True)
            merged = weighted

        trimmed = merged[: request.max_results]
        truncated = False
        if request.max_tokens:
            trimmed, truncated = enforce_token_budget(trimmed, request.max_tokens)
        return RecallResult(hits=trimmed, total_available=total_available, truncated=truncated)

    async def reflect(self, request: ReflectRequest) -> ReflectResult:
        """Reflect via the engine if it advertises support, otherwise via the pipeline.

        Raises:
            NotImplementedError: If neither backend can serve the request.
        """
        if self._engine is not None and self._engine.capabilities().supports_reflect:
            return await self._engine.reflect(request)
        if self._pipeline is not None:
            return await self._pipeline.reflect(request)
        raise NotImplementedError("reflect is not available on this HybridEngineProvider")

    async def forget(self, request: ForgetRequest) -> ForgetResult:
        """Forget via the engine if it advertises support, otherwise delete by id in the pipeline.

        The pipeline fallback only handles requests that carry ``memory_ids``; it calls
        ``pipeline.vector_store.delete(memory_ids, bank_id)``.

        Raises:
            NotImplementedError: If the engine cannot forget and the pipeline fallback
                does not apply (no pipeline, or no ``memory_ids`` on the request).
        """
        if self._engine is not None and self._engine.capabilities().supports_forget:
            return await self._engine.forget(request)
        if self._pipeline is not None and request.memory_ids:
            deleted = await self._pipeline.vector_store.delete(request.memory_ids, request.bank_id)
            return ForgetResult(deleted_count=deleted)
        raise NotImplementedError("forget is not available on this HybridEngineProvider")
222# ---------------------------------------------------------------------------
223# Adaptive query routing (Phase 2 innovation)
224# ---------------------------------------------------------------------------
# Whole-word, case-insensitive temporal cues: relative-time words, YYYY-MM / YYYY/MM
# fragments, month names and weekday names. Used by AdaptiveRouter.route.
_TEMPORAL_PATTERNS = re.compile(
    r"\b(yesterday|today|last\s+week|last\s+month|before|after|when|recently|"
    r"\d{4}[-/]\d{2}|january|february|march|april|may|june|july|august|"
    r"september|october|november|december|monday|tuesday|wednesday|thursday|"
    r"friday|saturday|sunday)\b",
    re.IGNORECASE,
)


class AdaptiveRouter:
    """Per-query re-weighting of the engine and pipeline recall scores.

    Four independent heuristics are applied in order, each multiplying one weight:

    - Temporal cues: engine x (1.5 + 0.2 per match, at most 3 matches counted) when
      the engine supports temporal search, otherwise pipeline x 1.2.
    - Entity density: if more than 30% of whitespace-separated words are capitalized
      and longer than one character, engine x 1.4 with graph search, else engine x 1.1.
    - Question shape: how/why/explain/analyze/compare openers give engine x 1.3 when
      the engine supports reflect; "what is"/"who is"/"where is"/"list" openers give
      pipeline x 1.2.
    - Short queries (three words or fewer): pipeline x 1.2.

    Synchronous and stateless.
    """

    def route(
        self,
        query: str,
        engine_caps: EngineCapabilities | None = None,
        base_engine_weight: float = 1.0,
        base_pipeline_weight: float = 1.0,
    ) -> tuple[float, float]:
        """Return ``(engine_weight, pipeline_weight)`` adjusted for *query*.

        Args:
            query: The recall query text.
            engine_caps: Capabilities of the engine; ``None`` disables every
                capability-gated engine boost.
            base_engine_weight: Starting weight for engine hits.
            base_pipeline_weight: Starting weight for pipeline hits.
        """

        def engine_has(flag: str) -> bool:
            # Capability flags are only consulted when capabilities were supplied.
            return bool(engine_caps and getattr(engine_caps, flag))

        analytical_openers = ("how ", "why ", "explain ", "analyze ", "compare ")
        factual_openers = ("what is ", "who is ", "where is ", "list ")

        engine_weight, pipeline_weight = base_engine_weight, base_pipeline_weight

        # 1. Temporal cues.
        temporal_hits = sum(1 for _ in _TEMPORAL_PATTERNS.finditer(query))
        if temporal_hits > 0:
            if engine_has("supports_temporal_search"):
                engine_weight *= 1.5 + 0.2 * min(temporal_hits, 3)
            else:
                # Pipeline can still filter by date.
                pipeline_weight *= 1.2

        # 2. Entity density, approximated by the share of capitalized words.
        tokens = query.split()
        entity_like = [tok for tok in tokens if len(tok) > 1 and tok[:1].isupper()]
        if len(entity_like) / max(len(tokens), 1) > 0.3:
            if engine_has("supports_graph_search"):
                engine_weight *= 1.4
            else:
                engine_weight *= 1.1

        # 3. Question shape: analytical openers favour the engine, factual ones the pipeline.
        normalized = query.lower().strip()
        if normalized.startswith(analytical_openers):
            if engine_has("supports_reflect"):
                engine_weight *= 1.3
        elif normalized.startswith(factual_openers):
            pipeline_weight *= 1.2

        # 4. Very short queries: keyword matching is sufficient.
        if len(tokens) < 4:
            pipeline_weight *= 1.2

        return engine_weight, pipeline_weight