Coverage for astrocyte/pipeline/fact_causal_extraction.py: 85%
91 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Fact-level cause→effect link extraction (Hindsight parity).
3This is the C2 rewrite of ``causal_extraction.py``. The key change is
4**granularity**: Hindsight extracts causal relations between FACTS
5(full-text statements with rich context), not between entities.
6``"burnout"`` → ``"resignation"`` is much weaker than
7``"Alice was burned out from 80-hour weeks"`` →
8``"Alice quit"`` because the latter preserves the textual evidence
9that makes causal reasoning useful at recall time.
11Hindsight's wire shape (from ``hindsight-api-slim/.../fact_extraction.py``):
13 class FactCausalRelation:
14 target_fact_index: int # 0-based index in same batch
15 relation_type: Literal["caused_by"]
17 class ExtractedFact:
18 text: str
19 causal_relations: list[FactCausalRelation] | None
21In our retain pipeline, "facts" map to chunks (one VectorItem each
22after chunking). This module:
241. Takes a batch of chunk texts from the same retain call.
252. Calls the LLM once to identify cause→effect pairs by chunk INDEX.
263. After the orchestrator stores chunks (assigning memory IDs), the
27 helper :func:`build_memory_links_from_relations` resolves the
28 indices to memory IDs and produces :class:`MemoryLink` objects.
294. Persisted via ``GraphStore.store_memory_links``.
31Hindsight uses a single ``"caused_by"`` relation_type only. We follow
32that constraint.
34Failure modes mirror the entity-level version:
36- LLM call failure → return ``[]`` (retain never aborts).
37- Malformed JSON → log + return ``[]``.
38- ``target_fact_index`` out of range → drop that pair.
39- Self-loops (source == target) → drop.
40- Duplicate pairs → dedupe.
41- Cap of ``max_pairs_per_fact`` causes per fact (default 2, i.e. at most
42 ``2 × num_chunks`` pairs; matches Hindsight's "max 2 per fact" guideline).
43"""
45from __future__ import annotations
47import json
48import logging
49import re
50from dataclasses import dataclass
51from datetime import datetime, timezone
53from astrocyte.types import MemoryLink, Message
55_logger = logging.getLogger("astrocyte.fact_causal_extraction")
@dataclass
class FactCausalRelation:
    """One ``caused_by`` edge between two facts of a single batch.

    Both index fields are positions within the batch the relation was
    extracted from; they are not memory IDs.
    """

    # Index of the EFFECT: this fact was caused_by the target fact.
    source_fact_index: int
    # Index of the CAUSE.
    target_fact_index: int
    # Supporting text for the relation; empty when none was supplied.
    evidence: str = ""
    # Confidence score for the relation; defaults to 1.0.
    confidence: float = 1.0
# System prompt sent as the ``system`` message by
# ``extract_fact_causal_relations``. It asks the model for a JSON array of
# ``caused_by`` pairs keyed by 0-based fact index, each with a verbatim
# evidence quote and a confidence score. The text below is runtime data:
# any edit changes what the model is told.
68_SYSTEM_PROMPT = """\
69You extract cause→effect relationships between facts in a numbered batch.
71Each fact has an index (0-based). For each fact, identify which OTHER \
72facts in the batch caused it. Output ONLY pairs that are EXPLICITLY \
73causal in the source — never inferred.
75Output a JSON array. Each element is:
76 {"source_fact_index": <int>, "target_fact_index": <int>, \
77"relation_type": "caused_by", "evidence": "<verbatim quote, ≤ 20 words>", \
78"confidence": <0.0-1.0>}
80The semantics: ``source_fact_index`` was caused_by ``target_fact_index``. \
81target must be a DIFFERENT index than source. target may be in any \
82position relative to source (no ordering constraint — sometimes the \
83cause is mentioned later in the source text).
85Rules:
861. Look for explicit causal language: "because", "caused", "led to", \
87"resulted in", "due to", "triggered", "as a result", "made him/her X", \
88"after X happened, Y happened" (only when the temporal sequence + \
89outcome implies causation).
902. Confidence ≥ 0.8 only when text explicitly states causation. \
910.6-0.8 for strongly-implied. Below 0.6: skip.
923. Evidence MUST be a verbatim quote from the input text (≤ 20 words). \
93Don't paraphrase.
944. Never relate a fact to itself.
955. Maximum 2 causes per fact.
966. If no causal relationships are stated, return [].
98Output JSON only. No prose.
99"""
def _build_user_prompt(chunk_texts: list[str]) -> str:
    """Render the batch as a numbered fact list for the user message.

    Each chunk becomes one ``[index] text`` line, using its 0-based
    position in ``chunk_texts`` as the index. The list is followed by a
    blank line and the trailing cue that asks for the JSON array.
    """

    def _clip(chunk: str) -> str:
        # Keep the prompt bounded: anything over 600 characters is cut to
        # its first 597 characters plus an ellipsis.
        trimmed = chunk.strip()
        return trimmed if len(trimmed) <= 600 else trimmed[:597] + "..."

    numbered = [
        f"[{position}] {_clip(chunk)}"
        for position, chunk in enumerate(chunk_texts)
    ]
    return "\n".join(
        ["Facts (numbered):", *numbered, "", "Causal pairs (JSON array):"]
    )
def _parse_relations(raw: str) -> list[dict]:
    """Pull the first usable JSON array of objects from the LLM response.

    Tolerates a leading/trailing ``` fence and surrounding prose, including
    prose that itself contains square brackets. Each ``[`` is tried in turn
    as the start of a JSON array; the first array that decodes and holds at
    least one JSON object wins.

    Args:
        raw: The raw completion text. Non-string values (e.g. ``None``)
            are treated as an empty response.

    Returns:
        The ``dict`` items of the selected array (non-dict items are
        dropped), or ``[]`` when nothing usable is found. Never raises on
        malformed input, because retain must not abort on causal extraction.
    """
    if not isinstance(raw, str):
        return []

    text = raw.strip()
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)

    decoder = json.JSONDecoder()
    first_error: json.JSONDecodeError | None = None
    decoded_any = False
    start = text.find("[")
    while start != -1:
        try:
            # raw_decode stops at the end of the first complete JSON value,
            # so trailing prose (even with brackets) cannot break decoding.
            payload, consumed = decoder.raw_decode(text[start:])
        except json.JSONDecodeError as exc:
            if first_error is None:
                first_error = exc
            # Not an array start (e.g. "[see below]"); try the next bracket.
            start = text.find("[", start + 1)
            continue
        decoded_any = True
        items = [item for item in payload if isinstance(item, dict)]
        if items:
            return items
        # Decoded, but held no objects (e.g. "[0, 1]" in prose): resume the
        # search after this array rather than inside it.
        start = text.find("[", start + consumed)

    if first_error is not None and not decoded_any:
        _logger.warning(
            "fact_causal_extraction: JSON decode failed (%s)", first_error
        )
    return []
def _coerce_fact_index(value: object) -> int | None:
    """Return ``value`` as an exact integer fact index, or ``None`` if it is not one."""
    # bool is a subclass of int; a JSON ``true``/``false`` is never an index.
    if isinstance(value, bool):
        return None
    if isinstance(value, int):
        return value
    if isinstance(value, float):
        # is_integer() is False for NaN/±inf and for fractional values, so
        # int() can neither raise OverflowError nor silently truncate here.
        return int(value) if value.is_integer() else None
    if isinstance(value, str):
        try:
            return int(value)
        except ValueError:
            return None
    return None


async def extract_fact_causal_relations(
    chunk_texts: list[str],
    llm_provider,
    *,
    max_pairs_per_fact: int = 2,
    min_confidence: float = 0.6,
) -> list[FactCausalRelation]:
    """Identify cause→effect pairs among the supplied chunks.

    Makes one LLM call over the whole batch and validates every returned
    pair. Indices in the result refer to positions in ``chunk_texts``.

    Args:
        chunk_texts: Chunk texts from the same retain call. Fewer than two
            chunks cannot form a pair, so no LLM call is made.
        llm_provider: Object exposing an awaitable ``complete(messages,
            max_tokens=..., temperature=...)`` whose result has a ``text``
            attribute.
        max_pairs_per_fact: Maximum number of causes kept per effect fact.
        min_confidence: Pairs scored below this value are dropped.

    Returns:
        Validated :class:`FactCausalRelation` objects in the order the
        model produced them. Returns ``[]`` on LLM failure, a missing or
        non-text response, or unparseable output; this function never
        raises for those cases so retain is not aborted.

    Dropped pairs: non-integer or out-of-range indices, self-loops,
    non-numeric or non-finite confidence, confidence below
    ``min_confidence``, duplicates, and anything beyond the per-fact cap.
    """
    import math  # local import: not provided by the module's import block

    if len(chunk_texts) < 2:
        return []

    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content=_SYSTEM_PROMPT),
                Message(role="user", content=_build_user_prompt(chunk_texts)),
            ],
            max_tokens=1024,
            temperature=0.0,
        )
    except Exception as exc:
        _logger.warning("fact_causal_extraction: LLM call failed (%s)", exc)
        return []

    # A provider may hand back None or a response without text; treat that
    # like any other extraction failure instead of raising AttributeError.
    raw_text = getattr(completion, "text", None)
    if not isinstance(raw_text, str):
        _logger.warning("fact_causal_extraction: LLM response has no text")
        return []

    parsed = _parse_relations(raw_text)
    if not parsed:
        return []

    n = len(chunk_texts)
    out: list[FactCausalRelation] = []
    seen: set[tuple[int, int]] = set()
    per_source_count: dict[int, int] = {}
    for raw in parsed:
        src_idx = _coerce_fact_index(raw.get("source_fact_index"))
        tgt_idx = _coerce_fact_index(raw.get("target_fact_index"))
        if src_idx is None or tgt_idx is None:
            continue
        if src_idx == tgt_idx:
            continue
        if not (0 <= src_idx < n) or not (0 <= tgt_idx < n):
            continue

        # A missing confidence defaults to 0.0 and is therefore dropped by
        # the threshold below (for any positive ``min_confidence``).
        try:
            confidence = float(raw.get("confidence", 0.0))
        except (TypeError, ValueError, OverflowError):
            continue
        # json.loads accepts NaN/Infinity literals; NaN would slip through a
        # plain ``<`` comparison, so reject non-finite scores explicitly.
        if not math.isfinite(confidence) or confidence < min_confidence:
            continue
        if per_source_count.get(src_idx, 0) >= max_pairs_per_fact:
            continue

        key = (src_idx, tgt_idx)
        if key in seen:
            continue
        seen.add(key)
        per_source_count[src_idx] = per_source_count.get(src_idx, 0) + 1

        # Evidence is stored trimmed and capped at 500 characters.
        evidence = str(raw.get("evidence") or "").strip()[:500]
        out.append(
            FactCausalRelation(
                source_fact_index=src_idx,
                target_fact_index=tgt_idx,
                evidence=evidence,
                confidence=confidence,
            )
        )
    return out
def build_memory_links_from_relations(
    relations: list[FactCausalRelation],
    memory_ids: list[str],
    *,
    bank_id: str,
) -> list[MemoryLink]:
    """Turn index-based relations into ``caused_by`` :class:`MemoryLink` objects.

    ``memory_ids[i]`` is taken as the memory ID of fact ``i``. The link's
    source is the effect and its target is the cause (Hindsight's
    ``caused_by`` semantics). A relation is skipped when either index is
    not a valid position in ``memory_ids``, which guards against
    off-by-one errors when chunking changes the count.

    Every link built by one call shares the same UTC ``created_at``
    timestamp and carries ``bank_id`` in its metadata.
    """
    id_count = len(memory_ids)
    stamp = datetime.now(timezone.utc)

    def _resolvable(index: int) -> bool:
        return 0 <= index < id_count

    return [
        MemoryLink(
            source_memory_id=memory_ids[relation.source_fact_index],
            target_memory_id=memory_ids[relation.target_fact_index],
            link_type="caused_by",
            evidence=relation.evidence,
            confidence=relation.confidence,
            weight=1.0,
            created_at=stamp,
            metadata={"bank_id": bank_id, "source": "fact_causal_extraction"},
        )
        for relation in relations
        if _resolvable(relation.source_fact_index)
        and _resolvable(relation.target_fact_index)
    ]