Coverage for astrocyte/pipeline/premise_verification.py: 83%
126 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Premise extraction and verification for adversarial defense.
3Adversarial questions in evaluation benchmarks (LoCoMo, etc.) follow
4predictable shapes:
6- **False premise**: "Why did Alice quit her job at Google?" — Alice
7 never worked at Google. The LLM left to its own devices rationalizes
8 the false premise instead of refusing.
9- **Negative existence**: "Did Caroline ever go skiing?" — she didn't.
10 The LLM invents a yes-answer from weakly-related hits.
11- **Time-shift**: "What happened in 2024?" — the conversation was 2023.
12 The LLM silently adopts the wrong date.
13- **Cross-entity confusion**: "Did they eat there?" with multiple
14 referents — the LLM picks the wrong referent.
16The shared failure mode is that the LLM produces a confident answer
17when the correct answer is "I don't know" or "the question presupposes
18something I have no evidence for."
20This module adds a **pre-loop verification step**: the question is
21decomposed by the LLM into atomic claims, and each claim is verified
22against memory via a focused recall. The results are returned as a
23structured verdict that the agentic-reflect loop (or single-shot
24synthesis) can incorporate into its evidence context. Crucially, when
25ANY presupposition lacks supporting evidence with high confidence, the
26caller can short-circuit to "insufficient evidence: <unsupported
27claim>" without even running the main reflect loop.
29Cost: 1 LLM call (premise extraction) + N focused recalls (one per
30claim, typically 1–3). Cheap relative to a 10-iter agentic loop;
31expensive relative to single-shot synth. Opt-in via config.
32"""
34from __future__ import annotations
36import json
37import logging
38import re
39from dataclasses import dataclass
40from typing import Awaitable, Callable
42from astrocyte.types import MemoryHit, Message
44_logger = logging.getLogger("astrocyte.premise_verification")
47# ---------------------------------------------------------------------------
48# Data shapes
49# ---------------------------------------------------------------------------
@dataclass
class Premise:
    """One atomic factual claim that a question takes for granted."""

    # The presupposed statement itself.
    claim: str
    # Short explanation from the extraction LLM of why the question implies
    # the claim. Handy when debugging false-positive presupposition
    # detection; empty when the LLM supplied none.
    rationale: str = ""
@dataclass
class PremiseVerdict:
    """Outcome of checking a single premise against memory.

    ``supported`` is the final decision. ``verify_premise`` sets it to
    True only when the judge LLM reports support AND ``confidence`` is at
    least the caller's ``min_confidence``; in every other case the
    question's presupposition is treated as failed and the caller should
    abstain.

    ``confidence`` is always a float, never ``None``. When verification
    could not run (recall error or judge-LLM error) the verdict is
    reported as ``supported=False`` with ``confidence=0.0`` and the
    failure is described in ``rationale``.
    """

    # The premise this verdict is about.
    premise: Premise
    # Final decision: judge said "supported" and confidence met the threshold.
    supported: bool
    # Judge-reported strength of the evidence; 0.0 on failure paths.
    confidence: float
    # IDs of the memories the judge cited as attesting the claim.
    evidence_ids: list[str]
    # One-sentence explanation from the judge, or a failure description.
    rationale: str = ""
@dataclass
class QuestionVerification:
    """Aggregate result of verifying every premise of one question."""

    # Premises extracted from the question (empty when none were found).
    premises: list[Premise]
    # One verdict per verified premise.
    verdicts: list[PremiseVerdict]
    # True when EVERY premise was supported with sufficient confidence.
    # Callers consult this to decide whether to short-circuit before the
    # main reflect loop runs.
    all_premises_supported: bool

    def unsupported_premises(self) -> list[PremiseVerdict]:
        """Return the verdicts whose ``supported`` flag is falsy, in order."""
        failing = []
        for verdict in self.verdicts:
            if not verdict.supported:
                failing.append(verdict)
        return failing

    def short_circuit_message(self) -> str | None:
        """Build the "insufficient evidence" message for a failed check.

        Returns ``None`` when the question is safe to proceed: either
        ``all_premises_supported`` is set or no verdict is unsupported.
        """
        failing = None if self.all_premises_supported else self.unsupported_premises()
        if not failing:
            return None
        # Only the first unsupported claim is quoted: its absence usually
        # explains the others (e.g. "Alice worked at Google" failing makes
        # "Alice quit Google" moot).
        claim = failing[0].premise.claim
        return f"insufficient evidence: the question presupposes '{claim}' which is not supported by memory."
113# ---------------------------------------------------------------------------
114# Prompts
115# ---------------------------------------------------------------------------
# System prompt for the premise-extraction LLM call. It asks the model to
# return a JSON array of {claim, rationale} objects (at most 3 per question,
# [] for pure fact lookups) and to output JSON only.
118_EXTRACTION_SYSTEM_PROMPT = """\
119You decompose questions into the FACTUAL PRESUPPOSITIONS they assume \
120to be true. The downstream system will verify each presupposition \
121against a memory bank before the question is answered, so it can \
122abstain when a presupposition is false.
124Return a JSON array of {claim, rationale} objects. Each ``claim`` is \
125ONE short atomic statement (≤ 15 words) that the question takes for \
126granted. ``rationale`` is one sentence explaining why the question \
127implies the claim.
129Rules:
1301. Decompose only PRESUPPOSITIONS — facts the question takes as given. \
131For "Why did Alice quit Google?" the presuppositions are:
132 - "Alice worked at Google"
133 - "Alice quit"
134 The "why" is the question's actual content; don't include it.
1352. For yes/no questions ("Did X happen?") the presupposition is the \
136participants/setting, NOT the event itself. The event IS what's being \
137asked. Example: "Did Alice play tennis at the club?" presupposes \
138"Alice was at the club" but NOT "Alice played tennis".
1393. For pure-fact lookups with no embedded assumption ("What is X?", \
140"When did Y happen?"), return [].
1414. Maximum 3 presuppositions per question — pick the most central.
1425. Output JSON only, no prose.
143"""
146def _build_extraction_user_prompt(question: str) -> str:
147 return f"Question: {question.strip()}\n\nPresuppositions (JSON array):"
# System prompt for the judge LLM call. It asks for a single JSON object
# with the keys "supported", "confidence", "evidence_ids" and "rationale",
# and defines the confidence bands the judge should use.
150_VERIFICATION_SYSTEM_PROMPT = """\
151You judge whether retrieved memories support a specific factual claim.
153Output a JSON object: {"supported": bool, "confidence": float, \
154"evidence_ids": [...], "rationale": "<1 sentence>"}
156Rules:
1571. ``supported`` is True only when at least one memory directly \
158attests the claim. Adjacent / topical relevance is NOT support.
1592. ``confidence`` ∈ [0, 1]. ≥ 0.8 only when explicit. 0.5-0.8 for \
160strongly implied. Below 0.5: not supported.
1613. ``evidence_ids`` lists the memory IDs that attest the claim. Empty \
162list when not supported.
1634. Rationale is one sentence explaining the verdict.
165Output JSON only.
166"""
169def _build_verification_user_prompt(claim: str, hits: list[MemoryHit]) -> str:
170 if not hits:
171 return f"Claim: {claim}\n\nRetrieved memories: (none)\n\nVerdict (JSON):"
172 lines = [f"Claim: {claim}", "", "Retrieved memories:"]
173 for hit in hits:
174 text = (hit.text or "").strip()
175 if len(text) > 400:
176 text = text[:397] + "..."
177 lines.append(f"[{hit.memory_id}] {text}")
178 lines.extend(["", "Verdict (JSON):"])
179 return "\n".join(lines)
182# ---------------------------------------------------------------------------
183# JSON parsing helpers
184# ---------------------------------------------------------------------------
187def _parse_json_array(raw: str) -> list[dict]:
188 text = raw.strip()
189 if text.startswith("```"):
190 text = re.sub(r"^```(?:json)?\s*", "", text)
191 text = re.sub(r"\s*```$", "", text)
192 match = re.search(r"\[.*\]", text, re.DOTALL)
193 if match is None:
194 return []
195 try:
196 payload = json.loads(match.group(0))
197 except json.JSONDecodeError:
198 return []
199 if not isinstance(payload, list):
200 return []
201 return [p for p in payload if isinstance(p, dict)]
204def _parse_json_object(raw: str) -> dict | None:
205 text = raw.strip()
206 if text.startswith("```"):
207 text = re.sub(r"^```(?:json)?\s*", "", text)
208 text = re.sub(r"\s*```$", "", text)
209 match = re.search(r"\{.*\}", text, re.DOTALL)
210 if match is None:
211 return None
212 try:
213 payload = json.loads(match.group(0))
214 except json.JSONDecodeError:
215 return None
216 return payload if isinstance(payload, dict) else None
219# ---------------------------------------------------------------------------
220# Extraction + verification
221# ---------------------------------------------------------------------------
async def extract_premises(question: str, llm_provider) -> list[Premise]:
    """Decompose a question into the factual claims it presupposes.

    Makes one LLM call with the extraction system prompt and parses the
    reply as a JSON array of ``{claim, rationale}`` objects.

    Returns ``[]`` for a blank question, for pure fact-lookup questions
    (no embedded assumption), and on any LLM or parse failure, including
    a completion whose ``text`` is missing or not a string. The caller
    treats an empty list as "no presupposition to verify, proceed
    normally." At most 3 premises are returned; items with a blank claim
    are skipped and do not count toward that limit.
    """
    if not question or not question.strip():
        return []
    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content=_EXTRACTION_SYSTEM_PROMPT),
                Message(role="user", content=_build_extraction_user_prompt(question)),
            ],
            max_tokens=512,
            temperature=0.0,
        )
    except Exception as exc:
        # Deliberate best-effort: extraction failure must not block answering.
        _logger.warning("premise extraction LLM call failed (%s)", exc)
        return []

    # Guard here so a provider returning no text degrades to "no premises"
    # instead of raising inside the JSON helper.
    raw_text = getattr(completion, "text", None)
    if not isinstance(raw_text, str):
        _logger.warning("premise extraction returned non-text completion (%s)", type(raw_text).__name__)
        return []

    premises: list[Premise] = []
    for item in _parse_json_array(raw_text):
        claim = str(item.get("claim") or "").strip()
        if not claim:
            continue
        rationale = str(item.get("rationale") or "").strip()
        premises.append(Premise(claim=claim, rationale=rationale))
        # Mirror the prompt's "maximum 3 presuppositions" rule, counting
        # only usable claims.
        if len(premises) == 3:
            break
    return premises
# Signature of the recall callback used for premise verification: an async
# callable taking (query text, max results) and resolving to a list of
# MemoryHit. ``verify_premise`` invokes it with the premise's claim text.
257RecallFn = Callable[[str, int], Awaitable[list[MemoryHit]]]
async def verify_premise(
    premise: Premise,
    recall_fn: RecallFn,
    llm_provider,
    *,
    recall_max_results: int = 5,
    min_confidence: float = 0.6,
) -> PremiseVerdict:
    """Verify one premise against memory via focused recall + LLM judge.

    ``recall_fn`` is awaited with the claim text and ``recall_max_results``.
    The hits are then shown to a judge LLM, which must answer with a JSON
    object (``supported``, ``confidence``, ``evidence_ids``,
    ``rationale``).

    The returned verdict is ``supported`` only when the judge says so AND
    the reported confidence is at least ``min_confidence``. Confidence is
    normalised into ``[0, 1]``: unparseable, missing or NaN values become
    0.0, and out-of-range values are clamped.

    This function never raises for recall or judge failures. Those yield
    ``supported=False`` with ``confidence=0.0`` and a rationale naming the
    failure. An unparseable judge reply is also treated as unsupported.
    """
    try:
        hits = await recall_fn(premise.claim, recall_max_results)
    except Exception as exc:
        _logger.warning("premise verification recall failed (%s)", exc)
        return PremiseVerdict(
            premise=premise,
            supported=False,
            confidence=0.0,
            evidence_ids=[],
            rationale=f"recall failed: {exc}",
        )

    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content=_VERIFICATION_SYSTEM_PROMPT),
                Message(role="user", content=_build_verification_user_prompt(premise.claim, hits)),
            ],
            max_tokens=256,
            temperature=0.0,
        )
    except Exception as exc:
        _logger.warning("premise verification LLM call failed (%s)", exc)
        return PremiseVerdict(
            premise=premise,
            supported=False,
            confidence=0.0,
            evidence_ids=[],
            rationale=f"judge LLM failed: {exc}",
        )

    # A completion without string text is handled like an unparseable reply.
    raw_text = getattr(completion, "text", None)
    parsed = (_parse_json_object(raw_text) if isinstance(raw_text, str) else None) or {}

    # Models sometimes emit the boolean as a string ("true" / " True ").
    supported_raw = parsed.get("supported", False)
    if isinstance(supported_raw, str):
        supported = supported_raw.strip().lower() == "true"
    else:
        supported = bool(supported_raw)

    try:
        confidence = float(parsed.get("confidence", 0.0))
    except (TypeError, ValueError, OverflowError):
        confidence = 0.0
    if confidence != confidence:
        # NaN (json.loads accepts the bare token) carries no information.
        confidence = 0.0
    # Keep the stored value inside the documented [0, 1] range.
    confidence = min(1.0, max(0.0, confidence))

    evidence_ids_raw = parsed.get("evidence_ids") or []
    if isinstance(evidence_ids_raw, list):
        evidence_ids = [str(e) for e in evidence_ids_raw if e]
    else:
        evidence_ids = []
    rationale = str(parsed.get("rationale") or "").strip()

    # Final supported flag combines the LLM's verdict with the confidence
    # threshold, which guards against the LLM saying "supported: true"
    # with low confidence.
    final_supported = supported and confidence >= min_confidence

    return PremiseVerdict(
        premise=premise,
        supported=final_supported,
        confidence=confidence,
        evidence_ids=evidence_ids,
        rationale=rationale,
    )
async def verify_question(
    question: str,
    *,
    recall_fn: RecallFn,
    llm_provider,
    recall_max_results: int = 5,
    min_confidence: float = 0.6,
) -> QuestionVerification:
    """Extract a question's premises and verify each one against memory.

    The result's ``short_circuit_message()`` is non-None when at least one
    presupposition fails verification.

    A question with no extracted premises yields an empty result with
    ``all_premises_supported=True`` so the caller proceeds normally.
    """
    extracted = await extract_premises(question, llm_provider)
    if not extracted:
        return QuestionVerification(premises=[], verdicts=[], all_premises_supported=True)

    import asyncio

    # Each premise's recall + judge call is independent of the others, so
    # they run concurrently and latency is bounded by the slowest one.
    pending = [
        verify_premise(
            candidate,
            recall_fn,
            llm_provider,
            recall_max_results=recall_max_results,
            min_confidence=min_confidence,
        )
        for candidate in extracted
    ]
    results = await asyncio.gather(*pending)

    every_premise_ok = all(verdict.supported for verdict in results)
    return QuestionVerification(
        premises=list(extracted),
        verdicts=list(results),
        all_premises_supported=every_premise_ok,
    )