Coverage for astrocyte/pipeline/cross_encoder_rerank.py: 71%

92 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Cross-encoder final-stage reranker (Hindsight parity). 

2 

3The retrieval stack uses a bi-encoder (embedding cosine) to fetch a 

4broad candidate set, then this module's cross-encoder to rerank the 

5top-K with full query/document attention. Cross-encoders score every 

6(query, candidate) pair jointly — slower per-pair than bi-encoders, but 

7substantially more accurate. The combination is the standard IR pattern 

8Hindsight uses (see ``hindsight-docs/docs/developer/configuration.md``): 

9default model ``cross-encoder/ms-marco-MiniLM-L-6-v2``, with pluggable 

10local / FlashRank / jina-mlx backends. 

11 

12Design: 

13 

14- :class:`CrossEncoderProtocol` defines the minimal scoring surface. 

15 Production uses :class:`SentenceTransformersCrossEncoder` (pulls 

16 ``sentence-transformers`` and ``torch`` from the optional 

17 ``[rerank]`` extras). Tests can pass a fake. 

18- :func:`cross_encoder_rerank` reuses :class:`ScoredItem` from 

19 :mod:`astrocyte.pipeline.reranking` so it slots into the existing 

20 pipeline at the same boundary as ``cross_encoder_like_rerank``. 

21- A module-level cache keys models by ``(model_name, force_cpu)`` so 

22 repeated calls within a process amortize the load. 

23 

24Failure mode: when ``sentence-transformers`` isn't installed and no 

25explicit model is supplied, callers fall back to the heuristic 

26``cross_encoder_like_rerank``. The pipeline orchestrator threads this 

27fallback automatically based on config. 

28""" 

29 

30from __future__ import annotations 

31 

32import logging 

33import os 

34from threading import Lock 

35from typing import Protocol, runtime_checkable 

36 

37from astrocyte.pipeline.reranking import ScoredItem 

38 

39_logger = logging.getLogger("astrocyte.cross_encoder_rerank") 

40 

41# --------------------------------------------------------------------------- 

42# Protocol 

43# --------------------------------------------------------------------------- 

44 

45 

@runtime_checkable
class CrossEncoderProtocol(Protocol):
    """Structural interface every cross-encoder backend has to satisfy.

    A backend exposes a single ``score`` method that yields one relevance
    score per candidate, aligned with the order of ``candidates``. Larger
    values mean "more relevant". The numeric scale depends on the backend
    (sentence-transformers cross-encoders emit raw logits; FlashRank emits
    calibrated values in [0, 1]). Reranking only sorts by the score vector,
    so the absolute scale is irrelevant — only the relative order matters.
    """

    def score(self, query: str, candidates: list[str]) -> list[float]:
        """Return one relevance score per entry of ``candidates``, in order."""
        ...

59 

60 

61# --------------------------------------------------------------------------- 

62# Sentence-transformers backend (default production implementation) 

63# --------------------------------------------------------------------------- 

64 

65 

class SentenceTransformersCrossEncoder:
    """Default backend wrapping ``sentence_transformers.CrossEncoder``.

    The underlying model is loaded lazily on the first non-empty
    :meth:`score` call; a clear :class:`ImportError` is raised if the
    dependency isn't installed. The Hindsight default model is
    ``cross-encoder/ms-marco-MiniLM-L-6-v2`` — small (~80MB), CPU-fast,
    and trained on MS MARCO passage ranking which transfers well to
    open-domain QA reranking.

    **Device selection** (added 2026-05-16):
    By default the backend auto-detects Apple Silicon (MPS) and routes
    inference there, giving ~3-5× speedup over CPU on Mac. Other
    platforms use sentence-transformers' default (CUDA if available,
    else CPU). Force a specific device via ``device=...`` or
    ``force_cpu=True``.

    **Thread safety of the lazy load:**
    Instances are typically shared across threads through the module-level
    cache, so the first-time load is serialized by a lock and re-checked
    inside it. Concurrent first calls therefore construct the model once
    instead of racing to load it several times.

    For stronger quality, pass a preset model name from
    ``APACHE2_MODEL_PRESETS``::

        encoder = SentenceTransformersCrossEncoder(
            model_name=APACHE2_MODEL_PRESETS["mxbai-base"],
        )
    """

    # Class-level on purpose: it guards only the rare first-time load, and
    # keeping it off the instance leaves instance state (and therefore
    # copying / pickling / ``__new__``-built instances) exactly as before.
    _load_lock = Lock()

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        *,
        force_cpu: bool = False,
        device: str | None = None,
        max_length: int = 512,
    ) -> None:
        """Record the configuration; no model is loaded here.

        Args:
            model_name: Model identifier handed to ``CrossEncoder``.
            force_cpu: Pin inference to ``"cpu"`` unless ``device`` is given.
            device: Explicit device string; wins over every other rule.
            max_length: Forwarded to ``CrossEncoder`` as ``max_length``.
        """
        self.model_name = model_name
        self.force_cpu = force_cpu
        self._explicit_device = device
        self.max_length = max_length
        self._model: object | None = None  # lazily populated on first score()

    def _resolve_device(self) -> str | None:
        """Pick the device to load the model on. Returns None for sentence-transformers default.

        Priority: explicit ``device=`` arg > ``force_cpu=True`` > Apple-Silicon MPS auto-detect > default.
        """
        if self._explicit_device is not None:
            return self._explicit_device
        if self.force_cpu:
            return "cpu"
        if is_mps_available():
            return "mps"
        return None  # let sentence-transformers pick (CUDA/CPU)

    def _load(self) -> object:
        """Return the underlying ``CrossEncoder``, constructing it on first use.

        Raises:
            ImportError: if ``sentence-transformers`` is not installed.
        """
        # Fast path: once the model exists no locking is needed.
        if self._model is not None:
            return self._model
        try:
            from sentence_transformers import CrossEncoder  # type: ignore
        except ImportError as exc:  # pragma: no cover — import-time failure
            raise ImportError(
                "Cross-encoder reranking requires the 'sentence-transformers' "
                "package. Install with: pip install 'astrocyte[rerank]' "
                "(or: pip install sentence-transformers torch)."
            ) from exc

        with self._load_lock:
            # Re-check under the lock: another thread may have finished the
            # load while this one was waiting.
            if self._model is None:
                kwargs: dict[str, object] = {"max_length": self.max_length}
                device = self._resolve_device()
                if device is not None:
                    kwargs["device"] = device

                _logger.info(
                    "Loading cross-encoder model %r (device=%s, force_cpu=%s)",
                    self.model_name,
                    device or "<st-default>",
                    self.force_cpu,
                )
                self._model = CrossEncoder(self.model_name, **kwargs)
            return self._model

    def score(self, query: str, candidates: list[str]) -> list[float]:
        """Score every ``(query, candidate)`` pair with the cross-encoder.

        Args:
            query: Text each candidate is scored against.
            candidates: Candidate texts.

        Returns:
            One ``float`` per candidate, in input order. An empty
            ``candidates`` list returns ``[]`` without loading the model.
        """
        if not candidates:
            return []
        model = self._load()
        pairs = [(query, candidate) for candidate in candidates]
        # ``CrossEncoder.predict`` returns a numpy array; convert to plain
        # floats so callers don't need numpy in their typing.
        raw = model.predict(pairs)  # type: ignore[attr-defined]
        return [float(score) for score in raw]

151 

152 

153# --------------------------------------------------------------------------- 

154# Apple Silicon device detection (MPS routing for sentence-transformers) 

155# --------------------------------------------------------------------------- 

156 

157 

def is_apple_silicon() -> bool:
    """Report whether this process runs on arm64 macOS.

    That is the platform where MPS (Metal) acceleration applies.
    """
    import platform

    on_macos = platform.system() == "Darwin"
    return on_macos and platform.machine() == "arm64"

163 

164 

def is_mps_available() -> bool:
    """Report whether torch's MPS backend can be used in this process.

    ``True`` requires all of the following:

    - the host is Apple Silicon (see :func:`is_apple_silicon`);
    - ``torch`` can be imported (it arrives transitively with the
      ``[rerank]`` extras / sentence-transformers);
    - ``torch.backends.mps.is_available()`` reports a truthy value.

    Any other outcome — including an exception raised while probing the
    backend — yields ``False``. ``torch`` is imported inside the function,
    so importing this module never requires it.
    """
    if is_apple_silicon():
        try:
            import torch  # type: ignore  # noqa: PLC0415
        except ImportError:
            return False
        try:
            usable = bool(torch.backends.mps.is_available())
        except Exception:  # noqa: BLE001
            # Best-effort probe: a broken backend just means "not available".
            return False
        else:
            return usable
    return False

186 

187 

# Apache-2.0 model presets: short alias -> HuggingFace model path.
# Production-friendly defaults that pair well with MPS on Apple Silicon and
# CUDA / CPU elsewhere. Select one through the ``model_name`` argument of
# SentenceTransformersCrossEncoder.
APACHE2_MODEL_PRESETS: dict[str, str] = {
    # --- MiniLM: smallest and fastest; the historical default / baseline.
    "minilm": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    # --- mixedbread-ai/mxbai-rerank: production-focused, strong on BEIR.
    # Noted as ~184M params, ~3-5× quality lift over MiniLM at ~2× the
    # latency, Apache 2.0, and the recommended default for new deployments.
    # NOTE(review): the ~184M figure may describe the v1 base model rather
    # than the v2 checkpoint listed here — confirm against the model card.
    "mxbai-base": "mixedbread-ai/mxbai-rerank-base-v2",
    "mxbai-large": "mixedbread-ai/mxbai-rerank-large-v2",
    # --- BAAI/bge-reranker: multilingual; well-benchmarked.
    "bge-base": "BAAI/bge-reranker-base",
    "bge-large": "BAAI/bge-reranker-large",
    "bge-v2-m3": "BAAI/bge-reranker-v2-m3",
}

204 

205 

206# --------------------------------------------------------------------------- 

207# Module-level model cache 

208# --------------------------------------------------------------------------- 

209 

# Guards every read and write of ``_model_cache`` below.
_cache_lock = Lock()
# Process-wide registry of encoder wrappers, keyed by ``(model_name, force_cpu)``.
_model_cache: dict[tuple[str, bool], CrossEncoderProtocol] = {}

#: Name of the environment variable that overrides the default model at
#: runtime. Its value may be a full HuggingFace path or a preset alias from
#: :data:`APACHE2_MODEL_PRESETS`. M33-1a (v0.15.0): scaffolding for A/B-ing
#: bge-reranker-large vs mxbai-rerank-large-v2.
_ENV_VAR: str = "ASTROCYTE_CROSS_ENCODER_MODEL"

#: Built-in default cross-encoder, used when the environment variable is
#: unset or blank. Matches the historical Hindsight default.
_DEFAULT_MODEL: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"

220 

221 

def _resolve_default_model() -> str:
    """Return the model path to use when the caller names no model.

    Lookup order:

    1. The ``ASTROCYTE_CROSS_ENCODER_MODEL`` environment variable, when it
       holds something other than whitespace. The stripped value is first
       tried as a preset alias (``bge-large``) in
       :data:`APACHE2_MODEL_PRESETS`; if it is not an alias it is returned
       as-is, i.e. treated as a full HuggingFace path
       (``BAAI/bge-reranker-large``).
    2. Otherwise :data:`_DEFAULT_MODEL` (MiniLM-L-6-v2).

    This is the seam used for bench A/B without code changes.
    """
    override = os.environ.get(_ENV_VAR, "").strip()
    if override:
        # Aliases map to their full path; anything else passes through.
        return APACHE2_MODEL_PRESETS.get(override, override)
    return _DEFAULT_MODEL

239 

240 

def get_default_cross_encoder(
    model_name: str | None = None,
    *,
    force_cpu: bool = False,
) -> CrossEncoderProtocol:
    """Return the shared :class:`SentenceTransformersCrossEncoder` for a configuration.

    Args:
        model_name: Explicit model HF path. ``None`` (the default) defers
            to :func:`_resolve_default_model`, which lets the
            ``ASTROCYTE_CROSS_ENCODER_MODEL`` env var pick the model
            without touching code.
        force_cpu: Pin to CPU even when MPS/CUDA is available.

    Returns:
        The wrapper cached under ``(resolved model name, force_cpu)``; it
        is created and stored on the first request for that key.

    The cache lookup and insertion happen under a module-level lock, so
    concurrent callers with the same key all receive the same wrapper
    instance. Constructing the wrapper does not load the model — the
    wrapper does that itself, lazily, on its first ``score()`` call.
    """
    if model_name is None:
        resolved = _resolve_default_model()
    else:
        resolved = model_name
    key = (resolved, force_cpu)

    with _cache_lock:
        try:
            return _model_cache[key]
        except KeyError:
            encoder = SentenceTransformersCrossEncoder(resolved, force_cpu=force_cpu)
            _model_cache[key] = encoder
            return encoder

270 

271 

def reset_default_cross_encoder_cache() -> None:
    """Empty the module-level cross-encoder cache. Intended for tests only.

    The cache is cleared while holding the same lock that guards lookups,
    so later :func:`get_default_cross_encoder` calls build fresh wrappers.
    """
    with _cache_lock:
        _model_cache.clear()

276 

277 

278# --------------------------------------------------------------------------- 

279# Reranking entry point 

280# --------------------------------------------------------------------------- 

281 

282 

def cross_encoder_rerank(
    items: list[ScoredItem],
    query: str,
    *,
    model: CrossEncoderProtocol | None = None,
    top_k: int | None = None,
) -> list[ScoredItem]:
    """Reorder ``items`` using joint (query, text) scores from a cross-encoder.

    Args:
        items: Candidates to rerank — typically the top-K produced by a
            cheaper retrieval stage (bi-encoder or BM25).
        query: The user query / synthesis prompt fragment that every
            candidate's ``text`` is scored against.
        model: Cross-encoder backend. When ``None``, the cached default
            from :func:`get_default_cross_encoder` is used.
        top_k: When given, only ``items[:top_k]`` is rescored and
            ``items[top_k:]`` is appended untouched after the reranked
            head, which bounds inference cost on long candidate lists.
            The split uses plain list slicing, so negative values follow
            Python slice semantics. ``None`` (default) rescored everything.

    Returns:
        New :class:`ScoredItem` objects for the rescored head, sorted by
        descending cross-encoder score, followed by the tail items with
        their original scores and original relative order. The input list
        itself is returned unchanged when ``items`` or ``query`` is empty,
        or when the backend returns a score count that does not match the
        number of rescored items (a warning is logged in that case).
    """
    if not (items and query):
        return items

    encoder = get_default_cross_encoder() if model is None else model

    if top_k is None:
        head, tail = items, []
    else:
        head, tail = items[:top_k], items[top_k:]

    texts = [candidate.text for candidate in head]
    scores = encoder.score(query, texts)
    if len(scores) != len(head):  # pragma: no cover — backend contract violation
        _logger.warning(
            "cross_encoder model returned %d scores for %d items; falling back to original order.",
            len(scores),
            len(head),
        )
        return items

    # Build fresh items carrying the new score; every other field is copied
    # from the source item. ``sorted`` is stable, so ties keep input order.
    reranked = sorted(
        (
            ScoredItem(
                id=source.id,
                text=source.text,
                score=float(new_score),
                fact_type=source.fact_type,
                metadata=source.metadata,
                tags=source.tags,
                memory_layer=source.memory_layer,
                occurred_at=source.occurred_at,
                retained_at=source.retained_at,
            )
            for source, new_score in zip(head, scores, strict=True)
        ),
        key=lambda entry: entry.score,
        reverse=True,
    )
    return reranked + tail