Coverage for astrocyte/eval/judges/locomo_judge.py: 84%

90 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Canonical LoCoMo judge — ported from the paper's reference evaluation. 

2 

3Upstream: ``datasets/locomo/task_eval/evaluation.py`` from 

4https://github.com/snap-research/locomo. This module reproduces the 

5scoring logic used in the LoCoMo paper and subsequent public 

6comparisons (Mem0, Zep, Hindsight). Deterministic, pure Python, no LLM 

7— cheaper and more reproducible than LLM-judge approaches. 

8 

9## Scoring model 

10 

11LoCoMo evaluates QA predictions with **stemmed token-F1** on normalized 

12text, with **category-specific adjustments**: 

13 

14- **Category 1 — multi-hop**: prediction and ground truth are each split 

15 on commas into sub-answers; for each ground-truth sub-answer, take the 

16 max F1 across all prediction sub-answers; average those maxes. 

17- **Category 2 — temporal**: plain stemmed-token F1. 

18- **Category 3 — open-domain**: ground truth may have multiple acceptable 

19 forms separated by ``;`` — use the first. Plain F1. 

20- **Category 4 — single-hop**: plain F1. 

21- **Category 5 — adversarial**: correct if the prediction signals 

22 abstention (``"no information available"``, ``"not mentioned"``, or any other phrase listed in ``_ABSTENTION_PHRASES``). 

23 Binary 1/0 score. 

24 

25## Normalization pipeline 

26 

271. lowercase 

282. remove commas 

293. remove punctuation 

304. remove articles (``a|an|the|and``) 

315. collapse whitespace 

326. Porter stem each resulting token 

33 

34F1 is then computed on multiset-intersection of stemmed tokens. 

35 

36## What this module does NOT do 

37 

38- Does not generate answers (that's the upstream LLM pass — Astrocyte's 

39 reflect stage). 

40- Does not ingest per-question context / evidence checking. 

41- Does not compute BERTScore or RougeL (unused by the paper's headline 

42 metric). 

43 

44Those live upstream in the adapter that calls this judge. 

45""" 

46 

47from __future__ import annotations 

48 

49import re 

50import string 

51from collections import Counter 

52from typing import TYPE_CHECKING, Final 

53 

54from astrocyte.eval.judges._stemmer import porter_stem 

55 

56if TYPE_CHECKING: 

57 from astrocyte.provider import LLMProvider 

58 

#: Maps Astrocyte's string category names to the integer ids used by the
#: canonical LoCoMo evaluator. Callers translate names via
#: :func:`locomo_category_id` before scoring.
#:
#: Ids are assigned in listing order starting at 1. Verified against
#: ``datasets/locomo/data/locomo10.json``:
#:
#: - 1 — multi-hop (multi-speaker synthesis, comma-listed answers)
#: - 2 — temporal ("when did..." questions)
#: - 3 — open-domain (commonsense inference — "would X likely...")
#: - 4 — single-hop (single-session factual)
#: - 5 — adversarial (unanswerable — empty GT)
LOCOMO_CATEGORY_IDS: Final[dict[str, int]] = {
    name: category_id
    for category_id, name in enumerate(
        ("multi-hop", "temporal", "open-domain", "single-hop", "adversarial"),
        start=1,
    )
}

77 

78#: Tokens treated as articles and removed during normalization. Matches 

79#: the upstream ``normalize_answer`` regex exactly (``a|an|the|and``). 

80_ARTICLES_RE: Final[re.Pattern[str]] = re.compile(r"\b(a|an|the|and)\b") 

81 

82#: Punctuation set removed during normalization (Python ``string.punctuation``). 

83_PUNCTUATION: Final[frozenset[str]] = frozenset(string.punctuation) 

84 

85#: Abstention signal phrases for category-5 scoring. 

86#: 

87#: The upstream paper's list is narrow (``"no information available"`` and 

88#: ``"not mentioned"``) because its baseline LLMs were instruction-tuned 

89#: toward that phrasing. Real-world reflect stages produce a wider range 

90#: of abstention expressions ("not in my memory", "cannot find", etc.) 

91#: that are semantically equivalent but miss the narrow match. Extend 

92#: here when operators find false-negatives in their v5+ runs; the tests 

93#: in :mod:`tests.test_eval_judges` pin the current set. 

94_ABSTENTION_PHRASES: Final[tuple[str, ...]] = ( 

95 # Upstream canonical phrases 

96 "no information available", 

97 "not mentioned", 

98 # Common LLM-output variants observed in v5 runs 

99 "no information", 

100 "not available", 

101 "not found", 

102 "not stated", 

103 "not specified", 

104 "not provided", 

105 "not discussed", 

106 "not indicated", 

107 "cannot find", 

108 "can't find", 

109 "don't have", 

110 "do not have", 

111 "unable to find", 

112 "no record", 

113 "nothing about", 

114 "not in the memor", # prefix: "not in the memory" / "memories" 

115 "not in my memor", # prefix: "not in my memory" / "memories" 

116 "not in the conversation", 

117 "not in the provided", 

118 "no mention", 

119 "isn't mentioned", 

120 "wasn't mentioned", 

121 "i don't know", 

122 "i do not know", 

123) 

124 

125 

126# --------------------------------------------------------------------------- 

127# Normalization — mirrors ``normalize_answer`` upstream 

128# --------------------------------------------------------------------------- 

129 

130 

def _normalize_answer(text: str) -> str:
    """Return ``text`` in the canonical form that scoring compares.

    Mirrors upstream ``normalize_answer`` in
    ``datasets/locomo/task_eval/evaluation.py``. Steps, in order: drop
    commas, lowercase, delete punctuation, blank out articles, collapse
    whitespace. ``None`` yields ``""``; any other value is passed
    through ``str()`` first.
    """
    if text is None:
        return ""
    lowered = str(text).replace(",", "").lower()
    # Punctuation goes before the article pass, so the word-boundary
    # regex only ever sees bare words.
    depunctuated = "".join([ch for ch in lowered if ch not in _PUNCTUATION])
    # Each article becomes a space so its neighbours stay separate tokens.
    without_articles = _ARTICLES_RE.sub(" ", depunctuated)
    # split()/join collapses whitespace runs and trims both ends.
    return " ".join(without_articles.split())

150 

151 

def _normalize_and_stem(text: str) -> list[str]:
    """Return the Porter stems of the tokens left after normalizing ``text``.

    Text that normalizes to the empty string yields ``[]``.
    """
    # ``"".split()`` is already ``[]``, so the empty case needs no branch.
    return [porter_stem(token) for token in _normalize_answer(text).split()]

158 

159 

160# --------------------------------------------------------------------------- 

161# F1 — mirrors ``f1_score`` upstream 

162# --------------------------------------------------------------------------- 

163 

164 

def _f1_score(prediction: str, ground_truth: str) -> float:
    """Compute stemmed-token F1 for one prediction / ground-truth pair.

    Mirrors ``f1_score`` in
    ``datasets/locomo/task_eval/evaluation.py``. The result is 0.0 when
    either string has no tokens after normalization or when the two
    token multisets share nothing.
    """
    predicted = _normalize_and_stem(prediction)
    expected = _normalize_and_stem(ground_truth)

    # Multiset intersection: a repeated stem counts only as often as it
    # occurs on both sides.
    overlap = sum((Counter(predicted) & Counter(expected)).values())
    if not overlap:
        # Also covers an empty side, which keeps the divisions below safe.
        return 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(expected)
    return (2.0 * precision * recall) / (precision + recall)

183 

184 

def _multi_hop_f1(prediction: str, ground_truth: str) -> float:
    """Score a comma-listed answer against a comma-listed ground truth.

    Both strings are split on ``,``; blank pieces are discarded. Each
    ground-truth piece is credited with its best F1 over all prediction
    pieces, and the result is the mean of those best scores. Returns
    0.0 when either side has no non-blank piece. Corresponds to ``f1``
    (the multi-answer variant) upstream.
    """
    pred_parts = [part for part in map(str.strip, prediction.split(",")) if part]
    gt_parts = [part for part in map(str.strip, ground_truth.split(",")) if part]
    if not (pred_parts and gt_parts):
        return 0.0
    best_per_gt = [
        max(_f1_score(candidate, gt_part) for candidate in pred_parts)
        for gt_part in gt_parts
    ]
    return sum(best_per_gt) / len(best_per_gt)

200 

201 

202# --------------------------------------------------------------------------- 

203# Category-specific dispatch — mirrors ``eval_question_answering`` upstream 

204# --------------------------------------------------------------------------- 

205 

206 

def locomo_category_id(category: str | int) -> int:
    """Translate an Astrocyte category into the paper's integer id.

    Args:
        category: Either the string name (``"single-hop"``; surrounding
            whitespace and letter case are ignored) or the integer id
            itself (``4``).

    Returns:
        The integer id from ``LOCOMO_CATEGORY_IDS``.

    Raises:
        ValueError: If the name or id is not a known category, so a
            regression in adapter naming fails loudly. Booleans are
            rejected as well.
    """
    # bool is a subclass of int, so without this guard ``True`` would be
    # accepted as id 1 (and returned as ``True``) instead of failing loudly.
    if isinstance(category, bool):
        raise ValueError(f"Unknown LoCoMo category id: {category!r}")
    if isinstance(category, int):
        if category in LOCOMO_CATEGORY_IDS.values():
            return category
        raise ValueError(f"Unknown LoCoMo category id: {category!r}")
    key = str(category).strip().lower()
    try:
        return LOCOMO_CATEGORY_IDS[key]
    except KeyError as exc:
        known = ", ".join(sorted(LOCOMO_CATEGORY_IDS))
        raise ValueError(
            f"Unknown LoCoMo category: {category!r} (known: {known})",
        ) from exc

226 

227 

def locomo_score_qa(
    prediction: str,
    ground_truth: str,
    category: str | int,
) -> float:
    """Score a single LoCoMo QA pair using the canonical judge.

    Returns a float in ``[0.0, 1.0]``. The aggregator (adapter code)
    converts scores to a pass/fail at whatever threshold it wants; the
    paper reports raw means of these F1 scores per category, which we
    match.

    Args:
        prediction: The model's answer. ``None`` is treated as ``""``;
            any other non-string value is converted with ``str()``.
        ground_truth: The reference answer, coerced the same way as
            ``prediction``.
        category: Category name or integer id, resolved through
            :func:`locomo_category_id`.

    Raises:
        ValueError: If ``category`` is not a known LoCoMo category.

    Category-specific semantics:

    - 1 / multi-hop: F1 on split sub-answers; average of per-GT-max.
    - 2 / temporal: plain F1.
    - 3 / open-domain: ground truth may carry alternates separated by ``;``;
      upstream takes the first alternate before scoring. Plain F1.
    - 4 / single-hop: plain F1.
    - 5 / adversarial: 1.0 when prediction contains an abstention
      phrase; 0.0 otherwise.
    """
    cid = locomo_category_id(category)
    # Coerce once, up front. Plain F1 already tolerates non-string values
    # because ``_normalize_answer`` applies ``str()``, but the ``.lower()``
    # and ``.split()`` calls below would raise AttributeError on e.g. a
    # numeric answer, making categories 1, 3 and 5 behave differently
    # from 2 and 4.
    prediction = "" if prediction is None else str(prediction)
    ground_truth = "" if ground_truth is None else str(ground_truth)

    if cid == 5:  # adversarial
        # Substring match on the lowercased prediction; ground truth is unused.
        lower = prediction.lower()
        return 1.0 if any(p in lower for p in _ABSTENTION_PHRASES) else 0.0

    if cid == 1:  # multi-hop
        return _multi_hop_f1(prediction, ground_truth)

    if cid == 3:  # open-domain — upstream takes first alternate
        # Some open-domain answers carry ``;``-separated alternates.
        # Only the first one is scored; plain F1 after that.
        gt_for_scoring = ground_truth.split(";")[0].strip()
        return _f1_score(prediction, gt_for_scoring)

    # cid in {2 temporal, 4 single-hop} — plain F1
    return _f1_score(prediction, ground_truth)

271 

272 

# Public view of the normalization step, for adapters that want to show
# what text was actually scored (debug logs, per-question reports).


def normalized_for_scoring(text: str) -> str:
    """Return ``text`` as the scorer sees it before stemming.

    The result is lowercase with articles and punctuation removed and
    whitespace collapsed. Stemming is NOT applied here; it happens per
    token while F1 is computed.
    """
    scored_form = _normalize_answer(text)
    return scored_form

282 

283 

284# --------------------------------------------------------------------------- 

285# LLM judge — matches Mem0/Hindsight/MemMachine scoring convention 

286# --------------------------------------------------------------------------- 

287 

288#: Category-specific judge prompts. These match the binary yes/no convention 

289#: used by Mem0 (ECAI 2025 arXiv:2504.19413), Hindsight (arXiv:2512.12818), 

290#: and MemMachine (arXiv:2604.04853) — the scoring framework competitors 

291#: publish against. Use this judge for cross-competitor comparisons; use 

292#: :func:`locomo_score_qa` for numbers comparable to the original paper's F1. 

293_LLM_TEMPLATES: Final[dict[str, str]] = { 

294 "single-hop": ( 

295 "I will give you a question, the correct answer, and a response from a model. " 

296 "Answer yes if the response correctly answers the question. " 

297 "Otherwise, answer no.\n\n" 

298 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n" 

299 "Is the model response correct? Answer yes or no only." 

300 ), 

301 "multi-hop": ( 

302 "I will give you a question requiring information from multiple parts of a " 

303 "conversation, the correct answer, and a response from a model. " 

304 "Answer yes if the response contains the correct information, even if phrased " 

305 "differently. Otherwise, answer no.\n\n" 

306 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n" 

307 "Is the model response correct? Answer yes or no only." 

308 ), 

309 "temporal": ( 

310 "I will give you a question about events over time, the correct answer, " 

311 "and a response from a model. " 

312 "Answer yes if the response correctly answers the question. " 

313 "Do not penalize minor date approximations (e.g. off-by-one day or week).\n\n" 

314 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n" 

315 "Is the model response correct? Answer yes or no only." 

316 ), 

317 "open-domain": ( 

318 "I will give you a question, the correct answer, and a response from a model. " 

319 "Answer yes if the response correctly answers the question. " 

320 "Otherwise, answer no.\n\n" 

321 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n" 

322 "Is the model response correct? Answer yes or no only." 

323 ), 

324 "adversarial": ( 

325 "I will give you a question that cannot be answered from the conversation " 

326 "history, and a response from a model. " 

327 "Answer yes if the model correctly says the information is not available or " 

328 "that it cannot answer. Answer no if the model attempts to provide an answer.\n\n" 

329 "Question: {question}\n\nModel Response: {response}\n\n" 

330 "Does the model correctly identify the question as unanswerable? " 

331 "Answer yes or no only." 

332 ), 

333} 

334 

335 

def build_locomo_llm_judge_prompt(
    question: str,
    answer: str,
    category: str,
    response: str,
) -> str:
    """Fill in the LLM-judge prompt template for one LoCoMo question.

    ``category`` is matched after stripping whitespace and lowercasing.
    A category with no template of its own is rendered with the
    ``"single-hop"`` template instead of raising, so an unexpected
    category string does not abort an evaluation run. The
    ``"adversarial"`` prompt is rendered without ``answer``.
    """
    key = category.strip().lower()
    template = _LLM_TEMPLATES.get(key, _LLM_TEMPLATES["single-hop"])
    fields = {"question": question, "response": response}
    if key != "adversarial":
        # Only the answer-bearing templates receive the reference answer.
        fields["answer"] = answer
    return template.format(**fields)

351 

352 

class LoCoMoLLMJudge:
    """Binary (yes/no) LLM judge for LoCoMo predictions.

    Each question is scored with a short yes/no prompt, following the
    convention used by Mem0 (ECAI 2025), Hindsight, and MemMachine. Use
    this judge — not the stemmed-F1 :func:`locomo_score_qa` — for
    numbers directly comparable to those published by competitors.

    Instantiate once per benchmark run. ``score`` only reads the
    configuration stored by ``__init__`` and never mutates the
    instance, so one judge can serve concurrent calls (provided the
    LLM provider itself allows that).
    """

    def __init__(
        self,
        llm_provider: LLMProvider,
        *,
        model: str | None = None,
        max_tokens: int = 4,
        temperature: float = 0.0,
    ) -> None:
        # Completion settings are captured once and reused for every call.
        self._llm = llm_provider
        self._model = model
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def score(
        self,
        question: str,
        answer: str,
        category: str,
        response: str,
    ) -> float:
        """Ask the LLM for a verdict and return it as 1.0 (yes) or 0.0 (no).

        Nothing is caught here: errors from the LLM call propagate to
        the caller.
        """
        # Imported inside the method, as in the original layout, so the
        # module does not need this import at load time.
        from astrocyte.types import Message

        judge_prompt = build_locomo_llm_judge_prompt(question, answer, category, response)
        messages = [Message(role="user", content=judge_prompt)]
        reply = await self._llm.complete(
            messages=messages,
            model=self._model,
            max_tokens=self._max_tokens,
            temperature=self._temperature,
        )
        # LongMemEval's yes/no parser is reused so both judges share the
        # same tolerance rules.
        from astrocyte.eval.judges.longmemeval_judge import parse_yes_no

        verdict = parse_yes_no(reply.text)
        return verdict

397 return parse_yes_no(completion.text)