Coverage for astrocyte/_provider_dispatch.py: 97%

87 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Provider dispatcher — routes operations to engine provider or pipeline.""" 

2 

3from __future__ import annotations 

4 

5import logging 

6from typing import TYPE_CHECKING, Any 

7 

8from astrocyte.errors import CapabilityNotSupported, ConfigError 

9from astrocyte.recall.merge_result import merge_external_into_recall_result 

10from astrocyte.types import ( 

11 ForgetRequest, 

12 ForgetResult, 

13 MemoryHit, 

14 RecallRequest, 

15 RecallResult, 

16 ReflectRequest, 

17 ReflectResult, 

18 RetainRequest, 

19 RetainResult, 

20) 

21 

22if TYPE_CHECKING: 

23 from astrocyte.config import AstrocyteConfig 

24 from astrocyte.pipeline.orchestrator import PipelineOrchestrator 

25 from astrocyte.pipeline.pageindex_pipeline import PageIndexPipeline 

26 from astrocyte.pipeline.tiered_retrieval import TieredRetriever 

27 from astrocyte.provider import EngineProvider 

28 from astrocyte.types import EngineCapabilities 

29 

# Module logger. Uses the fixed "astrocyte" logger name rather than __name__.
logger: logging.Logger = logging.getLogger("astrocyte")

31 

32 

class ProviderDispatcher:
    """Routes retain/recall/reflect/forget to the configured engine or pipeline.

    Routing precedence:

    * ``recall``: ``pageindex_pipeline`` if set, else ``engine_provider``,
      else ``pipeline``.
    * ``retain`` / ``reflect`` / ``forget``: ``engine_provider`` if set,
      else ``pipeline``.

    When neither an engine provider nor a pipeline can serve the call,
    ``retain``/``recall``/``reflect`` raise ``ConfigError`` and ``forget``
    raises ``CapabilityNotSupported``.
    """

    # Page size used when ``forget(scope="all")`` drains a bank from the
    # pipeline's vector store.
    _FORGET_ALL_BATCH_SIZE = 100

    def __init__(
        self,
        config: AstrocyteConfig,
        engine_provider: EngineProvider | None = None,
        pipeline: PipelineOrchestrator | None = None,
        capabilities: EngineCapabilities | None = None,
        tiered_retriever: TieredRetriever | None = None,
        pageindex_pipeline: PageIndexPipeline | None = None,
    ) -> None:
        """Store the routing targets.

        Args:
            config: Astrocyte configuration. Reads ``provider``,
                ``fallback_strategy`` and ``tiered_retrieval.full_recall``.
            engine_provider: External engine; takes precedence over ``pipeline``.
            pipeline: Built-in pipeline orchestrator, used when no engine is set.
            capabilities: Engine capability flags (``supports_reflect`` /
                ``supports_forget``) that gate reflect/forget on the engine.
            tiered_retriever: Optional retriever notified on retain and used
                for some recall paths.
            pageindex_pipeline: When set, ``recall()`` is routed here first.
        """
        self._config = config
        self.engine_provider = engine_provider
        self.pipeline = pipeline
        self.capabilities = capabilities
        self.tiered_retriever = tiered_retriever
        # M32 — when set, ``recall()`` routes here in preference to the
        # legacy ``engine_provider`` / ``pipeline``. This is the PageIndex
        # stack — the same one the bench has validated since M14. See
        # docs/_design/m32-stack-unification.md for the rationale.
        self.pageindex_pipeline = pageindex_pipeline

    @property
    def provider_name(self) -> str:
        """Configured provider name, or ``"pipeline"`` when none is set."""
        return self._config.provider or "pipeline"

    async def retain(self, request: RetainRequest) -> RetainResult:
        """Store ``request`` via the engine provider, falling back to the pipeline.

        After a successful store that yielded a ``memory_id``, the tiered
        retriever (if any) is notified.

        Raises:
            ConfigError: if neither an engine provider nor a pipeline is set.
        """
        if self.engine_provider:
            result = await self.engine_provider.retain(request)
        elif self.pipeline:
            result = await self.pipeline.retain(request)
        else:
            raise ConfigError("No provider or pipeline configured")

        # Notify tiered retriever of new content (populates recent buffer, invalidates cache)
        if result.stored and self.tiered_retriever is not None and result.memory_id:
            self.tiered_retriever.notify_retain(
                request.bank_id,
                result.memory_id,
                request.content,
                request.metadata,
            )

        return result

    async def recall(self, request: RecallRequest) -> RecallResult:
        """Recall memories for ``request`` from the highest-precedence backend.

        ``request.external_context`` is merged into the result with
        ``merge_external_into_recall_result`` on the PageIndex path and on the
        plain engine path. It is not merged on the hybrid engine path, the
        tiered-retriever paths, or the plain pipeline path.

        Raises:
            ConfigError: if no PageIndex pipeline, engine provider or pipeline is set.
        """
        # M32 — PageIndex stack takes precedence when configured. This
        # is the same pipeline the bench harness has validated since
        # M14; routing public API calls through it closes the long-
        # standing bench-vs-API parity gap. ``external_context`` RRF
        # merge still applies on top of the result.
        if self.pageindex_pipeline is not None:
            result = await self.pageindex_pipeline.recall(request)
            if request.external_context:
                return merge_external_into_recall_result(
                    result, request.external_context, request.max_results,
                )
            return result
        if self.engine_provider:
            if self.tiered_retriever is not None and self._config.tiered_retrieval.full_recall == "hybrid":
                return await self.tiered_retriever.retrieve(request)
            result = await self.engine_provider.recall(request)
            # Hybrid merges pipeline (which already fuses external_context in RRF); do not merge twice.
            # Imported locally, as in the original; presumably avoids an import cycle.
            from astrocyte.hybrid import HybridEngineProvider

            if request.external_context and not isinstance(self.engine_provider, HybridEngineProvider):
                return merge_external_into_recall_result(result, request.external_context, request.max_results)
            return result
        if self.pipeline:
            if self.tiered_retriever is not None:
                return await self.tiered_retriever.retrieve(request)
            return await self.pipeline.recall(request)
        raise ConfigError("No provider or pipeline configured")

    async def reflect(self, request: ReflectRequest) -> ReflectResult:
        """Answer ``request`` by reflecting over stored memories.

        With an engine provider that advertises ``supports_reflect``, the
        engine handles the call. Otherwise ``config.fallback_strategy``
        decides:

        * ``"error"``: raise ``CapabilityNotSupported``.
        * ``"degrade"``: recall up to 10 hits via :meth:`recall` and return
          their texts joined by newlines.
        * anything else: use the pipeline's reflect if a pipeline exists,
          else raise ``CapabilityNotSupported``.

        Without an engine provider, the pipeline's reflect is used.

        Raises:
            CapabilityNotSupported: see the fallback rules above.
            ConfigError: if neither an engine provider nor a pipeline is set.
        """
        if self.engine_provider:
            if self.capabilities and self.capabilities.supports_reflect:
                return await self.engine_provider.reflect(request)
            # Fallback
            if self._config.fallback_strategy == "error":
                raise CapabilityNotSupported(self.provider_name, "reflect")
            if self._config.fallback_strategy == "degrade":
                # Return recall results as-is
                recall_result = await self.recall(
                    RecallRequest(query=request.query, bank_id=request.bank_id, max_results=10)
                )
                return ReflectResult(
                    answer="\n".join(h.text for h in recall_result.hits),
                    sources=recall_result.hits,
                )
            # local_llm (or any other strategy) needs the pipeline's reflect
            if self.pipeline:
                return await self.pipeline.reflect(request)
            raise CapabilityNotSupported(self.provider_name, "reflect")

        if self.pipeline:
            return await self.pipeline.reflect(request)

        raise ConfigError("No provider or pipeline configured")

    async def forget(self, request: ForgetRequest) -> ForgetResult:
        """Delete memories described by ``request``.

        * Engine provider: delegated when ``capabilities.supports_forget`` is true.
        * Pipeline with ``scope == "all"``: drains the bank, if the vector store
          exposes ``list_vectors``.
        * Pipeline otherwise: deletes ``request.memory_ids``.

        Raises:
            CapabilityNotSupported: if the engine lacks forget support, or if
                the pipeline path has nothing it can delete (no
                ``list_vectors`` for ``scope="all"`` and no ``memory_ids``),
                or if nothing is configured.
        """
        if self.engine_provider:
            if self.capabilities and self.capabilities.supports_forget:
                return await self.engine_provider.forget(request)
            raise CapabilityNotSupported(self.provider_name, "forget")
        # Pipeline: delete from vector store
        if self.pipeline:
            if request.scope == "all" and hasattr(self.pipeline.vector_store, "list_vectors"):
                total_deleted = await self._delete_all_vectors(request.bank_id)
                return ForgetResult(deleted_count=total_deleted)
            if request.memory_ids:
                count = await self.pipeline.vector_store.delete(request.memory_ids, request.bank_id)
                return ForgetResult(deleted_count=count)
        raise CapabilityNotSupported(self.provider_name, "forget")

    async def _delete_all_vectors(self, bank_id: str) -> int:
        """Delete every vector in ``bank_id`` from the pipeline's vector store.

        Each pass re-reads the first page (``offset=0``) because the previous
        page has just been deleted. The loop stops on an empty page. It also
        stops when a pass deletes nothing: a store that keeps listing ids it
        does not remove would otherwise make this loop spin forever.

        Returns:
            The sum of the counts reported by the store's ``delete`` calls.
        """
        vector_store = self.pipeline.vector_store
        total_deleted = 0
        while True:
            batch = await vector_store.list_vectors(bank_id, offset=0, limit=self._FORGET_ALL_BATCH_SIZE)
            if not batch:
                break
            deleted = await vector_store.delete([v.id for v in batch], bank_id)
            total_deleted += deleted
            if not deleted:
                # No progress: the next listing would return the same page.
                break
        return total_deleted

    async def reflect_from_hits(
        self,
        query: str,
        hits: list[MemoryHit],
        bank_id: str,
        max_tokens: int | None = None,
        dispositions: Any = None,
        authority_context: str | None = None,
    ) -> ReflectResult:
        """Synthesize over pre-fetched hits (used by multi-bank reflect).

        Tries in order:
        1. Pipeline reflect (if available) — calls synthesize() directly,
           with the pipeline's ``llm_provider``.
        2. Degrade fallback — concatenate hit texts.
        3. Empty answer if no hits.

        Args:
            query: The question to answer.
            hits: Memory hits gathered by the caller.
            bank_id: Accepted for interface compatibility; not read here.
            max_tokens: Synthesis token budget. A falsy value (``None`` or 0)
                becomes 2048.
            dispositions: Passed through unchanged to ``synthesize``.
            authority_context: Passed through unchanged to ``synthesize``.
        """
        # If we have a pipeline with an LLM, use its synthesis directly
        if self.pipeline:
            from astrocyte.pipeline.reflect import synthesize

            return await synthesize(
                query=query,
                hits=hits,
                llm_provider=self.pipeline.llm_provider,
                dispositions=dispositions,
                max_tokens=max_tokens or 2048,
                authority_context=authority_context,
            )

        # Fall back to degrade mode: concatenate hit texts as the answer.
        if hits:
            return ReflectResult(
                answer="\n".join(h.text for h in hits),
                sources=hits,
            )

        return ReflectResult(answer="No relevant memories found across banks.", sources=[])