Coverage for astrocyte/pipeline/cross_encoder_rerank.py: 71%
92 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Cross-encoder final-stage reranker (Hindsight parity).
3The retrieval stack uses a bi-encoder (embedding cosine) to fetch a
4broad candidate set, then this module's cross-encoder to rerank the
5top-K with full query/document attention. Cross-encoders score every
6(query, candidate) pair jointly — slower per-pair than bi-encoders, but
7substantially more accurate. The combination is the standard IR pattern
8Hindsight uses (see ``hindsight-docs/docs/developer/configuration.md``):
9default model ``cross-encoder/ms-marco-MiniLM-L-6-v2``, with pluggable
10local / FlashRank / jina-mlx backends.
12Design:
14- :class:`CrossEncoderProtocol` defines the minimal scoring surface.
15 Production uses :class:`SentenceTransformersCrossEncoder` (pulls
16 ``sentence-transformers`` and ``torch`` from the optional
17 ``[rerank]`` extras). Tests can pass a fake.
18- :func:`cross_encoder_rerank` reuses :class:`ScoredItem` from
19 :mod:`astrocyte.pipeline.reranking` so it slots into the existing
20 pipeline at the same boundary as ``cross_encoder_like_rerank``.
21- A module-level cache keys models by ``(model_name, force_cpu)`` so
22 repeated calls within a process amortize the load.
24Failure mode: when ``sentence-transformers`` isn't installed and no
25explicit model is supplied, callers fall back to the heuristic
26``cross_encoder_like_rerank``. The pipeline orchestrator threads this
27fallback automatically based on config.
28"""
30from __future__ import annotations
32import logging
33import os
34from threading import Lock
35from typing import Protocol, runtime_checkable
37from astrocyte.pipeline.reranking import ScoredItem
39_logger = logging.getLogger("astrocyte.cross_encoder_rerank")
41# ---------------------------------------------------------------------------
42# Protocol
43# ---------------------------------------------------------------------------
@runtime_checkable
class CrossEncoderProtocol(Protocol):
    """Structural interface every cross-encoder backend has to satisfy.

    A backend exposes a single method, :meth:`score`, that yields one
    relevance score per candidate, in the same order as ``candidates``.
    Larger means more relevant. The numeric scale is backend-specific
    (sentence-transformers cross-encoders return raw logits; FlashRank
    returns calibrated [0, 1]). Reranking only compares scores against
    one another, so the absolute scale is irrelevant — only the relative
    ranking matters.
    """

    def score(self, query: str, candidates: list[str]) -> list[float]:
        """Score each of ``candidates`` against ``query``, order preserved."""
        ...
61# ---------------------------------------------------------------------------
62# Sentence-transformers backend (default production implementation)
63# ---------------------------------------------------------------------------
class SentenceTransformersCrossEncoder:
    """Default backend built on ``sentence_transformers.CrossEncoder``.

    The underlying model is loaded lazily on the first :meth:`score` call
    with a non-empty candidate list; a clear ``ImportError`` is raised at
    that point if the dependency isn't installed. The Hindsight default
    model is ``cross-encoder/ms-marco-MiniLM-L-6-v2`` — small (~80MB),
    CPU-fast, and trained on MS MARCO passage ranking which transfers
    well to open-domain QA reranking.

    **Device selection** (added 2026-05-16):
    By default the backend auto-detects Apple Silicon (MPS) and routes
    inference there, giving ~3-5× speedup over CPU on Mac. Other
    platforms use sentence-transformers' default (CUDA if available,
    else CPU). Force a specific device via ``device=...`` or
    ``force_cpu=True``.

    For stronger quality, pass a preset model name from
    ``APACHE2_MODEL_PRESETS``:
        encoder = SentenceTransformersCrossEncoder(
            model_name=APACHE2_MODEL_PRESETS["mxbai-base"],
        )

    Args:
        model_name: HuggingFace model path handed to ``CrossEncoder``.
        force_cpu: Pin inference to ``"cpu"`` unless ``device`` is given.
        device: Explicit device string; takes precedence over everything.
        max_length: Forwarded to ``CrossEncoder`` as ``max_length``.
    """

    def __init__(
        self,
        model_name: str = "cross-encoder/ms-marco-MiniLM-L-6-v2",
        *,
        force_cpu: bool = False,
        device: str | None = None,
        max_length: int = 512,
    ) -> None:
        self.model_name = model_name
        self.force_cpu = force_cpu
        self._explicit_device = device
        self.max_length = max_length
        # Filled in by ``_load()`` the first time a score is requested.
        self._model: object | None = None

    def _resolve_device(self) -> str | None:
        """Choose the device string, or ``None`` for the library default.

        Precedence: explicit ``device=`` argument, then ``force_cpu=True``,
        then Apple-Silicon MPS auto-detection, then ``None``.
        """
        explicit = self._explicit_device
        if explicit is not None:
            return explicit
        if self.force_cpu:
            return "cpu"
        # ``None`` lets sentence-transformers pick for itself (CUDA/CPU).
        return "mps" if is_mps_available() else None

    def _load(self) -> object:
        """Return the wrapped model, constructing it on first use.

        Raises:
            ImportError: if ``sentence_transformers`` cannot be imported.
        """
        if self._model is None:
            try:
                from sentence_transformers import CrossEncoder  # type: ignore
            except ImportError as exc:  # pragma: no cover — import-time failure
                raise ImportError(
                    "Cross-encoder reranking requires the 'sentence-transformers' "
                    "package. Install with: pip install 'astrocyte[rerank]' "
                    "(or: pip install sentence-transformers torch)."
                ) from exc

            target = self._resolve_device()
            init_kwargs: dict[str, object] = {"max_length": self.max_length}
            # Only pass ``device`` when one was actually chosen, so the
            # library's own default selection stays in effect otherwise.
            if target is not None:
                init_kwargs["device"] = target

            _logger.info(
                "Loading cross-encoder model %r (device=%s, force_cpu=%s)",
                self.model_name,
                target or "<st-default>",
                self.force_cpu,
            )
            self._model = CrossEncoder(self.model_name, **init_kwargs)
        return self._model

    def score(self, query: str, candidates: list[str]) -> list[float]:
        """Score every candidate against ``query``.

        Returns one float per candidate, in input order. An empty
        candidate list short-circuits to ``[]`` without loading the model.
        """
        if not candidates:
            return []
        encoder = self._load()
        # ``CrossEncoder.predict`` returns a numpy array; plain floats keep
        # numpy out of callers' typing.
        predictions = encoder.predict(  # type: ignore[attr-defined]
            [(query, text) for text in candidates]
        )
        return [float(value) for value in predictions]
153# ---------------------------------------------------------------------------
154# Apple Silicon device detection (MPS routing for sentence-transformers)
155# ---------------------------------------------------------------------------
def is_apple_silicon() -> bool:
    """Report whether this process runs on arm64 macOS.

    That is the platform where MPS (Metal) acceleration applies.
    """
    import platform

    on_macos = platform.system() == "Darwin"
    # Short-circuits: the machine type is only queried on macOS.
    return on_macos and platform.machine() == "arm64"
def is_mps_available() -> bool:
    """Report whether torch's MPS backend can be used in this process.

    True requires all of the following:
    - the host is Apple Silicon (see :func:`is_apple_silicon`),
    - ``torch`` is importable (it arrives transitively via the
      ``[rerank]`` extras / sentence-transformers),
    - ``torch.backends.mps.is_available()`` reports a truthy value.

    Any failure along the way yields False rather than raising, so
    callers can opt into MPS automatically without making torch a hard
    import-time dependency.
    """
    if not is_apple_silicon():
        return False

    try:
        import torch  # type: ignore  # noqa: PLC0415
    except ImportError:
        return False

    try:
        usable = bool(torch.backends.mps.is_available())
    except Exception:  # noqa: BLE001
        # Deliberately best-effort: any probing error means "not usable".
        usable = False
    return usable
# Apache-2.0 model presets — production-friendly defaults that pair well
# with MPS on Apple Silicon and CUDA / CPU elsewhere. Maps a short alias
# to the full HuggingFace model path. Pick one via the ``model_name`` arg
# to SentenceTransformersCrossEncoder; ``_resolve_default_model`` also
# looks aliases up here when reading the env-var override.
APACHE2_MODEL_PRESETS = {
    # Smallest, fastest. Our historical default — good baseline.
    "minilm": "cross-encoder/ms-marco-MiniLM-L-6-v2",
    # mixedbread-ai/mxbai-rerank — production-focused, strong on BEIR.
    # ~184M params; ~3-5× quality lift over MiniLM at ~2× the latency.
    # Apache 2.0. Recommended default for new deployments.
    # NOTE(review): the size/quality/latency figures above are carried
    # over from the original comment and are not verified in this file.
    "mxbai-base": "mixedbread-ai/mxbai-rerank-base-v2",
    "mxbai-large": "mixedbread-ai/mxbai-rerank-large-v2",
    # BAAI/bge-reranker — multilingual; well-benchmarked.
    "bge-base": "BAAI/bge-reranker-base",
    "bge-large": "BAAI/bge-reranker-large",
    "bge-v2-m3": "BAAI/bge-reranker-v2-m3",
}
206# ---------------------------------------------------------------------------
207# Module-level model cache
208# ---------------------------------------------------------------------------
# Process-wide cache of encoder wrappers keyed by ``(model_name, force_cpu)``.
# The helpers in this module only touch it while holding ``_cache_lock``.
_model_cache: dict[tuple[str, bool], CrossEncoderProtocol] = {}
_cache_lock = Lock()

#: Built-in default cross-encoder. Matches the historical Hindsight default.
#: Override at runtime via the ``ASTROCYTE_CROSS_ENCODER_MODEL`` env var
#: (accepts a full HuggingFace path or a preset alias from
#: :data:`APACHE2_MODEL_PRESETS`). M33-1a (v0.15.0): scaffolding for A/B-ing
#: bge-reranker-large vs mxbai-rerank-large-v2.
_DEFAULT_MODEL: str = "cross-encoder/ms-marco-MiniLM-L-6-v2"
# Name of the environment variable consulted by ``_resolve_default_model``.
_ENV_VAR: str = "ASTROCYTE_CROSS_ENCODER_MODEL"
def _resolve_default_model() -> str:
    """Return the model name to use when the caller didn't pick one.

    Resolution order:

    1. The ``ASTROCYTE_CROSS_ENCODER_MODEL`` env var, when it is set to a
       non-blank value (surrounding whitespace is stripped). The value
       may be a preset alias (``bge-large``), which is translated through
       :data:`APACHE2_MODEL_PRESETS`, or any other string (for example a
       full HuggingFace path such as ``BAAI/bge-reranker-large``), which
       is returned unchanged.
    2. :data:`_DEFAULT_MODEL` (MiniLM-L-6-v2) otherwise.

    This is the seam used for bench A/B without code changes.
    """
    override = os.environ.get(_ENV_VAR, "").strip()
    if override:
        # Unknown names fall through untouched and are treated as a path.
        return APACHE2_MODEL_PRESETS.get(override, override)
    return _DEFAULT_MODEL
def get_default_cross_encoder(
    model_name: str | None = None,
    *,
    force_cpu: bool = False,
) -> CrossEncoderProtocol:
    """Fetch (creating on first request) the shared encoder for a config.

    Args:
        model_name: Explicit model HF path. When ``None`` (default), the
            name comes from :func:`_resolve_default_model`, so the
            ``ASTROCYTE_CROSS_ENCODER_MODEL`` env var can override it
            without touching code.
        force_cpu: Pin to CPU even when MPS/CUDA is available.

    Returns:
        The :class:`SentenceTransformersCrossEncoder` cached under
        ``(resolved_model_name, force_cpu)``.

    Cache lookup and insertion happen under ``_cache_lock``, so concurrent
    callers asking for the same key all receive the same wrapper object.
    Building the wrapper does not load model weights; the wrapper loads
    them lazily the first time it is asked to score.
    """
    if model_name is None:
        resolved = _resolve_default_model()
    else:
        resolved = model_name
    cache_key = (resolved, force_cpu)

    with _cache_lock:
        try:
            return _model_cache[cache_key]
        except KeyError:
            encoder = SentenceTransformersCrossEncoder(resolved, force_cpu=force_cpu)
            _model_cache[cache_key] = encoder
            return encoder
def reset_default_cross_encoder_cache() -> None:
    """Empty the module-level encoder cache. Intended for tests only.

    Holds ``_cache_lock`` while clearing, so it does not interleave with
    a lookup in :func:`get_default_cross_encoder`.
    """
    with _cache_lock:
        _model_cache.clear()
278# ---------------------------------------------------------------------------
279# Reranking entry point
280# ---------------------------------------------------------------------------
def cross_encoder_rerank(
    items: list[ScoredItem],
    query: str,
    *,
    model: CrossEncoderProtocol | None = None,
    top_k: int | None = None,
) -> list[ScoredItem]:
    """Rerank ``items`` by a cross-encoder's joint relevance score.

    Args:
        items: Candidate items to rerank — typically the top-K from a
            cheaper retrieval stage (bi-encoder or BM25).
        query: The user query / synthesis prompt fragment to score
            candidates against.
        model: Cross-encoder backend. Defaults to the cached encoder
            returned by :func:`get_default_cross_encoder`.
        top_k: When set, only the first ``top_k`` items are rescored;
            the remainder is appended after the reranked head with their
            original scores. Bounds inference cost on long candidate
            lists. Default ``None`` (rescore everything). A value of zero
            or less rescores nothing.

    Returns:
        Items sorted by descending cross-encoder score. Items beyond
        ``top_k`` retain their original score and follow the reranked
        head in their original relative order. When ``items`` or
        ``query`` is empty, or the backend returns the wrong number of
        scores, the input list itself is returned unchanged. When
        ``top_k`` is zero or negative, a shallow copy of ``items`` in the
        original order is returned and the model is never invoked.
    """
    if not items or not query:
        return items

    # A non-positive ``top_k`` selects nothing to rescore. Handle it before
    # slicing: a negative bound would otherwise slice from the END of the
    # list (``items[:-1]``) and silently rescore all but the last item.
    # Returning here also avoids resolving the default model for no work.
    if top_k is not None and top_k <= 0:
        return list(items)

    if model is None:
        model = get_default_cross_encoder()

    head = items if top_k is None else items[:top_k]
    tail = [] if top_k is None else items[top_k:]

    scores = model.score(query, [item.text for item in head])
    if len(scores) != len(head):  # pragma: no cover — backend contract violation
        _logger.warning(
            "cross_encoder model returned %d scores for %d items; falling back to original order.",
            len(scores),
            len(head),
        )
        return items

    # Rebuild each head item with its new score; all other fields carry over.
    rescored = [
        ScoredItem(
            id=item.id,
            text=item.text,
            score=float(score),
            fact_type=item.fact_type,
            metadata=item.metadata,
            tags=item.tags,
            memory_layer=item.memory_layer,
            occurred_at=item.occurred_at,
            retained_at=item.retained_at,
        )
        for item, score in zip(head, scores, strict=True)
    ]
    # ``list.sort`` is stable, so equal scores keep their incoming order.
    rescored.sort(key=lambda x: x.score, reverse=True)
    return rescored + tail