Coverage for astrocyte/pipeline/fact_rerank.py: 95%

21 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""M12.3: Cross-encoder rerank over fact-grain hits. 

2 

3Sits between fact retrieval (semantic / entity / temporal) and the 

4``[FACTS]`` block in the synth prompt. Mirrors 

5``astrocyte.pipeline.section_rerank.rerank_fused_hits`` — same cross- 

6encoder backend, same module-level cache, same pattern of building a 

7text representation per candidate and calling 

8``cross_encoder_rerank``. 

9 

10Why a separate module: 

11 

12- Facts have different metadata (fact_type, speaker, entities, 

13 occurred_*) and a richer rerank-input text could in principle attend 

14 to those. The v1 keeps it minimal: just ``fact.text``. The MS MARCO 

15 cross-encoder is trained on natural-language passages; injecting 

16 structured metadata as `[key=value]` tokens tends to confuse it. 

17- Facts are typically retrieved from a wider pool (top-30+ semantic) 

18 and then narrowed by a picker-line filter. Reranking is cheapest 

19 when applied to the already-filtered subset. 

20 

21Generic across benches — the cross-encoder doesn't know LME from 

22LoCoMo, and the rerank text contains no bench-specific shaping. 

23 

24See: 

25- ``docs/_design/benchmark-comparison-methodology.md`` for harness rules 

26- ``astrocyte.pipeline.section_rerank`` for the section-grain analogue 

27""" 

28 

29from __future__ import annotations 

30 

31import logging 

32from dataclasses import replace 

33from typing import TYPE_CHECKING 

34 

35from astrocyte.pipeline.cross_encoder_rerank import ( 

36 CrossEncoderProtocol, 

37 cross_encoder_rerank, 

38) 

39from astrocyte.pipeline.reranking import ScoredItem 

40 

41if TYPE_CHECKING: 

42 from astrocyte.types import PageIndexFactHit 

43 

44logger = logging.getLogger("astrocyte.pipeline.fact_rerank") 

45 

46 

def rerank_fact_hits(
    hits: list[PageIndexFactHit],
    question: str,
    *,
    model: CrossEncoderProtocol | None = None,
    rerank_top_k: int = 30,
    output_top_k: int = 12,
) -> list[PageIndexFactHit]:
    """Cross-encoder rerank ``hits`` against ``question``.

    Each candidate is represented to the cross-encoder by its ``text``
    alone (no metadata), wrapped in a ``ScoredItem`` keyed by
    ``fact_id``, and scored via ``cross_encoder_rerank``.

    Args:
        hits: Fact hits, typically the union of semantic / entity /
            temporal search results, already deduped by ``fact_id``.
            Only the first ``rerank_top_k`` entries (in input order) are
            considered, so input order matters whenever
            ``len(hits) > rerank_top_k``.
        question: User question, fed to the cross-encoder.
        model: Cross-encoder backend, forwarded unchanged to
            ``cross_encoder_rerank``. ``None`` selects that function's
            default backend.
        rerank_top_k: Cap on how many candidates are rescored.
            Cross-encoder inference is the slow part, so it is bounded
            (30 by default). Hits beyond this position are neither
            rescored nor returned. A value ``<= 0`` yields ``[]``.
        output_top_k: Maximum length of the returned list (post-rerank).
            A value ``<= 0`` yields ``[]``.

    Returns:
        At most ``output_top_k`` hits, in the order produced by
        ``cross_encoder_rerank`` (expected to be best-first). Each
        returned hit is a copy of the input hit whose ``score`` field is
        replaced by the cross-encoder score as a ``float``; the input
        hits are not mutated.
    """
    # Non-positive caps mean "score nothing" / "return nothing". Guarding
    # here avoids Python's negative-slice semantics (``xs[:-1]`` drops the
    # last element rather than selecting none) and skips cross-encoder
    # inference when the result would be empty anyway.
    if not hits or rerank_top_k <= 0 or output_top_k <= 0:
        return []

    head = hits[:rerank_top_k]
    items = [
        ScoredItem(
            id=h.fact_id,
            text=h.text,
            score=h.score,
        )
        for h in head
    ]

    rescored = cross_encoder_rerank(items, question, model=model)

    # Map ids back to the original hits so every non-score field is
    # carried over. If ``head`` contained duplicate ids, the last one wins.
    by_id = {h.fact_id: h for h in head}
    out: list[PageIndexFactHit] = []
    for item in rescored[:output_top_k]:
        original = by_id.get(item.id)
        if original is None:
            # Defensive: skip anything the reranker returned that was not
            # part of the submitted candidates.
            continue
        # ``replace`` shallow-copies the dataclass with the new score,
        # automatically picking up any future PageIndexFactHit fields
        # added to types.py.
        out.append(replace(original, score=float(item.score)))
    return out