Coverage for astrocyte/eval/judges/locomo_judge.py: 84%
90 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Canonical LoCoMo judge — ported from the paper's reference evaluation.
3Upstream: ``datasets/locomo/task_eval/evaluation.py`` from
4https://github.com/snap-research/locomo. This module reproduces the
5scoring logic used in the LoCoMo paper and subsequent public
6comparisons (Mem0, Zep, Hindsight). Deterministic, pure Python, no LLM
7— cheaper and more reproducible than LLM-judge approaches.
9## Scoring model
11LoCoMo evaluates QA predictions with **stemmed token-F1** on normalized
12text, with **category-specific adjustments**:
14- **Category 1 — multi-hop**: prediction and ground truth are each split
15 on commas into sub-answers; for each ground-truth sub-answer, take the
16 max F1 across all prediction sub-answers; average those maxes.
17- **Category 2 — temporal**: plain stemmed-token F1.
18- **Category 3 — open-domain**: ground truth may have multiple acceptable
19 forms separated by ``;`` — use the first. Plain F1.
20- **Category 4 — single-hop**: plain F1.
21- **Category 5 — adversarial**: correct if the prediction signals
22 abstention (``"no information available"`` or ``"not mentioned"``).
23 Binary 1/0 score.
25## Normalization pipeline
271. remove commas
282. lowercase
293. remove punctuation
304. remove articles (``a|an|the|and``) — after punctuation, so order matters
315. collapse whitespace
326. Porter stem each resulting token
34F1 is then computed on multiset-intersection of stemmed tokens.
36## What this module does NOT do
38- Does not generate answers (that's the upstream LLM pass — Astrocyte's
39 reflect stage).
40- Does not ingest per-question context / evidence checking.
41- Does not compute BERTScore or RougeL (unused by the paper's headline
42 metric).
44Those live upstream in the adapter that calls this judge.
45"""
47from __future__ import annotations
49import re
50import string
51from collections import Counter
52from typing import TYPE_CHECKING, Final
54from astrocyte.eval.judges._stemmer import porter_stem
56if TYPE_CHECKING:
57 from astrocyte.provider import LLMProvider
#: Name -> integer id mapping used by the canonical LoCoMo evaluator.
#: Astrocyte's adapter exposes categories as strings (``"multi-hop"``,
#: etc.); callers translate via :func:`locomo_category_id` before scoring.
#:
#: Verified against ``datasets/locomo/data/locomo10.json``:
#:
#: - cat 1 — multi-hop (multi-speaker synthesis, comma-listed answers)
#: - cat 2 — temporal ("when did..." questions)
#: - cat 3 — open-domain (commonsense inference — "would X likely...")
#: - cat 4 — single-hop (single-session factual)
#: - cat 5 — adversarial (unanswerable — empty GT)
LOCOMO_CATEGORY_IDS: Final[dict[str, int]] = {
    "multi-hop": 1,
    "temporal": 2,
    "open-domain": 3,
    "single-hop": 4,
    "adversarial": 5,
}
78#: Tokens treated as articles and removed during normalization. Matches
79#: the upstream ``normalize_answer`` regex exactly (``a|an|the|and``).
80_ARTICLES_RE: Final[re.Pattern[str]] = re.compile(r"\b(a|an|the|and)\b")
82#: Punctuation set removed during normalization (Python ``string.punctuation``).
83_PUNCTUATION: Final[frozenset[str]] = frozenset(string.punctuation)
85#: Abstention signal phrases for category-5 scoring.
86#:
87#: The upstream paper's list is narrow (``"no information available"`` and
88#: ``"not mentioned"``) because its baseline LLMs were instruction-tuned
89#: toward that phrasing. Real-world reflect stages produce a wider range
90#: of abstention expressions ("not in my memory", "cannot find", etc.)
91#: that are semantically equivalent but miss the narrow match. Extend
92#: here when operators find false-negatives in their v5+ runs; the tests
93#: in :mod:`tests.test_eval_judges` pin the current set.
94_ABSTENTION_PHRASES: Final[tuple[str, ...]] = (
95 # Upstream canonical phrases
96 "no information available",
97 "not mentioned",
98 # Common LLM-output variants observed in v5 runs
99 "no information",
100 "not available",
101 "not found",
102 "not stated",
103 "not specified",
104 "not provided",
105 "not discussed",
106 "not indicated",
107 "cannot find",
108 "can't find",
109 "don't have",
110 "do not have",
111 "unable to find",
112 "no record",
113 "nothing about",
114 "not in the memor", # prefix: "not in the memory" / "memories"
115 "not in my memor", # prefix: "not in my memory" / "memories"
116 "not in the conversation",
117 "not in the provided",
118 "no mention",
119 "isn't mentioned",
120 "wasn't mentioned",
121 "i don't know",
122 "i do not know",
123)
126# ---------------------------------------------------------------------------
127# Normalization — mirrors ``normalize_answer`` upstream
128# ---------------------------------------------------------------------------
def _normalize_answer(text: str) -> str:
    """Return ``text`` reduced to the canonical form used for token F1.

    Steps, in the order they are applied: drop commas, lowercase, delete
    every ``string.punctuation`` character, replace the article tokens
    (``a|an|the|and``) with a space, then collapse runs of whitespace.
    ``None`` yields ``""``; any other non-string value is passed through
    ``str()`` first.

    Intended to mirror upstream ``normalize_answer`` in
    ``datasets/locomo/task_eval/evaluation.py``.
    """
    if text is None:
        return ""
    # Comma removal is kept as its own step (as upstream does), even
    # though "," is also part of the punctuation set stripped below.
    lowered = str(text).replace(",", "").lower()
    without_punct = "".join(filter(lambda ch: ch not in _PUNCTUATION, lowered))
    # Articles are removed only after punctuation is gone, so a token
    # such as "the-end" has become "theend" and is left untouched.
    without_articles = _ARTICLES_RE.sub(" ", without_punct)
    return " ".join(without_articles.split())
def _normalize_and_stem(text: str) -> list[str]:
    """Normalize ``text`` and Porter-stem each whitespace-separated token.

    Returns an empty list when normalization leaves nothing behind.
    """
    cleaned = _normalize_answer(text)
    return [porter_stem(token) for token in cleaned.split()] if cleaned else []
160# ---------------------------------------------------------------------------
161# F1 — mirrors ``f1_score`` upstream
162# ---------------------------------------------------------------------------
def _f1_score(prediction: str, ground_truth: str) -> float:
    """Compute stemmed-token F1 of ``prediction`` against ``ground_truth``.

    Both strings are normalized and stemmed, then compared as multisets:
    a token repeated on both sides counts as many times as it appears on
    the side with fewer occurrences. The result is 0.0 when the two sides
    share no tokens (which includes either side being empty).

    Intended to mirror ``f1_score`` in
    ``datasets/locomo/task_eval/evaluation.py``.
    """
    predicted = _normalize_and_stem(prediction)
    expected = _normalize_and_stem(ground_truth)

    # Counter & Counter keeps the per-token minimum count.
    overlap = sum((Counter(predicted) & Counter(expected)).values())
    if not overlap:
        # Also guards the divisions below: overlap > 0 implies both
        # token lists are non-empty.
        return 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(expected)
    return (2.0 * precision * recall) / (precision + recall)
def _multi_hop_f1(prediction: str, ground_truth: str) -> float:
    """Score a comma-separated multi-part answer.

    Both strings are split on ``,`` and blank pieces are discarded. Each
    ground-truth piece is credited with the best :func:`_f1_score` it
    achieves against any prediction piece, and the mean of those best
    scores is returned. If either side has no non-blank pieces the score
    is 0.0.

    Intended to mirror ``f1`` (the multi-answer variant) upstream.
    """
    pred_parts = [part for part in (p.strip() for p in prediction.split(",")) if part]
    gt_parts = [part for part in (g.strip() for g in ground_truth.split(",")) if part]
    if not (pred_parts and gt_parts):
        return 0.0
    best_per_gt = [
        max(_f1_score(pred, gt) for pred in pred_parts) for gt in gt_parts
    ]
    return sum(best_per_gt) / len(best_per_gt)
202# ---------------------------------------------------------------------------
203# Category-specific dispatch — mirrors ``eval_question_answering`` upstream
204# ---------------------------------------------------------------------------
def locomo_category_id(category: str | int) -> int:
    """Translate a LoCoMo category name or id to the paper's integer id.

    Args:
        category: Either a category name from :data:`LOCOMO_CATEGORY_IDS`
            (e.g. ``"single-hop"``; surrounding whitespace and letter case
            are ignored) or one of the integer ids ``1``-``5``.

    Returns:
        The integer category id.

    Raises:
        ValueError: If ``category`` is not a known name or id, so that a
            regression in adapter naming fails loudly. Booleans are
            rejected as well.
    """
    # ``bool`` is a subclass of ``int`` and ``True == 1``; without this
    # guard ``True`` would be accepted and returned as the multi-hop id.
    if isinstance(category, bool):
        raise ValueError(f"Unknown LoCoMo category id: {category!r}")
    if isinstance(category, int):
        if category in LOCOMO_CATEGORY_IDS.values():
            return category
        raise ValueError(f"Unknown LoCoMo category id: {category!r}")
    key = str(category).strip().lower()
    try:
        return LOCOMO_CATEGORY_IDS[key]
    except KeyError as exc:
        known = ", ".join(sorted(LOCOMO_CATEGORY_IDS))
        raise ValueError(
            f"Unknown LoCoMo category: {category!r} (known: {known})",
        ) from exc
def locomo_score_qa(
    prediction: str,
    ground_truth: str,
    category: str | int,
) -> float:
    """Score a single LoCoMo QA pair using the canonical judge.

    Returns a float in ``[0.0, 1.0]``. The aggregator (adapter code)
    converts scores to a pass/fail at whatever threshold it wants; the
    paper reports raw means of these F1 scores per category, which we
    match.

    Category-specific semantics:

    - 1 / multi-hop: F1 on split sub-answers; average of per-GT-max.
    - 2 / temporal: plain F1.
    - 3 / open-domain: ground truth may carry alternates separated by
      ``;``; only the first alternate is scored. Plain F1.
    - 4 / single-hop: plain F1.
    - 5 / adversarial: 1.0 when the lowercased prediction contains any
      abstention phrase; 0.0 otherwise. The ground truth is ignored.

    Args:
        prediction: The model's answer. ``None`` is treated as ``""``;
            any other non-string value is converted with ``str()``.
        ground_truth: The reference answer, coerced the same way (so a
            numeric answer loaded from JSON does not crash scoring).
        category: Category name or integer id; see
            :func:`locomo_category_id`.

    Raises:
        ValueError: If ``category`` is not a known LoCoMo category.
    """
    cid = locomo_category_id(category)

    # Coerce once here: the branches below call ``.lower()`` / ``.split()``
    # directly, which would raise on non-string input even though
    # ``_normalize_answer`` itself tolerates it.
    prediction = "" if prediction is None else str(prediction)
    ground_truth = "" if ground_truth is None else str(ground_truth)

    if cid == 5:  # adversarial
        lower = prediction.lower()
        return 1.0 if any(p in lower for p in _ABSTENTION_PHRASES) else 0.0

    if cid == 1:  # multi-hop
        return _multi_hop_f1(prediction, ground_truth)

    if cid == 3:  # open-domain
        # Some open-domain answers carry ``;``-separated alternates.
        # Take only the first; plain F1 after.
        gt_for_scoring = ground_truth.split(";")[0].strip()
        return _f1_score(prediction, gt_for_scoring)

    # cid in {2 temporal, 4 single-hop} — plain F1
    return _f1_score(prediction, ground_truth)
273# Normalized-text helpers exposed for adapters that want to display
274# what was actually scored (useful in debug logs / per-question reports).
def normalized_for_scoring(text: str) -> str:
    """Return ``text`` exactly as the scorer sees it before stemming.

    Public wrapper around :func:`_normalize_answer`: the result is
    lowercase with commas, punctuation and articles removed and
    whitespace collapsed. Stemming is NOT applied here — it happens per
    token while F1 is computed.
    """
    return _normalize_answer(text)
284# ---------------------------------------------------------------------------
285# LLM judge — matches Mem0/Hindsight/MemMachine scoring convention
286# ---------------------------------------------------------------------------
288#: Category-specific judge prompts. These match the binary yes/no convention
289#: used by Mem0 (ECAI 2025 arXiv:2504.19413), Hindsight (arXiv:2512.12818),
290#: and MemMachine (arXiv:2604.04853) — the scoring framework competitors
291#: publish against. Use this judge for cross-competitor comparisons; use
292#: :func:`locomo_score_qa` for numbers comparable to the original paper's F1.
293_LLM_TEMPLATES: Final[dict[str, str]] = {
294 "single-hop": (
295 "I will give you a question, the correct answer, and a response from a model. "
296 "Answer yes if the response correctly answers the question. "
297 "Otherwise, answer no.\n\n"
298 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n"
299 "Is the model response correct? Answer yes or no only."
300 ),
301 "multi-hop": (
302 "I will give you a question requiring information from multiple parts of a "
303 "conversation, the correct answer, and a response from a model. "
304 "Answer yes if the response contains the correct information, even if phrased "
305 "differently. Otherwise, answer no.\n\n"
306 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n"
307 "Is the model response correct? Answer yes or no only."
308 ),
309 "temporal": (
310 "I will give you a question about events over time, the correct answer, "
311 "and a response from a model. "
312 "Answer yes if the response correctly answers the question. "
313 "Do not penalize minor date approximations (e.g. off-by-one day or week).\n\n"
314 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n"
315 "Is the model response correct? Answer yes or no only."
316 ),
317 "open-domain": (
318 "I will give you a question, the correct answer, and a response from a model. "
319 "Answer yes if the response correctly answers the question. "
320 "Otherwise, answer no.\n\n"
321 "Question: {question}\n\nCorrect Answer: {answer}\n\nModel Response: {response}\n\n"
322 "Is the model response correct? Answer yes or no only."
323 ),
324 "adversarial": (
325 "I will give you a question that cannot be answered from the conversation "
326 "history, and a response from a model. "
327 "Answer yes if the model correctly says the information is not available or "
328 "that it cannot answer. Answer no if the model attempts to provide an answer.\n\n"
329 "Question: {question}\n\nModel Response: {response}\n\n"
330 "Does the model correctly identify the question as unanswerable? "
331 "Answer yes or no only."
332 ),
333}
def build_locomo_llm_judge_prompt(
    question: str,
    answer: str,
    category: str,
    response: str,
) -> str:
    """Render the LLM judge prompt for one LoCoMo question.

    ``category`` is matched case-insensitively, ignoring surrounding
    whitespace. A category with no template of its own is rendered with
    the ``"single-hop"`` template instead of raising, so an unexpected
    category string does not abort an evaluation run.

    ``answer`` is not used for the ``"adversarial"`` category, whose
    template has no ``{answer}`` placeholder.
    """
    key = category.strip().lower()
    template = _LLM_TEMPLATES.get(key, _LLM_TEMPLATES["single-hop"])
    # ``str.format`` ignores keyword arguments the template does not
    # reference, so one call serves the adversarial template as well.
    return template.format(question=question, answer=answer, response=response)
class LoCoMoLLMJudge:
    """Binary (yes/no) LLM judge for LoCoMo predictions.

    Each question is scored with a short yes/no prompt, following the
    convention used by Mem0 (ECAI 2025), Hindsight, and MemMachine. Use
    this judge — not the stemmed-F1 :func:`locomo_score_qa` — for numbers
    directly comparable to those published by competitors.

    Instantiate once per benchmark run. The instance holds only the
    configuration passed to ``__init__``; :meth:`score` does not modify
    it.
    """

    def __init__(
        self,
        llm_provider: LLMProvider,
        *,
        model: str | None = None,
        max_tokens: int = 4,
        temperature: float = 0.0,
    ) -> None:
        """Store the provider and the completion settings.

        Args:
            llm_provider: Object whose awaitable ``complete`` method is
                called once per scored question.
            model: Forwarded as ``model=`` to ``complete``; ``None`` is
                passed through unchanged.
            max_tokens: Completion length cap forwarded to ``complete``.
            temperature: Sampling temperature forwarded to ``complete``.
        """
        self._llm = llm_provider
        self._model = model
        self._max_tokens = max_tokens
        self._temperature = temperature

    async def score(
        self,
        question: str,
        answer: str,
        category: str,
        response: str,
    ) -> float:
        """Judge one response and return the parsed yes/no verdict.

        The prompt is built with :func:`build_locomo_llm_judge_prompt`
        and sent as a single user message; the completion text is handed
        to ``parse_yes_no``. Nothing is caught here, so provider errors
        propagate to the caller.
        """
        # Imported at call time rather than at module import time.
        from astrocyte.types import Message

        user_message = Message(
            role="user",
            content=build_locomo_llm_judge_prompt(question, answer, category, response),
        )
        completion = await self._llm.complete(
            messages=[user_message],
            model=self._model,
            max_tokens=self._max_tokens,
            temperature=self._temperature,
        )

        # Reuse LongMemEval's yes/no parser — same tolerance rules.
        from astrocyte.eval.judges.longmemeval_judge import parse_yes_no

        return parse_yes_no(completion.text)