Coverage for astrocyte/pipeline/semantic_link_graph.py: 88%
32 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Precomputed semantic-kNN graph (Hindsight parity, C3a).
3Hindsight's link-expansion retrieval relies on a precomputed semantic
4kNN graph: at retain time, each new memory is linked to its top-K
5most-similar existing memories with similarity above a threshold
6(default ``0.7``). At recall time, those edges become a parallel
7expansion signal alongside entity-overlap and causal links.
9The semantic-kNN edges are essentially a static "what's already in the
10bank that's nearby in embedding space" — much cheaper to maintain than
11recomputing kNN at every recall.
13This module provides :func:`compute_semantic_links` which takes a
14freshly-embedded batch of chunks (with their assigned memory_ids), runs
15``search_similar`` against the existing bank for each, and produces
16:class:`MemoryLink` records with ``link_type="semantic"``. The
17orchestrator persists them via ``GraphStore.store_memory_links``.
19Notes:
21- Per-chunk asyncio.gather: the searches (one per new chunk) are
22 independent; gather lets the bank-side concurrency soak up the cost.
23- Self-exclusion: each chunk's own memory_id is filtered out of its
24 own kNN result set (a chunk would otherwise link to itself with
25 similarity 1.0).
26- Same-batch exclusion: memories created in the same retain call are
27 also filtered (we don't want chunk_2 to link to chunk_3 just because
28 they shared the same source paragraph — that's already captured by
29 the causal_by signal when applicable).
30- Threshold: configurable, default ``0.7`` matches Hindsight.
31- top-K: configurable, default ``5`` matches Hindsight.
32"""
34from __future__ import annotations
36import asyncio
37import logging
38from datetime import datetime, timezone
40from astrocyte.types import MemoryLink
42_logger = logging.getLogger("astrocyte.semantic_link_graph")
async def compute_semantic_links(
    *,
    bank_id: str,
    new_memory_ids: list[str],
    new_embeddings: list[list[float]],
    vector_store,
    top_k: int = 5,
    similarity_threshold: float = 0.7,
) -> list[MemoryLink]:
    """Build the semantic-kNN edges for a batch of new memories.

    For each ``(memory_id, embedding)`` pair, run a similarity search
    against ``bank_id`` and emit one :class:`MemoryLink` per hit at or
    above ``similarity_threshold``. Edges are stored directionally (new
    memory -> existing neighbor), but semantic similarity itself is
    symmetric — the link-expansion retrieval queries both directions at
    recall time.

    Hits whose id belongs to ``new_memory_ids`` are skipped, which covers
    both self-matches and other memories created in the same batch.

    Args:
        bank_id: Target bank; the search runs scoped to this bank.
        new_memory_ids: IDs of the freshly-stored memories. Must align
            with ``new_embeddings`` index-for-index.
        new_embeddings: Embeddings for each new memory. An empty
            embedding yields no links for that memory.
        vector_store: Provider implementing an awaitable
            ``search_similar(embedding, bank_id, limit=...)`` whose hits
            expose ``id`` and ``score``.
        top_k: Maximum number of nearest neighbors per new memory.
            Values ``<= 0`` produce no links.
        similarity_threshold: Minimum cosine similarity to keep an edge.

    Returns:
        :class:`MemoryLink` objects with ``link_type="semantic"`` and
        ``weight=similarity``. Empty list when no neighbors qualify, when
        the inputs are empty or misaligned, or when ``top_k <= 0``.
        A failing ``search_similar`` call is logged and contributes no
        links; it does not abort the rest of the batch.
    """
    if not new_memory_ids:
        return []
    if len(new_memory_ids) != len(new_embeddings):
        # Ids and embeddings cannot be paired reliably; stay best-effort
        # (no links) but leave a trace instead of failing silently.
        _logger.warning(
            "semantic_link_graph: got %d memory ids but %d embeddings; skipping",
            len(new_memory_ids),
            len(new_embeddings),
        )
        return []
    if top_k <= 0:
        # Without this guard the loop below would still emit one link per
        # chunk, because it appends before checking the top_k cap.
        return []

    same_batch = set(new_memory_ids)
    # Fetch a few extra so post-filtering for self + same-batch
    # exclusions still leaves us close to top_k.
    fetch_limit = top_k + len(same_batch) + 2

    async def _search_one(memory_id: str, embedding: list[float]) -> list[MemoryLink]:
        """Return the semantic links for a single new memory."""
        if not embedding:
            return []
        try:
            hits = await vector_store.search_similar(
                embedding,
                bank_id,
                limit=fetch_limit,
            )
        except Exception as exc:
            # Best-effort: one failing search must not sink the batch.
            _logger.warning(
                "semantic_link_graph: search_similar failed for %r (%s)",
                memory_id,
                exc,
            )
            return []

        out: list[MemoryLink] = []
        now = datetime.now(timezone.utc)
        for hit in hits:
            if hit.id in same_batch:
                # Covers the memory itself and its same-batch siblings.
                continue
            if hit.score < similarity_threshold:
                # search_similar returns hits sorted descending; once we
                # drop below threshold, the rest will too.
                break
            out.append(
                MemoryLink(
                    source_memory_id=memory_id,
                    target_memory_id=hit.id,
                    link_type="semantic",
                    evidence="",
                    confidence=1.0,
                    weight=float(hit.score),
                    created_at=now,
                    metadata={"source": "semantic_link_graph"},
                )
            )
            if len(out) >= top_k:
                break
        return out

    # The per-memory searches are independent, so run them concurrently.
    per_chunk = await asyncio.gather(
        *(
            _search_one(memory_id, embedding)
            for memory_id, embedding in zip(new_memory_ids, new_embeddings)
        )
    )
    return [link for chunk_links in per_chunk for link in chunk_links]