Coverage for astrocyte/pipeline/rerank_boosts.py: 99%
83 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Post-rerank multiplicative boosts (M14 Gap 3 closure).

After the cross-encoder rerank in ``section_rerank.py`` produces a
score for each candidate, we compose three additional signals
multiplicatively:

1. **Recency boost** — linear decay on section age over
   ``RerankBoostConfig.recency_window_days`` (default 365 days, matching
   Hindsight's curve). The raw score lies in [0.1, 1.0]; the 0.1 floor
   means a stale section is down-weighted but never zeroed. Undated
   sections get a neutral raw score of 0.5.

2. **Temporal-band intersection** — when the question carries a
   ``date_range`` (from ``question_annotator``), sections whose own
   time range (session_date or [occurred_start, occurred_end]) overlaps
   the question's band get a multiplicative bonus, and dated sections
   that miss the band get a penalty. Sections with no date, or
   questions with no band, are neutral (multiplier 1.0).

3. **Proof-count boost** — log-normalised count of facts linked to a
   section. Sections backed by more atomic facts are stronger
   evidence. The raw score lies in [0.5, 1.0] (saturating at about 149
   facts), so this signal can only raise a score, never lower it.

Each raw score ``s`` becomes a factor ``1 + alpha * (s - 0.5)``, so a raw
score of 0.5 is neutral (factor 1.0). Composition is multiplicative
(Hindsight pattern, see
``hindsight-api-slim/hindsight_api/engine/search/reranking.py``):

    final = CE_score × recency × temporal_band × proof_count

The multiplicative form avoids cancellation (an additive scheme can
let a strong negative signal wipe out a strong positive). With the
default alphas the combined multiplier lies in roughly [0.86, 1.27]
(worst case ~-14%, best case ~+27%), so the cross-encoder remains the
primary ordering signal — the boosts nudge ties and adjacent ranks.

Gated behind a config flag (``RerankBoostConfig.enabled``) so the
component can be ablated independently for bench gates. The flag
defaults to ON (``enabled=True``).
"""
38from __future__ import annotations
40import math
41from dataclasses import dataclass
42from datetime import datetime, timezone
43from typing import TYPE_CHECKING
45if TYPE_CHECKING:
46 from astrocyte.pipeline.section_recall import FusedHit
47 from astrocyte.types import PageIndexSection
UTC = timezone.utc  # module-wide alias for the UTC tzinfo singleton

# Default per-signal weights, mirroring Hindsight's published values.
# A signal's factor is ``1 + alpha * (raw - 0.5)`` with ``raw`` in [0, 1],
# so no single signal moves a score by more than alpha/2 either way.
_RECENCY_ALPHA: float = 0.2  # at most ±10%
_TEMPORAL_ALPHA: float = 0.2  # at most ±10%
_PROOF_COUNT_ALPHA: float = 0.1  # at most ±5%
@dataclass
class RerankBoostConfig:
    """Switch and tuning knobs for the post-rerank boosts.

    The defaults reproduce Hindsight's tuned values (see the upstream
    ``reranking.py``). Setting ``enabled=False`` turns the boost into a
    no-op (multiplier 1.0).
    """

    #: Master switch for all boosts.
    enabled: bool = True
    #: Weight of the recency signal.
    recency_alpha: float = _RECENCY_ALPHA
    #: Weight of the temporal-band signal.
    temporal_alpha: float = _TEMPORAL_ALPHA
    #: Weight of the proof-count signal.
    proof_count_alpha: float = _PROOF_COUNT_ALPHA
    #: Length in days of the linear recency decay (365 matches Hindsight).
    #: The raw recency score hits its 0.1 floor once a section is 90% of
    #: this window old and stays there for anything older.
    recency_window_days: float = 365.0
def _section_date(section: PageIndexSection | None) -> datetime | None:
    """Return the timestamp used for the recency signal, or ``None``.

    The event date (``occurred_start``) wins over ``session_date`` because
    it records when the content itself happened; ``session_date`` is used
    only when the event date is missing or falsy.

    Lookups go through ``getattr`` with a default, so duck-typed objects
    lacking either attribute (e.g. a ``FusedHit`` from section_recall,
    which only carries ``document_id``/``line_num``/``rrf_score``) yield
    ``None`` instead of raising.
    """
    if section is None:
        return None
    event_start = getattr(section, "occurred_start", None)
    if event_start:
        return event_start
    return getattr(section, "session_date", None)
def _normalize_to_utc(dt: datetime) -> datetime:
    """Make ``dt`` timezone-aware, assuming UTC for naive values.

    Naive datetimes get ``tzinfo=UTC`` attached (wall-clock unchanged).
    Aware datetimes are returned as-is and are *not* converted to UTC.
    That is sufficient here because the callers in this module only
    subtract and compare aware datetimes, which is offset-independent.
    """
    if dt.tzinfo is not None:
        return dt
    return dt.replace(tzinfo=UTC)
def recency_score(
    section_date: datetime | None,
    now: datetime,
    *,
    window_days: float = 365.0,
) -> float:
    """Map a section's age to a recency score in ``[0.1, 1.0]``.

    The score decays linearly from 1.0 (dated at ``now`` or later) towards
    0.0 at ``window_days`` old, and is floored at 0.1 so very old items are
    down-weighted but never de-ranked to oblivion (the floor is reached at
    90% of the window). Same shape Hindsight uses. Naive datetimes are
    treated as UTC.

    Args:
        section_date: Timestamp of the section, or ``None`` if undated.
        now: Reference "current" time.
        window_days: Length of the linear decay in days; must be positive.

    Returns:
        The recency score, or 0.5 (neutral) when ``section_date`` is None.

    Raises:
        ValueError: If a date is given and ``window_days`` is not a
            positive number.
    """
    if section_date is None:
        return 0.5
    # A zero window would divide by zero and a negative one would invert
    # the decay (older sections scoring higher); ``not > 0`` also rejects NaN.
    if not window_days > 0:
        raise ValueError(f"window_days must be positive, got {window_days!r}")
    section_date = _normalize_to_utc(section_date)
    now = _normalize_to_utc(now)
    days_ago = (now - section_date).total_seconds() / 86400.0
    return max(0.1, min(1.0, 1.0 - (days_ago / window_days)))
def temporal_band_score(
    section: PageIndexSection | None,
    query_range: tuple[datetime, datetime] | None,
) -> float:
    """Score how well a section's dates line up with the question's band.

    Returns:
        * 0.5 (neutral) when the question has no date band, the section is
          ``None``, or the section carries no usable date. Duck-typed
          inputs without ``occurred_start``/``occurred_end``/``session_date``
          (e.g. a ``FusedHit``) therefore collapse to neutral.
        * 1.0 when the section's range overlaps ``query_range`` (both ends
          inclusive).
        * 0.2 when both are dated and disjoint — such wrong-period sections
          were probably promoted on topical relevance despite missing the
          temporal target, so they are penalised.

    The section's range is ``[occurred_start, occurred_end]`` when
    populated; a missing end falls back to ``occurred_start`` and then to
    ``session_date``. Naive datetimes are treated as UTC.
    """
    if query_range is None or section is None:
        return 0.5
    raw_lo, raw_hi = query_range
    band_lo, band_hi = _normalize_to_utc(raw_lo), _normalize_to_utc(raw_hi)

    event_lo = getattr(section, "occurred_start", None)
    event_hi = getattr(section, "occurred_end", None)
    session = getattr(section, "session_date", None)
    sec_lo = event_lo or session
    sec_hi = event_hi or event_lo or session
    if sec_lo is None or sec_hi is None:
        return 0.5
    sec_lo, sec_hi = _normalize_to_utc(sec_lo), _normalize_to_utc(sec_hi)

    # Two closed intervals intersect iff each starts before the other ends.
    overlaps = sec_lo <= band_hi and sec_hi >= band_lo
    return 1.0 if overlaps else 0.2
def proof_count_score(proof_count: int | None) -> float:
    """Map the number of supporting facts to a score in ``[0.5, 1.0]``.

    The curve is ``0.5 + ln(proof_count) / 10``: a single-fact section is
    neutral (0.5) and the score saturates at 1.0 once ``ln(proof_count)``
    reaches 5, i.e. from about 149 facts on. A missing count (``None``) or
    a count below one is also neutral, so its boost factor is 1.0.
    """
    if proof_count is None or proof_count < 1:
        return 0.5
    scaled = 0.5 + math.log(proof_count) / 10.0
    lower_clamped = max(0.0, scaled)
    return min(1.0, lower_clamped)
def compute_boost_multiplier(
    section: PageIndexSection | None,
    *,
    query_range: tuple[datetime, datetime] | None = None,
    proof_count: int | None = None,
    now: datetime | None = None,
    config: RerankBoostConfig | None = None,
    section_date_override: datetime | None = None,
) -> float:
    """Return the multiplicative score boost for a single candidate.

    Reusable helper so callers outside ``section_rerank`` (notably the
    bench's unified fact+section+wiki rerank in ``astrocyte_client``)
    can apply the same Hindsight-parity bounded boosts.

    Each raw signal score ``s`` (recency, temporal band, proof count) is
    turned into a factor ``1 + alpha * (s - 0.5)``, so a raw score of 0.5
    is neutral, and the three factors are multiplied together.

    Args:
        section: The ``PageIndexSection`` associated with the candidate.
            ``None`` for candidates without a section anchor (e.g. wiki
            grain, top-level facts) — collapses temporal_band to neutral
            (0.5 → factor 1.0).
        query_range: Inferred date band from question analyzer. ``None``
            collapses the temporal_band factor to 1.0.
        proof_count: Optional fact count for this section. ``None``
            collapses the proof_count factor to 1.0.
        now: Reference timestamp for recency. Defaults to
            ``datetime.now(UTC)``; naive values are treated as UTC.
        config: Toggles + alphas. ``None`` uses defaults (``enabled=True``).
            When ``config.enabled`` is False, returns 1.0 (no-op multiplier).
        section_date_override: Use this datetime for the recency calculation
            instead of the section's own date. Lets fact-grain callers pass
            ``fact.occurred_start`` directly without round-tripping through
            a section lookup. It does not affect the temporal-band signal,
            which still reads ``section``.

    Returns:
        The product of the three factors. With the default alphas the
        factors span recency 0.92–1.10, temporal band 0.94–1.10 and proof
        count 1.00–1.05, so the multiplier lies in roughly ``[0.86, 1.27]``.
        A candidate with no date, no query band and no proof count gets
        exactly 1.0. The caller multiplies its CE score by this and
        re-sorts. For a *negative* score (e.g. a raw logit), multiplying by
        a factor > 1 lowers it, so such callers should divide instead to
        keep the boost pointing the right way.
    """
    if config is None:
        config = RerankBoostConfig()
    if not config.enabled:
        return 1.0
    if now is None:
        now = datetime.now(UTC)
    now = _normalize_to_utc(now)

    section_dt = section_date_override if section_date_override is not None else _section_date(section)
    recency = recency_score(section_dt, now, window_days=config.recency_window_days)
    temporal = temporal_band_score(section, query_range)
    proof = proof_count_score(proof_count)

    # Centre each raw score on 0.5 so "no information" is a neutral 1.0.
    recency_boost = 1.0 + config.recency_alpha * (recency - 0.5)
    temporal_boost = 1.0 + config.temporal_alpha * (temporal - 0.5)
    proof_count_boost = 1.0 + config.proof_count_alpha * (proof - 0.5)
    return recency_boost * temporal_boost * proof_count_boost
def _scale_score(score: float, multiplier: float) -> float:
    """Apply a boost ``multiplier`` to ``score`` without flipping its meaning.

    Plain multiplication is direction-correct only for non-negative scores:
    a raw cross-encoder logit of -2.0 times a 1.1 "boost" becomes -2.2 and
    *drops* in the ranking. For negative scores we divide instead, so a
    multiplier > 1 still moves the score up (towards 0) and one < 1 still
    moves it down. Non-negative scores, and non-positive multipliers
    (where division is undefined or would flip the sign), keep plain
    ``score * multiplier``.
    """
    if score < 0 and multiplier > 0:
        return score / multiplier
    return score * multiplier


def apply_boosts(
    hits: list[FusedHit],
    sections_by_key: dict[tuple[str, int], PageIndexSection],
    *,
    query_range: tuple[datetime, datetime] | None = None,
    proof_counts: dict[tuple[str, int], int] | None = None,
    now: datetime | None = None,
    config: RerankBoostConfig | None = None,
) -> list[FusedHit]:
    """Apply multiplicative post-rerank boosts and re-sort.

    Args:
        hits: Cross-encoder-reranked hits. ``rrf_score`` carries the CE
            score after ``section_rerank.rerank_fused_hits`` runs.
        sections_by_key: ``(document_id, line_num) → PageIndexSection``
            for date lookups.
        query_range: Inferred date band from question_annotator. ``None``
            collapses the temporal_band boost to 1.0.
        proof_counts: Optional ``(document_id, line_num) → int`` mapping
            of facts linked to each section. ``None`` collapses the
            proof_count boost to 1.0.
        now: Reference timestamp for recency. Defaults to ``datetime.now(UTC)``.
        config: Toggles + alphas. ``None`` uses defaults (enabled=True).

    Returns:
        A new list of new ``FusedHit`` objects sorted by boosted score,
        descending; ties keep their input order (the sort is stable). The
        input list and its hits are not mutated. Negative scores are divided
        rather than multiplied by the boost so a boost always raises rank
        (see ``_scale_score``). When ``config.enabled`` is False or ``hits``
        is empty, returns a shallow copy of ``hits`` in the original order.
    """
    if config is None:
        config = RerankBoostConfig()
    if not config.enabled or not hits:
        return list(hits)

    if now is None:
        now = datetime.now(UTC)
    # One reference time so every hit is aged against the same clock.
    now = _normalize_to_utc(now)
    counts = proof_counts or {}  # hoisted out of the per-hit loop

    from astrocyte.pipeline.section_recall import FusedHit  # avoid circular import

    boosted: list[FusedHit] = []
    for h in hits:
        key = (h.document_id, h.line_num)
        multiplier = compute_boost_multiplier(
            sections_by_key.get(key),
            query_range=query_range,
            proof_count=counts.get(key),
            now=now,
            config=config,
        )
        boosted.append(
            FusedHit(
                document_id=h.document_id,
                line_num=h.line_num,
                rrf_score=_scale_score(h.rrf_score, multiplier),
                per_strategy_rank=dict(h.per_strategy_rank),
            )
        )

    boosted.sort(key=lambda hit: hit.rrf_score, reverse=True)
    return boosted