Coverage for astrocyte/pipeline/section_rerank.py: 98%
63 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Section cross-encoder rerank (M9 PR2 commit C).

Sits between the section recall orchestrator (RRF-fused candidates) and
the picker (PageIndex's reasoning loop). Two responsibilities:

1. **Cross-encoder rerank**: take the top-30 RRF-fused hits, score each
   against the question with a cross-encoder, and return the top-15. Same
   pattern Hindsight uses (``cross-encoder/ms-marco-MiniLM-L-6-v2``);
   we reuse the existing ``astrocyte.pipeline.cross_encoder_rerank``
   plumbing rather than reinventing it.

2. **Picker-as-reranker constraint**: build a "constrained skeleton" —
   the picker still sees a nested-dict tree, but only the 15 reranked
   nodes appear (all other nodes are pruned). This is the structural fix
   for the v6 picker non-compliance: gpt-4o-mini reliably picks 5-10
   sections from a curated 15 but degenerates to ``[1]`` when given the
   raw 30-node skeleton (shown in the Phase A failure analysis).

The picker's prompt format is unchanged — it just sees fewer nodes.
No prompt tuning is required. This is the cheapest accuracy lift in PR2.

See:
- ``docs/_design/recall.md`` §6, §8.3
- ``docs/_design/adr/adr-006-three-layer-recall-stack.md``
"""
27from __future__ import annotations
29import logging
30from typing import TYPE_CHECKING
32from astrocyte.pipeline.cross_encoder_rerank import (
33 CrossEncoderProtocol,
34 cross_encoder_rerank,
35)
36from astrocyte.pipeline.rerank_boosts import RerankBoostConfig, apply_boosts
37from astrocyte.pipeline.reranking import ScoredItem
39if TYPE_CHECKING:
40 from datetime import datetime
42 from astrocyte.pipeline.section_recall import FusedHit
43 from astrocyte.types import PageIndexSection
45logger = logging.getLogger("astrocyte.pipeline.section_rerank")
def rerank_fused_hits(
    fused: list[FusedHit],
    sections_by_key: dict[tuple[str, int], PageIndexSection],
    question: str,
    *,
    model: CrossEncoderProtocol | None = None,
    rerank_top_k: int = 30,
    output_top_k: int = 15,
    query_range: tuple[datetime, datetime] | None = None,
    proof_counts: dict[tuple[str, int], int] | None = None,
    boost_config: RerankBoostConfig | None = None,
) -> list[FusedHit]:
    """Rerank top-K fused hits with a cross-encoder; return top-N.

    Args:
        fused: ``SectionRecallResult.fused`` (sorted by RRF score desc).
        sections_by_key: Map ``(document_id, line_num) → PageIndexSection``
            so we can fetch (title + summary) without re-querying the store.
            Caller pre-builds this from the conv_tree's skeleton.
        question: User question, fed to the cross-encoder.
        model: Cross-encoder; ``None`` → default Hindsight model.
        rerank_top_k: How many of the RRF top hits to actually rescore. The
            cross-encoder is the slow part (transformer inference); we cap
            at 30 by default. Hits ranked below this cutoff are not
            rescored and do not appear in the returned list.
        output_top_k: Final length of the returned list.
        query_range: Optional inferred date band from question_annotator.
            Feeds the temporal-band boost; ``None`` makes it a no-op.
        proof_counts: Optional ``(document_id, line_num) → fact count``
            map. Feeds the proof-count boost; ``None`` makes it a no-op.
        boost_config: Tuning + toggle for the multiplicative post-rerank
            boosts. ``None`` uses defaults (enabled=True with Hindsight-
            parity alphas).

    Returns:
        ``FusedHit`` list, sorted by cross-encoder score (modulated by
        post-rerank boosts) descending, truncated to ``output_top_k``.
        The ``rrf_score`` field is replaced with the cross-encoder score
        (then multiplicatively boosted) for transparency. An empty
        ``fused`` returns ``[]`` without invoking the cross-encoder.
    """
    if not fused:
        return []

    head = fused[:rerank_top_k]
    items: list[ScoredItem] = []
    # Map each ScoredItem id back to the hit that produced it. Parsing the
    # id string instead ("doc:line") breaks when a document_id itself
    # contains ':' (wrong doc id, or ValueError from int()).
    hit_by_id: dict[str, FusedHit] = {}
    for h in head:
        item_id = f"{h.document_id}:{h.line_num}"
        hit_by_id[item_id] = h
        section = sections_by_key.get((h.document_id, h.line_num))
        # Build the rerank input text from title + summary. The body
        # itself isn't needed at this stage; the picker fetches
        # excerpts for the synth, not the reranker.
        if section is None:
            text = f"line {h.line_num}"
        else:
            title = section.title or ""
            summary = section.summary or ""
            text = f"{title}. {summary}".strip(" .")
        items.append(ScoredItem(id=item_id, text=text, score=h.rrf_score))

    rescored = cross_encoder_rerank(items, question, model=model)

    # Emit FusedHit objects with the new score in ``rrf_score`` (we abuse
    # the field for uniformity downstream — the picker doesn't care which
    # scorer wrote it).
    from astrocyte.pipeline.section_recall import FusedHit  # avoid circular import

    # Gap 3 (2026-05-16): pre-truncation boost. Apply post-rerank
    # multiplicative boosts BEFORE the output_top_k cut so the boosts
    # can re-order adjacent rank-15/16 candidates into the visible set
    # rather than rescoring an already-truncated list.
    rescored_full: list[FusedHit] = []
    for item in rescored:
        original = hit_by_id.get(item.id)
        if original is None:
            # Unknown id from the reranker: skip rather than fabricate a hit.
            continue
        rescored_full.append(
            FusedHit(
                document_id=original.document_id,
                line_num=original.line_num,
                rrf_score=float(item.score),
                per_strategy_rank=dict(original.per_strategy_rank),
            )
        )

    boosted = apply_boosts(
        rescored_full,
        sections_by_key,
        query_range=query_range,
        proof_counts=proof_counts,
        config=boost_config,
    )
    return boosted[:output_top_k]
def build_constrained_skeleton(
    full_skeleton: list | dict,
    keep_keys: set[tuple[str, int]],
    document_id: str,
) -> list:
    """Return a copy of the picker's nested-dict skeleton, pruned to ``keep_keys``.

    PR1 handed the picker the FULL tree (~30 nodes for LoCoMo). PR2 commit C
    restricts it to the top-15 reranked sections, while keeping the nested
    shape the picker's prompt format expects.

    A node survives when ``(document_id, node["line_num"])`` is in
    ``keep_keys``, or when any of its descendants survives. The second rule
    keeps the parent chain from the root down to every selected node, which
    the picker's tree-walk needs.

    Surviving nodes are shallow copies that keep every key except
    ``"nodes"``. ``"nodes"`` is re-attached only when at least one child
    survived. Non-dict entries are ignored, and so is a ``"nodes"`` value
    that is not a list. A dict ``full_skeleton`` is unwrapped via its
    ``"structure"`` key, or treated as a single root if that key is absent.

    Returns:
        The surviving root-level nodes. The caller passes this to the
        picker the same way it would pass the full skeleton.
    """

    def _prune_forest(nodes: list) -> list[dict]:
        # Prune each dict node in order; drop non-dicts and pruned-away nodes.
        candidates = (_prune_tree(n) for n in nodes if isinstance(n, dict))
        return [c for c in candidates if c is not None]

    def _prune_tree(node: dict) -> dict | None:
        # Membership is decided before recursing, mirroring a pre-order check.
        selected = (document_id, node.get("line_num")) in keep_keys
        child_list = node.get("nodes")
        survivors = _prune_forest(child_list) if isinstance(child_list, list) else []
        if not (selected or survivors):
            return None
        trimmed = {key: value for key, value in node.items() if key != "nodes"}
        if survivors:
            trimmed["nodes"] = survivors
        return trimmed

    roots = full_skeleton
    if isinstance(roots, dict):
        roots = roots.get("structure", [roots])
    return _prune_forest(roots)