# Coverage for astrocyte/pipeline/fusion.py: 100%
# 69 statements
# coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Reciprocal Rank Fusion (RRF) — merge results from multiple retrieval strategies.

Sync, pure computation — Rust migration candidate.
See docs/_design/built-in-pipeline.md section 3.
"""
from __future__ import annotations

import hashlib
from dataclasses import dataclass, replace
from datetime import datetime

from astrocyte.types import MemoryHit
#: Default RRF smoothing constant ``k`` in ``1 / (k + rank)``. Larger values
#: flatten the reciprocal-rank curve, so lower-ranked items count relatively
#: more compared with top-ranked ones.
#: Standard value from the original RRF paper (Cormack et al., 2009).
DEFAULT_RRF_K = 60
@dataclass
class ScoredItem:
    """A scored item from any retrieval strategy.

    The fusion functions in this module never mutate their inputs: they build
    new ``ScoredItem`` objects whose ``score`` is the fused score, carrying the
    other fields over from the input item.
    """

    id: str  # deduplication key: fusion merges entries that share an id
    text: str
    score: float  # strategy score on input; fused score on fusion output
    fact_type: str | None = None
    # Passed through by reference (not copied) when fusion rebuilds an item.
    metadata: dict[str, str | int | float | bool | None] | None = None
    tags: list[str] | None = None  # also passed through by reference
    memory_layer: str | None = None  # "fact", "observation", "model"; keys layer_weights
    occurred_at: datetime | None = None
    retained_at: datetime | None = None  # M9: wall-clock time item was stored
    chunk_id: str | None = None  # M10: source-chunk backreference
def rrf_fusion(
    ranked_lists: list[list[ScoredItem]],
    k: int = DEFAULT_RRF_K,
) -> list[ScoredItem]:
    """Reciprocal Rank Fusion across multiple ranked result lists.

    RRF score = Σ(1 / (k + rank)) for each list where item appears, with a
    1-based ``rank``. Items are deduplicated by id; when an id occurs more
    than once, the occurrence with the highest original ``score`` supplies
    the non-score fields of the output item.

    This is exactly :func:`weighted_rrf_fusion` with every list weighted 1.0,
    so it delegates there rather than keeping a second copy of the scoring,
    deduplication and copy-out logic.

    Args:
        ranked_lists: Result lists from the individual strategies, each
            ordered best-first.
        k: RRF smoothing constant.

    Returns:
        New items sorted by RRF score, descending, with ``score`` replaced by
        the RRF score. Empty input yields an empty list.

    Sync, pure computation — Rust migration candidate.
    """
    # weight / (k + rank) with weight == 1.0 is the plain RRF contribution.
    return weighted_rrf_fusion([(ranked_list, 1.0) for ranked_list in ranked_lists], k=k)
def weighted_rrf_fusion(
    ranked_lists_with_weights: list[tuple[list[ScoredItem], float]],
    k: int = DEFAULT_RRF_K,
) -> list[ScoredItem]:
    """RRF fusion where each input list contributes a scaled reciprocal rank.

    Standard RRF adds ``1 / (k + rank)`` for each list an item appears in
    (``rank`` is 1-based). Weighted RRF adds ``weight / (k + rank)`` — a list
    with weight 1.5 counts 50% more per rank slot than a list with weight 1.0.

    Used for intent-aware retrieval (see
    :mod:`astrocyte.pipeline.query_intent`): when a query's intent biases
    a strategy (e.g. TEMPORAL → boost temporal strategy weight to 1.5),
    pass the strategy weights here so the biased strategy pulls items
    up more aggressively.

    A weight of 0.0 mutes a strategy (it contributes nothing). Negative
    weights are a caller bug — they would silently invert strategy
    rankings (items pushed *down* instead of up), which is never
    intentional. NaN weights are rejected as well: they would turn every
    score they touch into NaN and make the final ordering arbitrary.
    Pass ``weight=0.0`` to mute a strategy explicitly.

    Items are deduplicated by ``id``. When an id appears more than once, the
    occurrence with the highest original ``score`` supplies the non-score
    fields of the output item. Items with equal fused scores keep the order
    in which they were first seen.

    Args:
        ranked_lists_with_weights: ``(ranked_list, weight)`` pairs; each list
            is ordered best-first.
        k: RRF smoothing constant.

    Returns:
        New items sorted by fused score, descending, with ``score`` replaced
        by the fused score. Empty input yields an empty list.

    Raises:
        ValueError: If any weight is negative or NaN.

    Sync, pure computation — Rust migration candidate.
    """
    if not ranked_lists_with_weights:
        return []

    scores: dict[str, float] = {}
    items: dict[str, ScoredItem] = {}

    for ranked_list, weight in ranked_lists_with_weights:
        # ``not (weight >= 0.0)`` is true for negatives AND for NaN, which
        # compares false against everything and would slip past ``< 0.0``.
        if not (weight >= 0.0):
            raise ValueError(f"RRF weight must be >= 0.0; got {weight!r}. Pass weight=0.0 to mute a strategy.")
        if weight == 0.0:
            # Muted: contributes no score and is not eligible to supply item fields.
            continue
        for rank, item in enumerate(ranked_list, start=1):
            scores[item.id] = scores.get(item.id, 0.0) + weight / (k + rank)
            # NOTE(review): original scores from different strategies may not
            # share a scale; this only decides which copy's fields survive.
            if item.id not in items or item.score > items[item.id].score:
                items[item.id] = item

    # sorted() is stable even with reverse=True, so ties keep first-seen order.
    ranked_ids = sorted(scores, key=scores.__getitem__, reverse=True)
    # replace() copies every field, so fields added to ScoredItem later are kept.
    return [replace(items[iid], score=scores[iid]) for iid in ranked_ids]
def layer_weighted_rrf_fusion(
    ranked_lists: list[list[ScoredItem]],
    k: int = DEFAULT_RRF_K,
    layer_weights: dict[str, float] | None = None,
) -> list[ScoredItem]:
    """RRF fusion with optional layer-based score boosting.

    Runs :func:`rrf_fusion`, then multiplies each item's fused score by the
    weight for its ``memory_layer``. A layer missing from ``layer_weights``
    gets weight 1.0. Items with ``memory_layer=None`` are looked up under the
    key ``""``, so they also get 1.0 unless ``""`` is mapped.

    layer_weights example: {"fact": 1.0, "observation": 1.5, "model": 2.0}
    Higher layers (models) are boosted above raw facts.

    Args:
        ranked_lists: Result lists from the individual strategies, each
            ordered best-first.
        k: RRF smoothing constant (same default as the other fusion functions).
        layer_weights: Score multiplier per memory layer. ``None`` or an
            empty mapping returns the plain RRF result.

    Returns:
        New items sorted by weighted score, descending. Items with equal
        weighted scores keep their RRF order.

    Sync, pure computation — Rust migration candidate.
    """
    fused = rrf_fusion(ranked_lists, k=k)

    if not layer_weights:
        return fused

    # Build new items rather than mutating rrf_fusion's output; replace()
    # carries every other field over unchanged.
    weighted = [
        replace(item, score=item.score * layer_weights.get(item.memory_layer or "", 1.0))
        for item in fused
    ]

    # Re-sort by weighted score. list.sort is stable, so ties keep RRF order.
    weighted.sort(key=lambda x: x.score, reverse=True)
    return weighted
def memory_hits_as_scored(hits: list[MemoryHit]) -> list[ScoredItem]:
    """Adapt MemoryHit rows (e.g. federated / proxy recall) to ScoredItem for RRF.

    Every field is carried over one-to-one, in input order. A hit with a falsy
    ``memory_id`` gets a deterministic synthetic id: ``"ext-"`` plus the first
    24 hex characters of the SHA-256 of its UTF-8-encoded text. As a result,
    id-less hits with identical text share an id and are merged when the
    fusion functions deduplicate by id. ``retained_at`` and ``chunk_id`` are
    read with ``getattr`` and default to ``None`` when the hit lacks them.
    """

    def _fusion_id(hit: MemoryHit) -> str:
        # Prefer the real id; hash the text only when there is none.
        return hit.memory_id or f"ext-{hashlib.sha256(hit.text.encode()).hexdigest()[:24]}"

    return [
        ScoredItem(
            id=_fusion_id(hit),
            text=hit.text,
            score=hit.score,
            fact_type=hit.fact_type,
            metadata=hit.metadata,
            tags=hit.tags,
            memory_layer=hit.memory_layer,
            occurred_at=hit.occurred_at,
            retained_at=getattr(hit, "retained_at", None),
            chunk_id=getattr(hit, "chunk_id", None),
        )
        for hit in hits
    ]