Coverage for astrocyte/pipeline/link_expansion.py: 86%

154 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

"""3-parallel-signal link expansion (Hindsight parity, C3b).

This is the C3 rewrite of ``spreading_activation.py``. The previous
module did a BFS hop walk over ``co_occurs`` entity edges; Hindsight's
``link_expansion_retrieval.py`` doesn't walk multi-hop entity chains
that way. Instead, it queries three first-class link signals in
parallel and combines them:

1. **Entity overlap** — query-time set overlap. Candidate memories
   that share entities with the seeds score by ``count(distinct shared
   entities)``. Computed here in Python via the
   ``GraphStore.get_entity_ids_for_memories`` SPI (seed → entities)
   plus a reverse entity → memories lookup through
   ``GraphStore.query_neighbors``.

2. **Semantic links** — precomputed at retain time
   (:mod:`semantic_link_graph`). Edges of type ``"semantic"`` connect
   each new memory to its top-K most-similar neighbors. The
   link-expansion query reads these directly from
   ``GraphStore.find_memory_links``.

3. **Causal links** — explicit ``"caused_by"`` chains extracted at
   retain time (:mod:`fact_causal_extraction`). Boosted (+1.0 per edge)
   as the highest-quality signal because the source-text causal
   evidence is unambiguous.

Hindsight's actual implementation runs all three as a single
recursive-CTE Postgres query for speed. A graph store that exposes
``expand_memory_links_fast`` can do the same here; otherwise we compute
the same three signals in Python, because the orchestrator's
:class:`GraphStore` SPI must work for arbitrary backends (in-memory
tests, AGE, future stores). For LoCoMo-scale workloads (~thousands of
memories per bank), the Python path is well within latency budget.

The return type is ``list[ScoredItem]`` — the same as the old
``spread_activation`` function — so the orchestrator's RRF-fusion
plumbing slots in without changes.
"""

37 

38from __future__ import annotations 

39 

40import logging 

41from dataclasses import dataclass, field 

42 

43from astrocyte.pipeline.fusion import ScoredItem 

44from astrocyte.provider import GraphStore, VectorStore 

45 

46_logger = logging.getLogger("astrocyte.link_expansion") 

47 

48 

49# --------------------------------------------------------------------------- 

50# Parameters 

51# --------------------------------------------------------------------------- 

52 

53 

@dataclass
class LinkExpansionParams:
    """Configuration for the three-signal link expansion.

    Every field has a default, so ``LinkExpansionParams()`` is a complete
    configuration. The three ``*_weight`` fields blend the per-signal
    scores into one total. The scorer clamps each signal's score to at
    most 1.0. ``activation_threshold`` then discards weak candidates.
    """

    # Upper bound on the number of expanded candidates handed back.
    expansion_limit: int = 30
    # Fan-out guard modelled on Hindsight's ``graph_per_entity_limit``:
    # the entity-overlap neighbor lookup is sized at this many results
    # per seed entity, so very popular entities cannot explode the query.
    per_entity_limit: int = 200
    # Blend weights for entity overlap, semantic links and causal links.
    # Causal carries the largest weight because explicit cause chains are
    # the most precise of the three signals.
    entity_overlap_weight: float = 0.5
    semantic_weight: float = 0.3
    causal_weight: float = 0.7
    # Link types treated as causal. Only ``caused_by`` is produced at
    # retain time today; the tuple leaves room for ``enables``,
    # ``prevents`` and similar relations later on.
    causal_link_types: tuple[str, ...] = ("caused_by",)
    # Link types treated as semantic-similarity edges.
    semantic_link_types: tuple[str, ...] = ("semantic",)
    # Weighted total a candidate must reach before it is surfaced.
    activation_threshold: float = 0.05

81 

82 

83# --------------------------------------------------------------------------- 

84# Tag scope helper (mirrors spread/expand path) 

85# --------------------------------------------------------------------------- 

86 

87 

88def _hit_has_required_tags( 

89 metadata: dict | None, 

90 tags: list[str] | None, 

91 required_tags: set[str], 

92) -> bool: 

93 if not required_tags: 

94 return True 

95 item_tags = {str(t).lower() for t in (tags or [])} 

96 return required_tags.issubset(item_tags) 

97 

98 

99# --------------------------------------------------------------------------- 

100# Score accumulator 

101# --------------------------------------------------------------------------- 

102 

103 

104@dataclass 

105class _CandidateScore: 

106 memory_id: str 

107 entity_overlap: int = 0 # count of distinct shared entities 

108 semantic_total: float = 0.0 # sum of semantic edge weights 

109 causal_total: float = 0.0 # sum of (causal weight + 1.0) per Hindsight 

110 sources: set[str] = field(default_factory=set) # which signals contributed 

111 

112 def total(self, params: LinkExpansionParams) -> float: 

113 # Normalize entity overlap by a small constant — diminishing 

114 # returns past 5 shared entities. Hindsight uses raw count; 

115 # the normalization here keeps the weighted sum interpretable. 

116 eo_norm = min(1.0, self.entity_overlap / 5.0) 

117 sem_norm = min(1.0, self.semantic_total) 

118 causal_norm = min(1.0, self.causal_total) 

119 return ( 

120 params.entity_overlap_weight * eo_norm 

121 + params.semantic_weight * sem_norm 

122 + params.causal_weight * causal_norm 

123 ) 

124 

125 

126# --------------------------------------------------------------------------- 

127# Main entry point 

128# --------------------------------------------------------------------------- 

129 

130 

async def link_expansion(
    seed_hits: list[ScoredItem],
    *,
    bank_id: str,
    vector_store: VectorStore,
    graph_store: GraphStore,
    params: LinkExpansionParams | None = None,
    tags: list[str] | None = None,
) -> list[ScoredItem]:
    """Expand seeds through the three first-class memory-link signals.

    Returns NEW candidate memories only (seeds are filtered out). Each
    returned ``ScoredItem`` carries metadata explaining which signal(s)
    surfaced it:

    - ``_link_signal``: comma-separated list of contributing signals
      (``entity_overlap``, ``semantic``, ``causal``).
    - ``_entity_overlap_count``: how many distinct entities it shares
      with the seeds.
    - ``_semantic_weight_total``: sum of semantic edge weights to seeds.
    - ``_causal_weight_total``: sum of causal edge weights to seeds
      (each edge contributes its weight plus a 1.0 boost).

    If ``graph_store`` exposes ``expand_memory_links_fast``, that backend
    query is tried first. An exception, an empty row set, or an empty
    hydrated result falls through to the portable Python path.

    Args:
        seed_hits: Top-K from initial RRF fusion. Their memory IDs drive
            all three signal queries.
        bank_id: Constrains every query.
        vector_store: Used to hydrate full memory bodies after scoring.
        graph_store: Source of the signals via the optional
            ``get_entity_ids_for_memories`` (entity overlap) and
            ``find_memory_links`` (semantic + causal) methods. A missing
            method only disables its own signal. Lookup errors are logged
            and treated as "no results" for that signal.
        params: Tuning knobs; ``LinkExpansionParams()`` when omitted.
        tags: Optional tag filter. A candidate is dropped before being
            returned unless its tags include every requested tag
            (case-insensitive). LoCoMo's ``convo:<id>`` scoping reuses this.

    Returns:
        Up to ``params.expansion_limit`` hydrated candidates, highest
        weighted score first. Returns ``[]`` when there are no seeds or no
        candidate survives scoring, hydration and tag scope.
    """
    if not seed_hits:
        return []
    p = params or LinkExpansionParams()
    seed_id_list = [h.id for h in seed_hits]
    seed_ids = set(seed_id_list)
    required_tags = {str(t).lower() for t in tags} if tags else set()

    fast_result = await _try_fast_expansion(
        graph_store, vector_store, bank_id, seed_id_list, p, required_tags
    )
    if fast_result:
        return fast_result

    candidates: dict[str, _CandidateScore] = {}
    await _add_entity_overlap_signal(candidates, graph_store, bank_id, seed_id_list, seed_ids, p)
    await _add_memory_link_signals(candidates, graph_store, bank_id, seed_id_list, seed_ids, p)

    if not candidates:
        return []

    return await _hydrate_candidate_scores(vector_store, bank_id, list(candidates.values()), p, required_tags)


async def _try_fast_expansion(
    graph_store: GraphStore,
    vector_store: VectorStore,
    bank_id: str,
    seed_id_list: list[str],
    p: LinkExpansionParams,
    required_tags: set[str],
) -> list[ScoredItem]:
    """Run the backend's optional one-shot expansion; ``[]`` means "fall back"."""
    fast_expand = getattr(graph_store, "expand_memory_links_fast", None)
    if not callable(fast_expand):
        return []
    try:
        rows = await fast_expand(list(seed_id_list), bank_id, params=p)
    except Exception as exc:
        _logger.warning("fast link expansion failed (%s); using portable fallback", exc)
        return []
    if not rows:
        return []
    return await _hydrate_candidate_scores(
        vector_store,
        bank_id,
        _candidate_scores_from_rows(rows),
        p,
        required_tags,
    )


async def _add_entity_overlap_signal(
    candidates: dict[str, _CandidateScore],
    graph_store: GraphStore,
    bank_id: str,
    seed_id_list: list[str],
    seed_ids: set[str],
    p: LinkExpansionParams,
) -> None:
    """Signal 1: credit non-seed memories with the distinct entities they share with seeds."""
    get_entities = getattr(graph_store, "get_entity_ids_for_memories", None)
    if get_entities is None:
        return
    try:
        seed_entity_map = await get_entities(list(seed_id_list), bank_id)
    except Exception as exc:
        _logger.warning("entity-overlap lookup failed (%s)", exc)
        seed_entity_map = {}

    seed_entity_ids: set[str] = set()
    for ents in (seed_entity_map or {}).values():
        seed_entity_ids.update(ents or ())
    if not seed_entity_ids:
        return

    try:
        # Reverse lookup: which other memories carry these entities? The
        # result limit scales with the number of seed entities, following
        # Hindsight's ``graph_per_entity_limit``.
        graph_hits = await graph_store.query_neighbors(
            list(seed_entity_ids),
            bank_id,
            max_depth=1,
            limit=p.per_entity_limit * len(seed_entity_ids),
        )
    except Exception as exc:
        _logger.warning("query_neighbors failed (%s)", exc)
        graph_hits = []

    # Union shared entities per memory. A memory may appear in several
    # hits, and taking the max over hits would undercount
    # ``count(distinct shared entities)``. Dict insertion order keeps
    # first-hit order, so downstream tie-breaking matches hit order.
    shared_by_memory: dict[str, set[str]] = {}
    for ghit in graph_hits or ():
        mid = ghit.memory_id
        if mid in seed_ids:
            continue
        shared = set(ghit.connected_entities or []) & seed_entity_ids
        if shared:
            shared_by_memory.setdefault(mid, set()).update(shared)

    for mid, shared in shared_by_memory.items():
        cand = candidates.setdefault(mid, _CandidateScore(memory_id=mid))
        cand.entity_overlap = max(cand.entity_overlap, len(shared))
        cand.sources.add("entity_overlap")


async def _add_memory_link_signals(
    candidates: dict[str, _CandidateScore],
    graph_store: GraphStore,
    bank_id: str,
    seed_id_list: list[str],
    seed_ids: set[str],
    p: LinkExpansionParams,
) -> None:
    """Signals 2 & 3: fold precomputed semantic and causal memory links into candidates."""
    find_links = getattr(graph_store, "find_memory_links", None)
    if find_links is None:
        return
    try:
        links = await find_links(
            list(seed_id_list),
            bank_id,
            link_types=list(p.semantic_link_types) + list(p.causal_link_types),
            limit=p.expansion_limit * 4,
        )
    except Exception as exc:
        _logger.warning("find_memory_links failed (%s)", exc)
        links = []

    for link in links or ():
        # Surface the link's "other end" relative to a seed. Skip links
        # with both or neither endpoint among the seeds.
        if link.source_memory_id in seed_ids and link.target_memory_id not in seed_ids:
            other = link.target_memory_id
        elif link.target_memory_id in seed_ids and link.source_memory_id not in seed_ids:
            other = link.source_memory_id
        else:
            continue

        # Only create a candidate for a recognised link type. Otherwise an
        # unexpected type would leave an empty, source-less candidate behind.
        if link.link_type in p.semantic_link_types:
            cand = candidates.setdefault(other, _CandidateScore(memory_id=other))
            cand.semantic_total += float(link.weight)
            cand.sources.add("semantic")
        elif link.link_type in p.causal_link_types:
            cand = candidates.setdefault(other, _CandidateScore(memory_id=other))
            # Hindsight: causal weight + 1.0 boost.
            cand.causal_total += float(link.weight) + 1.0
            cand.sources.add("causal")

272 

273 

def _candidate_scores_from_rows(rows: list[dict]) -> list[_CandidateScore]:
    """Convert rows from ``expand_memory_links_fast`` into score accumulators.

    Every row must provide ``memory_id``; a missing key raises ``KeyError``.
    ``entity_overlap``, ``semantic_total`` and ``causal_total`` are optional,
    and missing or falsy values become zero.

    ``sources`` may be an iterable of signal names or a comma-separated
    string (``"a,b"`` or ``"a, b"``). Names are whitespace-stripped and
    blanks are dropped. When a row supplies no sources, they are inferred
    from the non-zero totals. This keeps the downstream ``_link_signal``
    metadata from being empty for a candidate that actually scored.
    """
    candidates: list[_CandidateScore] = []
    for row in rows:
        entity_overlap = int(row.get("entity_overlap") or 0)
        semantic_total = float(row.get("semantic_total") or 0.0)
        causal_total = float(row.get("causal_total") or 0.0)

        raw_sources = row.get("sources") or []
        if isinstance(raw_sources, str):
            raw_sources = raw_sources.split(",")
        # Tolerate "a, b" as well as "a,b"; drop empty labels.
        sources = {str(source).strip() for source in raw_sources}
        sources.discard("")
        if not sources:
            # Backend didn't label contributors; mirror the portable path's
            # labels from whichever signals produced a non-zero total.
            if entity_overlap > 0:
                sources.add("entity_overlap")
            if semantic_total > 0:
                sources.add("semantic")
            if causal_total > 0:
                sources.add("causal")

        candidates.append(
            _CandidateScore(
                memory_id=str(row["memory_id"]),
                entity_overlap=entity_overlap,
                semantic_total=semantic_total,
                causal_total=causal_total,
                sources=sources,
            )
        )
    return candidates

290 

291 

async def _hydrate_candidate_scores(
    vector_store: VectorStore,
    bank_id: str,
    candidates: list[_CandidateScore],
    params: LinkExpansionParams,
    required_tags: set[str],
) -> list[ScoredItem]:
    """Rank, hydrate and tag-filter scored candidates into ``ScoredItem`` objects.

    Shared by the SQL fast path and the portable Python path:

    1. Score every candidate once and order by score, highest first.
       Ties keep their input order.
    2. Drop candidates below ``params.activation_threshold`` and any
       repeated ``memory_id``; the first, highest-scoring copy wins. Keep
       at most ``2 * params.expansion_limit`` as an over-fetch, since the
       tag filter may discard some.
    3. Load bodies from ``vector_store``. Candidates with no body, or
       outside the tag scope, are skipped.
    4. Emit at most ``params.expansion_limit`` items. Each carries
       ``_link_signal`` plus whichever per-signal totals are non-zero.
    """
    if not candidates:
        return []

    # Score once: the same value drives ordering, the threshold test and
    # the emitted score.
    scored = sorted(
        ((cand.total(params), cand) for cand in candidates),
        key=lambda pair: pair[0],
        reverse=True,
    )

    fetch_cap = params.expansion_limit * 2  # over-fetch; tag filter cuts later
    ranked: list[tuple[float, _CandidateScore]] = []
    seen_ids: set[str] = set()
    for score, cand in scored:
        if len(ranked) >= fetch_cap:
            break
        # Written as ``>=`` so a NaN score fails the threshold, as before.
        passes_threshold = score >= params.activation_threshold
        if not passes_threshold or cand.memory_id in seen_ids:
            continue
        seen_ids.add(cand.memory_id)
        ranked.append((score, cand))
    if not ranked:
        return []

    bodies = await _fetch_bodies_by_id(vector_store, bank_id, [cand.memory_id for _, cand in ranked])

    out: list[ScoredItem] = []
    for score, cand in ranked:
        body = bodies.get(cand.memory_id)
        if body is None:
            continue
        if not _hit_has_required_tags(body.metadata, body.tags, required_tags):
            continue

        metadata = dict(body.metadata or {})
        metadata["_link_signal"] = ",".join(sorted(cand.sources))
        if cand.entity_overlap > 0:
            metadata["_entity_overlap_count"] = cand.entity_overlap
        if cand.semantic_total > 0:
            metadata["_semantic_weight_total"] = round(cand.semantic_total, 4)
        if cand.causal_total > 0:
            metadata["_causal_weight_total"] = round(cand.causal_total, 4)

        out.append(
            ScoredItem(
                id=body.id,
                text=body.text,
                score=score,
                fact_type=body.fact_type,
                metadata=metadata,
                tags=body.tags,
                memory_layer=body.memory_layer,
                occurred_at=body.occurred_at,
                retained_at=body.retained_at,
            )
        )
        if len(out) >= params.expansion_limit:
            break

    return out

350 

351 

352# --------------------------------------------------------------------------- 

353# Helpers 

354# --------------------------------------------------------------------------- 

355 

356 

357async def _fetch_bodies_by_id( 

358 vector_store: VectorStore, 

359 bank_id: str, 

360 memory_ids: list[str], 

361): 

362 """Resolve memory IDs to their full ``VectorItem`` bodies. 

363 

364 Bounded ``list_vectors`` scan; same pattern as 

365 ``PipelineOrchestrator._fetch_memory_hits_by_id``. For LoCoMo-scale 

366 banks this is fine; for very large banks we'd want a batched 

367 ``get_by_ids`` SPI extension. 

368 """ 

369 target = set(memory_ids) 

370 out: dict[str, object] = {} 

371 offset = 0 

372 batch = 200 

373 while target: 

374 chunk = await vector_store.list_vectors(bank_id, offset=offset, limit=batch) 

375 if not chunk: 

376 break 

377 for item in chunk: 

378 if item.id in target: 

379 out[item.id] = item 

380 target.discard(item.id) 

381 if len(chunk) < batch: 

382 break 

383 offset += batch 

384 return out