Coverage for astrocyte/pipeline/fusion.py: 100%

69 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Reciprocal Rank Fusion (RRF) — merge results from multiple retrieval strategies. 

2 

3Sync, pure computation — Rust migration candidate. 

4See docs/_design/built-in-pipeline.md section 3. 

5""" 

6 

7from __future__ import annotations 

8 

9import hashlib 

10from dataclasses import dataclass 

11from datetime import datetime 

12 

13from astrocyte.types import MemoryHit 

14 

#: RRF smoothing constant used when callers do not pass ``k``. Larger values
#: flatten the ``1 / (k + rank)`` curve, so lower-ranked items lose less
#: ground to top-ranked ones. 60 is the value from the original RRF paper
#: (Cormack et al., 2009).
DEFAULT_RRF_K: int = 60

18 

19 

@dataclass
class ScoredItem:
    """A scored item from any retrieval strategy.

    This is the common record passed into the fusion functions of this module.
    On input, ``score`` is the producing strategy's own score. The fusion
    functions return new ``ScoredItem`` objects whose ``score`` holds the
    fused (RRF) score instead. Every other field is copied through unchanged.
    """

    id: str  # deduplication key used by the fusion functions
    text: str
    score: float  # higher is better; fusion output is sorted descending
    fact_type: str | None = None
    metadata: dict[str, str | int | float | bool | None] | None = None
    tags: list[str] | None = None
    memory_layer: str | None = None  # "fact", "observation", "model"
    occurred_at: datetime | None = None
    retained_at: datetime | None = None  # M9: wall-clock time item was stored
    chunk_id: str | None = None  # M10: source-chunk backreference

34 

35 

def rrf_fusion(
    ranked_lists: list[list[ScoredItem]],
    k: int = DEFAULT_RRF_K,
) -> list[ScoredItem]:
    """Reciprocal Rank Fusion across multiple ranked result lists.

    RRF score = Σ(1 / (k + rank)) for each list where the item appears, with
    ``rank`` counted from 1. Items are deduplicated by ``id``. An id that
    occurs more than once inside the *same* list contributes only once, at
    its first (best) rank. A strategy that emits duplicates therefore cannot
    double-boost an item.

    An id may occur several times, within one list or across lists. In that
    case the copy with the highest original ``score`` supplies the returned
    item's fields, and its ``score`` is replaced by the fused RRF score.
    Items with equal RRF scores keep the order in which they were first seen.

    Args:
        ranked_lists: One list per retrieval strategy, best item first.
        k: RRF smoothing constant. Must be >= 0.

    Returns:
        New ``ScoredItem`` objects sorted by RRF score, highest first. The
        input items are not mutated.

    Raises:
        ValueError: If ``k`` is negative.

    Sync, pure computation — Rust migration candidate.
    """
    if not ranked_lists:
        return []
    if k < 0:
        # A negative k makes 1 / (k + rank) divide by zero or go negative,
        # silently inverting the contribution of top-ranked items.
        raise ValueError(f"RRF k must be >= 0; got {k!r}.")

    # Accumulate RRF scores by item id
    scores: dict[str, float] = {}
    items: dict[str, ScoredItem] = {}

    for ranked_list in ranked_lists:
        counted_in_list: set[str] = set()
        for rank, item in enumerate(ranked_list, start=1):
            # Keep the item with the highest original score as representative.
            if item.id not in items or item.score > items[item.id].score:
                items[item.id] = item
            # Only the first (best-ranked) occurrence within one list counts
            # towards the RRF sum.
            if item.id in counted_in_list:
                continue
            counted_in_list.add(item.id)
            scores[item.id] = scores.get(item.id, 0.0) + 1.0 / (k + rank)

    # sorted() is stable (also with reverse=True), so ties keep first-seen order.
    sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)

    # Build result with RRF score replacing original score
    result: list[ScoredItem] = []
    for item_id in sorted_ids:
        item = items[item_id]
        result.append(
            ScoredItem(
                id=item.id,
                text=item.text,
                score=scores[item_id],
                fact_type=item.fact_type,
                metadata=item.metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                occurred_at=item.occurred_at,
                retained_at=item.retained_at,
                chunk_id=getattr(item, "chunk_id", None),
            )
        )

    return result

85 

86 

def weighted_rrf_fusion(
    ranked_lists_with_weights: list[tuple[list[ScoredItem], float]],
    k: int = DEFAULT_RRF_K,
) -> list[ScoredItem]:
    """RRF fusion where each input list contributes a scaled reciprocal rank.

    Standard RRF adds ``1 / (k + rank)`` for each list an item appears in.
    Weighted RRF adds ``weight / (k + rank)``. A list with weight 1.5 counts
    50% more per rank slot than a list with weight 1.0. ``rank`` is counted
    from 1. An id that occurs more than once inside the same list contributes
    only once, at its first (best) rank.

    Used for intent-aware retrieval (see
    :mod:`astrocyte.pipeline.query_intent`). When a query's intent biases
    a strategy (e.g. TEMPORAL → boost temporal strategy weight to 1.5),
    pass the strategy weights here so the biased strategy pulls items
    up more aggressively.

    A weight of 0.0 mutes a strategy: it contributes nothing, and items that
    appear only in muted lists are absent from the result. Negative weights
    are a caller bug. They would silently invert strategy rankings (items
    pushed *down* instead of up), which is never intentional. NaN weights are
    rejected as well, since they would make every fused score and the final
    order meaningless. Pass ``weight=0.0`` to mute a strategy explicitly.

    Among duplicates (from non-muted lists), the copy with the highest
    original ``score`` supplies the returned item's fields. Ties in fused
    score keep first-seen order.

    Args:
        ranked_lists_with_weights: ``(ranked_list, weight)`` pairs, one per
            strategy, each list best item first.
        k: RRF smoothing constant. Must be >= 0.

    Returns:
        New ``ScoredItem`` objects sorted by weighted RRF score, highest first.

    Raises:
        ValueError: If any weight is negative or NaN, or if ``k`` is negative.

    Sync, pure computation — Rust migration candidate.
    """
    if not ranked_lists_with_weights:
        return []

    # Validate every weight before doing any work. ``not weight >= 0.0`` also
    # catches NaN, which a plain ``weight < 0.0`` check lets through.
    for _, weight in ranked_lists_with_weights:
        if not weight >= 0.0:
            raise ValueError(f"RRF weight must be >= 0.0; got {weight!r}. Pass weight=0.0 to mute a strategy.")
    if k < 0:
        # A negative k makes weight / (k + rank) divide by zero or flip sign.
        raise ValueError(f"RRF k must be >= 0; got {k!r}.")

    scores: dict[str, float] = {}
    items: dict[str, ScoredItem] = {}

    for ranked_list, weight in ranked_lists_with_weights:
        if weight == 0.0:
            continue  # muted strategy
        counted_in_list: set[str] = set()
        for rank, item in enumerate(ranked_list, start=1):
            if item.id not in items or item.score > items[item.id].score:
                items[item.id] = item
            # Only the first (best-ranked) occurrence within one list counts.
            if item.id in counted_in_list:
                continue
            counted_in_list.add(item.id)
            scores[item.id] = scores.get(item.id, 0.0) + weight / (k + rank)

    # Stable sort: equal scores keep first-seen order.
    sorted_ids = sorted(scores, key=scores.__getitem__, reverse=True)

    return [
        ScoredItem(
            id=items[iid].id,
            text=items[iid].text,
            score=scores[iid],
            fact_type=items[iid].fact_type,
            metadata=items[iid].metadata,
            tags=items[iid].tags,
            memory_layer=items[iid].memory_layer,
            occurred_at=items[iid].occurred_at,
            retained_at=items[iid].retained_at,
            chunk_id=getattr(items[iid], "chunk_id", None),
        )
        for iid in sorted_ids
    ]

148 

149 

def layer_weighted_rrf_fusion(
    ranked_lists: list[list[ScoredItem]],
    k: int = DEFAULT_RRF_K,
    layer_weights: dict[str, float] | None = None,
) -> list[ScoredItem]:
    """RRF fusion with optional layer-based score boosting.

    Runs :func:`rrf_fusion`, then multiplies each item's fused score by the
    weight for its ``memory_layer`` and re-sorts. Items whose ``memory_layer``
    is ``None`` get weight 1.0, as do items whose layer is missing from
    ``layer_weights``. The re-sort is stable, so equal weighted scores keep
    their RRF order.

    layer_weights example: {"fact": 1.0, "observation": 1.5, "model": 2.0}
    Higher layers (models) are boosted above raw facts.

    Args:
        ranked_lists: One list per retrieval strategy, best item first.
        k: RRF smoothing constant passed to :func:`rrf_fusion`.
        layer_weights: Multiplier per memory layer. ``None`` or empty means
            plain RRF.

    Returns:
        New ``ScoredItem`` objects sorted by weighted score, highest first.

    Raises:
        ValueError: If any layer weight is negative or NaN. Fused scores are
            positive, so a negative multiplier would invert the ordering
            inside that layer.

    Sync, pure computation — Rust migration candidate.
    """
    if layer_weights:
        for layer, weight in layer_weights.items():
            # ``not weight >= 0.0`` also rejects NaN.
            if not weight >= 0.0:
                raise ValueError(f"Layer weight must be >= 0.0; got {weight!r} for layer {layer!r}.")

    fused = rrf_fusion(ranked_lists, k=k)

    if not layer_weights:
        return fused

    # Apply layer weights — create new items to avoid mutating rrf_fusion output
    weighted: list[ScoredItem] = []
    for item in fused:
        layer = item.memory_layer
        multiplier = 1.0 if layer is None else layer_weights.get(layer, 1.0)
        weighted.append(
            ScoredItem(
                id=item.id,
                text=item.text,
                score=item.score * multiplier,
                fact_type=item.fact_type,
                metadata=item.metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                occurred_at=item.occurred_at,
                retained_at=item.retained_at,
                chunk_id=getattr(item, "chunk_id", None),
            )
        )

    # Re-sort by weighted score
    weighted.sort(key=lambda x: x.score, reverse=True)
    return weighted

190 

191 

def memory_hits_as_scored(hits: list[MemoryHit]) -> list[ScoredItem]:
    """Adapt ``MemoryHit`` rows (e.g. federated / proxy recall) to ``ScoredItem`` for RRF.

    Fields are copied across one-to-one, in input order. A hit whose
    ``memory_id`` is falsy (``None`` or empty) gets a content-derived id:
    ``"ext-"`` plus the first 24 hex digits of the SHA-256 of its encoded
    text. Hits with identical text therefore share an id. ``retained_at`` and
    ``chunk_id`` are read with ``getattr`` and fall back to ``None`` when the
    hit object lacks them.
    """

    def _resolve_id(hit: MemoryHit) -> str:
        # Prefer the store's own id; otherwise derive a stable one from the text.
        if hit.memory_id:
            return hit.memory_id
        digest = hashlib.sha256(hit.text.encode()).hexdigest()[:24]
        return f"ext-{digest}"

    return [
        ScoredItem(
            id=_resolve_id(hit),
            text=hit.text,
            score=hit.score,
            fact_type=hit.fact_type,
            metadata=hit.metadata,
            tags=hit.tags,
            memory_layer=hit.memory_layer,
            occurred_at=hit.occurred_at,
            retained_at=getattr(hit, "retained_at", None),
            chunk_id=getattr(hit, "chunk_id", None),
        )
        for hit in hits
    ]