Coverage for astrocyte/_provider_dispatch.py: 97%
87 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Provider dispatcher — routes operations to engine provider or pipeline."""
3from __future__ import annotations
5import logging
6from typing import TYPE_CHECKING, Any
8from astrocyte.errors import CapabilityNotSupported, ConfigError
9from astrocyte.recall.merge_result import merge_external_into_recall_result
10from astrocyte.types import (
11 ForgetRequest,
12 ForgetResult,
13 MemoryHit,
14 RecallRequest,
15 RecallResult,
16 ReflectRequest,
17 ReflectResult,
18 RetainRequest,
19 RetainResult,
20)
22if TYPE_CHECKING:
23 from astrocyte.config import AstrocyteConfig
24 from astrocyte.pipeline.orchestrator import PipelineOrchestrator
25 from astrocyte.pipeline.pageindex_pipeline import PageIndexPipeline
26 from astrocyte.pipeline.tiered_retrieval import TieredRetriever
27 from astrocyte.provider import EngineProvider
28 from astrocyte.types import EngineCapabilities
30logger = logging.getLogger("astrocyte")
class ProviderDispatcher:
    """Routes retain/recall/reflect/forget to the configured engine or pipeline.

    The dispatcher holds optional references to an engine provider, a pipeline
    orchestrator, a tiered retriever and a PageIndex pipeline, and each public
    coroutine picks the first applicable backend for its operation.
    """

    def __init__(
        self,
        config: AstrocyteConfig,
        engine_provider: EngineProvider | None = None,
        pipeline: PipelineOrchestrator | None = None,
        capabilities: EngineCapabilities | None = None,
        tiered_retriever: TieredRetriever | None = None,
        pageindex_pipeline: PageIndexPipeline | None = None,
    ) -> None:
        """Store the configuration and the optional backends to dispatch to."""
        self._config = config
        self.engine_provider = engine_provider
        self.pipeline = pipeline
        self.capabilities = capabilities
        self.tiered_retriever = tiered_retriever
        # M32 — when set, ``recall()`` routes here in preference to the
        # legacy ``engine_provider`` / ``pipeline``. This is the PageIndex
        # stack — the same one the bench has validated since M14. See
        # docs/_design/m32-stack-unification.md for the rationale.
        self.pageindex_pipeline = pageindex_pipeline

    @property
    def provider_name(self) -> str:
        """Return the configured provider name, or ``"pipeline"`` when unset."""
        return self._config.provider or "pipeline"

    async def retain(self, request: RetainRequest) -> RetainResult:
        """Store ``request`` via the engine provider, else via the pipeline.

        After a successful store that produced a memory id, the tiered
        retriever (if any) is notified of the new content.

        Raises:
            ConfigError: neither an engine provider nor a pipeline is set.
        """
        if self.engine_provider:
            result = await self.engine_provider.retain(request)
        elif self.pipeline:
            result = await self.pipeline.retain(request)
        else:
            raise ConfigError("No provider or pipeline configured")

        # Notify tiered retriever of new content (populates recent buffer, invalidates cache)
        if result.stored and self.tiered_retriever is not None and result.memory_id:
            self.tiered_retriever.notify_retain(
                request.bank_id,
                result.memory_id,
                request.content,
                request.metadata,
            )

        return result

    async def recall(self, request: RecallRequest) -> RecallResult:
        """Answer ``request`` from the first applicable backend.

        Routing order:
            1. ``pageindex_pipeline`` — result is merged with
               ``request.external_context`` when that is provided.
            2. ``engine_provider`` — delegated to the tiered retriever when one
               is set and ``tiered_retrieval.full_recall == "hybrid"``;
               otherwise the engine result is merged with the external context
               unless the engine is a ``HybridEngineProvider``.
            3. ``pipeline`` — through the tiered retriever when one is set.

        Raises:
            ConfigError: no backend is configured.
        """
        # M32 — PageIndex stack takes precedence when configured. This
        # is the same pipeline the bench harness has validated since
        # M14; routing public API calls through it closes the long-
        # standing bench-vs-API parity gap. ``external_context`` RRF
        # merge still applies on top of the result.
        if self.pageindex_pipeline is not None:
            result = await self.pageindex_pipeline.recall(request)
            if request.external_context:
                return merge_external_into_recall_result(
                    result, request.external_context, request.max_results,
                )
            return result

        if self.engine_provider:
            if self.tiered_retriever is not None and self._config.tiered_retrieval.full_recall == "hybrid":
                return await self.tiered_retriever.retrieve(request)
            result = await self.engine_provider.recall(request)
            # Hybrid merges pipeline (which already fuses external_context in RRF); do not merge twice.
            from astrocyte.hybrid import HybridEngineProvider

            if request.external_context and not isinstance(self.engine_provider, HybridEngineProvider):
                return merge_external_into_recall_result(result, request.external_context, request.max_results)
            return result

        if self.pipeline:
            if self.tiered_retriever is not None:
                return await self.tiered_retriever.retrieve(request)
            return await self.pipeline.recall(request)

        raise ConfigError("No provider or pipeline configured")

    async def reflect(self, request: ReflectRequest) -> ReflectResult:
        """Produce a reflection for ``request``.

        With an engine provider: use it when ``capabilities.supports_reflect``
        is truthy; otherwise follow ``config.fallback_strategy`` — ``"error"``
        raises, ``"degrade"`` returns the joined texts of a 10-hit recall, and
        any other value falls back to the pipeline's reflect when a pipeline is
        available. Without an engine provider the pipeline's reflect is used.

        Raises:
            CapabilityNotSupported: the engine cannot reflect and no fallback
                applies.
            ConfigError: neither an engine provider nor a pipeline is set.
        """
        # Check if provider supports reflect
        if self.engine_provider:
            if self.capabilities and self.capabilities.supports_reflect:
                return await self.engine_provider.reflect(request)

            # Fallback
            if self._config.fallback_strategy == "error":
                raise CapabilityNotSupported(self.provider_name, "reflect")
            if self._config.fallback_strategy == "degrade":
                # Return recall results as-is
                recall_result = await self.recall(
                    RecallRequest(query=request.query, bank_id=request.bank_id, max_results=10)
                )
                return ReflectResult(
                    answer="\n".join(h.text for h in recall_result.hits),
                    sources=recall_result.hits,
                )
            # local_llm fallback needs pipeline's reflect
            if self.pipeline:
                return await self.pipeline.reflect(request)
            raise CapabilityNotSupported(self.provider_name, "reflect")

        if self.pipeline:
            return await self.pipeline.reflect(request)

        raise ConfigError("No provider or pipeline configured")

    async def forget(self, request: ForgetRequest) -> ForgetResult:
        """Delete memories via the engine provider or the pipeline's vector store.

        With an engine provider, the call is delegated only when
        ``capabilities.supports_forget`` is truthy. With a pipeline,
        ``scope == "all"`` drains the bank in batches of 100 (when the vector
        store exposes ``list_vectors``); otherwise ``request.memory_ids`` are
        deleted directly.

        Raises:
            CapabilityNotSupported: no backend can satisfy the request.
        """
        if self.engine_provider:
            if self.capabilities and self.capabilities.supports_forget:
                return await self.engine_provider.forget(request)
            raise CapabilityNotSupported(self.provider_name, "forget")

        # Pipeline: delete from vector store
        if self.pipeline:
            if request.scope == "all" and hasattr(self.pipeline.vector_store, "list_vectors"):
                vector_store = self.pipeline.vector_store
                # Delete all vectors in bank. Each pass re-reads from offset 0
                # because the previous pass is expected to have removed the
                # vectors it listed.
                total_deleted = 0
                while True:
                    batch = await vector_store.list_vectors(request.bank_id, offset=0, limit=100)
                    if not batch:
                        break
                    ids = [v.id for v in batch]
                    deleted = await vector_store.delete(ids, request.bank_id)
                    total_deleted += deleted
                    if not deleted:
                        # Nothing was removed, so the next listing would return
                        # the same batch again; bail out instead of spinning
                        # forever.
                        logger.warning(
                            "forget(scope='all') made no progress deleting %d vectors "
                            "from bank %r; stopping",
                            len(ids),
                            request.bank_id,
                        )
                        break
                return ForgetResult(deleted_count=total_deleted)

            if request.memory_ids:
                count = await self.pipeline.vector_store.delete(request.memory_ids, request.bank_id)
                return ForgetResult(deleted_count=count)

        raise CapabilityNotSupported(self.provider_name, "forget")

    async def reflect_from_hits(
        self,
        query: str,
        hits: list[MemoryHit],
        bank_id: str,
        max_tokens: int | None = None,
        dispositions: Any = None,
        authority_context: str | None = None,
    ) -> ReflectResult:
        """Synthesize over pre-fetched hits (used by multi-bank reflect).

        Tries in order:
        1. Pipeline reflect (if a pipeline is set) — calls synthesize() directly
           with the pipeline's ``llm_provider``; ``max_tokens`` defaults to 2048.
        2. Degrade fallback — concatenate hit texts.
        3. A fixed "no relevant memories" answer if there are no hits.

        ``bank_id`` is accepted for interface compatibility and is not read here.
        """
        # If we have a pipeline, use its synthesis directly
        if self.pipeline:
            from astrocyte.pipeline.reflect import synthesize

            return await synthesize(
                query=query,
                hits=hits,
                llm_provider=self.pipeline.llm_provider,
                dispositions=dispositions,
                max_tokens=max_tokens or 2048,
                authority_context=authority_context,
            )

        # Fall back to degrade mode: concatenate hit texts as the answer.
        if hits:
            return ReflectResult(
                answer="\n".join(h.text for h in hits),
                sources=hits,
            )

        return ReflectResult(answer="No relevant memories found across banks.", sources=[])