Coverage for astrocyte/pipeline/reranking.py: 90%
169 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Reranking stages for retrieved memory candidates.

Phase 1 (sync, pure computation — Rust migration candidate):

* ``basic_rerank`` — adds keyword-overlap, proper-noun and observation
  proof-count bonuses to each item's incoming score.
* ``cross_encoder_like_rerank`` — a deterministic local stand-in for a
  cross-encoder that scores each query/item pair jointly (content-term
  overlap, shared token bigrams, memory layer, person-name agreement).
* ``apply_context_diversity`` — softly penalizes repeated ``session_id``
  values unless the query looks like an aggregation question.

Phase 2:

* ``llm_pairwise_rerank`` — async; asks the configured LLM for a listwise
  ranking of the top candidates and falls back to the deterministic local
  order when the response is missing or malformed.
"""
from __future__ import annotations

import json
import logging
import re
from string import punctuation
from typing import TYPE_CHECKING

from astrocyte.mip.schema import RerankSpec
from astrocyte.pipeline.fusion import ScoredItem
from astrocyte.types import Message

if TYPE_CHECKING:
    from astrocyte.provider import LLMProvider
# Function words that carry no topical signal in a query; excluded from
# keyword overlap and from proper-noun detection.
COMMON_QUESTION_WORDS: set[str] = {
    # interrogatives
    "what", "when", "where", "who", "why", "how", "which",
    # auxiliaries
    "did", "does", "do", "is", "are", "was", "were", "has", "have",
    # modals
    "can", "could", "would", "should", "will",
    # imperative prompts
    "tell", "describe",
}

# basic_rerank: bonus per matching query term / query proper noun.
KEYWORD_OVERLAP_WEIGHT = 0.05
PROPER_NOUN_WEIGHT = 0.10

# cross_encoder_like_rerank: query/item interaction features.
QUERY_INTERACTION_WEIGHT = 0.35
QUERY_PHRASE_WEIGHT = 0.08
QUERY_NAME_MATCH_WEIGHT = 0.25
QUERY_NAME_MISS_PENALTY = 0.20

# _layer_boost: bonuses for curated memory layers.
COMPILED_LAYER_WEIGHT = 0.30
OBSERVATION_LAYER_WEIGHT = 0.20

# apply_context_diversity: penalty per earlier item from the same session.
SESSION_DIVERSITY_PENALTY = 0.08

# Observation proof-count boost: each additional confirming memory adds this
# to the score, capped at OBSERVATION_PROOF_CAP × weight. A 5-evidence
# observation gets a +0.10 bonus over a single-evidence raw memory.
OBSERVATION_PROOF_WEIGHT = 0.025
OBSERVATION_PROOF_CAP = 4  # clamp at 4 additional proofs

# Characters allowed inside proper names (apostrophes and hyphens):
# straight apostrophe, left/right single quotation marks, and hyphen.
NAME_CONNECTOR_CHARS = ("'", "\u2018", "\u2019", "-")
def _tokenize_terms(text: str) -> list[str]:
    """Split *text* into normalized word tokens for keyword/item matching.

    Each whitespace-separated word has leading/trailing ASCII punctuation
    (``string.punctuation``) removed and is lowercased; words that become
    empty are discarded. Order and duplicates are preserved.
    """
    tokens: list[str] = []
    for raw_word in text.split():
        token = raw_word.strip(punctuation).lower()
        if token:
            tokens.append(token)
    return tokens
def _content_terms(text: str) -> set[str]:
    """Return the distinct content-bearing terms of *text*.

    Tokens come from ``_tokenize_terms``; words listed in
    ``COMMON_QUESTION_WORDS`` and tokens of two characters or fewer are
    dropped.
    """
    terms: set[str] = set()
    for token in _tokenize_terms(text):
        if len(token) <= 2 or token in COMMON_QUESTION_WORDS:
            continue
        terms.add(token)
    return terms
def _is_name_token(token: str) -> bool:
    """Check if a token looks like a proper name, allowing apostrophes and hyphens.

    The token must be non-empty, begin and end with a letter, and contain
    only letters or characters from ``NAME_CONNECTOR_CHARS``.

    Accepts: "Alice", "O'Brien", "Mary-Ann", "Jean-Paul"
    Rejects: "--", "'hello", "123", ""
    """
    if not token:
        return False
    bounded_by_letters = token[0].isalpha() and token[-1].isalpha()
    return bounded_by_letters and all(
        ch.isalpha() or ch in NAME_CONNECTOR_CHARS for ch in token
    )
def basic_rerank(
    items: list[ScoredItem],
    query: str,
    *,
    mip_rerank: RerankSpec | None = None,
) -> list[ScoredItem]:
    """Rerank items using keyword overlap and proper-noun boosting.

    Adds bonuses to items whose text contains:
    - General query terms (``keyword_weight`` per matching term)
    - Proper nouns / names from the query (``proper_noun_weight`` per match)
    Observation items additionally receive ``_observation_proof_boost``.

    Defaults come from module constants (``KEYWORD_OVERLAP_WEIGHT`` /
    ``PROPER_NOUN_WEIGHT``); a MIP ``RerankSpec`` from the active routing
    decision can override either weight on a per-call basis without mutating
    module state. ``None`` fields fall through to the default.

    Args:
        items: Candidates to rescore. New ``ScoredItem`` objects are built;
            the inputs are not modified (metadata/tags are shared by reference).
        query: The user query the candidates are scored against.
        mip_rerank: Optional per-call weight overrides.

    Returns:
        The rescored items sorted by score, highest first. If ``items`` or
        ``query`` is empty, ``items`` is returned as-is.

    This is a heuristic — production systems should use cross-encoders.
    """
    if not items or not query:
        return items

    keyword_weight = (
        mip_rerank.keyword_weight
        if mip_rerank is not None and mip_rerank.keyword_weight is not None
        else KEYWORD_OVERLAP_WEIGHT
    )
    proper_noun_weight = (
        mip_rerank.proper_noun_weight
        if mip_rerank is not None and mip_rerank.proper_noun_weight is not None
        else PROPER_NOUN_WEIGHT
    )

    # Tokenize query once; filter common question words for overlap scoring.
    query_terms = {t for t in _tokenize_terms(query) if t not in COMMON_QUESTION_WORDS}

    # Detect proper nouns: Title Case ("Alice", "O'Brien") or ALL CAPS of at
    # least two characters ("USA", "AI") that also pass _is_name_token.
    # Lowercase words are never treated as names, so an all-lowercase query
    # earns no proper-noun bonus.
    proper_nouns: set[str] = set()
    for word in query.split():
        cleaned = word.strip(punctuation)
        if not cleaned or cleaned.lower() in COMMON_QUESTION_WORDS:
            continue
        is_proper = cleaned.istitle() or (cleaned.isupper() and len(cleaned) >= 2)
        if is_proper and _is_name_token(cleaned):
            proper_nouns.add(cleaned.lower())

    # Pre-compute tokenized terms for all items to avoid repeated work.
    item_terms_by_item = [(item, set(_tokenize_terms(item.text))) for item in items]

    return sorted(
        (
            ScoredItem(
                id=item.id,
                text=item.text,
                score=item.score
                + len(query_terms & item_terms) * keyword_weight
                + len(proper_nouns & item_terms) * proper_noun_weight
                + _observation_proof_boost(item),
                fact_type=item.fact_type,
                metadata=item.metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                # Carry the event time through; omitting it dropped it from
                # every reranked copy.
                occurred_at=getattr(item, "occurred_at", None),
                retained_at=item.retained_at,
                chunk_id=getattr(item, "chunk_id", None),
            )
            for item, item_terms in item_terms_by_item
        ),
        key=lambda x: x.score,
        reverse=True,
    )
def cross_encoder_like_rerank(
    items: list[ScoredItem],
    query: str,
) -> list[ScoredItem]:
    """Final precision rerank using query-item interaction features.

    This is a deterministic local stand-in for a cross-encoder: it scores each
    query/memory pair jointly, then applies entity/person and memory-layer
    signals. It is intentionally cheap enough to run before ``reflect()``
    synthesis, where precision matters more than broad candidate coverage.

    Each item's new score is its incoming score plus:
    - ``QUERY_INTERACTION_WEIGHT`` times the fraction of the query's content
      terms that also appear in the item;
    - ``QUERY_PHRASE_WEIGHT`` per shared token bigram (at most 3 counted);
    - ``_layer_boost`` for compiled / observation memories;
    - ``QUERY_NAME_MATCH_WEIGHT`` when the query names a person associated
      with the item, or minus ``QUERY_NAME_MISS_PENALTY`` when the item is
      associated only with other names.
    Scores are floored at 0.0 and items are returned highest score first.
    Empty ``items`` or ``query`` returns ``items`` unchanged.
    """
    if not items or not query:
        return items

    query_terms = _content_terms(query)
    query_names = _proper_names(query)
    query_bigrams = _bigrams(_tokenize_terms(query))

    scored: list[ScoredItem] = []
    for item in items:
        item_terms = _content_terms(item.text)
        item_names = _candidate_names(item)
        overlap = len(query_terms & item_terms) / max(len(query_terms), 1)
        phrase_hits = len(query_bigrams & _bigrams(_tokenize_terms(item.text)))

        score = item.score
        score += overlap * QUERY_INTERACTION_WEIGHT
        score += min(phrase_hits, 3) * QUERY_PHRASE_WEIGHT
        score += _layer_boost(item)

        if query_names:
            if query_names & item_names:
                score += QUERY_NAME_MATCH_WEIGHT
            elif item_names:
                # The item is about someone, just not the person asked about.
                score -= QUERY_NAME_MISS_PENALTY

        scored.append(
            ScoredItem(
                id=item.id,
                text=item.text,
                score=max(score, 0.0),
                fact_type=item.fact_type,
                metadata=item.metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                # Carry the event time through; omitting it dropped it from
                # every reranked copy.
                occurred_at=getattr(item, "occurred_at", None),
                retained_at=item.retained_at,
                chunk_id=getattr(item, "chunk_id", None),
            )
        )

    return sorted(scored, key=lambda x: x.score, reverse=True)
async def llm_pairwise_rerank(
    items: list[ScoredItem],
    query: str,
    llm_provider: LLMProvider,
    *,
    top_n: int = 30,
    keep_n: int | None = None,
) -> list[ScoredItem]:
    """Use the configured LLM as a lightweight listwise reranker.

    The first ``top_n`` items are pre-scored with ``cross_encoder_like_rerank``
    and shown to the LLM, which is asked for candidate IDs in descending
    relevance order. If the LLM call fails or its response is missing or
    malformed, the deterministic local order is used instead.

    In both paths ``apply_context_diversity`` is applied to the ranked head,
    which is then truncated to ``keep_n`` items (``None`` or ``0`` keeps every
    candidate), and ``items[top_n:]`` is appended unchanged.

    Args:
        items: Candidates, presumably already ordered by an earlier stage.
        query: The user query.
        llm_provider: Provider whose ``complete`` coroutine produces the ranking.
        top_n: How many leading items are sent to the LLM.
        keep_n: How many reranked head items to keep.

    Returns:
        The reranked head followed by the untouched remainder. Empty
        ``items`` or ``query`` returns ``items`` unchanged.
    """
    if not items or not query:
        return items
    candidates = cross_encoder_like_rerank(items[:top_n], query)
    remainder = items[top_n:]
    if not candidates:
        # Nothing in the head to rank (e.g. top_n == 0): skip the LLM round-trip.
        return remainder
    keep = keep_n or len(candidates)

    prompt_lines = [
        "Rank the memory candidates by how directly they answer the query.",
        "Penalize wrong-person or wrong-premise candidates.",
        'Return JSON only: {"ranked_ids": ["id1", "id2"]}.',
        "",
        f"Query: {query}",
        "",
        "Candidates:",
    ]
    for item in candidates:
        prompt_lines.append(f"- id={item.id} score={item.score:.4f}: {item.text[:500]}")

    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content="You are a strict memory reranker."),
                Message(role="user", content="\n".join(prompt_lines)),
            ],
            max_tokens=512,
            temperature=0.0,
        )
        ranked_ids = _parse_ranked_ids(completion.text)
    except Exception as exc:
        # Deliberate best-effort: any provider/parse failure falls back to the
        # local order, but leave a trace instead of failing silently.
        logging.getLogger(__name__).warning(
            "llm_pairwise_rerank: LLM rerank failed, using local order: %r", exc
        )
        ranked_ids = []

    if not ranked_ids:
        return apply_context_diversity(candidates, query)[:keep] + remainder

    # Key by str(id): _parse_ranked_ids always yields strings.
    by_id = {str(item.id): item for item in candidates}
    ordered: list[ScoredItem] = []
    seen: set[str] = set()
    for item_id in ranked_ids:
        item = by_id.get(item_id)
        if item is not None and item_id not in seen:
            ordered.append(item)
            seen.add(item_id)
    # Candidates the LLM did not mention keep their local order at the tail.
    ordered.extend(item for item in candidates if str(item.id) not in seen)

    # apply_context_diversity re-sorts by score, so passing the LLM order with
    # the old scores discarded it (the LLM only broke ties). Re-deal the same
    # scores, highest first, down the LLM-ranked list: the sort then follows
    # the LLM ranking, the set of score values is unchanged, and session
    # penalties can still demote repeats.
    score_ladder = sorted((item.score for item in ordered), reverse=True)
    reranked = [
        ScoredItem(
            id=item.id,
            text=item.text,
            score=score,
            fact_type=item.fact_type,
            metadata=item.metadata,
            tags=item.tags,
            memory_layer=item.memory_layer,
            occurred_at=getattr(item, "occurred_at", None),
            retained_at=item.retained_at,
            chunk_id=getattr(item, "chunk_id", None),
        )
        for item, score in zip(ordered, score_ladder, strict=True)
    ]
    return apply_context_diversity(reranked, query)[:keep] + remainder
def apply_context_diversity(items: list[ScoredItem], query: str) -> list[ScoredItem]:
    """Softly penalize repeated sessions unless the query asks for aggregation.

    Items are visited in input order. The k-th item (0-based) sharing a
    ``metadata["session_id"]`` loses ``k * SESSION_DIVERSITY_PENALTY`` from
    its score, floored at 0.0. Items without a session id are not penalized.
    The rebuilt items are returned sorted by score, highest first. Empty
    input and aggregate queries (per ``_is_aggregate_query``) are returned
    unchanged.
    """
    if not items or _is_aggregate_query(query):
        return items

    occurrences: dict[str, int] = {}
    penalized: list[ScoredItem] = []
    for entry in items:
        session_key = str((entry.metadata or {}).get("session_id") or "")
        prior = 0
        if session_key:
            prior = occurrences.get(session_key, 0)
            occurrences[session_key] = prior + 1
        new_score = max(entry.score - prior * SESSION_DIVERSITY_PENALTY, 0.0)
        penalized.append(
            ScoredItem(
                id=entry.id,
                text=entry.text,
                score=new_score,
                fact_type=entry.fact_type,
                metadata=entry.metadata,
                tags=entry.tags,
                memory_layer=entry.memory_layer,
                occurred_at=entry.occurred_at,
                retained_at=entry.retained_at,
                chunk_id=getattr(entry, "chunk_id", None),
            )
        )
    penalized.sort(key=lambda candidate: candidate.score, reverse=True)
    return penalized
def _proper_names(text: str) -> set[str]:
    """Return lowercased Title-Case name tokens found in *text*.

    Words are stripped of surrounding ASCII punctuation and kept only when
    non-empty, ``str.istitle()`` and accepted by ``_is_name_token``.
    """
    stripped = (word.strip(punctuation) for word in text.split())
    return {
        token.lower()
        for token in stripped
        if token and token.istitle() and _is_name_token(token)
    }
def _candidate_names(item: ScoredItem) -> set[str]:
    """Return lowercased person names associated with *item*.

    Combines Title-Case names found in ``item.text`` (via ``_proper_names``)
    with names from the metadata keys ``locomo_speakers``, ``locomo_persons``,
    ``speakers`` and ``person``. A string value may hold several names
    separated by ``,`` or ``|``. A list/tuple/set value is read element by
    element, and each element is split the same way.
    """
    names = _proper_names(item.text)
    metadata = item.metadata or {}
    for key in ("locomo_speakers", "locomo_persons", "speakers", "person"):
        raw = metadata.get(key)
        if not raw:
            continue
        # str() of a list yields "['alice', 'bob']", which would split into
        # bogus names like "['alice'" and trigger wrong-person penalties.
        values = raw if isinstance(raw, (list, tuple, set, frozenset)) else (raw,)
        for value in values:
            if value is None:
                continue
            for part in str(value).replace("|", ",").split(","):
                cleaned = part.strip().lower()
                if cleaned:
                    names.add(cleaned)
    return names
def _parse_ranked_ids(text: str | None) -> list[str]:
    """Extract the ranked candidate ids from an LLM reranking response.

    Markdown code fences and prose around the JSON object are tolerated: the
    span from the first ``{`` to the last ``}`` is parsed, and its
    ``ranked_ids`` list is used. A bare JSON array of ids is also accepted.
    Ids are stringified and whitespace-stripped. ``null`` and empty ids are
    skipped, while falsy-but-valid ids such as ``0`` are kept.

    Returns:
        The ids in the order given, or ``[]`` when the response is empty,
        not valid JSON, or carries no list of ids.
    """
    if not text:
        return []
    raw = text.strip()
    if raw.startswith("```"):
        raw = "\n".join(line for line in raw.splitlines() if not line.strip().startswith("```"))
    start = raw.find("{")
    end = raw.rfind("}")
    if start >= 0 and end > start:
        raw = raw[start : end + 1]
    try:
        parsed = json.loads(raw)
    except json.JSONDecodeError:
        return []
    ids = parsed.get("ranked_ids") if isinstance(parsed, dict) else parsed
    if not isinstance(ids, list):
        return []
    ranked: list[str] = []
    for entry in ids:
        if entry is None:
            continue
        cleaned = str(entry).strip()
        if cleaned:
            ranked.append(cleaned)
    return ranked
# Whole-word triggers for questions that ask for several facts at once.
_AGGREGATE_QUERY_RE = re.compile(
    r"\b(?:how\s+many|what\s+are|list|all|which)\b",
    re.IGNORECASE,
)


def _is_aggregate_query(query: str | None) -> bool:
    """Return True when *query* looks like it asks for an aggregate answer.

    Triggers on the whole words/phrases "how many", "what are", "list", "all"
    and "which", case-insensitively. Matching on word boundaries avoids the
    false positives of plain substring search, such as "list" inside
    "specialist" or "all" inside "call"/"small"/"football".
    """
    return bool(_AGGREGATE_QUERY_RE.search(query or ""))
367def _bigrams(tokens: list[str]) -> set[tuple[str, str]]:
368 return set(zip(tokens, tokens[1:], strict=False))
def _layer_boost(item: ScoredItem) -> float:
    """Return the score bonus for curated memory layers.

    Compiled knowledge (``fact_type == "wiki"`` or ``memory_layer ==
    "compiled"``) earns ``COMPILED_LAYER_WEIGHT`` and takes precedence.
    Observations (by fact type or by layer) earn ``OBSERVATION_LAYER_WEIGHT``
    plus ``_observation_proof_boost``. Note that the proof boost only applies
    when ``fact_type == "observation"``. Anything else earns 0.0.
    """
    is_compiled = item.fact_type == "wiki" or item.memory_layer == "compiled"
    if is_compiled:
        return COMPILED_LAYER_WEIGHT
    is_observation = item.fact_type == "observation" or item.memory_layer == "observation"
    return OBSERVATION_LAYER_WEIGHT + _observation_proof_boost(item) if is_observation else 0.0
def _observation_proof_boost(item: ScoredItem) -> float:
    """Additive boost for observation items proportional to their proof count.

    Reads ``item.metadata["_obs_proof_count"]`` (default 1). A single-evidence
    observation gets +0.0. Each additional corroborating memory adds
    ``OBSERVATION_PROOF_WEIGHT``, capped at ``OBSERVATION_PROOF_CAP`` extra
    proofs (+0.10 total). Non-observation items, observations without
    metadata, and counts that ``int()`` cannot parse all get 0.0.
    """
    if item.fact_type != "observation":
        return 0.0
    if not item.metadata:
        return 0.0
    raw_count = item.metadata.get("_obs_proof_count", 1)
    try:
        proof_count = int(raw_count)
    except (TypeError, ValueError):
        return 0.0
    additional = proof_count - 1 if proof_count > 1 else 0
    return min(additional, OBSERVATION_PROOF_CAP) * OBSERVATION_PROOF_WEIGHT