Coverage for astrocyte/pipeline/section_rerank.py: 98%

63 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Section cross-encoder rerank (M9 PR2 commit C). 

2 

3Sits between the section recall orchestrator (RRF-fused candidates) and 

4the picker (PageIndex's reasoning loop). Two responsibilities: 

5 

61. **Cross-encoder rerank**: take the top-30 RRF-fused hits, score each 

7 against the question with a cross-encoder, return the top-15. Same 

8 pattern Hindsight uses (``cross-encoder/ms-marco-MiniLM-L-6-v2``); 

9 we reuse the existing ``astrocyte.pipeline.cross_encoder_rerank`` 

10 plumbing rather than reinventing it. 

11 

122. **Picker-as-reranker constraint**: build a "constrained skeleton" — 

13 the picker still sees a nested-dict tree, but only the 15 reranked 

14 nodes appear (other nodes are pruned). This is the structural fix 

15 for the v6 picker non-compliance: gpt-4o-mini reliably picks 5-10 

16 from a curated 15 but degenerates to ``[1]`` when given the raw 

17 30-node skeleton (proven in Phase A failure analysis). 

18 

19The picker's prompt format is unchanged — it just sees fewer nodes. 

20No prompt tuning required. This is the cheapest accuracy lift in PR2. 

21 

22See: 

23- ``docs/_design/recall.md`` §6, §8.3 

24- ``docs/_design/adr/adr-006-three-layer-recall-stack.md`` 

25""" 

26 

27from __future__ import annotations 

28 

29import logging 

30from typing import TYPE_CHECKING 

31 

32from astrocyte.pipeline.cross_encoder_rerank import ( 

33 CrossEncoderProtocol, 

34 cross_encoder_rerank, 

35) 

36from astrocyte.pipeline.rerank_boosts import RerankBoostConfig, apply_boosts 

37from astrocyte.pipeline.reranking import ScoredItem 

38 

39if TYPE_CHECKING: 

40 from datetime import datetime 

41 

42 from astrocyte.pipeline.section_recall import FusedHit 

43 from astrocyte.types import PageIndexSection 

44 

45logger = logging.getLogger("astrocyte.pipeline.section_rerank") 

46 

47 

def rerank_fused_hits(
    fused: list[FusedHit],
    sections_by_key: dict[tuple[str, int], PageIndexSection],
    question: str,
    *,
    model: CrossEncoderProtocol | None = None,
    rerank_top_k: int = 30,
    output_top_k: int = 15,
    query_range: tuple[datetime, datetime] | None = None,
    proof_counts: dict[tuple[str, int], int] | None = None,
    boost_config: RerankBoostConfig | None = None,
) -> list[FusedHit]:
    """Rerank the top-K fused hits with a cross-encoder; return the top-N.

    Args:
        fused: ``SectionRecallResult.fused`` (sorted by RRF score desc).
        sections_by_key: Map ``(document_id, line_num) → PageIndexSection``
            so we can fetch (title + summary) without re-querying the store.
            Caller pre-builds this from the conv_tree's skeleton. Hits with
            no entry here are scored on the placeholder text ``"line <n>"``.
        question: User question, fed to the cross-encoder.
        model: Cross-encoder, forwarded to ``cross_encoder_rerank``;
            ``None`` → that helper's default model.
        rerank_top_k: How many of the RRF top to actually rescore. The
            cross-encoder is the slow part (transformer inference); we cap
            at 30 by default. Hits beyond this rank are never handed to the
            cross-encoder and therefore do not appear in the result.
        output_top_k: Maximum length of the returned list.
        query_range: Optional inferred date band from question_annotator,
            forwarded to ``apply_boosts`` (temporal-band boost).
        proof_counts: Optional ``(document_id, line_num) → fact count``
            map, forwarded to ``apply_boosts`` (proof-count boost).
        boost_config: Tuning + toggle for the multiplicative post-rerank
            boosts, forwarded to ``apply_boosts``.

    Returns:
        ``FusedHit`` list in the order produced by ``apply_boosts``,
        truncated to ``output_top_k``. Each hit's ``rrf_score`` field
        carries the cross-encoder score (as modulated by the boosts), not
        the original RRF score; ``per_strategy_rank`` is copied from the
        original hit. An empty ``fused`` yields ``[]`` without invoking
        the cross-encoder.
    """
    if not fused:
        return []

    head = fused[:rerank_top_k]
    items = []
    for h in head:
        section = sections_by_key.get((h.document_id, h.line_num))
        # Build the rerank input text from title + summary. The body
        # itself isn't needed at this stage; the picker fetches
        # excerpts for the synth, not the reranker.
        if section is None:
            text = f"line {h.line_num}"
        else:
            title = section.title or ""
            summary = section.summary or ""
            # strip(" .") removes the dangling ". " when either part is empty.
            text = f"{title}. {summary}".strip(" .")
        items.append(
            ScoredItem(
                id=f"{h.document_id}:{h.line_num}",
                text=text,
                score=h.rrf_score,
            )
        )

    rescored = cross_encoder_rerank(items, question, model=model)

    # Map ScoredItem.id back to (doc, line) and emit FusedHit objects
    # with the new score in ``rrf_score`` (we abuse the field for
    # uniformity downstream — picker doesn't care which scorer wrote it).
    from astrocyte.pipeline.section_recall import FusedHit  # avoid circular import

    by_key = {(h.document_id, h.line_num): h for h in fused}
    # Gap 3 (2026-05-16): pre-truncation boost. Apply post-rerank
    # multiplicative boosts BEFORE the output_top_k cut so the boosts
    # can re-order adjacent rank-15/16 candidates into the visible set
    # rather than rescoring an already-truncated list.
    rescored_full: list[FusedHit] = []
    for item in rescored:
        # The id was built as "<document_id>:<line_num>". Split on the LAST
        # colon so a document_id that itself contains ":" round-trips
        # instead of leaking into the line-number component.
        doc_id, line_str = item.id.rsplit(":", 1)
        line_num = int(line_str)
        original = by_key.get((doc_id, line_num))
        if original is None:
            # Reranker returned an id we did not submit; ignore it.
            continue
        rescored_full.append(
            FusedHit(
                document_id=doc_id,
                line_num=line_num,
                rrf_score=float(item.score),
                # Copy so later mutation of the result cannot alias the input.
                per_strategy_rank=dict(original.per_strategy_rank),
            )
        )

    boosted = apply_boosts(
        rescored_full,
        sections_by_key,
        query_range=query_range,
        proof_counts=proof_counts,
        config=boost_config,
    )
    return boosted[:output_top_k]

150 

151 

152def build_constrained_skeleton( 

153 full_skeleton: list | dict, 

154 keep_keys: set[tuple[str, int]], 

155 document_id: str, 

156) -> list: 

157 """Prune the picker's nested-dict skeleton to only the nodes in 

158 ``keep_keys`` (preserving tree structure). 

159 

160 The picker's prompt format expects a nested-dict tree; the skeleton 

161 we pass in PR1 was the FULL tree (~30 nodes for LoCoMo). PR2 commit 

162 C narrows that to the top-15 reranked sections so the picker has a 

163 much smaller, more relevant input. 

164 

165 A node is kept if EITHER: 

166 a) ``(document_id, node.line_num)`` is in ``keep_keys``, OR 

167 b) it has at least one descendant that is in ``keep_keys`` 

168 (so the path from root to a kept leaf survives — the picker's 

169 tree-walk needs the parent chain). 

170 

171 Returns a list of root-level nodes. The caller passes this to the 

172 picker the same way it would pass the full skeleton. 

173 """ 

174 

175 def _walk(node: dict) -> dict | None: 

176 kept = (document_id, node.get("line_num")) in keep_keys 

177 children = node.get("nodes") 

178 kept_children: list[dict] = [] 

179 if isinstance(children, list): 

180 for child in children: 

181 if not isinstance(child, dict): 

182 continue 

183 pruned = _walk(child) 

184 if pruned is not None: 

185 kept_children.append(pruned) 

186 if not kept and not kept_children: 

187 return None 

188 out = {k: v for k, v in node.items() if k != "nodes"} 

189 if kept_children: 

190 out["nodes"] = kept_children 

191 return out 

192 

193 if isinstance(full_skeleton, dict): 

194 full_skeleton = full_skeleton.get("structure", [full_skeleton]) 

195 

196 pruned: list[dict] = [] 

197 for node in full_skeleton: 

198 if not isinstance(node, dict): 

199 continue 

200 kept = _walk(node) 

201 if kept is not None: 

202 pruned.append(kept) 

203 return pruned