Coverage for astrocyte/_multi_bank.py: 94%

126 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Multi-bank recall orchestrator — parallel, cascade, and first-match strategies.""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import logging 

7from collections.abc import Awaitable, Callable 

8from dataclasses import replace 

9 

10from astrocyte._recall_params import RecallParams 

11from astrocyte.errors import ConfigError 

12from astrocyte.policy.homeostasis import enforce_token_budget 

13from astrocyte.policy.observability import MetricsCollector 

14from astrocyte.types import ( 

15 MemoryHit, 

16 MultiBankStrategy, 

17 RecallRequest, 

18 RecallResult, 

19) 

20 

# Module logger; note it uses the fixed name "astrocyte" rather than __name__.
logger = logging.getLogger("astrocyte")

# Type aliases for the callbacks injected from Astrocyte
# RecallFn: awaitable callback that executes one prepared RecallRequest and
# yields its RecallResult.
RecallFn = Callable[[RecallRequest], Awaitable[RecallResult]]
# MakeRequestFn: awaitable callback that builds a RecallRequest. The orchestrator
# calls it as (query, bank_id, max_results, None, tags, params); the positional
# slots mirror MultiBankOrchestrator.recall's (query, ..., max_results,
# max_tokens, tags, params), so the 4th slot is presumably a per-bank token
# budget — NOTE(review): confirm against the Astrocyte implementation.
MakeRequestFn = Callable[
    [str, str, int, int | None, list[str] | None, RecallParams],
    Awaitable[RecallRequest],
]

29 

30 

def _bank_visit_order(bank_ids: list[str], cascade_order: list[str] | None) -> list[str]:
    """Return the order in which the sequential strategies visit banks.

    Banks listed in *cascade_order* come first, in that order, but only if they
    are also present in *bank_ids*. Every remaining bank from *bank_ids* follows
    in its original order. Duplicates are dropped in all cases, so no bank is
    queried twice. A repeated bank would otherwise duplicate its hits and
    double-count its ``total_available``.

    Args:
        bank_ids: Banks requested for this recall.
        cascade_order: Optional preferred visit order. Entries that are not in
            *bank_ids* are ignored.

    Returns:
        A new list containing each distinct id from *bank_ids* exactly once.
    """
    # Build the set once: O(1) membership instead of scanning the list per entry.
    requested = set(bank_ids)
    preferred = [b for b in (cascade_order or []) if b in requested]
    # dict.fromkeys keeps first-seen order while discarding repeats.
    return list(dict.fromkeys([*preferred, *bank_ids]))

45 

46 

def _dedupe_hits_by_text(hits: list[MemoryHit]) -> list[MemoryHit]:
    """Collapse hits that share identical text, keeping the top scorer of each.

    The first occurrence wins ties. The survivors are returned in descending
    score order. The sort is stable, so equal scores keep first-seen text order.
    """
    winners: dict[str, MemoryHit] = {}
    for hit in hits:
        incumbent = winners.get(hit.text)
        if incumbent is not None and not hit.score > incumbent.score:
            continue  # an equal or better instance of this text is already kept
        winners[hit.text] = hit
    ranked = list(winners.values())
    ranked.sort(key=lambda item: item.score, reverse=True)
    return ranked

55 

56 

def _apply_bank_weights(hits: list[MemoryHit], weights: dict[str, float] | None) -> list[MemoryHit]:
    """Scale each hit's score by the weight configured for its bank.

    A hit with no ``bank_id`` is looked up under ``""``. A bank with no
    configured weight keeps a multiplier of 1.0. When *weights* is empty or
    None, the hits are returned unchanged in a new list.
    """
    if weights:
        return [
            replace(hit, score=hit.score * float(weights.get(hit.bank_id or "", 1.0)))
            for hit in hits
        ]
    return list(hits)

66 

67 

def _tag_hits_with_bank(hits: list[MemoryHit], bank_id: str) -> list[MemoryHit]:
    """Attribute untagged hits to *bank_id*; hits that already name a bank are kept as-is."""
    tagged: list[MemoryHit] = []
    for hit in hits:
        tagged.append(hit if hit.bank_id is not None else replace(hit, bank_id=bank_id))
    return tagged

70 

71 

class MultiBankOrchestrator:
    """Dispatches multi-bank recall across parallel, cascade, and first-match strategies.

    All collaborators are injected as callbacks:

    * ``do_recall`` executes one prepared single-bank request.
    * ``make_request`` builds that request. The orchestrator always passes
      ``None`` in the 4th slot and applies ``max_tokens`` once, to the merged
      result.
    * ``circuit_breaker_record_failure`` is called once per failed bank.
    * ``metrics`` receives the error and timeout counters.

    Every strategy tolerates per-bank failures. A failure is logged, reported
    to the circuit breaker and counted, and the remaining banks still
    contribute results.
    """

    def __init__(
        self,
        *,
        do_recall: RecallFn,
        make_request: MakeRequestFn,
        circuit_breaker_record_failure: Callable[[], None],
        metrics: MetricsCollector,
        provider_name: str,
        parallel_timeout: float = 30.0,
    ) -> None:
        """Store the injected callbacks.

        Args:
            do_recall: Runs a single-bank recall for a prepared request.
            make_request: Builds the per-bank request.
            circuit_breaker_record_failure: Called once for each bank that fails.
            metrics: Sink for ``inc_counter`` calls.
            provider_name: Value of the ``provider`` label on error metrics.
            parallel_timeout: Seconds to wait for the whole ``parallel``
                fan-out before giving up and returning an empty result.
        """
        self._do_recall = do_recall
        self._make_request = make_request
        self._cb_record_failure = circuit_breaker_record_failure
        self._metrics = metrics
        self._provider_name = provider_name
        self._parallel_timeout = parallel_timeout

    async def recall(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Multi-bank recall — strategy dispatch on ``strategy.mode``.

        Raises:
            ConfigError: If ``strategy.mode`` is not ``"parallel"``,
                ``"cascade"`` or ``"first_match"``.
        """
        if strategy.mode == "parallel":
            return await self._parallel(query, bank_ids, max_results, max_tokens, tags, params, strategy)
        if strategy.mode == "cascade":
            return await self._cascade(query, bank_ids, max_results, max_tokens, tags, params, strategy)
        if strategy.mode == "first_match":
            return await self._first_match(query, bank_ids, max_results, max_tokens, tags, params, strategy)
        raise ConfigError(f"Unknown multi-bank mode: {strategy.mode!r}")

    async def _parallel(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Query every bank concurrently and merge all successful hits.

        If the whole fan-out exceeds ``parallel_timeout``, the timeout is logged
        and counted and an empty result is returned.
        """
        # Requests are built sequentially; only the recalls run concurrently.
        requests = [
            await self._make_request(query, bid, max_results, None, tags, params)
            for bid in bank_ids
        ]
        timeout = self._parallel_timeout
        try:
            results = await asyncio.wait_for(
                asyncio.gather(*(self._do_recall(r) for r in requests), return_exceptions=True),
                timeout=timeout,
            )
        except asyncio.TimeoutError:
            logger.error("Multi-bank parallel recall timed out after %gs for banks: %s", timeout, bank_ids)
            self._metrics.inc_counter(
                "astrocyte_multi_bank_recall_timeout_total",
                {"bank_ids": ",".join(bank_ids)},
            )
            return RecallResult(hits=[], total_available=0, truncated=False)

        all_hits: list[MemoryHit] = []
        total_available = 0
        # gather() preserves input order, so results line up with bank_ids.
        for bid, result in zip(bank_ids, results):
            if isinstance(result, RecallResult):
                all_hits.extend(_tag_hits_with_bank(result.hits, bid))
                total_available += result.total_available
            elif isinstance(result, BaseException):
                self._record_bank_failure(bid, result)

        return self._finalize(
            all_hits, total_available, max_results, max_tokens, strategy,
            dedup=strategy.dedup_across_banks,
        )

    async def _cascade(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Visit banks in cascade order until enough hits have accumulated.

        After each successful bank, the number of hits gathered so far is
        compared with ``strategy.min_results_to_stop``. That count is of
        distinct texts when ``dedup_across_banks`` is set. A failing bank is
        recorded and skipped.
        """
        order = _bank_visit_order(bank_ids, strategy.cascade_order)
        accumulated: list[MemoryHit] = []
        total_available = 0

        for bid in order:
            result = await self._recall_bank(query, bid, max_results, tags, params)
            if result is None:
                continue
            total_available += result.total_available
            accumulated.extend(_tag_hits_with_bank(result.hits, bid))

            found = len({h.text for h in accumulated}) if strategy.dedup_across_banks else len(accumulated)
            if found >= strategy.min_results_to_stop:
                break

        return self._finalize(
            accumulated, total_available, max_results, max_tokens, strategy,
            dedup=strategy.dedup_across_banks,
        )

    async def _first_match(
        self,
        query: str,
        bank_ids: list[str],
        max_results: int,
        max_tokens: int | None,
        tags: list[str] | None,
        params: RecallParams,
        strategy: MultiBankStrategy,
    ) -> RecallResult:
        """Return the hits of the first bank, in visit order, that yields any.

        ``total_available`` sums the totals of every bank visited up to and
        including the matching one. A failing bank is recorded and skipped.
        """
        order = _bank_visit_order(bank_ids, strategy.cascade_order)
        total_available = 0
        for bid in order:
            result = await self._recall_bank(query, bid, max_results, tags, params)
            if result is None:
                continue
            total_available += result.total_available
            if result.hits:
                # Single bank, so there is nothing to de-duplicate "across banks".
                return self._finalize(
                    _tag_hits_with_bank(result.hits, bid), total_available,
                    max_results, max_tokens, strategy, dedup=False,
                )
        return RecallResult(hits=[], total_available=total_available, truncated=False)

    async def _recall_bank(
        self,
        query: str,
        bank_id: str,
        max_results: int,
        tags: list[str] | None,
        params: RecallParams,
    ) -> RecallResult | None:
        """Recall from one bank for the sequential strategies.

        Returns None after recording the failure if the recall raises, so one
        failing bank does not abort the whole strategy. This matches parallel
        mode.
        """
        request = await self._make_request(query, bank_id, max_results, None, tags, params)
        try:
            return await self._do_recall(request)
        except Exception as exc:  # CancelledError is a BaseException and still propagates
            self._record_bank_failure(bank_id, exc)
            return None

    def _record_bank_failure(self, bank_id: str, error: BaseException) -> None:
        """Log one bank's recall failure and report it to the circuit breaker and metrics."""
        logger.warning("Multi-bank recall failed for bank '%s': %s", bank_id, error)
        self._cb_record_failure()
        self._metrics.inc_counter(
            "astrocyte_recall_total",
            {"bank_id": bank_id, "provider": self._provider_name, "status": "error"},
        )

    @staticmethod
    def _finalize(
        hits: list[MemoryHit],
        total_available: int,
        max_results: int,
        max_tokens: int | None,
        strategy: MultiBankStrategy,
        *,
        dedup: bool,
    ) -> RecallResult:
        """Weight, rank, optionally de-duplicate, trim and token-budget merged hits.

        Weights are applied before de-duplication. When the same text comes from
        several banks, the instance with the best weighted score survives, and
        this is the same in every strategy. Sorting happens before trimming, so
        ``max_results`` always keeps the top-scoring hits.
        """
        ranked = _apply_bank_weights(hits, strategy.bank_weights)  # always a fresh list
        ranked.sort(key=lambda h: h.score, reverse=True)
        if dedup:
            ranked = _dedupe_hits_by_text(ranked)
        trimmed = ranked[:max_results]
        truncated = False
        if max_tokens:
            trimmed, truncated = enforce_token_budget(trimmed, max_tokens)
        return RecallResult(hits=trimmed, total_available=total_available, truncated=truncated)