Coverage for astrocyte/pipeline/reranking.py: 90%

169 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Basic reranking — keyword and entity-aware scoring for Phase 1. 

2 

3Sync, pure computation — Rust migration candidate. 

4Phase 2: cross-encoder or LLM-based reranking. 

5""" 

6 

7from __future__ import annotations 

8 

9import json 

10from string import punctuation 

11from typing import TYPE_CHECKING 

12 

13from astrocyte.mip.schema import RerankSpec 

14from astrocyte.pipeline.fusion import ScoredItem 

15from astrocyte.types import Message 

16 

17if TYPE_CHECKING: 

18 from astrocyte.provider import LLMProvider 

19 

# Lowercased interrogatives, auxiliaries, and request verbs that carry no
# topical signal. ``basic_rerank`` drops these from the query before keyword
# overlap scoring and skips them during its proper-noun detection;
# ``_content_terms`` drops them as well.
COMMON_QUESTION_WORDS: set[str] = {
    "what",
    "when",
    "where",
    "who",
    "why",
    "how",
    "which",
    "did",
    "does",
    "do",
    "is",
    "are",
    "was",
    "were",
    "has",
    "have",
    "can",
    "could",
    "would",
    "should",
    "will",
    "tell",
    "describe",
}

45 

# --- basic_rerank defaults (a MIP RerankSpec may override either per call) ---
# Bonus per distinct query term that also appears in the item text.
KEYWORD_OVERLAP_WEIGHT = 0.05
# Bonus per distinct query proper noun that also appears in the item text.
PROPER_NOUN_WEIGHT = 0.10

# --- cross_encoder_like_rerank ---
# Scaled by the fraction of query content terms present in the item.
QUERY_INTERACTION_WEIGHT = 0.35
# Bonus per query bigram found in the item; at most 3 bigrams are counted.
QUERY_PHRASE_WEIGHT = 0.08
# Added when the item shares at least one name with the query.
QUERY_NAME_MATCH_WEIGHT = 0.25
# Subtracted when the query has names, the item has names, and none are shared.
QUERY_NAME_MISS_PENALTY = 0.20
# Boost for items with fact_type "wiki" or memory_layer "compiled".
COMPILED_LAYER_WEIGHT = 0.30
# Boost for items with fact_type or memory_layer "observation".
OBSERVATION_LAYER_WEIGHT = 0.20

# --- apply_context_diversity ---
# Subtracted once for every earlier item in the list with the same session_id.
SESSION_DIVERSITY_PENALTY = 0.08

# Observation proof-count boost: each additional confirming memory adds this
# to the score, capped at OBSERVATION_PROOF_CAP × weight. A 5-evidence
# observation gets a +0.10 bonus over a single-evidence raw memory.
OBSERVATION_PROOF_WEIGHT = 0.025
OBSERVATION_PROOF_CAP = 4  # clamp at 4 additional proofs

# Characters allowed inside proper names (apostrophes and hyphens).
# Straight apostrophe, left/right single quotation marks, and hyphen.
NAME_CONNECTOR_CHARS = ("'", "\u2018", "\u2019", "-")

64 

65 

66def _tokenize_terms(text: str) -> list[str]: 

67 """Tokenize text consistently for keyword and item matching. 

68 

69 - Split on whitespace 

70 - Strip leading/trailing punctuation 

71 - Lowercase 

72 - Drop empty tokens 

73 """ 

74 return [t for t in (w.strip(punctuation).lower() for w in text.split()) if t] 

75 

76 

def _content_terms(text: str) -> set[str]:
    """Return the distinct tokens of *text* that carry content.

    A token is kept when it is longer than two characters and is not one of
    ``COMMON_QUESTION_WORDS``.
    """
    kept: set[str] = set()
    for term in _tokenize_terms(text):
        if len(term) <= 2 or term in COMMON_QUESTION_WORDS:
            continue
        kept.add(term)
    return kept

79 

80 

def _is_name_token(token: str) -> bool:
    """Report whether *token* has the shape of a proper name.

    A name-shaped token is non-empty, begins and ends with a letter, and
    consists only of letters plus the connector characters in
    ``NAME_CONNECTOR_CHARS`` (apostrophes and hyphens).

    Accepts: "Alice", "O'Brien", "Mary-Ann", "Jean-Paul"
    Rejects: "--", "'hello", "123", ""
    """
    if not token:
        return False
    first, last = token[0], token[-1]
    if not (first.isalpha() and last.isalpha()):
        return False
    # Every character (including the interior) must be a letter or connector.
    return all(ch.isalpha() or ch in NAME_CONNECTOR_CHARS for ch in token)

97 

98 

def basic_rerank(
    items: list[ScoredItem],
    query: str,
    *,
    mip_rerank: RerankSpec | None = None,
) -> list[ScoredItem]:
    """Rerank items using keyword overlap and proper-noun boosting.

    Adds bonuses to items whose text contains:
    - General query terms (``keyword_weight`` per matching term)
    - Proper nouns / names from the query (``proper_noun_weight`` per match;
      a proper noun is also a query term, so it earns both bonuses)

    Observation items additionally receive ``_observation_proof_boost``.

    Defaults come from module constants (``KEYWORD_OVERLAP_WEIGHT`` /
    ``PROPER_NOUN_WEIGHT``); a MIP ``RerankSpec`` from the active routing
    decision can override either weight on a per-call basis without mutating
    module state. ``None`` fields fall through to the default.

    This is a heuristic — production systems should use cross-encoders.

    Args:
        items: Candidates to rerank. They are not mutated.
        query: The user query.
        mip_rerank: Optional per-call weight overrides.

    Returns:
        New ``ScoredItem`` objects sorted by descending adjusted score, or
        ``items`` itself when ``items`` or ``query`` is empty.
    """
    if not items or not query:
        return items

    keyword_weight = KEYWORD_OVERLAP_WEIGHT
    proper_noun_weight = PROPER_NOUN_WEIGHT
    if mip_rerank is not None:
        if mip_rerank.keyword_weight is not None:
            keyword_weight = mip_rerank.keyword_weight
        if mip_rerank.proper_noun_weight is not None:
            proper_noun_weight = mip_rerank.proper_noun_weight

    # Tokenize query once; filter common question words for overlap scoring.
    query_terms = {t for t in _tokenize_terms(query) if t not in COMMON_QUESTION_WORDS}

    # Detect proper nouns by capitalisation only: Title Case ("Alice") or
    # ALL CAPS of at least two letters ("USA", "AI"). A name typed in lowercase
    # is NOT detected here. _is_name_token then rejects tokens containing
    # digits or stray punctuation.
    proper_nouns: set[str] = set()
    for w in query.split():
        cleaned = w.strip(punctuation)
        if not cleaned or cleaned.lower() in COMMON_QUESTION_WORDS:
            continue
        is_proper = cleaned.istitle() or (cleaned.isupper() and len(cleaned) >= 2)
        if is_proper and _is_name_token(cleaned):
            proper_nouns.add(cleaned.lower())

    reranked: list[ScoredItem] = []
    for item in items:
        item_terms = set(_tokenize_terms(item.text))
        reranked.append(
            ScoredItem(
                id=item.id,
                text=item.text,
                score=item.score
                + len(query_terms & item_terms) * keyword_weight
                + len(proper_nouns & item_terms) * proper_noun_weight
                + _observation_proof_boost(item),
                fact_type=item.fact_type,
                metadata=item.metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                # Carry the event timestamp through, as apply_context_diversity
                # does; omitting it would reset the field on every rerank.
                occurred_at=item.occurred_at,
                retained_at=item.retained_at,
                chunk_id=getattr(item, "chunk_id", None),
            )
        )

    # list.sort is stable, so equally scored items keep their input order.
    reranked.sort(key=lambda scored: scored.score, reverse=True)
    return reranked

174 

175 

def cross_encoder_like_rerank(
    items: list[ScoredItem],
    query: str,
) -> list[ScoredItem]:
    """Final precision rerank using query-item interaction features.

    This is a deterministic local stand-in for a cross-encoder: it scores each
    query/memory pair jointly, then applies entity/person and memory-layer
    signals. It is intentionally cheap enough to run before ``reflect()``
    synthesis, where precision matters more than broad candidate coverage.

    Per item, the incoming score is adjusted by:
    - the fraction of query content terms found in the item
      (× ``QUERY_INTERACTION_WEIGHT``);
    - shared bigrams, at most 3 counted (× ``QUERY_PHRASE_WEIGHT``);
    - ``_layer_boost`` for compiled / observation items;
    - when the query contains names: ``+QUERY_NAME_MATCH_WEIGHT`` if the item
      shares one, ``-QUERY_NAME_MISS_PENALTY`` if the item has names but none
      shared, and no change if the item has no names at all.

    The final score is floored at 0.0.

    Args:
        items: Candidates to rerank. They are not mutated.
        query: The user query.

    Returns:
        New ``ScoredItem`` objects sorted by descending adjusted score, or
        ``items`` itself when ``items`` or ``query`` is empty.
    """
    if not items or not query:
        return items

    # Query-side features are computed once and reused for every item.
    query_terms = _content_terms(query)
    query_names = _proper_names(query)
    query_bigrams = _bigrams(_tokenize_terms(query))

    scored: list[ScoredItem] = []
    for item in items:
        item_terms = _content_terms(item.text)
        item_names = _candidate_names(item)
        # max(..., 1) avoids dividing by zero when the query has no content terms.
        overlap = len(query_terms & item_terms) / max(len(query_terms), 1)
        phrase_hits = len(query_bigrams & _bigrams(_tokenize_terms(item.text)))

        score = item.score
        score += overlap * QUERY_INTERACTION_WEIGHT
        score += min(phrase_hits, 3) * QUERY_PHRASE_WEIGHT
        score += _layer_boost(item)

        if query_names:
            if query_names & item_names:
                score += QUERY_NAME_MATCH_WEIGHT
            elif item_names:
                # The item is about someone, but not anyone the query names.
                score -= QUERY_NAME_MISS_PENALTY

        scored.append(
            ScoredItem(
                id=item.id,
                text=item.text,
                score=max(score, 0.0),
                fact_type=item.fact_type,
                metadata=item.metadata,
                tags=item.tags,
                memory_layer=item.memory_layer,
                # Carry the event timestamp through, as apply_context_diversity
                # does; omitting it would reset the field on every rerank.
                occurred_at=item.occurred_at,
                retained_at=item.retained_at,
                chunk_id=getattr(item, "chunk_id", None),
            )
        )

    return sorted(scored, key=lambda x: x.score, reverse=True)

227 

228 

async def llm_pairwise_rerank(
    items: list[ScoredItem],
    query: str,
    llm_provider: LLMProvider,
    *,
    top_n: int = 30,
    keep_n: int | None = None,
) -> list[ScoredItem]:
    """Use the configured LLM as a lightweight listwise reranker.

    The first ``top_n`` items are pre-ranked with
    ``cross_encoder_like_rerank`` and shown to the LLM, which is asked for
    candidate IDs in descending relevance order. Candidates the LLM leaves out
    follow the ranked ones in their local order. If the LLM call fails or the
    response is missing or malformed, the deterministic local ranking is used.

    ``apply_context_diversity`` re-sorts by score, so the LLM ordering is
    first written back into the scores (the best-ranked candidate receives the
    highest existing candidate score, and so on). Without that step the final
    sort would restore the local order and discard the LLM's ranking.

    Args:
        items: Candidates, best first. Items beyond ``top_n`` are appended to
            the result untouched.
        query: The user query.
        llm_provider: Provider whose ``complete`` coroutine is awaited.
        top_n: Number of leading items shown to the LLM.
        keep_n: Number of reranked candidates to keep; ``None`` (or 0) keeps
            all of them.

    Returns:
        The reranked, diversity-adjusted candidates (truncated to ``keep_n``)
        followed by the items beyond ``top_n``; ``items`` itself when
        ``items`` or ``query`` is empty.
    """
    if not items or not query:
        return items
    candidates = cross_encoder_like_rerank(items[:top_n], query)
    remainder = items[top_n:]
    if not candidates:
        # Nothing fell inside the top_n window: skip the pointless LLM call.
        return remainder
    keep = keep_n or len(candidates)

    prompt_lines = [
        "Rank the memory candidates by how directly they answer the query.",
        "Penalize wrong-person or wrong-premise candidates.",
        'Return JSON only: {"ranked_ids": ["id1", "id2"]}.',
        "",
        f"Query: {query}",
        "",
        "Candidates:",
    ]
    for item in candidates:
        prompt_lines.append(f"- id={item.id} score={item.score:.4f}: {item.text[:500]}")

    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content="You are a strict memory reranker."),
                Message(role="user", content="\n".join(prompt_lines)),
            ],
            max_tokens=512,
            temperature=0.0,
        )
        ranked_ids = _parse_ranked_ids(completion.text)
    except Exception:
        # Deliberate best-effort: any provider or parsing failure falls back
        # to the deterministic local ranking instead of failing the recall.
        ranked_ids = []

    ordered = candidates
    if ranked_ids:
        # Parsed ids are always strings, so key the lookup by str(id) to match
        # candidates whose id is not already a str.
        by_id = {str(item.id): item for item in candidates}
        ranked: list[ScoredItem] = []
        seen: set[str] = set()
        for item_id in ranked_ids:
            item = by_id.get(item_id)
            if item is not None and item_id not in seen:
                ranked.append(item)
                seen.add(item_id)
        # Candidates the LLM did not mention keep their local relative order.
        ranked.extend(item for item in candidates if str(item.id) not in seen)
        ordered = _scores_follow_order(ranked)

    return apply_context_diversity(ordered, query)[:keep] + remainder


def _scores_follow_order(ordered: list[ScoredItem]) -> list[ScoredItem]:
    """Return *ordered* with its own scores redistributed to be non-increasing.

    The multiset of scores is unchanged; the i-th item simply receives the
    i-th largest score, so a later sort by score reproduces the given order.
    """
    rank_scores = sorted((item.score for item in ordered), reverse=True)
    return [
        item
        if item.score == score
        else ScoredItem(
            id=item.id,
            text=item.text,
            score=score,
            fact_type=item.fact_type,
            metadata=item.metadata,
            tags=item.tags,
            memory_layer=item.memory_layer,
            occurred_at=item.occurred_at,
            retained_at=item.retained_at,
            chunk_id=getattr(item, "chunk_id", None),
        )
        for item, score in zip(ordered, rank_scores, strict=False)
    ]

286 

287 

def apply_context_diversity(items: list[ScoredItem], query: str) -> list[ScoredItem]:
    """Softly penalize repeated sessions unless the query asks for aggregation.

    Items are walked in the given order. Each item whose metadata carries a
    ``session_id`` loses ``SESSION_DIVERSITY_PENALTY`` for every earlier item
    from that same session (score floored at 0.0). The adjusted copies are
    returned sorted by descending score. Empty input and aggregate queries
    return ``items`` unchanged.
    """
    if not items or _is_aggregate_query(query):
        return items

    occurrences: dict[str, int] = {}
    adjusted: list[ScoredItem] = []
    for entry in items:
        meta = entry.metadata or {}
        session_key = str(meta.get("session_id") or "")
        prior = 0
        if session_key:
            # Items without a session id are never counted or penalized.
            prior = occurrences.get(session_key, 0)
            occurrences[session_key] = prior + 1
        deduction = prior * SESSION_DIVERSITY_PENALTY
        adjusted.append(
            ScoredItem(
                id=entry.id,
                text=entry.text,
                score=max(entry.score - deduction, 0.0),
                fact_type=entry.fact_type,
                metadata=entry.metadata,
                tags=entry.tags,
                memory_layer=entry.memory_layer,
                occurred_at=entry.occurred_at,
                retained_at=entry.retained_at,
                chunk_id=getattr(entry, "chunk_id", None),
            )
        )

    adjusted.sort(key=lambda entry: entry.score, reverse=True)
    return adjusted

319 

320 

def _proper_names(text: str) -> set[str]:
    """Return the lowercased Title Case, name-shaped words found in *text*.

    Words in ``COMMON_QUESTION_WORDS`` are skipped, matching the proper-noun
    detection in ``basic_rerank``. Otherwise a sentence-initial "What" or
    "Who" counts as a name, and a query that names nobody would still trigger
    the name match / miss adjustments in ``cross_encoder_like_rerank``.
    """
    names: set[str] = set()
    for word in text.split():
        cleaned = word.strip(punctuation)
        if not cleaned or not cleaned.istitle():
            continue
        lowered = cleaned.lower()
        if lowered in COMMON_QUESTION_WORDS:
            continue
        if _is_name_token(cleaned):
            names.add(lowered)
    return names

328 

329 

def _candidate_names(item: ScoredItem) -> set[str]:
    """Collect the lowercased names associated with *item*.

    Names come from the item text (via ``_proper_names``) and from the
    metadata keys ``locomo_speakers``, ``locomo_persons``, ``speakers`` and
    ``person``. A metadata value may be a string with names separated by
    ``,`` or ``|``, or a list/tuple/set of such strings. Collections are
    walked element by element, because stringifying a list would yield
    bracket- and quote-laden fragments that can never match a query name.
    """
    names = _proper_names(item.text)
    metadata = item.metadata or {}
    for key in ("locomo_speakers", "locomo_persons", "speakers", "person"):
        raw = metadata.get(key)
        if not raw:
            continue
        values = raw if isinstance(raw, (list, tuple, set, frozenset)) else (raw,)
        for value in values:
            if value is None:
                continue
            for part in str(value).replace("|", ",").split(","):
                cleaned = part.strip().lower()
                if cleaned:
                    names.add(cleaned)
    return names

342 

343 

344def _parse_ranked_ids(text: str) -> list[str]: 

345 raw = text.strip() 

346 if raw.startswith("```"): 

347 raw = "\n".join(line for line in raw.splitlines() if not line.strip().startswith("```")) 

348 start = raw.find("{") 

349 end = raw.rfind("}") 

350 if start >= 0 and end > start: 

351 raw = raw[start : end + 1] 

352 try: 

353 parsed = json.loads(raw) 

354 except json.JSONDecodeError: 

355 return [] 

356 ids = parsed.get("ranked_ids") if isinstance(parsed, dict) else None 

357 if not isinstance(ids, list): 

358 return [] 

359 return [str(item) for item in ids if item] 

360 

361 

362def _is_aggregate_query(query: str) -> bool: 

363 query_l = (query or "").lower() 

364 return any(term in query_l for term in ("how many", "what are", "list", "all ", "which ")) 

365 

366 

367def _bigrams(tokens: list[str]) -> set[tuple[str, str]]: 

368 return set(zip(tokens, tokens[1:], strict=False)) 

369 

370 

def _layer_boost(item: ScoredItem) -> float:
    """Return the score boost *item* earns from its memory layer.

    Compiled knowledge (``fact_type == "wiki"`` or
    ``memory_layer == "compiled"``) takes precedence and earns
    ``COMPILED_LAYER_WEIGHT``. Observations earn ``OBSERVATION_LAYER_WEIGHT``
    plus their proof-count boost. Everything else earns 0.0.
    """
    if item.fact_type == "wiki" or item.memory_layer == "compiled":
        boost = COMPILED_LAYER_WEIGHT
    elif item.fact_type == "observation" or item.memory_layer == "observation":
        boost = OBSERVATION_LAYER_WEIGHT + _observation_proof_boost(item)
    else:
        boost = 0.0
    return boost

377 

378 

379def _observation_proof_boost(item: ScoredItem) -> float: 

380 """Additive boost for observation items proportional to their proof count. 

381 

382 A single-evidence observation (``_obs_proof_count=1``) gets +0.0. 

383 Each additional corroborating memory adds ``OBSERVATION_PROOF_WEIGHT``, 

384 capped at ``OBSERVATION_PROOF_CAP`` extra proofs (+0.10 total). 

385 Raw memories (no ``_obs_proof_count``) are unaffected. 

386 """ 

387 if item.fact_type != "observation" or not item.metadata: 

388 return 0.0 

389 proof = item.metadata.get("_obs_proof_count", 1) 

390 try: 

391 extra = max(0, int(proof) - 1) 

392 except (TypeError, ValueError): 

393 return 0.0 

394 return min(extra, OBSERVATION_PROOF_CAP) * OBSERVATION_PROOF_WEIGHT