Coverage for astrocyte/pipeline/fact_causal_extraction.py: 85%

91 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Fact-level cause→effect link extraction (Hindsight parity). 

2 

3This is the C2 rewrite of ``causal_extraction.py``. The key change is 

4**granularity**: Hindsight extracts causal relations between FACTS 

5(full-text statements with rich context), not between entities. 

6``"burnout"`` → ``"resignation"`` is much weaker than 

7``"Alice was burned out from 80-hour weeks"`` → 

8``"Alice quit"`` because the latter preserves the textual evidence 

9that makes causal reasoning useful at recall time. 

10 

11Hindsight's wire shape (from ``hindsight-api-slim/.../fact_extraction.py``): 

12 

13 class FactCausalRelation: 

14 target_fact_index: int # 0-based index in same batch 

15 relation_type: Literal["caused_by"] 

16 

17 class ExtractedFact: 

18 text: str 

19 causal_relations: list[FactCausalRelation] | None 

20 

21In our retain pipeline, "facts" map to chunks (one VectorItem each 

22after chunking). This module: 

23 

241. Takes a batch of chunk texts from the same retain call. 

252. Calls the LLM once to identify cause→effect pairs by chunk INDEX. 

263. After the orchestrator stores chunks (assigning memory IDs), the 

27 helper :func:`build_memory_links_from_relations` resolves the 

28 indices to memory IDs and produces :class:`MemoryLink` objects. 

294. Persisted via ``GraphStore.store_memory_links``. 

30 

31Hindsight uses a single ``"caused_by"`` relation_type only. We follow 

32that constraint. 

33 

34Failure modes mirror the entity-level version: 

35 

36- LLM call failure → return ``[]`` (retain never aborts). 

37- Malformed JSON → log + return ``[]``. 

38- ``target_fact_index`` out of range → drop that pair. 

39- Self-loops (source == target) → drop. 

40- Duplicate pairs → dedupe. 

41- Cap on causes per effect fact (``max_pairs_per_fact``, default 2, so at

42 most ``2 × num_chunks`` pairs; matches Hindsight's "max 2 per fact" guideline).

43""" 

44 

45from __future__ import annotations 

46 

47import json 

48import logging 

49import re 

50from dataclasses import dataclass 

51from datetime import datetime, timezone 

52 

53from astrocyte.types import MemoryLink, Message 

54 

55_logger = logging.getLogger("astrocyte.fact_causal_extraction") 

56 

57 

@dataclass
class FactCausalRelation:
    """One ``caused_by`` edge between two facts of the same batch.

    Both indices are positions in the list of chunk texts handed to the
    extractor. The edge reads "source was caused by target".
    """

    # Index of the EFFECT fact (the one that was caused).
    source_fact_index: int
    # Index of the CAUSE fact.
    target_fact_index: int
    # Supporting quote for the relation; empty when none was supplied.
    evidence: str = ""
    # Confidence score attached to the relation; defaults to 1.0.
    confidence: float = 1.0

66 

67 

68_SYSTEM_PROMPT = """\ 

69You extract cause→effect relationships between facts in a numbered batch. 

70 

71Each fact has an index (0-based). For each fact, identify which OTHER \ 

72facts in the batch caused it. Output ONLY pairs that are EXPLICITLY \ 

73causal in the source — never inferred. 

74 

75Output a JSON array. Each element is: 

76 {"source_fact_index": <int>, "target_fact_index": <int>, \ 

77"relation_type": "caused_by", "evidence": "<verbatim quote, ≤ 20 words>", \ 

78"confidence": <0.0-1.0>} 

79 

80The semantics: ``source_fact_index`` was caused_by ``target_fact_index``. \ 

81target must be a DIFFERENT index than source. target may be in any \ 

82position relative to source (no ordering constraint — sometimes the \ 

83cause is mentioned later in the source text). 

84 

85Rules: 

861. Look for explicit causal language: "because", "caused", "led to", \ 

87"resulted in", "due to", "triggered", "as a result", "made him/her X", \ 

88"after X happened, Y happened" (only when the temporal sequence + \ 

89outcome implies causation). 

902. Confidence ≥ 0.8 only when text explicitly states causation. \ 

910.6-0.8 for strongly-implied. Below 0.6: skip. 

923. Evidence MUST be a verbatim quote from the input text (≤ 20 words). \ 

93Don't paraphrase. 

944. Never relate a fact to itself. 

955. Maximum 2 causes per fact. 

966. If no causal relationships are stated, return []. 

97 

98Output JSON only. No prose. 

99""" 

100 

101 

102def _build_user_prompt(chunk_texts: list[str]) -> str: 

103 lines = ["Facts (numbered):"] 

104 for idx, text in enumerate(chunk_texts): 

105 # Trim each chunk so prompt size stays bounded; preserves the 

106 # first ~600 chars which is plenty for causal-text detection. 

107 snippet = text.strip() 

108 if len(snippet) > 600: 

109 snippet = snippet[:597] + "..." 

110 lines.append(f"[{idx}] {snippet}") 

111 lines.append("") 

112 lines.append("Causal pairs (JSON array):") 

113 return "\n".join(lines) 

114 

115 

116def _parse_relations(raw: str) -> list[dict]: 

117 """Pull the first JSON array from the LLM response. 

118 

119 Tolerates ``` fences and surrounding prose. Returns ``[]`` on any 

120 failure — retain never aborts on causal extraction. 

121 """ 

122 text = raw.strip() 

123 if text.startswith("```"): 

124 text = re.sub(r"^```(?:json)?\s*", "", text) 

125 text = re.sub(r"\s*```$", "", text) 

126 match = re.search(r"\[.*\]", text, re.DOTALL) 

127 if match is None: 

128 return [] 

129 try: 

130 payload = json.loads(match.group(0)) 

131 except json.JSONDecodeError as exc: 

132 _logger.warning("fact_causal_extraction: JSON decode failed (%s)", exc) 

133 return [] 

134 if not isinstance(payload, list): 

135 return [] 

136 return [item for item in payload if isinstance(item, dict)] 

137 

138 

async def extract_fact_causal_relations(
    chunk_texts: list[str],
    llm_provider,
    *,
    max_pairs_per_fact: int = 2,
    min_confidence: float = 0.6,
) -> list[FactCausalRelation]:
    """Identify cause→effect pairs among the supplied chunks.

    Makes a single ``llm_provider.complete`` call and validates each
    returned pair. The orchestrator converts the result to
    :class:`MemoryLink` objects after storage assigns memory IDs to the
    chunks.

    Args:
        chunk_texts: Fact texts of one retain call; list position is the
            fact index. Fewer than two chunks yields ``[]`` without
            calling the LLM.
        llm_provider: Object exposing an awaitable
            ``complete(messages, max_tokens=..., temperature=...)`` whose
            result carries the response in ``.text``.
        max_pairs_per_fact: Maximum number of causes kept per effect fact.
        min_confidence: Pairs scoring below this value are dropped.

    Returns:
        Validated relations in response order. A pair is dropped when an
        index or the confidence is missing or non-numeric, an index is out
        of range, source equals target, the confidence is NaN or below
        ``min_confidence``, the per-source cap is already reached, or the
        pair is a duplicate. Confidence is clamped to ``[0.0, 1.0]`` and
        evidence is truncated to 500 characters. Any LLM failure or
        unusable response yields ``[]`` — retain never aborts on causal
        extraction.
    """
    if len(chunk_texts) < 2:
        return []

    try:
        completion = await llm_provider.complete(
            [
                Message(role="system", content=_SYSTEM_PROMPT),
                Message(role="user", content=_build_user_prompt(chunk_texts)),
            ],
            max_tokens=1024,
            temperature=0.0,
        )
    except Exception as exc:
        _logger.warning("fact_causal_extraction: LLM call failed (%s)", exc)
        return []

    # A completion with no/empty text is an unusable response, not a crash.
    response_text = getattr(completion, "text", None)
    if not response_text:
        _logger.warning("fact_causal_extraction: LLM returned no text")
        return []

    parsed = _parse_relations(response_text)
    if not parsed:
        return []

    n = len(chunk_texts)
    out: list[FactCausalRelation] = []
    seen: set[tuple[int, int]] = set()
    per_source_count: dict[int, int] = {}
    for raw in parsed:
        try:
            src_idx = int(raw.get("source_fact_index"))
            tgt_idx = int(raw.get("target_fact_index"))
            confidence = float(raw.get("confidence", 0.0))
        except (TypeError, ValueError, OverflowError):
            # OverflowError: json.loads accepts ``Infinity`` / ``1e999``
            # (int(inf) raises) and arbitrarily large ints (float() raises).
            continue
        if src_idx == tgt_idx:
            continue
        if not (0 <= src_idx < n) or not (0 <= tgt_idx < n):
            continue
        # Written as "not >=" so NaN (which compares False to everything)
        # is rejected instead of slipping past a "<" test.
        if not (confidence >= min_confidence):
            continue
        if per_source_count.get(src_idx, 0) >= max_pairs_per_fact:
            continue

        key = (src_idx, tgt_idx)
        if key in seen:
            continue
        seen.add(key)
        per_source_count[src_idx] = per_source_count.get(src_idx, 0) + 1

        # Keep stored scores inside the 0.0-1.0 range the prompt asks for,
        # even if the model answers with e.g. 95 or Infinity.
        confidence = max(0.0, min(confidence, 1.0))
        evidence = str(raw.get("evidence") or "").strip()[:500]
        out.append(
            FactCausalRelation(
                source_fact_index=src_idx,
                target_fact_index=tgt_idx,
                evidence=evidence,
                confidence=confidence,
            )
        )
    return out

212 

213 

def build_memory_links_from_relations(
    relations: list[FactCausalRelation],
    memory_ids: list[str],
    *,
    bank_id: str,
) -> list[MemoryLink]:
    """Turn index-based causal relations into :class:`MemoryLink` objects.

    Each relation's ``source_fact_index`` (the effect) and
    ``target_fact_index`` (the cause) are looked up in ``memory_ids``;
    the resulting link has type ``"caused_by"``, matching Hindsight's
    semantics. A relation with either index outside
    ``range(len(memory_ids))`` is silently skipped — a guard against
    off-by-one errors when chunking changes the count. All links from
    one call share a single UTC ``created_at`` timestamp.
    """
    stamped_at = datetime.now(timezone.utc)
    id_count = len(memory_ids)
    return [
        MemoryLink(
            source_memory_id=memory_ids[relation.source_fact_index],
            target_memory_id=memory_ids[relation.target_fact_index],
            link_type="caused_by",
            evidence=relation.evidence,
            confidence=relation.confidence,
            weight=1.0,
            created_at=stamped_at,
            metadata={"bank_id": bank_id, "source": "fact_causal_extraction"},
        )
        for relation in relations
        if 0 <= relation.source_fact_index < id_count
        and 0 <= relation.target_fact_index < id_count
    ]