Coverage for astrocyte/_multi_bank.py: 94%
126 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Multi-bank recall orchestrator — parallel, cascade, and first-match strategies."""
3from __future__ import annotations
5import asyncio
6import logging
7from collections.abc import Awaitable, Callable
8from dataclasses import replace
10from astrocyte._recall_params import RecallParams
11from astrocyte.errors import ConfigError
12from astrocyte.policy.homeostasis import enforce_token_budget
13from astrocyte.policy.observability import MetricsCollector
14from astrocyte.types import (
15 MemoryHit,
16 MultiBankStrategy,
17 RecallRequest,
18 RecallResult,
19)
# Uses the fixed "astrocyte" logger name rather than this module's __name__.
logger = logging.getLogger("astrocyte")

# Type aliases for the callbacks injected from Astrocyte.

# RecallFn: awaited with a single RecallRequest; resolves to the RecallResult for it.
RecallFn = Callable[[RecallRequest], Awaitable[RecallResult]]

# MakeRequestFn: awaited to build a RecallRequest. Within this module it is always
# called positionally as (query, bank_id, max_results, max_tokens, tags, params).
MakeRequestFn = Callable[
    [str, str, int, int | None, list[str] | None, RecallParams],
    Awaitable[RecallRequest],
]
31def _bank_visit_order(bank_ids: list[str], cascade_order: list[str] | None) -> list[str]:
32 if not cascade_order:
33 return list(bank_ids)
34 out: list[str] = []
35 seen: set[str] = set()
36 for b in cascade_order:
37 if b in bank_ids and b not in seen:
38 out.append(b)
39 seen.add(b)
40 for b in bank_ids:
41 if b not in seen:
42 out.append(b)
43 seen.add(b)
44 return out
47def _dedupe_hits_by_text(hits: list[MemoryHit]) -> list[MemoryHit]:
48 """One hit per distinct text, keeping the highest-scoring instance."""
49 best: dict[str, MemoryHit] = {}
50 for h in hits:
51 prev = best.get(h.text)
52 if prev is None or h.score > prev.score:
53 best[h.text] = h
54 return sorted(best.values(), key=lambda x: x.score, reverse=True)
57def _apply_bank_weights(hits: list[MemoryHit], weights: dict[str, float] | None) -> list[MemoryHit]:
58 if not weights:
59 return list(hits)
60 out: list[MemoryHit] = []
61 for h in hits:
62 bid = h.bank_id or ""
63 w = float(weights.get(bid, 1.0))
64 out.append(replace(h, score=h.score * w))
65 return out
68def _tag_hits_with_bank(hits: list[MemoryHit], bank_id: str) -> list[MemoryHit]:
69 return [replace(h, bank_id=bank_id) if h.bank_id is None else h for h in hits]
class MultiBankOrchestrator:
    """Dispatches multi-bank recall across parallel, cascade, and first-match strategies.

    All I/O goes through the callbacks injected at construction time. Every
    strategy finishes through the same pipeline (bank weights, sort by score,
    optional de-duplication, trim to ``max_results``, token budget) so that
    ranking behaves the same regardless of how the banks were visited.
    """

    def __init__(
        self,
        *,
        do_recall: RecallFn,
        make_request: MakeRequestFn,
        circuit_breaker_record_failure: Callable[[], None],
        metrics: MetricsCollector,
        provider_name: str,
        parallel_timeout: float = 30.0,
    ) -> None:
        """Store the injected collaborators.

        Args:
            do_recall: Awaited with one request per bank; resolves to that
                bank's result.
            make_request: Awaited to build each per-bank request.
            circuit_breaker_record_failure: Called once for every bank whose
                recall raised during a parallel fan-out.
            metrics: Counter sink for timeout and per-bank error metrics.
            provider_name: Value of the ``provider`` label on error counters.
            parallel_timeout: Seconds the parallel fan-out may take as a whole
                before it is abandoned. Defaults to 30 seconds, the value that
                used to be hard-coded.
        """
        self._do_recall = do_recall
        self._make_request = make_request
        self._cb_record_failure = circuit_breaker_record_failure
        self._metrics = metrics
        self._provider_name = provider_name
        self._parallel_timeout = parallel_timeout

    async def recall(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Multi-bank recall — strategy dispatch.

        Raises:
            ConfigError: If ``strategy.mode`` is not ``"parallel"``,
                ``"cascade"`` or ``"first_match"``.
        """
        if strategy.mode == "parallel":
            return await self._parallel(query, bank_ids, max_results, max_tokens, tags, params, strategy)
        if strategy.mode == "cascade":
            return await self._cascade(query, bank_ids, max_results, max_tokens, tags, params, strategy)
        if strategy.mode == "first_match":
            return await self._first_match(query, bank_ids, max_results, max_tokens, tags, params, strategy)
        raise ConfigError(f"Unknown multi-bank mode: {strategy.mode!r}")

    @staticmethod
    def _finalize(
        hits: list[MemoryHit],
        total_available: int,
        max_results: int,
        max_tokens: int | None,
        strategy: MultiBankStrategy,
        *,
        dedup: bool,
    ) -> RecallResult:
        """Weight, rank, optionally de-duplicate, trim and budget the merged hits.

        Weights are applied before de-duplication so that, when the same text
        was found in several banks, the copy that survives is the one with the
        highest *weighted* score. Sorting happens before trimming so the cut
        to ``max_results`` always drops the lowest-ranked hits.
        """
        ranked = _apply_bank_weights(hits, strategy.bank_weights)
        ranked.sort(key=lambda h: h.score, reverse=True)
        if dedup:
            ranked = _dedupe_hits_by_text(ranked)

        trimmed = ranked[:max_results]
        truncated = False
        # A falsy budget (None or 0) means "no token limit".
        if max_tokens:
            trimmed, truncated = enforce_token_budget(trimmed, max_tokens)
        return RecallResult(hits=trimmed, total_available=total_available, truncated=truncated)

    async def _parallel(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Query every bank concurrently and merge whatever succeeded.

        A bank whose recall raises is logged, reported to the circuit breaker
        and counted as an error, but does not fail the overall call. If the
        fan-out as a whole exceeds the configured timeout, an empty result is
        returned instead of raising.
        """
        # Requests are built one after another; only the recalls run concurrently.
        # Per-bank requests carry no token budget (None): the budget is enforced
        # once, on the merged list.
        requests: list[RecallRequest] = []
        for bid in bank_ids:
            requests.append(await self._make_request(query, bid, max_results, None, tags, params))

        try:
            results = await asyncio.wait_for(
                asyncio.gather(*(self._do_recall(r) for r in requests), return_exceptions=True),
                timeout=self._parallel_timeout,
            )
        except asyncio.TimeoutError:
            logger.error(
                "Multi-bank parallel recall timed out after %gs for banks: %s",
                self._parallel_timeout,
                bank_ids,
            )
            self._metrics.inc_counter(
                "astrocyte_multi_bank_recall_timeout_total",
                {"bank_ids": ",".join(bank_ids)},
            )
            return RecallResult(hits=[], total_available=0, truncated=False)

        merged: list[MemoryHit] = []
        total_available = 0
        # gather() preserves argument order, so results line up with bank_ids.
        for bid, result in zip(bank_ids, results):
            if isinstance(result, RecallResult):
                merged.extend(_tag_hits_with_bank(result.hits, bid))
                total_available += result.total_available
            elif isinstance(result, BaseException):
                logger.warning("Multi-bank recall failed for bank '%s': %s", bid, result)
                self._cb_record_failure()
                self._metrics.inc_counter(
                    "astrocyte_recall_total",
                    {"bank_id": bid, "provider": self._provider_name, "status": "error"},
                )

        return self._finalize(
            merged,
            total_available,
            max_results,
            max_tokens,
            strategy,
            dedup=bool(strategy.dedup_across_banks),
        )

    async def _cascade(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Query banks one at a time until enough results have accumulated.

        Banks are visited in the order given by ``_bank_visit_order``. After
        each bank the running count is compared with
        ``strategy.min_results_to_stop``; when ``dedup_across_banks`` is set
        the count is the number of distinct texts. An exception from any bank
        propagates to the caller.
        """
        gathered: list[MemoryHit] = []
        total_available = 0

        for bid in _bank_visit_order(bank_ids, strategy.cascade_order):
            request = await self._make_request(query, bid, max_results, None, tags, params)
            result = await self._do_recall(request)
            total_available += result.total_available
            gathered.extend(_tag_hits_with_bank(result.hits, bid))

            if strategy.dedup_across_banks:
                found = len({h.text for h in gathered})
            else:
                found = len(gathered)
            if found >= strategy.min_results_to_stop:
                break

        return self._finalize(
            gathered,
            total_available,
            max_results,
            max_tokens,
            strategy,
            dedup=bool(strategy.dedup_across_banks),
        )

    async def _first_match(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Return the hits of the first bank that yields any.

        Banks are visited in the order given by ``_bank_visit_order``.
        ``total_available`` sums the counts of every bank visited up to and
        including the matching one. No cross-bank de-duplication is applied.
        When no bank has hits, an empty result is returned. An exception from
        any bank propagates to the caller.
        """
        total_available = 0
        for bid in _bank_visit_order(bank_ids, strategy.cascade_order):
            request = await self._make_request(query, bid, max_results, None, tags, params)
            result = await self._do_recall(request)
            total_available += result.total_available
            if not result.hits:
                continue
            return self._finalize(
                _tag_hits_with_bank(result.hits, bid),
                total_available,
                max_results,
                max_tokens,
                strategy,
                dedup=False,
            )
        return RecallResult(hits=[], total_available=total_available, truncated=False)