Coverage for astrocyte/pipeline/premise_verification.py: 83%

126 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Premise extraction and verification for adversarial defense. 

2 

3Adversarial questions in evaluation benchmarks (LoCoMo, etc.) follow 

4predictable shapes: 

5 

6- **False premise**: "Why did Alice quit her job at Google?" — Alice 

7 never worked at Google. The LLM left to its own devices rationalizes 

8 the false premise instead of refusing. 

9- **Negative existence**: "Did Caroline ever go skiing?" — she didn't. 

10 The LLM invents a yes-answer from weakly-related hits. 

11- **Time-shift**: "What happened in 2024?" — the conversation was 2023. 

12 The LLM silently adopts the wrong date. 

13- **Cross-entity confusion**: "Did they eat there?" with multiple 

14 referents — the LLM picks the wrong referent. 

15 

16The shared failure mode is that the LLM produces a confident answer 

17when the correct answer is "I don't know" or "the question presupposes 

18something I have no evidence for." 

19 

20This module adds a **pre-loop verification step**: the question is 

21decomposed by the LLM into atomic claims, and each claim is verified 

22against memory via a focused recall. The results are returned as a 

23structured verdict that the agentic-reflect loop (or single-shot 

24synthesis) can incorporate into its evidence context. Crucially, when 

25ANY presupposition lacks supporting evidence with high confidence, the 

26caller can short-circuit to "insufficient evidence: <unsupported 

27claim>" without even running the main reflect loop. 

28 

29Cost: 1 LLM call (premise extraction) + N focused recalls (one per 

30claim, typically 1–3). Cheap relative to a 10-iter agentic loop; 

31expensive relative to single-shot synth. Opt-in via config. 

32""" 

33 

34from __future__ import annotations 

35 

36import json 

37import logging 

38import re 

39from dataclasses import dataclass 

40from typing import Awaitable, Callable 

41 

42from astrocyte.types import MemoryHit, Message 

43 

# Module logger. Recall and LLM failures in this module are reported here at
# WARNING level before the code falls back to a conservative result (an empty
# premise list or an unsupported verdict).
_logger = logging.getLogger("astrocyte.premise_verification")

45 

46 

47# --------------------------------------------------------------------------- 

48# Data shapes 

49# --------------------------------------------------------------------------- 

50 

51 

@dataclass()
class Premise:
    """One atomic factual claim that a question takes for granted.

    Instances are produced by premise extraction and then checked one by
    one against memory.
    """

    #: The claim itself, phrased as a short standalone statement.
    claim: str
    #: The extractor's short explanation of why the question implies the
    #: claim; handy when debugging spurious presupposition detection.
    rationale: str = ""

60 

61 

@dataclass
class PremiseVerdict:
    """Result of verifying a single premise against memory.

    ``supported`` is the final verdict. It is ``True`` only when the judge
    LLM reported the claim as supported *and* ``confidence`` reached the
    caller's ``min_confidence`` threshold. A ``False`` value means the
    question's presupposition failed, and the caller should abstain.

    ``confidence`` is the judge's score for how strongly the retrieved
    evidence supports the claim; the judge is asked for a value in
    ``[0, 1]``. It is always a float. When verification could not run
    (the recall or the judge call failed), the verdict is
    ``supported=False`` with ``confidence=0.0``, and ``rationale``
    describes the failure.
    """

    premise: Premise
    supported: bool
    confidence: float
    #: IDs of the memories the judge cited as attesting the claim.
    evidence_ids: list[str]
    #: One-sentence explanation from the judge, or a failure description.
    rationale: str = ""

80 

81 

@dataclass
class QuestionVerification:
    """Aggregate outcome of verifying every premise extracted from a question."""

    premises: list[Premise]
    verdicts: list[PremiseVerdict]
    #: ``True`` only when every premise verdict came back supported. Callers
    #: consult this to decide whether to skip the main reflect loop.
    all_premises_supported: bool

    def unsupported_premises(self) -> list[PremiseVerdict]:
        """Return the verdicts whose premise was not supported, in order."""
        return list(filter(lambda verdict: not verdict.supported, self.verdicts))

    def short_circuit_message(self) -> str | None:
        """Return an "insufficient evidence" message, or ``None`` to proceed.

        ``None`` is returned when ``all_premises_supported`` is set, or when
        no verdict is marked unsupported. Otherwise the message quotes only
        the first unsupported claim. Its failure usually explains the rest:
        if "Alice worked at Google" fails, "Alice quit Google" is moot.
        """
        failing = [] if self.all_premises_supported else self.unsupported_premises()
        if not failing:
            return None
        claim = failing[0].premise.claim
        return f"insufficient evidence: the question presupposes '{claim}' which is not supported by memory."

111 

112 

113# --------------------------------------------------------------------------- 

114# Prompts 

115# --------------------------------------------------------------------------- 

116 

117 

# System prompt for the premise-extraction call made by ``extract_premises``.
# The model is asked for a JSON array of {claim, rationale} objects. The
# "Maximum 3 presuppositions" rule mirrors the cap of three premises applied
# when ``extract_premises`` parses the reply. A trailing backslash inside this
# (non-raw) triple-quoted string is a line continuation, so each wrapped
# sentence is sent to the model as one line.
_EXTRACTION_SYSTEM_PROMPT = """\
You decompose questions into the FACTUAL PRESUPPOSITIONS they assume \
to be true. The downstream system will verify each presupposition \
against a memory bank before the question is answered, so it can \
abstain when a presupposition is false.

Return a JSON array of {claim, rationale} objects. Each ``claim`` is \
ONE short atomic statement (≤ 15 words) that the question takes for \
granted. ``rationale`` is one sentence explaining why the question \
implies the claim.

Rules:
1. Decompose only PRESUPPOSITIONS — facts the question takes as given. \
For "Why did Alice quit Google?" the presuppositions are:
   - "Alice worked at Google"
   - "Alice quit"
   The "why" is the question's actual content; don't include it.
2. For yes/no questions ("Did X happen?") the presupposition is the \
participants/setting, NOT the event itself. The event IS what's being \
asked. Example: "Did Alice play tennis at the club?" presupposes \
"Alice was at the club" but NOT "Alice played tennis".
3. For pure-fact lookups with no embedded assumption ("What is X?", \
"When did Y happen?"), return [].
4. Maximum 3 presuppositions per question — pick the most central.
5. Output JSON only, no prose.
"""

144 

145 

def _build_extraction_user_prompt(question: str) -> str:
    """Render the user turn for premise extraction.

    The question is stripped of surrounding whitespace and followed by a
    cue asking the model for its JSON array of presuppositions.
    """
    cleaned = question.strip()
    return "Question: " + cleaned + "\n\nPresuppositions (JSON array):"

148 

149 

# System prompt for the judge call made by ``verify_premise``. The reply is
# expected to be a single JSON object with ``supported``, ``confidence``,
# ``evidence_ids`` and ``rationale`` keys, which ``verify_premise`` parses.
# NOTE(review): rule 2 calls 0.5-0.8 "strongly implied", but
# ``verify_premise`` / ``verify_question`` default ``min_confidence`` to 0.6.
# A judge score in [0.5, 0.6) is therefore treated as unsupported
# downstream. Confirm the prompt and the threshold are meant to differ.
_VERIFICATION_SYSTEM_PROMPT = """\
You judge whether retrieved memories support a specific factual claim.

Output a JSON object: {"supported": bool, "confidence": float, \
"evidence_ids": [...], "rationale": "<1 sentence>"}

Rules:
1. ``supported`` is True only when at least one memory directly \
attests the claim. Adjacent / topical relevance is NOT support.
2. ``confidence`` ∈ [0, 1]. ≥ 0.8 only when explicit. 0.5-0.8 for \
strongly implied. Below 0.5: not supported.
3. ``evidence_ids`` lists the memory IDs that attest the claim. Empty \
list when not supported.
4. Rationale is one sentence explaining the verdict.

Output JSON only.
"""

167 

168 

def _build_verification_user_prompt(claim: str, hits: list[MemoryHit]) -> str:
    """Render the judge's user turn: the claim, then each retrieved memory.

    Memories are listed one per line as ``[<memory_id>] <text>``. The text
    is whitespace-stripped and cut to 400 characters (397 plus ``...``).
    When nothing was retrieved, a fixed ``(none)`` placeholder is emitted
    instead of the list.
    """
    header = f"Claim: {claim}"
    footer = "Verdict (JSON):"
    if not hits:
        return f"{header}\n\nRetrieved memories: (none)\n\n{footer}"

    def _snippet(hit):
        # Long memories are truncated so one hit cannot dominate the prompt.
        body = (hit.text or "").strip()
        return body if len(body) <= 400 else body[:397] + "..."

    rendered = [f"[{hit.memory_id}] {_snippet(hit)}" for hit in hits]
    return "\n".join([header, "", "Retrieved memories:", *rendered, "", footer])

180 

181 

182# --------------------------------------------------------------------------- 

183# JSON parsing helpers 

184# --------------------------------------------------------------------------- 

185 

186 

def _strip_code_fence(raw: str | None) -> str:
    """Return *raw* stripped of whitespace and of a surrounding Markdown fence.

    LLMs often wrap JSON in a fenced block (three backticks, optionally
    tagged ``json``). The opening and closing fences are removed. ``None``
    is treated as an empty response.
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)
    return text


def _iter_json_values(text: str, opener: str):
    """Yield each JSON value that decodes at an occurrence of *opener*.

    Candidates are tried left to right. When a value decodes, scanning
    resumes after its end, so values nested inside it are not yielded
    separately. Positions where decoding fails are skipped. Unlike a greedy
    ``opener ... closer`` regex, this tolerates prose or stray brackets
    before and after the real payload.
    """
    decoder = json.JSONDecoder()
    start = text.find(opener)
    while start != -1:
        try:
            value, end = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find(opener, start + 1)
            continue
        yield value
        start = text.find(opener, end)


def _parse_json_array(raw: str | None) -> list[dict]:
    """Extract the list of JSON objects from an LLM reply.

    Returns the ``dict`` elements of the first top-level JSON array in
    *raw* that contains at least one object. Non-object elements are
    dropped. Returns ``[]`` when no such array exists, or when *raw* is
    empty or ``None``.
    """
    for value in _iter_json_values(_strip_code_fence(raw), "["):
        items = [item for item in value if isinstance(item, dict)]
        if items:
            return items
    return []


def _parse_json_object(raw: str | None) -> dict | None:
    """Extract the first decodable JSON object from an LLM reply.

    Returns ``None`` when *raw* contains no valid JSON object, or when it
    is empty or ``None``.
    """
    # Decoding from a "{" always yields a dict, so the first hit is the answer.
    return next(_iter_json_values(_strip_code_fence(raw), "{"), None)

217 

218 

219# --------------------------------------------------------------------------- 

220# Extraction + verification 

221# --------------------------------------------------------------------------- 

222 

223 

async def extract_premises(question: str, llm_provider) -> list[Premise]:
    """Decompose a question into the factual claims it presupposes.

    Makes one ``llm_provider.complete`` call (extraction system prompt,
    ``max_tokens=512``, ``temperature=0.0``) and parses the JSON array in
    the reply.

    Args:
        question: The user question. Empty or whitespace-only input returns
            ``[]`` without calling the LLM.
        llm_provider: Object with an async ``complete(messages, *,
            max_tokens, temperature)`` method whose result exposes ``.text``.

    Returns:
        At most three :class:`Premise` objects, in the order the model
        listed them. Entries with an empty ``claim`` are skipped, as are
        repeated claims (compared case-insensitively). Returns ``[]`` for
        pure fact-lookup questions and on any LLM/parse failure. The caller
        treats an empty list as "no presupposition to verify, proceed
        normally."
    """
    if not question or not question.strip():
        return []
    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content=_EXTRACTION_SYSTEM_PROMPT),
                Message(role="user", content=_build_extraction_user_prompt(question)),
            ],
            max_tokens=512,
            temperature=0.0,
        )
    except Exception as exc:
        _logger.warning("premise extraction LLM call failed (%s)", exc)
        return []

    # A missing/None completion text is treated like unparseable output.
    parsed = _parse_json_array(completion.text or "")
    out: list[Premise] = []
    seen: set[str] = set()
    for item in parsed:
        claim = str(item.get("claim") or "").strip()
        key = claim.casefold()
        # Skip empty and duplicate claims *before* applying the cap, so they
        # don't consume one of the three slots or trigger redundant recalls.
        if not claim or key in seen:
            continue
        seen.add(key)
        rationale = str(item.get("rationale") or "").strip()
        out.append(Premise(claim=claim, rationale=rationale))
        if len(out) == 3:  # the extraction prompt asks for at most 3
            break
    return out

255 

256 

# Signature of the recall callable supplied by the caller. It takes a query
# string and a maximum result count, and asynchronously returns a list of
# ``MemoryHit``. ``verify_premise`` calls it as
# ``recall_fn(premise.claim, recall_max_results)``.
RecallFn = Callable[[str, int], Awaitable[list[MemoryHit]]]

258 

259 

def _failed_verdict(premise: Premise, reason: str) -> PremiseVerdict:
    """Build the conservative (unsupported, zero-confidence) verdict used on failure."""
    return PremiseVerdict(
        premise=premise,
        supported=False,
        confidence=0.0,
        evidence_ids=[],
        rationale=reason,
    )


def _coerce_supported(raw: object) -> bool:
    """Interpret the judge's ``supported`` field; a string must read ``"true"``."""
    if isinstance(raw, str):
        # bool("false") would be True, so compare the text instead.
        return raw.strip().lower() == "true"
    return bool(raw)


def _coerce_confidence(raw: object) -> float:
    """Convert the judge's ``confidence`` to a float clamped into ``[0, 1]``.

    Unconvertible values (``None``, non-numeric strings, containers, ints
    too large for a float) and NaN all become ``0.0``.
    """
    try:
        value = float(raw)  # type: ignore[arg-type]
    except (TypeError, ValueError, OverflowError):
        return 0.0
    # NaN compares false against everything, so it also lands here.
    if not (value >= 0.0):
        return 0.0
    return min(value, 1.0)


def _coerce_evidence_ids(raw: object, shown_ids: set[str]) -> list[str]:
    """Return the cited IDs that name a memory the judge was actually shown.

    Entries are stringified and stripped of whitespace and of the ``[...]``
    brackets used when listing memories in the prompt. They are then
    de-duplicated in order. Falsy entries and IDs not in *shown_ids* are
    dropped.
    """
    if not isinstance(raw, list):
        return []
    ids: list[str] = []
    for entry in raw:
        if not entry:
            continue
        candidate = str(entry).strip().strip("[]").strip()
        if candidate in shown_ids and candidate not in ids:
            ids.append(candidate)
    return ids


async def verify_premise(
    premise: Premise,
    recall_fn: RecallFn,
    llm_provider,
    *,
    recall_max_results: int = 5,
    min_confidence: float = 0.6,
) -> PremiseVerdict:
    """Verify one premise against memory via focused recall plus an LLM judge.

    ``recall_fn(premise.claim, recall_max_results)`` fetches candidate
    memories. It is presumably the orchestrator's regular recall pipeline;
    confirm against the caller. The hits are shown to a judge LLM
    (``max_tokens=256``, ``temperature=0.0``), which answers with a JSON
    verdict.

    Args:
        premise: The claim to verify.
        recall_fn: Async recall callable, see :data:`RecallFn`.
        llm_provider: Object with an async ``complete`` method whose result
            exposes ``.text``.
        recall_max_results: Maximum number of memories to retrieve.
        min_confidence: Minimum judge confidence for the claim to count as
            supported.

    Returns:
        A :class:`PremiseVerdict`. ``supported`` is ``True`` only when the
        judge says so, its confidence (clamped to ``[0, 1]``) is at least
        *min_confidence*, and at least one memory was retrieved.
        ``evidence_ids`` keeps only IDs of retrieved memories. If the
        recall or judge call raises, the exception is logged and an
        unsupported ``0.0``-confidence verdict describing it is returned.
        Unparseable judge output also yields an unsupported verdict.
    """
    try:
        hits = await recall_fn(premise.claim, recall_max_results)
    except Exception as exc:
        _logger.warning("premise verification recall failed (%s)", exc)
        return _failed_verdict(premise, f"recall failed: {exc}")
    hits = list(hits or [])

    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content=_VERIFICATION_SYSTEM_PROMPT),
                Message(role="user", content=_build_verification_user_prompt(premise.claim, hits)),
            ],
            max_tokens=256,
            temperature=0.0,
        )
    except Exception as exc:
        _logger.warning("premise verification LLM call failed (%s)", exc)
        return _failed_verdict(premise, f"judge LLM failed: {exc}")

    # A missing/None completion text is treated like unparseable output.
    parsed = _parse_json_object(completion.text or "") or {}
    supported = _coerce_supported(parsed.get("supported", False))
    confidence = _coerce_confidence(parsed.get("confidence", 0.0))
    # Only IDs of memories actually shown to the judge count as evidence;
    # anything else was invented by the model. The prompt renders IDs with
    # an f-string, so compare against that same string form.
    shown_ids = {f"{hit.memory_id}" for hit in hits}
    evidence_ids = _coerce_evidence_ids(parsed.get("evidence_ids"), shown_ids)
    rationale = str(parsed.get("rationale") or "").strip()

    # The final flag requires three things. First, the judge's verdict.
    # Second, the confidence threshold, which defends against
    # "supported: true" with low confidence. Third, at least one retrieved
    # memory: with nothing retrieved, nothing can attest the claim, whatever
    # the judge says.
    final_supported = bool(hits) and supported and confidence >= min_confidence

    return PremiseVerdict(
        premise=premise,
        supported=final_supported,
        confidence=confidence,
        evidence_ids=evidence_ids,
        rationale=rationale,
    )

328 

329 

async def verify_question(
    question: str,
    *,
    recall_fn: RecallFn,
    llm_provider,
    recall_max_results: int = 5,
    min_confidence: float = 0.6,
) -> QuestionVerification:
    """End-to-end verification: extract the question's premises, then verify each.

    Args:
        question: The user question.
        recall_fn: Async recall callable passed through to
            :func:`verify_premise`.
        llm_provider: LLM provider used for both extraction and judging.
        recall_max_results: Per-premise recall limit.
        min_confidence: Per-premise support threshold.

    Returns:
        A :class:`QuestionVerification` whose ``short_circuit_message()``
        is non-``None`` when a presupposition fails. Questions with no
        extracted presupposition, including extraction failures, return
        ``all_premises_supported=True`` so the caller proceeds normally.
        If verifying one premise raises an unexpected ``Exception``, that
        premise gets an unsupported verdict, matching how recall and judge
        failures are handled. Cancellation still propagates.
    """
    import asyncio

    premises = await extract_premises(question, llm_provider)
    if not premises:
        return QuestionVerification(
            premises=[],
            verdicts=[],
            all_premises_supported=True,
        )

    # Each premise's recall + judge are independent, so run them concurrently;
    # latency is bounded by the slowest. return_exceptions=True keeps one
    # failing premise from aborting the others mid-flight.
    results = await asyncio.gather(
        *(
            verify_premise(
                p,
                recall_fn,
                llm_provider,
                recall_max_results=recall_max_results,
                min_confidence=min_confidence,
            )
            for p in premises
        ),
        return_exceptions=True,
    )

    verdicts: list[PremiseVerdict] = []
    for premise, result in zip(premises, results):
        if not isinstance(result, BaseException):
            verdicts.append(result)
            continue
        if not isinstance(result, Exception):
            # e.g. CancelledError: never swallow cancellation/exit signals.
            raise result
        _logger.warning("premise verification failed unexpectedly (%s)", result, exc_info=result)
        verdicts.append(
            PremiseVerdict(
                premise=premise,
                supported=False,
                confidence=0.0,
                evidence_ids=[],
                rationale=f"verification failed: {result}",
            )
        )

    all_supported = all(v.supported for v in verdicts)
    return QuestionVerification(
        premises=list(premises),
        verdicts=verdicts,
        all_premises_supported=all_supported,
    )