Coverage for astrocyte/pipeline/link_expansion.py: 86%
154 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""3-parallel-signal link expansion (Hindsight parity, C3b).

This is the C3 rewrite of ``spreading_activation.py``. The previous
module did a BFS hop walk over ``co_occurs`` entity edges; Hindsight's
``link_expansion_retrieval.py`` doesn't walk multi-hop entity chains
that way. Instead, it queries three first-class link signals in
parallel and combines them:

1. **Entity overlap** — query-time set overlap. Candidate memories
   that share entities with the seeds score by ``count(distinct shared
   entities)``. Computed here in Python: the seeds' entities come from
   the ``GraphStore.get_entity_ids_for_memories`` SPI, and the reverse
   map (entity → memories) is materialized on the fly via
   ``GraphStore.query_neighbors``.

2. **Semantic links** — precomputed at retain time
   (:mod:`semantic_link_graph`). Edges of type ``"semantic"`` connect
   each new memory to its top-K most-similar neighbors. The
   link-expansion query reads these directly from
   ``GraphStore.find_memory_links``.

3. **Causal links** — explicit ``"caused_by"`` chains extracted at
   retain time (:mod:`fact_causal_extraction`). Each edge's weight is
   boosted by +1.0, treating this as the highest-quality signal because
   the source-text causal evidence is unambiguous.

Hindsight's actual implementation runs all three as a single
recursive-CTE Postgres query for speed. We reproduce the same shape in
Python because the orchestrator's :class:`GraphStore` SPI must work
for arbitrary backends (in-memory tests, AGE, future stores). A store
may additionally expose an optional ``expand_memory_links_fast`` method;
when it is present it is tried first, and the portable Python path
(which issues the signal queries one after another) is the fallback.
For LoCoMo-scale workloads (~thousands of memories per bank), the
Python path is well within latency budget.

The return type is ``list[ScoredItem]`` — the same as the old
``spread_activation`` function — so the orchestrator's RRF-fusion
plumbing slots in without changes.
"""
38from __future__ import annotations
40import logging
41from dataclasses import dataclass, field
43from astrocyte.pipeline.fusion import ScoredItem
44from astrocyte.provider import GraphStore, VectorStore
46_logger = logging.getLogger("astrocyte.link_expansion")
49# ---------------------------------------------------------------------------
50# Parameters
51# ---------------------------------------------------------------------------
@dataclass
class LinkExpansionParams:
    """Configuration for the three-signal link expansion (Hindsight parity).

    Where Hindsight publishes a value, the default mirrors it. The score
    weights follow its reranking blend: entity overlap, semantic weight,
    and a causal signal that carries a +1.0 per-edge boost. Each raw
    signal is capped at 1.0 before weighting, so with the defaults the
    largest attainable score is ``0.5 + 0.3 + 0.7 = 1.5``.
    """

    #: Upper bound on candidates returned. It also scales the link-query
    #: limit (x4) and the hydration over-fetch (x2).
    expansion_limit: int = field(default=30)
    #: Per-entity LATERAL cap (Hindsight's ``graph_per_entity_limit``).
    #: The portable path multiplies it by the number of seed entities to
    #: bound the neighbor lookup and avoid fan-out explosion.
    per_entity_limit: int = field(default=200)
    #: Signal weights, applied after each signal is capped at 1.0.
    #: Causal is weighted highest because Hindsight treats cause chains
    #: as the highest-precision signal.
    entity_overlap_weight: float = field(default=0.5)
    semantic_weight: float = field(default=0.3)
    causal_weight: float = field(default=0.7)
    #: Link types counted as causal. Only ``caused_by`` is extracted at
    #: retain time today; the tuple leaves room for ``enables``,
    #: ``prevents``, and similar types.
    causal_link_types: tuple[str, ...] = field(default=("caused_by",))
    #: Link types counted as semantic-similarity edges.
    semantic_link_types: tuple[str, ...] = field(default=("semantic",))
    #: Weighted-score floor; candidates below it are never surfaced.
    activation_threshold: float = field(default=0.05)
83# ---------------------------------------------------------------------------
84# Tag scope helper (mirrors spread/expand path)
85# ---------------------------------------------------------------------------
88def _hit_has_required_tags(
89 metadata: dict | None,
90 tags: list[str] | None,
91 required_tags: set[str],
92) -> bool:
93 if not required_tags:
94 return True
95 item_tags = {str(t).lower() for t in (tags or [])}
96 return required_tags.issubset(item_tags)
99# ---------------------------------------------------------------------------
100# Score accumulator
101# ---------------------------------------------------------------------------
104@dataclass
105class _CandidateScore:
106 memory_id: str
107 entity_overlap: int = 0 # count of distinct shared entities
108 semantic_total: float = 0.0 # sum of semantic edge weights
109 causal_total: float = 0.0 # sum of (causal weight + 1.0) per Hindsight
110 sources: set[str] = field(default_factory=set) # which signals contributed
112 def total(self, params: LinkExpansionParams) -> float:
113 # Normalize entity overlap by a small constant — diminishing
114 # returns past 5 shared entities. Hindsight uses raw count;
115 # the normalization here keeps the weighted sum interpretable.
116 eo_norm = min(1.0, self.entity_overlap / 5.0)
117 sem_norm = min(1.0, self.semantic_total)
118 causal_norm = min(1.0, self.causal_total)
119 return (
120 params.entity_overlap_weight * eo_norm
121 + params.semantic_weight * sem_norm
122 + params.causal_weight * causal_norm
123 )
126# ---------------------------------------------------------------------------
127# Main entry point
128# ---------------------------------------------------------------------------
async def link_expansion(
    seed_hits: list[ScoredItem],
    *,
    bank_id: str,
    vector_store: VectorStore,
    graph_store: GraphStore,
    params: LinkExpansionParams | None = None,
    tags: list[str] | None = None,
) -> list[ScoredItem]:
    """Expand seeds through the three first-class memory-link signals.

    Returns NEW candidate memories only (seeds are filtered out). Each
    returned ``ScoredItem`` carries metadata explaining which signal(s)
    surfaced it:

    - ``_link_signal``: comma-separated list of contributing signals
      (``entity_overlap``, ``semantic``, ``causal``).
    - ``_entity_overlap_count``: how many distinct entities it shares
      with the seeds.
    - ``_semantic_weight_total``: sum of semantic edge weights to seeds.
    - ``_causal_weight_total``: sum of causal edge weights to seeds, each
      including the +1.0 causal boost.

    Resolution order:

    1. If ``graph_store`` exposes a callable ``expand_memory_links_fast``,
       its rows are scored and hydrated first. A non-empty result is
       returned directly. An exception, no rows, or an empty hydrated
       result falls through to the portable path.
    2. Portable path: entity overlap via ``get_entity_ids_for_memories``
       plus ``query_neighbors``, then semantic/causal links via
       ``find_memory_links``. Each signal is skipped independently when
       its method is missing, and contributes nothing when its query
       raises (the error is logged as a warning).

    Args:
        seed_hits: Top-K from initial RRF fusion. Their memory IDs drive
            all signal queries.
        bank_id: Constrains every query.
        vector_store: Used to hydrate full memory bodies after scoring.
        graph_store: Source of the link signals (see above).
        params: Tuning knobs; ``LinkExpansionParams()`` when omitted.
        tags: Optional tag filter. Candidates that do not carry every tag
            (compared case-insensitively) are dropped before being
            returned. LoCoMo's ``convo:<id>`` scoping reuses this.

    Returns:
        At most ``params.expansion_limit`` items in descending score
        order. Returns ``[]`` when there are no seeds or no candidate
        clears the activation threshold.
    """
    if not seed_hits:
        return []
    p = params or LinkExpansionParams()
    seed_id_list = [h.id for h in seed_hits]
    seed_ids = set(seed_id_list)
    required_tags = {str(t).lower() for t in tags} if tags else set()

    # Optional backend-native fast path. Any failure or empty outcome
    # falls back to the portable implementation below.
    fast_expand = getattr(graph_store, "expand_memory_links_fast", None)
    if callable(fast_expand):
        try:
            rows = await fast_expand(seed_id_list, bank_id, params=p)
        except Exception as exc:
            _logger.warning("fast link expansion failed (%s); using portable fallback", exc)
            rows = []
        if rows:
            fast_result = await _hydrate_candidate_scores(
                vector_store,
                bank_id,
                _candidate_scores_from_rows(rows),
                p,
                required_tags,
            )
            if fast_result:
                return fast_result

    shared_entities = await _collect_entity_overlap(graph_store, seed_id_list, seed_ids, bank_id, p)
    links = await _collect_memory_links(graph_store, seed_id_list, bank_id, p)

    # Merge entity candidates first, then link candidates, so dict insertion
    # order (which breaks score ties in the stable ranking) is unchanged.
    candidates: dict[str, _CandidateScore] = {}

    # --- Signal 1: entity overlap (count of DISTINCT shared entities) ---
    for mid, shared in shared_entities.items():
        candidates[mid] = _CandidateScore(
            memory_id=mid,
            entity_overlap=len(shared),
            sources={"entity_overlap"},
        )

    # --- Signals 2 & 3: precomputed memory_links -----------------------
    for link in links:
        # Only edges with exactly one seed endpoint surface a candidate:
        # the non-seed end of the edge.
        if link.source_memory_id in seed_ids and link.target_memory_id not in seed_ids:
            other = link.target_memory_id
        elif link.target_memory_id in seed_ids and link.source_memory_id not in seed_ids:
            other = link.source_memory_id
        else:
            continue

        cand = candidates.setdefault(other, _CandidateScore(memory_id=other))
        if link.link_type in p.semantic_link_types:
            cand.semantic_total += float(link.weight)
            cand.sources.add("semantic")
        elif link.link_type in p.causal_link_types:
            # Hindsight: causal weight + 1.0 boost.
            cand.causal_total += float(link.weight) + 1.0
            cand.sources.add("causal")

    if not candidates:
        return []

    return await _hydrate_candidate_scores(vector_store, bank_id, list(candidates.values()), p, required_tags)


async def _collect_entity_overlap(
    graph_store: GraphStore,
    seed_id_list: list[str],
    seed_ids: set[str],
    bank_id: str,
    p: LinkExpansionParams,
) -> dict[str, set[str]]:
    """Map each non-seed memory to the set of seed entities it shares (signal 1).

    Returns ``{}`` in any of these cases (lookup errors are logged):

    - the store lacks ``get_entity_ids_for_memories``;
    - the seeds carry no entities;
    - either lookup raises.
    """
    get_entities = getattr(graph_store, "get_entity_ids_for_memories", None)
    if not callable(get_entities):
        return {}
    try:
        seed_entity_map = await get_entities(seed_id_list, bank_id)
    except Exception as exc:
        _logger.warning("entity-overlap lookup failed (%s)", exc)
        return {}

    seed_entity_ids: set[str] = set()
    for ents in (seed_entity_map or {}).values():
        seed_entity_ids.update(ents or ())
    if not seed_entity_ids:
        return {}

    try:
        # Reverse lookup: which other memories carry these entities? The
        # per-entity cap (Hindsight's ``graph_per_entity_limit``) is scaled
        # by the number of seed entities to bound the total fan-out.
        graph_hits = await graph_store.query_neighbors(
            list(seed_entity_ids),
            bank_id,
            max_depth=1,
            limit=p.per_entity_limit * len(seed_entity_ids),
        )
    except Exception as exc:
        _logger.warning("query_neighbors failed (%s)", exc)
        return {}

    shared_by_memory: dict[str, set[str]] = {}
    for ghit in graph_hits or []:
        mid = ghit.memory_id
        if mid in seed_ids:
            continue
        shared = set(ghit.connected_entities or []) & seed_entity_ids
        if shared:
            # Union across hits so the final count is of DISTINCT entities,
            # even when the store returns several hits for the same memory.
            shared_by_memory.setdefault(mid, set()).update(shared)
    return shared_by_memory


async def _collect_memory_links(
    graph_store: GraphStore,
    seed_id_list: list[str],
    bank_id: str,
    p: LinkExpansionParams,
) -> list:
    """Fetch semantic and causal memory links touching the seeds (signals 2 & 3).

    Returns ``[]`` when the store lacks ``find_memory_links`` or the query
    raises (the error is logged).
    """
    find_links = getattr(graph_store, "find_memory_links", None)
    if not callable(find_links):
        return []
    try:
        links = await find_links(
            seed_id_list,
            bank_id,
            link_types=list(p.semantic_link_types) + list(p.causal_link_types),
            limit=p.expansion_limit * 4,
        )
    except Exception as exc:
        _logger.warning("find_memory_links failed (%s)", exc)
        return []
    return list(links or [])
274def _candidate_scores_from_rows(rows: list[dict]) -> list[_CandidateScore]:
275 candidates: list[_CandidateScore] = []
276 for row in rows:
277 sources = row.get("sources") or []
278 if isinstance(sources, str):
279 sources = [part for part in sources.split(",") if part]
280 candidates.append(
281 _CandidateScore(
282 memory_id=str(row["memory_id"]),
283 entity_overlap=int(row.get("entity_overlap") or 0),
284 semantic_total=float(row.get("semantic_total") or 0.0),
285 causal_total=float(row.get("causal_total") or 0.0),
286 sources={str(source) for source in sources},
287 )
288 )
289 return candidates
292async def _hydrate_candidate_scores(
293 vector_store: VectorStore,
294 bank_id: str,
295 candidates: list[_CandidateScore],
296 params: LinkExpansionParams,
297 required_tags: set[str],
298) -> list[ScoredItem]:
299 """Hydrate candidate IDs from either the SQL fast path or Python fallback."""
300 if not candidates:
301 return []
303 # Cap candidate set before fetching bodies to bound the cost.
304 ranked = sorted(
305 candidates,
306 key=lambda c: c.total(params),
307 reverse=True,
308 )
309 ranked = [c for c in ranked if c.total(params) >= params.activation_threshold]
310 ranked = ranked[: params.expansion_limit * 2] # over-fetch; tag filter cuts later
311 if not ranked:
312 return []
314 bodies = await _fetch_bodies_by_id(vector_store, bank_id, [c.memory_id for c in ranked])
316 out: list[ScoredItem] = []
317 for cand in ranked:
318 body = bodies.get(cand.memory_id)
319 if body is None:
320 continue
321 if not _hit_has_required_tags(body.metadata, body.tags, required_tags):
322 continue
324 metadata = dict(body.metadata or {})
325 metadata["_link_signal"] = ",".join(sorted(cand.sources))
326 if cand.entity_overlap > 0:
327 metadata["_entity_overlap_count"] = cand.entity_overlap
328 if cand.semantic_total > 0:
329 metadata["_semantic_weight_total"] = round(cand.semantic_total, 4)
330 if cand.causal_total > 0:
331 metadata["_causal_weight_total"] = round(cand.causal_total, 4)
333 out.append(
334 ScoredItem(
335 id=body.id,
336 text=body.text,
337 score=cand.total(params),
338 fact_type=body.fact_type,
339 metadata=metadata,
340 tags=body.tags,
341 memory_layer=body.memory_layer,
342 occurred_at=body.occurred_at,
343 retained_at=body.retained_at,
344 )
345 )
346 if len(out) >= params.expansion_limit:
347 break
349 return out
352# ---------------------------------------------------------------------------
353# Helpers
354# ---------------------------------------------------------------------------
357async def _fetch_bodies_by_id(
358 vector_store: VectorStore,
359 bank_id: str,
360 memory_ids: list[str],
361):
362 """Resolve memory IDs to their full ``VectorItem`` bodies.
364 Bounded ``list_vectors`` scan; same pattern as
365 ``PipelineOrchestrator._fetch_memory_hits_by_id``. For LoCoMo-scale
366 banks this is fine; for very large banks we'd want a batched
367 ``get_by_ids`` SPI extension.
368 """
369 target = set(memory_ids)
370 out: dict[str, object] = {}
371 offset = 0
372 batch = 200
373 while target:
374 chunk = await vector_store.list_vectors(bank_id, offset=offset, limit=batch)
375 if not chunk:
376 break
377 for item in chunk:
378 if item.id in target:
379 out[item.id] = item
380 target.discard(item.id)
381 if len(chunk) < batch:
382 break
383 offset += batch
384 return out