Coverage for astrocyte/testing/in_memory.py: 84%
1057 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""In-memory provider implementations for testing.
3These are fully functional providers backed by Python dicts/lists.
4Used by conformance tests and integration tests.
5"""
7from __future__ import annotations
9import math
10import uuid
11from datetime import UTC, datetime
12from typing import ClassVar
14from astrocyte.types import (
15 Completion,
16 Document,
17 DocumentFilters,
18 DocumentHit,
19 EngineCapabilities,
20 Entity,
21 EntityCandidateMatch,
22 EntityLink,
23 ForgetRequest,
24 ForgetResult,
25 GraphHit,
26 HealthStatus,
27 LLMCapabilities,
28 MemoryEntityAssociation,
29 MemoryHit,
30 MentalModel,
31 Message,
32 PageIndexDocument,
33 PageIndexFact,
34 PageIndexSection,
35 PageIndexSectionEntity,
36 PageIndexSectionLink,
37 RecallRequest,
38 RecallResult,
39 ReflectRequest,
40 ReflectResult,
41 RetainRequest,
42 RetainResult,
43 SourceChunk,
44 SourceDocument,
45 TokenUsage,
46 VectorFilters,
47 VectorHit,
48 VectorItem,
49 WikiPage,
50)
53def _cosine_sim(a: list[float], b: list[float]) -> float:
54 dot = sum(x * y for x, y in zip(a, b))
55 na = math.sqrt(sum(x * x for x in a))
56 nb = math.sqrt(sum(x * x for x in b))
57 if na == 0 or nb == 0:
58 return 0.0
59 return dot / (na * nb)
62# ---------------------------------------------------------------------------
63# In-memory Vector Store
64# ---------------------------------------------------------------------------
class InMemoryVectorStore:
    """Dict-backed vector store for tests.

    All items live in one mapping keyed by ``item.id`` regardless of bank;
    every read and delete path filters on ``bank_id``.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        self._vectors: dict[str, VectorItem] = {}

    @staticmethod
    def _as_hit(item: VectorItem, score: float) -> VectorHit:
        """Project a stored item onto a ``VectorHit`` carrying *score*."""
        return VectorHit(
            id=item.id,
            text=item.text,
            score=score,
            metadata=item.metadata,
            tags=item.tags,
            fact_type=item.fact_type,
            occurred_at=item.occurred_at,
            memory_layer=item.memory_layer,
            retained_at=item.retained_at,  # M9
            chunk_id=item.chunk_id,  # M10
        )

    @staticmethod
    def _passes(item: VectorItem, filters: VectorFilters) -> bool:
        """Return True when *item* survives every active filter.

        The tag, fact-type, ``as_of`` and ``time_range`` filters only reject
        an item that actually carries the corresponding attribute; an item
        lacking it passes. The session filter is strict: an item without a
        ``session_id`` entry in dict metadata is rejected.
        """
        if filters.tags and item.tags and not (set(filters.tags) & set(item.tags)):
            return False
        if filters.fact_types and item.fact_type and item.fact_type not in filters.fact_types:
            return False
        # M9: time-travel filter — drop items retained after ``as_of``.
        if filters.as_of is not None and item.retained_at is not None and item.retained_at > filters.as_of:
            return False
        if filters.time_range is not None and item.occurred_at is not None:
            window_start, window_end = filters.time_range
            if item.occurred_at < window_start or item.occurred_at > window_end:
                return False
        # M31 Fix 2 — session filter; ``metadata['session_id']`` is the
        # canonical location.
        if filters.session_id is not None:
            meta = item.metadata
            session = meta.get("session_id") if meta and isinstance(meta, dict) else None
            if session != filters.session_id:
                return False
        return True

    async def store_vectors(self, items: list[VectorItem]) -> list[str]:
        """Insert or overwrite *items* by id; return their ids in input order."""
        stored_ids: list[str] = []
        for entry in items:
            self._vectors[entry.id] = entry
            stored_ids.append(entry.id)
        return stored_ids

    async def search_similar(
        self,
        query_vector: list[float],
        bank_id: str,
        limit: int = 10,
        filters: VectorFilters | None = None,
    ) -> list[VectorHit]:
        """Return up to *limit* hits from *bank_id*, best cosine score first.

        Items are first narrowed by bank and (when *filters* is truthy) by
        the filter predicate, then scored against *query_vector*.
        """
        ranked: list[tuple[float, VectorItem]] = [
            (_cosine_sim(query_vector, item.vector), item)
            for item in self._vectors.values()
            if item.bank_id == bank_id and (not filters or self._passes(item, filters))
        ]
        # Stable sort: equal scores keep their insertion order.
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [self._as_hit(item, score) for score, item in ranked[:limit]]

    async def delete(self, ids: list[str], bank_id: str) -> int:
        """Delete the listed ids that belong to *bank_id*; return how many went."""
        removed = 0
        for vector_id in ids:
            found = self._vectors.get(vector_id)
            if found is None and vector_id not in self._vectors:
                continue
            if self._vectors[vector_id].bank_id != bank_id:
                continue
            del self._vectors[vector_id]
            removed += 1
        return removed

    async def get_by_chunk_ids(
        self,
        chunk_ids: list[str],
        bank_id: str,
    ) -> list[VectorHit]:
        """M10 chunk expansion: fetch every vector whose ``chunk_id`` is listed.

        Each hit is given a score of ``1.0`` and keeps its ``chunk_id`` so
        the caller can group results by source. An empty *chunk_ids* list
        short-circuits to an empty result.
        """
        if not chunk_ids:
            return []
        wanted = set(chunk_ids)
        return [
            self._as_hit(item, 1.0)
            for item in self._vectors.values()
            if item.bank_id == bank_id and item.chunk_id is not None and item.chunk_id in wanted
        ]

    async def list_vectors(
        self,
        bank_id: str,
        offset: int = 0,
        limit: int = 100,
    ) -> list[VectorItem]:
        """Page through the bank's items ordered by id."""
        in_bank = [item for item in self._vectors.values() if item.bank_id == bank_id]
        in_bank.sort(key=lambda item: item.id)
        return in_bank[offset : offset + limit]

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory vector store")
203# ---------------------------------------------------------------------------
204# In-memory Graph Store
205# ---------------------------------------------------------------------------
208class InMemoryGraphStore:
209 """Fully functional in-memory graph store for testing."""
211 SPI_VERSION: ClassVar[int] = 1
213 def __init__(self) -> None:
214 self._entities: dict[str, dict[str, Entity]] = {} # bank_id → {entity_id → Entity}
215 self._links: dict[str, list[EntityLink]] = {} # bank_id → links
216 self._memory_entity_map: dict[str, list[MemoryEntityAssociation]] = {} # bank_id → assocs
217 # Memory-to-memory links (Hindsight parity) — caused_by, semantic, etc.
218 self._memory_links: dict[str, list] = {} # bank_id → list[MemoryLink]
220 async def store_entities(self, entities: list[Entity], bank_id: str) -> list[str]:
221 if bank_id not in self._entities:
222 self._entities[bank_id] = {}
223 ids = []
224 for entity in entities:
225 self._entities[bank_id][entity.id] = entity
226 ids.append(entity.id)
227 return ids
229 async def store_links(self, links: list[EntityLink], bank_id: str) -> list[str]:
230 if bank_id not in self._links:
231 self._links[bank_id] = []
232 ids = []
233 for link in links:
234 lid = uuid.uuid4().hex[:12]
235 self._links[bank_id].append(link)
236 ids.append(lid)
237 return ids
239 async def link_memories_to_entities(
240 self,
241 associations: list[MemoryEntityAssociation],
242 bank_id: str,
243 ) -> None:
244 if bank_id not in self._memory_entity_map:
245 self._memory_entity_map[bank_id] = []
246 self._memory_entity_map[bank_id].extend(associations)
248 async def query_neighbors(
249 self,
250 entity_ids: list[str],
251 bank_id: str,
252 max_depth: int = 2,
253 limit: int = 20,
254 ) -> list[GraphHit]:
255 # Find memories linked to these entities within this bank.
256 # Per-memory ``connected_entities`` is the SUBSET of entity_ids
257 # that this memory is associated with — spreading activation uses
258 # this to decide which activated entity surfaced the memory.
259 per_memory: dict[str, set[str]] = {}
260 for assoc in self._memory_entity_map.get(bank_id, []):
261 if assoc.entity_id in entity_ids:
262 per_memory.setdefault(assoc.memory_id, set()).add(assoc.entity_id)
264 return [
265 GraphHit(
266 memory_id=mid,
267 text=f"[graph result for {mid}]",
268 connected_entities=sorted(connected),
269 depth=1,
270 score=0.5,
271 )
272 for mid, connected in list(per_memory.items())[:limit]
273 ]
275 async def query_entities(self, query: str, bank_id: str, limit: int = 10) -> list[Entity]:
276 query_lower = query.lower()
277 bank_entities = self._entities.get(bank_id, {})
278 results = [e for e in bank_entities.values() if query_lower in e.name.lower()]
279 return results[:limit]
281 async def find_entity_candidates(
282 self,
283 name: str,
284 bank_id: str,
285 threshold: float = 0.8,
286 limit: int = 5,
287 ) -> list[Entity]:
288 """Return entities whose name contains *name* as a substring (case-insensitive).
290 The in-memory implementation uses substring overlap as a proxy for
291 similarity; production adapters use vector or edit-distance similarity.
292 The *threshold* parameter is accepted for interface compatibility but
293 ignored — substring match is all-or-nothing.
294 """
295 name_lower = name.lower()
296 bank_entities = self._entities.get(bank_id, {})
297 results = [e for e in bank_entities.values() if name_lower in e.name.lower() or e.name.lower() in name_lower]
298 return results[:limit]
300 async def find_entity_candidates_scored(
301 self,
302 name: str,
303 bank_id: str,
304 *,
305 name_embedding: list[float] | None = None,
306 trigram_threshold: float = 0.3,
307 limit: int = 10,
308 ) -> list[EntityCandidateMatch]:
309 """In-memory equivalent of pg_trgm + cosine candidate scoring.
311 - ``name_similarity`` is computed with :class:`difflib.SequenceMatcher`
312 on lowercased names, which approximates PostgreSQL's pg_trgm
313 ``similarity()`` closely enough for tests.
314 - ``embedding_similarity`` is computed with the module's
315 :func:`_cosine_sim` against the candidate's stored embedding when
316 both sides have one; ``None`` otherwise.
317 - ``co_occurring_names`` are derived from the in-memory link list —
318 for each ``co_occurs`` link involving the candidate, the OTHER
319 entity's lowercased name is collected.
320 - ``last_seen`` mirrors PostgreSQL's ``updated_at`` semantics —
321 the in-memory store doesn't track update timestamps, so this is
322 ``None`` unless the entity has a ``last_seen`` in metadata.
324 Candidates with ``name_similarity < trigram_threshold`` are dropped.
325 Results are ordered by ``max(name_similarity, embedding_similarity or 0)``
326 descending.
327 """
328 from difflib import SequenceMatcher
330 bank_entities = self._entities.get(bank_id, {})
331 if not bank_entities:
332 return []
334 name_norm = (name or "").strip().lower()
335 if not name_norm:
336 return []
338 # Pre-build a map: candidate_id → list of co-occurring entity names.
339 # Walks the bank's links once; the resolver only iterates a
340 # pre-filtered candidate slice afterward.
341 cooccurrence_map: dict[str, list[str]] = {}
342 for link in self._links.get(bank_id, []):
343 if link.link_type != "co_occurs":
344 continue
345 other_a = bank_entities.get(link.entity_b)
346 other_b = bank_entities.get(link.entity_a)
347 if other_a is not None:
348 cooccurrence_map.setdefault(link.entity_a, []).append((other_a.name or "").strip().lower())
349 if other_b is not None:
350 cooccurrence_map.setdefault(link.entity_b, []).append((other_b.name or "").strip().lower())
352 scored: list[EntityCandidateMatch] = []
353 for entity in bank_entities.values():
354 cand_name = (entity.name or "").strip().lower()
355 if not cand_name:
356 continue
357 name_sim = SequenceMatcher(None, name_norm, cand_name).ratio()
358 if name_sim < trigram_threshold:
359 continue
361 emb_sim: float | None = None
362 if name_embedding is not None and entity.embedding is not None:
363 emb_sim = _cosine_sim(name_embedding, entity.embedding)
364 # Clamp to [0, 1] — _cosine_sim can return negative values
365 # for vectors that aren't non-negative.
366 emb_sim = max(0.0, min(1.0, emb_sim))
368 co_names = cooccurrence_map.get(entity.id, [])
369 last_seen = None
370 if entity.metadata and "last_seen" in entity.metadata:
371 # InMemoryGraphStore doesn't auto-stamp last_seen — tests can
372 # populate via metadata to exercise the temporal tier.
373 raw = entity.metadata["last_seen"]
374 if isinstance(raw, str):
375 try:
376 last_seen = datetime.fromisoformat(raw)
377 except ValueError:
378 last_seen = None
380 scored.append(
381 EntityCandidateMatch(
382 entity=entity,
383 name_similarity=name_sim,
384 embedding_similarity=emb_sim,
385 co_occurring_names=co_names,
386 last_seen=last_seen,
387 mention_count=getattr(entity, "mention_count", 1),
388 )
389 )
391 def _sort_key(m: EntityCandidateMatch) -> float:
392 return max(m.name_similarity, m.embedding_similarity or 0.0)
394 scored.sort(key=_sort_key, reverse=True)
395 return scored[:limit]
397 async def store_memory_links(
398 self,
399 links: list, # list[MemoryLink] — kept loose to avoid forward-ref noise
400 bank_id: str,
401 ) -> list[str]:
402 """Persist memory-to-memory links (Hindsight-parity)."""
403 store = self._memory_links.setdefault(bank_id, [])
404 ids: list[str] = []
405 for link in links:
406 lid = uuid.uuid4().hex[:12]
407 store.append(link)
408 ids.append(lid)
409 return ids
411 async def find_memory_links(
412 self,
413 seed_memory_ids: list[str],
414 bank_id: str,
415 *,
416 link_types: list[str] | None = None,
417 limit: int = 200,
418 ) -> list: # list[MemoryLink]
419 """Find links touching any seed memory in either direction."""
420 if not seed_memory_ids:
421 return []
422 seeds = set(seed_memory_ids)
423 type_filter = set(link_types) if link_types else None
424 out: list = []
425 for link in self._memory_links.get(bank_id, []):
426 if type_filter is not None and link.link_type not in type_filter:
427 continue
428 if link.source_memory_id in seeds or link.target_memory_id in seeds:
429 out.append(link)
430 if len(out) >= limit:
431 break
432 return out
434 async def get_entity_ids_for_memories(
435 self,
436 memory_ids: list[str],
437 bank_id: str,
438 ) -> dict[str, list[str]]:
439 """Return ``{memory_id: [entity_id, ...]}`` from the association table.
441 Used by the spreading-activation pipeline to seed entity IDs
442 directly from memory↔entity associations (the most-accurate
443 path; metadata-based extraction is the fallback).
444 """
445 if not memory_ids:
446 return {}
447 target = set(memory_ids)
448 out: dict[str, list[str]] = {}
449 for assoc in self._memory_entity_map.get(bank_id, []):
450 if assoc.memory_id in target:
451 out.setdefault(assoc.memory_id, []).append(assoc.entity_id)
452 return out
454 async def expand_entities_via_links(
455 self,
456 entity_ids: list[str],
457 bank_id: str,
458 *,
459 max_hops: int = 2,
460 link_types: list[str] | None = None,
461 ) -> dict[str, int]:
462 """BFS over entity-to-entity links to ``max_hops`` distance.
464 Returns ``{entity_id: hop_distance}`` where ``hop_distance == 0``
465 for entities in the input set, ``1`` for direct neighbors, etc.
467 Used by the spreading-activation pipeline to walk
468 ``co_occurs`` links from seed entities to their graph
469 neighborhood. Mirrors the AGE adapter's recursive-CTE
470 implementation so behavior is consistent across stores.
471 """
472 if not entity_ids:
473 return {}
474 accepted_types = set(link_types or ["co_occurs"])
475 bank_links = self._links.get(bank_id, [])
477 # Pre-build a neighbor map keyed by entity_id for O(1) lookup
478 # during the BFS frontier expansion.
479 neighbors: dict[str, set[str]] = {}
480 for link in bank_links:
481 if link.link_type not in accepted_types:
482 continue
483 neighbors.setdefault(link.entity_a, set()).add(link.entity_b)
484 neighbors.setdefault(link.entity_b, set()).add(link.entity_a)
486 distances: dict[str, int] = {eid: 0 for eid in entity_ids}
487 frontier = set(entity_ids)
488 for hop in range(1, max(1, max_hops) + 1):
489 next_frontier: set[str] = set()
490 for eid in frontier:
491 for nb in neighbors.get(eid, ()):
492 if nb not in distances:
493 distances[nb] = hop
494 next_frontier.add(nb)
495 if not next_frontier:
496 break
497 frontier = next_frontier
498 return distances
500 async def increment_mention_counts(
501 self,
502 entity_ids: list[str],
503 bank_id: str,
504 ) -> None:
505 """Bump ``mention_count`` for each canonical entity ID by 1.
507 Mirrors the AGE adapter's :meth:`increment_mention_counts` for
508 in-memory tests of the resolver's mention-count flow.
509 """
510 bank_entities = self._entities.get(bank_id, {})
511 for eid in entity_ids:
512 entity = bank_entities.get(eid)
513 if entity is None:
514 continue
515 # Entity is a frozen-ish dataclass — replace via dataclasses.replace
516 # to avoid mutating the original object held elsewhere.
517 from dataclasses import replace as _dc_replace
519 bank_entities[eid] = _dc_replace(
520 entity,
521 mention_count=int(getattr(entity, "mention_count", 1)) + 1,
522 )
524 async def store_entity_link(self, link: EntityLink, bank_id: str) -> str:
525 """Store a single resolved entity link (M11 entity resolution)."""
526 if bank_id not in self._links:
527 self._links[bank_id] = []
528 lid = uuid.uuid4().hex[:12]
529 self._links[bank_id].append(link)
530 return lid
532 async def health(self) -> HealthStatus:
533 return HealthStatus(healthy=True, message="in-memory graph store")
536# ---------------------------------------------------------------------------
537# In-memory Document Store
538# ---------------------------------------------------------------------------
class InMemoryDocumentStore:
    """Dict-backed document store for tests.

    Documents are keyed by ``document.id`` across all banks; each entry
    remembers the bank it was stored under.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        # document id → (document, owning bank_id)
        self._docs: dict[str, tuple[Document, str]] = {}

    async def store_document(self, document: Document, bank_id: str) -> str:
        """Insert or overwrite *document* by id and return that id."""
        doc_id = document.id
        self._docs[doc_id] = (document, bank_id)
        return doc_id

    async def search_fulltext(
        self,
        query: str,
        bank_id: str,
        limit: int = 10,
        filters: DocumentFilters | None = None,
    ) -> list[DocumentHit]:
        """Rank the bank's documents by whitespace-token overlap with *query*.

        Both sides are lowercased and split on whitespace. A document's
        score is the number of shared distinct terms divided by the number
        of distinct query terms; documents sharing no term are omitted.
        *filters* is accepted but not applied by this implementation.
        """
        wanted_terms = set(query.lower().split())
        denominator = max(len(wanted_terms), 1)

        ranked: list[tuple[float, Document]] = []
        for candidate, owner in self._docs.values():
            if owner != bank_id:
                continue
            shared = wanted_terms.intersection(candidate.text.lower().split())
            if shared:
                ranked.append((len(shared) / denominator, candidate))

        # Stable sort: equal scores keep their insertion order.
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        hits: list[DocumentHit] = []
        for relevance, candidate in ranked[:limit]:
            hits.append(
                DocumentHit(
                    document_id=candidate.id,
                    text=candidate.text,
                    score=relevance,
                    metadata=candidate.metadata,
                )
            )
        return hits

    async def get_document(self, document_id: str, bank_id: str) -> Document | None:
        """Return the document only if it exists and belongs to *bank_id*."""
        entry = self._docs.get(document_id)
        if not entry:
            return None
        stored, owner = entry
        return stored if owner == bank_id else None

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory document store")
587# ---------------------------------------------------------------------------
588# In-memory Engine Provider
589# ---------------------------------------------------------------------------
592class InMemoryEngineProvider:
593 """Fully functional in-memory engine provider for testing."""
595 SPI_VERSION: ClassVar[int] = 1
597 def __init__(self, supports_reflect: bool = True, supports_forget: bool = True) -> None:
598 self._memories: dict[str, list[MemoryHit]] = {} # bank_id -> memories
599 self._supports_reflect = supports_reflect
600 self._supports_forget = supports_forget
602 def capabilities(self) -> EngineCapabilities:
603 return EngineCapabilities(
604 supports_reflect=self._supports_reflect,
605 supports_forget=self._supports_forget,
606 supports_semantic_search=True,
607 supports_tags=True,
608 supports_metadata=True,
609 )
611 async def health(self) -> HealthStatus:
612 return HealthStatus(healthy=True, message="in-memory engine")
614 async def retain(self, request: RetainRequest) -> RetainResult:
615 if request.bank_id not in self._memories:
616 self._memories[request.bank_id] = []
617 mem_id = uuid.uuid4().hex[:16]
618 # Stamp _created_at so MIP forget min_age_days can be enforced. Existing
619 # callers that already set _created_at (eg. import/replay) win.
620 meta = dict(request.metadata) if request.metadata else {}
621 meta.setdefault("_created_at", datetime.now(UTC).isoformat())
622 self._memories[request.bank_id].append(
623 MemoryHit(
624 text=request.content,
625 score=1.0,
626 fact_type="world",
627 metadata=meta,
628 tags=request.tags,
629 occurred_at=request.occurred_at,
630 source=request.source,
631 memory_id=mem_id,
632 bank_id=request.bank_id,
633 )
634 )
635 return RetainResult(stored=True, memory_id=mem_id)
637 async def recall(self, request: RecallRequest) -> RecallResult:
638 memories = self._memories.get(request.bank_id, [])
640 # Apply filters before scoring
641 # MIP soft-delete: hide records flagged with `_deleted: true` in metadata.
642 # The flag is set by `soft_delete()` when forget.mode == "soft" so that
643 # records survive on disk for audit/restore but disappear from recall.
644 filtered = [m for m in memories if not (m.metadata and m.metadata.get("_deleted") is True)]
645 if request.tags:
646 tag_set = set(request.tags)
647 filtered = [m for m in filtered if m.tags and tag_set & set(m.tags)]
648 if request.fact_types:
649 ft_set = set(request.fact_types)
650 filtered = [m for m in filtered if m.fact_type in ft_set]
651 if request.time_range:
652 t_start, t_end = request.time_range
653 filtered = [m for m in filtered if m.occurred_at and t_start <= m.occurred_at <= t_end]
655 # Wildcard query returns all memories (used by export)
656 if request.query.strip() == "*":
657 scored = [(1.0, mem) for mem in filtered]
658 else:
659 # Simple keyword matching
660 query_terms = set(request.query.lower().split())
661 scored = []
662 for mem in filtered:
663 mem_terms = set(mem.text.lower().split())
664 overlap = len(query_terms & mem_terms)
665 if overlap > 0:
666 score = overlap / max(len(query_terms), 1)
667 scored.append((score, mem))
669 scored.sort(key=lambda x: x[0], reverse=True)
670 hits = [
671 MemoryHit(
672 text=mem.text,
673 score=score,
674 fact_type=mem.fact_type,
675 metadata=mem.metadata,
676 tags=mem.tags,
677 occurred_at=mem.occurred_at,
678 source=mem.source,
679 memory_id=mem.memory_id,
680 bank_id=request.bank_id,
681 )
682 for score, mem in scored[: request.max_results]
683 ]
684 return RecallResult(hits=hits, total_available=len(scored), truncated=len(scored) > request.max_results)
686 async def reflect(self, request: ReflectRequest) -> ReflectResult:
687 if not self._supports_reflect:
688 raise NotImplementedError("reflect not supported")
689 recall_result = await self.recall(RecallRequest(query=request.query, bank_id=request.bank_id))
690 answer = "Synthesis: " + "; ".join(h.text for h in recall_result.hits[:5])
691 return ReflectResult(answer=answer, sources=recall_result.hits[:5])
693 async def soft_delete(self, bank_id: str, memory_ids: list[str]) -> int:
694 """Mark records as soft-deleted by setting ``_deleted: true`` in metadata.
696 Used by MIP forget when ``forget.mode == "soft"``: records remain in
697 the store (auditable, restorable) but are hidden from recall.
698 Returns the number of records that were transitioned to deleted.
699 """
700 if not memory_ids:
701 return 0
702 ids_set = set(memory_ids)
703 count = 0
704 for m in self._memories.get(bank_id, []):
705 if m.memory_id in ids_set:
706 meta = dict(m.metadata) if m.metadata else {}
707 if meta.get("_deleted") is True:
708 continue # already soft-deleted
709 meta["_deleted"] = True
710 m.metadata = meta
711 count += 1
712 return count
714 async def forget(self, request: ForgetRequest) -> ForgetResult:
715 if not self._supports_forget:
716 raise NotImplementedError("forget not supported")
717 bank_memories = self._memories.get(request.bank_id, [])
719 if request.scope == "all":
720 count = len(bank_memories)
721 self._memories[request.bank_id] = []
722 return ForgetResult(deleted_count=count)
724 if request.memory_ids:
725 ids_set = set(request.memory_ids)
726 before = len(bank_memories)
727 bank_memories = [m for m in bank_memories if m.memory_id not in ids_set]
728 self._memories[request.bank_id] = bank_memories
729 return ForgetResult(deleted_count=before - len(bank_memories))
731 # Tag-based and/or date-based deletion
732 if request.tags or request.before_date:
733 before = len(bank_memories)
734 tag_set = set(request.tags) if request.tags else None
736 def _should_keep(m: MemoryHit) -> bool:
737 if tag_set and m.tags and tag_set & set(m.tags):
738 return False
739 if request.before_date and m.occurred_at and m.occurred_at < request.before_date:
740 return False
741 return True
743 bank_memories = [m for m in bank_memories if _should_keep(m)]
744 self._memories[request.bank_id] = bank_memories
745 return ForgetResult(deleted_count=before - len(bank_memories))
747 return ForgetResult(deleted_count=0)
750# ---------------------------------------------------------------------------
751# In-memory Wiki Store (M8)
752# ---------------------------------------------------------------------------
class InMemoryWikiStore:
    """Dict-backed wiki store for tests (M8).

    Holds the current revision of each page plus an audit trail of replaced
    revisions. Page embeddings are not kept here — they are managed by the
    VectorStore (stored with ``memory_layer="compiled"``); this store holds
    only the structured WikiPage records.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        # bank_id → {page_id → list[WikiPage]}: replaced revisions, newest last
        self._history: dict[str, dict[str, list[WikiPage]]] = {}
        # bank_id → {page_id → WikiPage}: current revisions
        self._pages: dict[str, dict[str, WikiPage]] = {}

    async def upsert_page(self, page: WikiPage, bank_id: str) -> str:
        """Store *page* as the current revision and return its page id.

        When a page with the same id already exists, the existing revision is
        archived and the incoming page is stored with ``revision`` set to the
        existing revision plus one (whatever revision the caller supplied).
        """
        from dataclasses import replace as _replace

        current = self._pages.get(bank_id)
        if current is None:
            current = self._pages[bank_id] = {}
            self._history[bank_id] = {}

        previous = current.get(page.page_id)
        if previous is not None:
            archive = self._history[bank_id].setdefault(page.page_id, [])
            archive.append(previous)
            page = _replace(page, revision=previous.revision + 1)

        current[page.page_id] = page
        return page.page_id

    async def get_page(self, page_id: str, bank_id: str) -> WikiPage | None:
        """Return the current revision of the page, or ``None`` if absent."""
        bank_pages = self._pages.get(bank_id, {})
        return bank_pages.get(page_id)

    async def list_pages(
        self,
        bank_id: str,
        scope: str | None = None,
        kind: str | None = None,
    ) -> list[WikiPage]:
        """List the bank's current pages, optionally narrowed by scope and kind."""
        return [
            candidate
            for candidate in self._pages.get(bank_id, {}).values()
            if (scope is None or candidate.scope == scope) and (kind is None or candidate.kind == kind)
        ]

    async def delete_page(self, page_id: str, bank_id: str) -> bool:
        """Remove the page and its revision history; return whether it existed."""
        bank_pages = self._pages.get(bank_id, {})
        if page_id in bank_pages:
            del bank_pages[page_id]
            self._history.get(bank_id, {}).pop(page_id, None)
            return True
        return False

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory wiki store")

    def revision_history(self, page_id: str, bank_id: str) -> list[WikiPage]:
        """Return a copy of the page's past revisions, oldest first. Testing helper."""
        archived = self._history.get(bank_id, {}).get(page_id, [])
        return list(archived)
821# ---------------------------------------------------------------------------
822# In-Memory PageIndex Store (M9 — section recall; see ADR-006/007)
823# ---------------------------------------------------------------------------
826class InMemoryPageIndexStore:
827 """Fully functional in-memory PageIndex store for testing (M9 section recall).
829 Holds documents (one per conversation), sections (tree nodes), entity
830 mentions per section, and section-to-section links. PR1 only exercises
831 the document/section path; PR2 populates entities and links.
833 See :class:`~astrocyte.provider.PageIndexStore` for the SPI.
834 """
836 SPI_VERSION: ClassVar[int] = 1
838 def __init__(self) -> None:
839 # (bank_id, source_id) → PageIndexDocument
840 self._documents: dict[tuple[str, str], PageIndexDocument] = {}
841 # document_id → list[PageIndexSection], ordered by line_num
842 self._sections: dict[str, list[PageIndexSection]] = {}
843 # document_id → list[PageIndexSectionEntity]
844 self._section_entities: dict[str, list[PageIndexSectionEntity]] = {}
845 # document_id → list[PageIndexSectionLink]
846 self._section_links: dict[str, list[PageIndexSectionLink]] = {}
847 # M10.1: bank_id → list[(WikiPage, embedding)]; provenance:
848 # page_id → list[(document_id, line_num)]
849 self._wiki_pages: dict[str, list[tuple["WikiPage", list[float] | None]]] = {}
850 self._wiki_provenance: dict[str, list[tuple[str, int]]] = {}
851 # M12.1: document_id → list[PageIndexFact]
852 self._facts: dict[str, list[PageIndexFact]] = {}
854 async def save_document(self, doc: PageIndexDocument) -> str:
855 # Upsert keyed on (bank_id, source_id). Behaviour matches the
856 # Postgres adapter:
857 # - If an existing row matches the key, preserve its id so
858 # child rows (sections, links) remain valid.
859 # - If no existing row AND the input id is empty/falsy,
860 # assign a fresh UUID (mirrors gen_random_uuid() default).
861 # - Otherwise honour the caller-supplied id.
862 import uuid as _uuid
863 from dataclasses import replace as _replace
865 key = (doc.bank_id, doc.source_id)
866 existing = self._documents.get(key)
867 if existing is not None:
868 doc = _replace(doc, id=existing.id)
869 elif not doc.id:
870 doc = _replace(doc, id=str(_uuid.uuid4()))
871 self._documents[key] = doc
872 return doc.id
874 async def save_sections(
875 self,
876 document_id: str,
877 sections: list[PageIndexSection],
878 ) -> int:
879 # Atomic replace — drop any prior sections for this document
880 # before writing the new tree. This matches the SQL adapter's
881 # DELETE + INSERT pattern; partial trees are never observable.
882 sorted_sections = sorted(sections, key=lambda s: s.line_num)
883 self._sections[document_id] = sorted_sections
884 # Cascading wipe of dependent rows mirrors the FK ON DELETE CASCADE
885 # in migration 015. PR2 will repopulate these.
886 self._section_entities.pop(document_id, None)
887 self._section_links.pop(document_id, None)
888 return len(sorted_sections)
890 async def load_document(
891 self,
892 bank_id: str,
893 source_id: str,
894 ) -> PageIndexDocument | None:
895 return self._documents.get((bank_id, source_id))
897 async def load_skeleton(self, document_id: str) -> list[PageIndexSection]:
898 # Return a shallow copy without summary_embedding to mirror the
899 # SQL adapter's projection (the picker doesn't need embeddings).
900 from dataclasses import replace as _replace
902 return [_replace(s, summary_embedding=None) for s in self._sections.get(document_id, [])]
904 async def save_section_embeddings(
905 self,
906 document_id: str,
907 embeddings: list[tuple[int, list[float]]],
908 ) -> int:
909 # Update existing PageIndexSection rows in place. PR2 commit A
910 # mirrors the Postgres adapter's UPDATE pattern: skip rows with
911 # no matching line_num (we don't auto-create sections from
912 # embeddings — the tree-build step is the source of truth).
913 from dataclasses import replace as _replace
915 sections = self._sections.get(document_id, [])
916 if not sections:
917 return 0
918 wanted = dict(embeddings) # line_num → vec
919 n = 0
920 new_sections: list[PageIndexSection] = []
921 for s in sections:
922 vec = wanted.pop(s.line_num, None)
923 if vec is not None:
924 new_sections.append(_replace(s, summary_embedding=list(vec)))
925 n += 1
926 else:
927 new_sections.append(s)
928 if n > 0:
929 self._sections[document_id] = new_sections
930 return n
async def save_section_entities(
    self,
    entities: list[PageIndexSectionEntity],
) -> int:
    """Append section-entity rows, skipping ones that are already stored.

    A row is a duplicate when the same document already holds an entry
    with the same ``(line_num, entity_name)``. Duplicates inside
    ``entities`` itself are collapsed as well. Returns the number of rows
    actually appended.
    """
    # One key set per touched document, built lazily from the stored rows,
    # so each membership test is O(1) instead of a rescan of the bucket.
    seen_by_doc: dict[str, set[tuple[int, str]]] = {}
    inserted = 0
    for entity in entities:
        bucket = self._section_entities.setdefault(entity.document_id, [])
        seen = seen_by_doc.get(entity.document_id)
        if seen is None:
            seen = {(row.line_num, row.entity_name) for row in bucket}
            seen_by_doc[entity.document_id] = seen
        key = (entity.line_num, entity.entity_name)
        if key in seen:
            continue
        seen.add(key)
        bucket.append(entity)
        inserted += 1
    return inserted
async def save_section_links(
    self,
    links: list[PageIndexSectionLink],
) -> int:
    """Append section links, skipping ones that are already stored.

    Links are bucketed by ``from_doc``. Within a bucket a link is a
    duplicate when ``(from_line, to_doc, to_line, link_type)`` matches an
    existing entry; the stored link (including its weight) is left
    untouched in that case. Returns the number of links actually appended.
    """
    # One key set per touched source document, built lazily from the
    # stored links, so each membership test is O(1) instead of a rescan.
    seen_by_doc: dict[str, set[tuple[int, str, int, str]]] = {}
    inserted = 0
    for link in links:
        bucket = self._section_links.setdefault(link.from_doc, [])
        seen = seen_by_doc.get(link.from_doc)
        if seen is None:
            seen = {(row.from_line, row.to_doc, row.to_line, row.link_type) for row in bucket}
            seen_by_doc[link.from_doc] = seen
        key = (link.from_line, link.to_doc, link.to_line, link.link_type)
        if key in seen:
            continue
        seen.add(key)
        bucket.append(link)
        inserted += 1
    return inserted
async def populate_semantic_knn_links(
    self,
    document_id: str,
    *,
    top_k: int = 5,
    min_similarity: float = 0.5,
) -> int:
    """PR2 D.7.1 in-memory kNN: link each section to its nearest siblings.

    For every section of ``document_id`` that has a summary embedding,
    the cosine similarity to every OTHER embedded section of the same
    document is computed. The ``top_k`` most similar ones with similarity
    of at least ``min_similarity`` become ``"semantic_knn"`` links whose
    weight is that similarity (mirrors the SQL adapter's LATERAL kNN).
    Returns the count reported by :meth:`save_section_links`.
    """
    candidates = [sec for sec in self._sections.get(document_id, []) if sec.summary_embedding]
    if len(candidates) < 2:
        return 0

    links: list[PageIndexSectionLink] = []
    for source in candidates:
        neighbours: list[tuple[float, PageIndexSection]] = []
        for target in candidates:
            if target.line_num == source.line_num:
                continue
            similarity = _cosine_sim(source.summary_embedding, target.summary_embedding)
            if similarity >= min_similarity:
                neighbours.append((similarity, target))
        ranked = sorted(neighbours, key=lambda pair: pair[0], reverse=True)
        links.extend(
            PageIndexSectionLink(
                from_doc=document_id,
                from_line=source.line_num,
                to_doc=document_id,
                to_line=target.line_num,
                link_type="semantic_knn",
                weight=float(similarity),
            )
            for similarity, target in ranked[:top_k]
        )
    # Delegating to save_section_links gives idempotent insert semantics.
    return await self.save_section_links(links)
1003 # ── PR2 commit B: parallel-strategy query methods ─────────────────
def _docs_in_bank(self, bank_id: str) -> set[str]:
    """Collect the ids of all stored documents whose key belongs to ``bank_id``."""
    matching: set[str] = set()
    for (doc_bank, _source_id), document in self._documents.items():
        if doc_bank == bank_id:
            matching.add(document.id)
    return matching
def _section_passes_session(self, doc_id: str, line_num: int, session_filter: str | None) -> bool:
    """M31 Fix 2 — section-level session check (analogue of ``_matches_session``).

    Passes when ``session_filter`` is None, or when the first stored
    section at ``(doc_id, line_num)`` carries that ``session_id``. A
    missing section fails the filter.
    """
    if session_filter is None:
        return True
    section = next(
        (sec for sec in self._sections.get(doc_id, []) if sec.line_num == line_num),
        None,
    )
    if section is None:
        return False
    return section.session_id == session_filter
async def search_sections_semantic(
    self,
    bank_id: str,
    query_embedding: list[float],
    *,
    top_k: int = 20,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Rank the bank's embedded sections by cosine similarity to the query.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first. Sections without a ``summary_embedding`` are
    skipped, as are sections whose ``session_id`` differs from
    ``session_filter`` when one is given. An empty ``query_embedding``
    yields an empty list.
    """
    if not query_embedding:
        return []
    scored: list[tuple[str, int, float]] = []
    # _docs_in_bank returns a set, whose iteration order for strings varies
    # between processes. Walking it sorted keeps the order of equal-score
    # results reproducible (the sort below is stable).
    for doc_id in sorted(self._docs_in_bank(bank_id)):
        for s in self._sections.get(doc_id, []):
            if not s.summary_embedding:
                continue
            if session_filter is not None and s.session_id != session_filter:
                continue
            # Cosine similarity (Postgres adapter uses 1 - distance).
            score = _cosine_sim(query_embedding, s.summary_embedding)
            scored.append((doc_id, s.line_num, score))
    scored.sort(key=lambda x: x[2], reverse=True)
    return scored[:top_k]
async def search_sections_keyword(
    self,
    bank_id: str,
    query: str,
    *,
    top_k: int = 20,
    speaker: str | None = None,
    document_id: str | None = None,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Lightweight keyword search over section titles and summaries.

    The score of a section is the number of DISTINCT lower-cased,
    whitespace-separated query tokens that occur as substrings of
    ``title + " " + summary``. This follows the spirit of ts_rank_cd's
    "more matched terms = higher score" without the tsvector machinery.
    Sections with no hit are dropped.

    When ``document_id`` is given, only that document's sections are
    searched and the bank fanout is skipped entirely (PR2.6
    temporal-arithmetic find_event_date use case). ``speaker`` and
    ``session_filter`` restrict results to sections with equal
    ``speaker`` / ``session_id``.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first.
    """
    if not query.strip():
        return []
    # Dedupe (order-preserving) so a repeated query word cannot inflate the
    # score; this matches the fact-level keyword search.
    terms = list(dict.fromkeys(query.lower().split()))
    if not terms:
        return []
    if document_id is not None:
        doc_ids = [document_id] if document_id in self._sections else []
    else:
        doc_ids = self._docs_in_bank(bank_id)
    scored: list[tuple[str, int, float]] = []
    # The bank fanout is a set; walk it sorted so equal-score results come
    # back in a reproducible order (the sort below is stable).
    for doc_id in sorted(doc_ids):
        for s in self._sections.get(doc_id, []):
            if speaker is not None and s.speaker != speaker:
                continue
            if session_filter is not None and s.session_id != session_filter:
                continue
            haystack = (s.title or "").lower() + " " + (s.summary or "").lower()
            hits = sum(1 for t in terms if t in haystack)
            if hits > 0:
                scored.append((doc_id, s.line_num, float(hits)))
    scored.sort(key=lambda x: x[2], reverse=True)
    return scored[:top_k]
async def search_sections_by_entities(
    self,
    bank_id: str,
    entity_names: list[str],
    *,
    top_k: int = 20,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Rank the bank's sections by how many requested entities they mention.

    Follows Hindsight's CTE pattern: the score of a section is the number
    of distinct requested names (compared case-insensitively) among its
    stored section entities. Blank or empty names in ``entity_names`` are
    ignored. When ``session_filter`` is set, sections failing
    ``_section_passes_session`` are dropped.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first.
    """
    if not entity_names:
        return []
    wanted = {n.lower() for n in entity_names if n and n.strip()}
    if not wanted:
        return []
    per_section: dict[tuple[str, int], set[str]] = {}
    # _docs_in_bank returns a set; walk it sorted so equal-score results
    # come back in a reproducible order (dict insertion order feeds the
    # stable sort below).
    for doc_id in sorted(self._docs_in_bank(bank_id)):
        for e in self._section_entities.get(doc_id, []):
            low = e.entity_name.lower()
            if low not in wanted:
                continue
            if not self._section_passes_session(doc_id, e.line_num, session_filter):
                continue
            per_section.setdefault((doc_id, e.line_num), set()).add(low)
    scored = [(doc_id, line_num, float(len(matches))) for (doc_id, line_num), matches in per_section.items()]
    scored.sort(key=lambda x: x[2], reverse=True)
    return scored[:top_k]
async def search_sections_temporal(
    self,
    bank_id: str,
    date_range,
    *,
    top_k: int = 20,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Return the bank's sections whose ``session_date`` lies in ``date_range``.

    ``date_range`` unpacks to ``(start, end)``; both bounds are inclusive.
    Sections without a ``session_date`` never match. Every hit scores
    ``1.0`` and results are ordered by ``(document_id, line_num)`` before
    the ``top_k`` cut.
    """
    lower, upper = date_range
    hits: list[tuple[str, int, float]] = []
    for doc_id in self._docs_in_bank(bank_id):
        for section in self._sections.get(doc_id, []):
            if session_filter is not None and section.session_id != session_filter:
                continue
            when = section.session_date
            if when is None:
                continue
            if lower <= when <= upper:
                hits.append((doc_id, section.line_num, 1.0))
    # All scores are equal, so order by position for a stable result.
    ordered = sorted(hits, key=lambda hit: (hit[0], hit[1]))
    return ordered[:top_k]
async def expand_section_links(
    self,
    seeds: list[tuple[str, int]],
    *,
    link_types: list[str] | None = None,
    top_k: int = 20,
) -> list[tuple[str, int, float]]:
    """Follow stored links one hop out of (and into) the seed sections.

    ``seeds`` are ``(document_id, line_num)`` pairs. Every link touching
    a seed contributes its weight to the section on the other end, in
    both directions. A non-empty ``link_types`` restricts which link
    types are followed; ``None`` or an empty list means no restriction.

    Returns up to ``top_k`` ``(document_id, line_num, summed_weight)``
    tuples, heaviest first.
    """
    if not seeds:
        return []
    seed_keys = {(doc, line) for doc, line in seeds}
    allowed_types = set(link_types) if link_types else None
    totals: dict[tuple[str, int], float] = {}

    def _credit(target: tuple[str, int], weight: float) -> None:
        totals[target] = totals.get(target, 0.0) + weight

    for bucket in self._section_links.values():
        for link in bucket:
            if allowed_types is not None and link.link_type not in allowed_types:
                continue
            source = (link.from_doc, link.from_line)
            destination = (link.to_doc, link.to_line)
            # Outgoing edge from a seed
            if source in seed_keys:
                _credit(destination, link.weight)
            # Incoming edge to a seed
            if destination in seed_keys:
                _credit(source, link.weight)

    ranked = [(doc, line, total) for (doc, line), total in totals.items()]
    ranked.sort(key=lambda row: row[2], reverse=True)
    return ranked[:top_k]
async def expand_sections_by_shared_entities(
    self,
    bank_id: str,
    seeds: list[tuple[str, int]],
    *,
    top_k: int = 20,
    exclude_seeds: bool = True,
) -> list[tuple[str, int, float]]:
    """Find sections in the bank that share entities with the seed sections.

    The seed entity set is the lower-cased ``entity_name`` of every stored
    section entity on a seed ``(document_id, line_num)``. Each section of
    a document in ``bank_id`` is then scored by the number of distinct
    seed entities it mentions. With ``exclude_seeds`` (the default) the
    seed sections themselves are left out.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first; empty when there are no seeds or the seeds carry
    no entities.
    """
    if not seeds:
        return []
    seed_set = {(d, ln) for d, ln in seeds}
    # Collect seed entity names (case-insensitive set).
    seed_entities: set[str] = set()
    for doc_id, line_num in seed_set:
        for e in self._section_entities.get(doc_id, []):
            if e.line_num == line_num:
                seed_entities.add(e.entity_name.lower())
    if not seed_entities:
        return []
    # For every section in the bank, count distinct overlap with seed_entities.
    per_section: dict[tuple[str, int], set[str]] = {}
    # _docs_in_bank returns a set; walk it sorted so equal-score results
    # come back in a reproducible order (dict insertion order feeds the
    # stable sort below).
    for doc_id in sorted(self._docs_in_bank(bank_id)):
        for e in self._section_entities.get(doc_id, []):
            low = e.entity_name.lower()
            if low not in seed_entities:
                continue
            key = (doc_id, e.line_num)
            if exclude_seeds and key in seed_set:
                continue
            per_section.setdefault(key, set()).add(low)
    scored = [(doc_id, line_num, float(len(matches))) for (doc_id, line_num), matches in per_section.items()]
    scored.sort(key=lambda x: x[2], reverse=True)
    return scored[:top_k]
async def list_distinct_entities(
    self,
    bank_id: str,
    document_id: str,
    *,
    pattern: str | None = None,
    limit: int = 50,
) -> list[tuple[str, int]]:
    """List a document's entity names with the number of sections mentioning each.

    Only ``document_id`` selects the rows; ``bank_id`` is not consulted
    here. When ``pattern`` is set, any ``%`` characters are stripped and
    the rest is used as a case-insensitive substring filter on
    ``entity_name``. Repeated mentions of a name within the same
    ``line_num`` count once, matching the Postgres primary key
    ``(document_id, line_num, entity_name)``.

    Returns ``(entity_name, section_count)`` pairs ordered by count
    descending, then name ascending, truncated to ``max(1, limit)``.
    """
    mentions = self._section_entities.get(document_id, [])
    if not mentions:
        return []
    if pattern is not None:
        fragment = pattern.lower().replace("%", "")
        mentions = [m for m in mentions if fragment in m.entity_name.lower()]

    # name → distinct line numbers; the set collapses repeat mentions.
    lines_by_name: dict[str, set[int]] = {}
    for mention in mentions:
        lines_by_name.setdefault(mention.entity_name, set()).add(mention.line_num)

    tallies = [(name, len(lines)) for name, lines in lines_by_name.items()]
    # Count descending, then name ascending for determinism.
    tallies.sort(key=lambda pair: (-pair[1], pair[0]))
    return tallies[: max(1, limit)]
1229 # ── M12.1 fact-grain ───────────────────────────────────────────
async def save_facts(self, facts) -> int:
    """Append ``facts`` to their per-document buckets and return how many were given.

    Facts are grouped by ``document_id``; no de-duplication is performed.
    """
    if not facts:
        return 0
    for fact in facts:
        bucket = self._facts.get(fact.document_id)
        if bucket is None:
            bucket = []
            self._facts[fact.document_id] = bucket
        bucket.append(fact)
    return len(facts)
async def update_fact_embeddings(self, embeddings) -> int:
    """Set the ``embedding`` of stored facts, matched by fact id.

    ``embeddings`` is an iterable of ``(fact_id, vector)`` pairs. Each
    matching fact is replaced in its bucket by a ``dataclasses.replace``
    copy carrying the vector; ids with no stored fact are ignored.
    Returns the number of facts updated.
    """
    from dataclasses import replace as _replace

    if not embeddings:
        return 0
    vector_for = dict(embeddings)
    touched = 0
    for bucket in self._facts.values():
        for position, fact in enumerate(bucket):
            if fact.id not in vector_for:
                continue
            bucket[position] = _replace(fact, embedding=vector_for[fact.id])
            touched += 1
    return touched
def _matches_session(self, fact: "PageIndexFact", session_filter: str | None) -> bool:
    """M31 Fix 2 — in-memory analogue of the Postgres EXISTS clause.

    Passes when ``session_filter`` is None, or when the section the fact
    is anchored to (same ``document_id`` and ``line_num``) has that
    ``session_id``. Facts without an anchor, facts whose anchor section
    is not stored, and anchors with a different or missing ``session_id``
    all fail — intentional, since the filter means "limit to this
    session's content".
    """
    if session_filter is None:
        return True
    anchored = fact.document_id is not None and fact.line_num is not None
    if not anchored:
        return False
    anchor = next(
        (sec for sec in self._sections.get(fact.document_id, []) if sec.line_num == fact.line_num),
        None,
    )
    if anchor is None:
        return False
    return anchor.session_id == session_filter
async def search_facts_semantic(
    self,
    bank_id: str,
    query_embedding: list[float],
    *,
    top_k: int = 20,
    document_id: str | None = None,
    fact_type: str | None = None,
    session_filter: str | None = None,
):
    """Rank stored facts by cosine similarity to ``query_embedding``.

    Candidates are the facts of ``document_id`` when it is given (truthy),
    otherwise the facts of every document. A fact is kept when its
    ``bank_id`` matches, its ``fact_type`` matches (if requested), it has
    an embedding, and it passes ``_matches_session``.

    Returns ``PageIndexFactHit`` objects, best score first, truncated to
    ``max(1, top_k)``. An empty ``query_embedding`` yields an empty list.
    """
    from astrocyte.types import PageIndexFactHit

    if not query_embedding:
        return []

    scored: list[tuple[float, PageIndexFact]] = []
    scope = (
        self._facts.get(document_id, []) if document_id else [f for bucket in self._facts.values() for f in bucket]
    )
    for f in scope:
        if f.bank_id != bank_id:
            continue
        if fact_type is not None and f.fact_type != fact_type:
            continue
        if not f.embedding:
            continue
        if not self._matches_session(f, session_filter):  # M31 Fix 2
            continue
        # Shared module-level helper, the same one the section-level
        # semantic search uses, instead of a per-call nested copy.
        scored.append((_cosine_sim(query_embedding, f.embedding), f))
    scored.sort(key=lambda kv: kv[0], reverse=True)
    return [
        PageIndexFactHit(
            fact_id=f.id,
            document_id=f.document_id,
            line_num=f.line_num,
            text=f.text,
            fact_type=f.fact_type,
            speaker=f.speaker,
            occurred_start=f.occurred_start,
            occurred_end=f.occurred_end,
            entities=list(f.entities or []),
            confidence_score=getattr(f, "confidence_score", None),  # M27 / M28-A
            mentioned_at=getattr(f, "mentioned_at", None),  # M27
            event_date=getattr(f, "event_date", None),  # M31 Fix 4
            score=float(s),
        )
        for s, f in scored[: max(1, top_k)]
    ]
async def search_facts_by_entity(
    self,
    bank_id: str,
    entity_name: str,
    *,
    top_k: int = 50,
    document_id: str | None = None,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
):
    """Return facts that mention ``entity_name`` in their ``entities`` list.

    Matching is a case-folded substring test against each entity of the
    fact. Candidates are the facts of ``document_id`` when it is given
    (truthy), otherwise the facts of every document; they are further
    filtered by ``bank_id``, optional ``fact_type`` and
    ``_matches_session``. Hits come back in storage order with score
    ``1.0``; collection stops as soon as ``top_k`` hits exist (the check
    runs after each append).
    """
    from astrocyte.types import PageIndexFactHit

    wanted = entity_name.casefold()
    if document_id:
        candidates = self._facts.get(document_id, [])
    else:
        candidates = [fact for bucket in self._facts.values() for fact in bucket]

    def _mentions(fact) -> bool:
        return any(wanted in (name or "").casefold() for name in (fact.entities or []))

    results: list["PageIndexFactHit"] = []
    for fact in candidates:
        if fact.bank_id != bank_id:
            continue
        if fact_type is not None and fact.fact_type != fact_type:
            continue
        if not _mentions(fact):
            continue
        if not self._matches_session(fact, session_filter):  # M31 Fix 2
            continue
        hit = PageIndexFactHit(
            fact_id=fact.id,
            document_id=fact.document_id,
            line_num=fact.line_num,
            text=fact.text,
            fact_type=fact.fact_type,
            speaker=fact.speaker,
            occurred_start=fact.occurred_start,
            occurred_end=fact.occurred_end,
            entities=list(fact.entities or []),
            confidence_score=getattr(fact, "confidence_score", None),
            mentioned_at=getattr(fact, "mentioned_at", None),
            event_date=getattr(fact, "event_date", None),  # M31 Fix 4
            score=1.0,
        )
        results.append(hit)
        if len(results) >= top_k:
            break
    return results
async def search_facts_keyword(
    self,
    bank_id: str,
    query: str,
    *,
    top_k: int = 20,
    document_id: str | None = None,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
):
    """M31c — token-overlap analogue of Postgres BM25 over fact text.

    The query is lower-cased and split on whitespace; one-character
    tokens are dropped and the rest de-duplicated. The score of a fact is
    the number of those tokens that occur as substrings of its
    lower-cased text. Facts with no hit are dropped. Not BM25 proper, but
    the same "literal keyword match wins" family of behaviour.

    Candidates are the facts of ``document_id`` when it is given (truthy),
    otherwise the facts of every document, filtered by ``bank_id``,
    optional ``fact_type`` and ``_matches_session``. Returns
    ``PageIndexFactHit`` objects, best score first, truncated to
    ``max(1, top_k)``.
    """
    from astrocyte.types import PageIndexFactHit

    if not query or not query.strip():
        return []
    # Order-preserving dedupe of the usable tokens.
    tokens = list(dict.fromkeys(tok for tok in query.lower().split() if len(tok) > 1))
    if not tokens:
        return []

    if document_id:
        candidates = self._facts.get(document_id, [])
    else:
        candidates = [fact for bucket in self._facts.values() for fact in bucket]

    ranked: list[tuple[float, "PageIndexFact"]] = []
    for fact in candidates:
        if fact.bank_id != bank_id:
            continue
        if fact_type is not None and fact.fact_type != fact_type:
            continue
        if not self._matches_session(fact, session_filter):  # M31 Fix 2
            continue
        text = (fact.text or "").lower()
        matched = len([tok for tok in tokens if tok in text])
        if matched > 0:
            ranked.append((float(matched), fact))
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    results = []
    for score, fact in ranked[: max(1, top_k)]:
        results.append(
            PageIndexFactHit(
                fact_id=fact.id,
                document_id=fact.document_id,
                line_num=fact.line_num,
                text=fact.text,
                fact_type=fact.fact_type,
                speaker=fact.speaker,
                occurred_start=fact.occurred_start,
                occurred_end=fact.occurred_end,
                entities=list(fact.entities or []),
                confidence_score=getattr(fact, "confidence_score", None),
                mentioned_at=getattr(fact, "mentioned_at", None),
                event_date=getattr(fact, "event_date", None),
                score=float(score),
            )
        )
    return results
async def search_facts_temporal(
    self,
    bank_id: str,
    date_range,
    *,
    top_k: int = 50,
    document_id: str | None = None,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
):
    """Return facts whose dates fall in ``date_range``, newest first.

    ``date_range`` unpacks to inclusive ``(start, end)`` bounds.

    M33-3b — Hindsight-parity 4-way OR (mirrors the Postgres SQL): a fact
    qualifies if ANY of these holds: its ``occurred_start``..``occurred_end``
    span overlaps the range, or ``mentioned_at``, ``occurred_start`` or
    ``occurred_end`` individually lies in the range.

    Candidates are the facts of ``document_id`` when it is given (truthy),
    otherwise the facts of every document, filtered by ``bank_id``,
    optional ``fact_type`` and ``_matches_session``. Results are sorted
    descending by ``COALESCE(occurred_start, mentioned_at, occurred_end)``
    so the ``top_k`` cap keeps the newest candidates. Every hit scores
    ``1.0``.
    """
    from astrocyte.types import PageIndexFactHit

    start, end = date_range
    scope = (
        self._facts.get(document_id, []) if document_id else [f for bucket in self._facts.values() for f in bucket]
    )

    def _in_range(d):
        return d is not None and start <= d <= end

    def _qualifies(f) -> bool:
        os_ = f.occurred_start
        oe_ = f.occurred_end
        ma_ = getattr(f, "mentioned_at", None)
        if os_ is not None and oe_ is not None and os_ <= end and oe_ >= start:
            return True
        return _in_range(ma_) or _in_range(os_) or _in_range(oe_)

    def _sort_key(f):
        # COALESCE(occurred_start, mentioned_at, occurred_end).
        # The sort below is descending, so the leading flag must be True
        # for dated facts to keep any dateless one at the bottom. Every
        # fact that qualifies has at least one date, so this is defensive.
        d = f.occurred_start or getattr(f, "mentioned_at", None) or f.occurred_end
        return (d is not None, d)

    matched = [
        f
        for f in scope
        if f.bank_id == bank_id
        and (fact_type is None or f.fact_type == fact_type)  # M34-3
        and _qualifies(f)
        and self._matches_session(f, session_filter)  # M31 Fix 2
    ]
    matched.sort(key=_sort_key, reverse=True)
    hits: list["PageIndexFactHit"] = []
    for f in matched[:top_k]:
        hits.append(
            PageIndexFactHit(
                fact_id=f.id,
                document_id=f.document_id,
                line_num=f.line_num,
                text=f.text,
                fact_type=f.fact_type,
                speaker=f.speaker,
                occurred_start=f.occurred_start,
                occurred_end=f.occurred_end,
                entities=list(f.entities or []),
                confidence_score=getattr(f, "confidence_score", None),
                mentioned_at=getattr(f, "mentioned_at", None),
                event_date=getattr(f, "event_date", None),  # M31 Fix 4
                score=1.0,
            )
        )
    return hits
async def count_facts_matching(
    self,
    bank_id: str,
    document_id: str,
    *,
    entity_pattern: str | None = None,
    fact_type: str | None = None,
) -> int:
    """Count the facts of ``document_id`` that satisfy the given filters.

    A fact counts when its ``bank_id`` matches, its ``fact_type`` matches
    (if requested) and, when ``entity_pattern`` is set, at least one of
    its entities contains the pattern. The pattern is lower-cased and
    stripped of ``%`` characters, then used as a substring test against
    lower-cased entity names.
    """
    fragment = None if entity_pattern is None else entity_pattern.lower().replace("%", "")

    def _counts(fact) -> bool:
        if fact.bank_id != bank_id:
            return False
        if fact_type is not None and fact.fact_type != fact_type:
            return False
        if fragment is None:
            return True
        return any(fragment in (name or "").lower() for name in (fact.entities or []))

    return sum(1 for fact in self._facts.get(document_id, []) if _counts(fact))
async def save_section_event_dates(
    self,
    document_id: str,
    event_dates,
) -> int:
    """Set ``occurred_start`` / ``occurred_end`` on the document's sections.

    ``event_dates`` is an iterable of ``(line_num, start, end)`` triples;
    if a ``line_num`` repeats, the last triple wins. Sections whose
    ``line_num`` is not listed are kept unchanged. Returns the number of
    sections updated, or ``0`` when there are no dates or no stored
    sections.
    """
    if not event_dates:
        return 0
    from dataclasses import replace as _replace

    existing = self._sections.get(document_id) or []
    if not existing:
        return 0

    window_by_line = {}
    for line_num, start, end in event_dates:
        window_by_line[line_num] = (start, end)

    rebuilt = []
    changed = 0
    for section in existing:
        window = window_by_line.get(section.line_num)
        if window is None:
            rebuilt.append(section)
            continue
        rebuilt.append(_replace(section, occurred_start=window[0], occurred_end=window[1]))
        changed += 1
    # The stored list is always swapped for the rebuilt one.
    self._sections[document_id] = rebuilt
    return changed
1553 # ── M10.1 wiki / consolidation ─────────────────────────────────
async def load_sections_with_embeddings(
    self,
    bank_id: str,
    document_id: str,
) -> list[PageIndexSection]:
    """Return a new list holding the sections stored for ``document_id``.

    The section objects are returned as stored (nothing is stripped);
    ``bank_id`` is accepted but not used for the lookup.
    """
    del bank_id  # documents are bank-scoped via _documents key
    stored = self._sections.get(document_id, [])
    return [*stored]
async def save_wiki_page(
    self,
    *,
    page,  # WikiPage
    embedding: list[float] | None,
    provenance: list[tuple[str, int]],
) -> str:
    """Upsert a wiki page, its embedding and its provenance.

    Any page already stored in the same bank with the same ``page_id`` is
    removed, and the new ``(page, embedding)`` pair is appended at the
    end of the bank's list. ``provenance`` is copied and stored under the
    page id. Returns ``page.page_id``.
    """
    pages = self._wiki_pages.setdefault(page.bank_id, [])
    survivors = [entry for entry in pages if entry[0].page_id != page.page_id]
    survivors.append((page, embedding))
    # Slice-assign so the list object held in _wiki_pages stays the same.
    pages[:] = survivors
    self._wiki_provenance[page.page_id] = list(provenance)
    return page.page_id
async def search_wiki_pages_semantic(
    self,
    bank_id: str,
    query_embedding: list[float],
    *,
    top_k: int = 5,
    document_id: str | None = None,
):
    """Rank a bank's wiki pages by cosine similarity to ``query_embedding``.

    Pages stored without an embedding are skipped. When ``document_id``
    is given, only pages whose ``scope`` equals
    ``"document:<document_id>"`` are considered.

    Returns ``WikiPageHit`` objects, best score first, truncated to
    ``max(1, top_k)``; empty when the bank has no pages or the query
    embedding is empty.
    """
    from astrocyte.types import WikiPageHit

    bucket = self._wiki_pages.get(bank_id, [])
    if not bucket or not query_embedding:
        return []
    scope_filter = f"document:{document_id}" if document_id is not None else None

    scored: list[tuple[float, WikiPage]] = []
    for page, emb in bucket:
        if scope_filter is not None and page.scope != scope_filter:
            continue
        if not emb:
            continue
        # Shared module-level helper instead of a per-call nested copy.
        scored.append((_cosine_sim(query_embedding, emb), page))
    scored.sort(key=lambda kv: kv[0], reverse=True)
    out: list[WikiPageHit] = []
    for score, page in scored[: max(1, top_k)]:
        out.append(
            WikiPageHit(
                page_id=page.page_id,
                title=page.title,
                content=page.content,
                scope=page.scope,
                kind=page.kind,
                score=float(score),
                source_ids=list(page.source_ids),
                bank_id=page.bank_id,
            )
        )
    return out
async def count_wiki_pages_for_doc(
    self,
    bank_id: str,
    document_id: str,
) -> int:
    """Count the bank's wiki pages whose ``scope`` is ``"document:<document_id>"``."""
    wanted_scope = f"document:{document_id}"
    total = 0
    for page, _embedding in self._wiki_pages.get(bank_id, []):
        if page.scope == wanted_scope:
            total += 1
    return total
async def list_wiki_pages_for_doc(
    self,
    bank_id: str,
    document_id: str,
) -> list[WikiPage]:
    """Return the bank's wiki pages whose ``scope`` is ``"document:<document_id>"``.

    Pages come back in storage order, without their embeddings.
    """
    wanted_scope = f"document:{document_id}"
    selected: list[WikiPage] = []
    for page, _embedding in self._wiki_pages.get(bank_id, []):
        if page.scope == wanted_scope:
            selected.append(page)
    return selected
async def list_wikis_affected_by_entities(
    self,
    bank_id: str,
    entities: list[str],
    *,
    min_overlap: int = 1,
    limit: int = 8,
) -> list[tuple[WikiPage, int, list[str]]]:
    """M14.2: find wiki pages whose provenance sections mention the given entities.

    Mirrors the Postgres JOIN: for each wiki page of the bank that has
    provenance, the entity names of its provenance sections are
    intersected (exact, case-sensitive match) with ``entities``. Pages
    sharing at least ``min_overlap`` names are returned as
    ``(page, overlap_count, sorted_shared_names)``, ordered by overlap
    descending with ``page_id`` ascending as tie-break, truncated to
    ``limit``.
    """
    if not entities:
        return []
    requested = set(entities)
    affected: list[tuple[WikiPage, int, list[str]]] = []
    for page, _embedding in self._wiki_pages.get(bank_id, []):
        provenance = self._wiki_provenance.get(page.page_id, [])
        if not provenance:
            continue
        page_entities = {
            row.entity_name
            for doc_id, line_num in provenance
            for row in self._section_entities.get(doc_id, [])
            if row.line_num == line_num
        }
        overlap = sorted(page_entities & requested)
        if len(overlap) < min_overlap:
            continue
        affected.append((page, len(overlap), overlap))
    affected.sort(key=lambda row: (-row[1], row[0].page_id))
    return affected[:limit]
async def health(self) -> HealthStatus:
    """Always report the store as healthy."""
    status = HealthStatus(healthy=True, message="in-memory pageindex store")
    return status
1684# ---------------------------------------------------------------------------
1685# In-Memory Mental Model Store (M9 — first-class, replaces wiki piggyback)
1686# ---------------------------------------------------------------------------
1689class InMemoryMentalModelStore:
1690 """Fully functional in-memory mental-model store for testing.
1692 Mirrors :class:`InMemoryWikiStore`'s shape but for the dedicated
1693 :class:`~astrocyte.provider.MentalModelStore` SPI. Holds the current
1694 revision of each model plus a history list of past revisions.
1695 """
1697 SPI_VERSION: ClassVar[int] = 1
1699 def __init__(self) -> None:
1700 # bank_id → {model_id → MentalModel} (current revisions)
1701 self._models: dict[str, dict[str, "MentalModel"]] = {}
1702 # bank_id → {model_id → list[MentalModel]} (past revisions, oldest first)
1703 self._history: dict[str, dict[str, list["MentalModel"]]] = {}
1705 async def upsert(self, model: "MentalModel", bank_id: str) -> int:
1706 from dataclasses import replace as _replace
1707 from datetime import UTC, datetime
1709 if bank_id not in self._models:
1710 self._models[bank_id] = {}
1711 self._history[bank_id] = {}
1713 existing = self._models[bank_id].get(model.model_id)
1714 new_revision = (existing.revision + 1) if existing is not None else 1
1715 if existing is not None:
1716 # Archive the prior current-revision before replacing.
1717 self._history[bank_id].setdefault(model.model_id, []).append(existing)
1719 # Store always assigns revision + refreshed_at (callers don't need
1720 # to fill these in). M40 — validate source_timestamps alignment;
1721 # drop on mismatch rather than persist a wrong-alignment array
1722 # that would silently corrupt trend computation downstream.
1723 src_ts = model.source_timestamps
1724 if src_ts is not None and len(src_ts) != len(model.source_ids):
1725 src_ts = None
1726 stamped = _replace(
1727 model,
1728 bank_id=bank_id,
1729 revision=new_revision,
1730 refreshed_at=datetime.now(UTC),
1731 source_timestamps=src_ts,
1732 )
1733 self._models[bank_id][model.model_id] = stamped
1734 return new_revision
1736 async def get(self, model_id: str, bank_id: str) -> "MentalModel | None":
1737 return self._models.get(bank_id, {}).get(model_id)
1739 async def list(
1740 self,
1741 bank_id: str,
1742 *,
1743 scope: str | None = None,
1744 kind: str | None = None,
1745 ) -> list["MentalModel"]:
1746 models = list(self._models.get(bank_id, {}).values())
1747 if scope is not None:
1748 models = [m for m in models if m.scope == scope]
1749 if kind is not None:
1750 models = [m for m in models if m.kind == kind]
1751 return models
1753 async def delete(self, model_id: str, bank_id: str) -> bool:
1754 bank_models = self._models.get(bank_id, {})
1755 if model_id not in bank_models:
1756 return False
1757 del bank_models[model_id]
1758 # Keep history for audit; tests can call ``revision_history()`` to
1759 # observe past revisions even after delete.
1760 return True
1762 async def update_via_ops(
1763 self,
1764 model_id: str,
1765 bank_id: str,
1766 operations_json: list[dict],
1767 ) -> "tuple[int, dict] | None":
1768 """M21 — apply structured delta operations and re-upsert.
1770 Lazy-migrates legacy rows by parsing ``content`` on first
1771 refresh. Re-renders ``content`` from the new structured doc so
1772 markdown readers see the updated text without a separate
1773 migration pass.
1774 """
1775 from dataclasses import replace as _replace
1777 from astrocyte.pipeline.delta_ops import (
1778 DeltaOperationList,
1779 apply_operations,
1780 )
1781 from astrocyte.pipeline.structured_doc import (
1782 StructuredDocument,
1783 parse_markdown,
1784 render_document,
1785 )
1787 current = self._models.get(bank_id, {}).get(model_id)
1788 if current is None:
1789 return None
1791 # Resolve the current structured doc — lazy-migrate from
1792 # raw markdown when the field is None (legacy rows).
1793 if current.structured_doc is None:
1794 doc = parse_markdown(current.content or "")
1795 else:
1796 doc = StructuredDocument.model_validate(current.structured_doc)
1798 # Parse the LLM/caller's op list via the Pydantic schema.
1799 # Schema-invalid op shapes (unknown ``op`` discriminator,
1800 # missing required fields) raise ValidationError; per the
1801 # conservative-failure contract, we catch and return a
1802 # zero-ops apply (doc stays as-is, no revision bump).
1803 try:
1804 ops_container = DeltaOperationList.model_validate({"operations": operations_json})
1805 except Exception:
1806 # Schema-level failure: nothing applied. Return current revision.
1807 return (current.revision, {"applied": [], "skipped": [], "changed": False})
1809 applied = apply_operations(doc, ops_container.operations)
1810 if not applied.changed:
1811 return (
1812 current.revision,
1813 {"applied": applied.applied, "skipped": applied.skipped, "changed": False},
1814 )
1816 new_doc_dict = applied.document.model_dump()
1817 new_content = render_document(applied.document)
1818 new_model = _replace(
1819 current,
1820 content=new_content,
1821 structured_doc=new_doc_dict,
1822 )
1823 new_revision = await self.upsert(new_model, bank_id)
1824 return (
1825 new_revision,
1826 {"applied": applied.applied, "skipped": applied.skipped, "changed": True},
1827 )
1829 async def refresh(
1830 self,
1831 model_id: str,
1832 bank_id: str,
1833 new_source_ids: list[str],
1834 ) -> "MentalModel | None":
1835 """M28 — merge new source_ids into existing model and bump revision.
1837 In-memory semantics: the production store will re-run an LLM
1838 compile against the merged source set; the test store simply
1839 acknowledges the new sources (deduped, preserving existing
1840 order then appending novel ids) and bumps the revision via
1841 :meth:`upsert`. That keeps tests focused on the wiring and
1842 contract surface without coupling them to LLM behaviour.
1843 """
1844 from dataclasses import replace as _replace
1846 current = self._models.get(bank_id, {}).get(model_id)
1847 if current is None:
1848 return None
1850 # Order-preserving dedup: keep existing source_ids in original
1851 # order, then append novel ids in input order.
1852 seen: set[str] = set(current.source_ids)
1853 merged = list(current.source_ids)
1854 for sid in new_source_ids:
1855 if sid not in seen:
1856 seen.add(sid)
1857 merged.append(sid)
1859 refreshed = _replace(current, source_ids=merged)
1860 await self.upsert(refreshed, bank_id)
1861 return self._models.get(bank_id, {}).get(model_id)
1863 async def health(self) -> HealthStatus:
1864 return HealthStatus(healthy=True, message="in-memory mental model store")
1866 def revision_history(self, model_id: str, bank_id: str) -> list["MentalModel"]:
1867 """Return past revisions for a model (oldest first). Testing helper."""
1868 return list(self._history.get(bank_id, {}).get(model_id, []))
1871# ---------------------------------------------------------------------------
1872# In-Memory Source Store (M10 — documents + chunks normalisation)
1873# ---------------------------------------------------------------------------
1876class InMemorySourceStore:
1877 """Fully functional in-memory source-document store for testing.
1879 Mirrors the :class:`~astrocyte.provider.SourceStore` SPI: document
1880 create/get/list/delete + bulk-chunk insert with content_hash dedup.
1881 Soft-deletes match the Postgres adapter's lifecycle (deleted_at on
1882 documents; cascade-by-store-eviction on chunks).
1883 """
1885 SPI_VERSION: ClassVar[int] = 1
1887 def __init__(self) -> None:
1888 # bank_id → {document_id → SourceDocument}
1889 self._docs: dict[str, dict[str, "SourceDocument"]] = {}
1890 # bank_id → {chunk_id → SourceChunk}
1891 self._chunks: dict[str, dict[str, "SourceChunk"]] = {}
1892 # bank_id → {(content_hash) → document_id} for fast dedup lookup
1893 self._doc_hash_index: dict[str, dict[str, str]] = {}
1894 # bank_id → {(content_hash) → chunk_id} for fast dedup lookup
1895 self._chunk_hash_index: dict[str, dict[str, str]] = {}
1897 async def store_document(self, document: "SourceDocument") -> str:
1898 from dataclasses import replace as _replace
1899 from datetime import UTC
1900 from datetime import datetime as _dt
1902 bank_id = document.bank_id
1903 self._docs.setdefault(bank_id, {})
1904 self._doc_hash_index.setdefault(bank_id, {})
1906 # Dedup by content_hash when set.
1907 if document.content_hash:
1908 existing_id = self._doc_hash_index[bank_id].get(document.content_hash)
1909 if existing_id is not None and existing_id in self._docs[bank_id]:
1910 return existing_id
1912 # Stamp created_at if not provided (store owns this field).
1913 stamped = _replace(
1914 document,
1915 created_at=document.created_at or _dt.now(UTC),
1916 )
1917 self._docs[bank_id][document.id] = stamped
1918 if document.content_hash:
1919 self._doc_hash_index[bank_id][document.content_hash] = document.id
1920 return document.id
1922 async def get_document(
1923 self,
1924 document_id: str,
1925 bank_id: str,
1926 ) -> "SourceDocument | None":
1927 return self._docs.get(bank_id, {}).get(document_id)
1929 async def find_document_by_hash(
1930 self,
1931 content_hash: str,
1932 bank_id: str,
1933 ) -> "SourceDocument | None":
1934 doc_id = self._doc_hash_index.get(bank_id, {}).get(content_hash)
1935 if doc_id is None:
1936 return None
1937 return self._docs.get(bank_id, {}).get(doc_id)
1939 async def list_documents(
1940 self,
1941 bank_id: str,
1942 *,
1943 limit: int = 100,
1944 ) -> list["SourceDocument"]:
1945 docs = list(self._docs.get(bank_id, {}).values())
1946 # Newest-first by created_at; tolerate Nones via a fallback.
1947 from datetime import UTC
1948 from datetime import datetime as _dt
1950 docs.sort(key=lambda d: d.created_at or _dt.fromtimestamp(0, UTC), reverse=True)
1951 return docs[:limit]
1953 async def delete_document(self, document_id: str, bank_id: str) -> bool:
1954 bank_docs = self._docs.get(bank_id, {})
1955 doc = bank_docs.pop(document_id, None)
1956 if doc is None:
1957 return False
1958 # Drop the hash-index entry too so a follow-up re-store doesn't
1959 # silently dedup against the deleted row.
1960 if doc.content_hash:
1961 self._doc_hash_index.get(bank_id, {}).pop(doc.content_hash, None)
1962 # Cascade: drop all chunks of this document.
1963 bank_chunks = self._chunks.get(bank_id, {})
1964 to_drop = [cid for cid, chunk in bank_chunks.items() if chunk.document_id == document_id]
1965 for cid in to_drop:
1966 chunk = bank_chunks.pop(cid)
1967 if chunk.content_hash:
1968 self._chunk_hash_index.get(bank_id, {}).pop(chunk.content_hash, None)
1969 return True
1971 async def store_chunks(self, chunks: list["SourceChunk"]) -> list[str]:
1972 from dataclasses import replace as _replace
1973 from datetime import UTC
1974 from datetime import datetime as _dt
1976 ids: list[str] = []
1977 for chunk in chunks:
1978 bank_id = chunk.bank_id
1979 self._chunks.setdefault(bank_id, {})
1980 self._chunk_hash_index.setdefault(bank_id, {})
1982 # Dedup by content_hash when set.
1983 if chunk.content_hash:
1984 existing = self._chunk_hash_index[bank_id].get(chunk.content_hash)
1985 if existing is not None and existing in self._chunks[bank_id]:
1986 ids.append(existing)
1987 continue
1989 stamped = _replace(
1990 chunk,
1991 created_at=chunk.created_at or _dt.now(UTC),
1992 )
1993 self._chunks[bank_id][chunk.id] = stamped
1994 if chunk.content_hash:
1995 self._chunk_hash_index[bank_id][chunk.content_hash] = chunk.id
1996 ids.append(chunk.id)
1997 return ids
1999 async def get_chunk(self, chunk_id: str, bank_id: str) -> "SourceChunk | None":
2000 return self._chunks.get(bank_id, {}).get(chunk_id)
2002 async def list_chunks(
2003 self,
2004 document_id: str,
2005 bank_id: str,
2006 ) -> list["SourceChunk"]:
2007 chunks = [c for c in self._chunks.get(bank_id, {}).values() if c.document_id == document_id]
2008 chunks.sort(key=lambda c: c.chunk_index)
2009 return chunks
2011 async def find_chunk_by_hash(
2012 self,
2013 content_hash: str,
2014 bank_id: str,
2015 ) -> "SourceChunk | None":
2016 chunk_id = self._chunk_hash_index.get(bank_id, {}).get(content_hash)
2017 if chunk_id is None:
2018 return None
2019 return self._chunks.get(bank_id, {}).get(chunk_id)
2021 async def health(self) -> HealthStatus:
2022 return HealthStatus(healthy=True, message="in-memory source store")
2025# ---------------------------------------------------------------------------
2026# Mock LLM Provider
2027# ---------------------------------------------------------------------------
class MockLLMProvider:
    """Mock LLM provider with bag-of-words embeddings and extractive synthesis.

    Embeddings are hashed term-frequency vectors, so texts that share
    words score a high cosine similarity — enough for meaningful vector
    retrieval without an API key.

    For synthesis prompts, the reply is extracted from the memory text
    carried in the prompt rather than being a static string.
    """

    SPI_VERSION: ClassVar[int] = 1

    _embed_dim: ClassVar[int] = 128

    def __init__(
        self,
        default_response: str = "Mock LLM response",
        embedding_dimensions: int | None = None,
    ) -> None:
        """Configure the canned reply and, optionally, the embedding width.

        Args:
            default_response: Text returned by :meth:`complete` for
                prompts that match none of its special cases.
            embedding_dimensions: Overrides the class-level default
                embedding width (128) for this instance when given.
        """
        self._default_response = default_response
        if embedding_dimensions is not None:
            self._embed_dim = int(embedding_dimensions)
        self._call_count = 0
        #: Last ``complete()`` user message content (for tests asserting prompt structure).
        self.last_user_message: str | None = None

    def capabilities(self) -> LLMCapabilities:
        """Return a default-constructed capabilities descriptor."""
        return LLMCapabilities()

    async def complete(
        self,
        messages: list[Message],
        model: str | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.0,
        tools: list = None,  # list[ToolDefinition] | None — kept untyped to avoid import cycle
        tool_choice: str | None = None,
        response_format: dict | None = None,
    ) -> Completion:
        """Return a canned completion chosen from the prompt's content.

        Every call increments the call counter and records the most
        recent string-valued user message in ``last_user_message``
        (left untouched when there is none). The reply is then:

        * a fixed one-entity JSON array when the prompt text contains
          "extract named entities" (case-insensitive);
        * an extractive answer when the prompt contains both
          ``<memories>`` and ``<query>``;
        * ``default_response`` otherwise.

        ``max_tokens``, ``temperature``, ``tools``, ``tool_choice`` and
        ``response_format`` are accepted but not used.
        """
        self._call_count += 1

        latest_user_text = next(
            (m.content for m in reversed(messages) if m.role == "user" and isinstance(m.content, str)),
            None,
        )
        if latest_user_text is not None:
            self.last_user_message = latest_user_text

        prompt_text = " ".join(m.content for m in messages if isinstance(m.content, str))

        if "extract named entities" in prompt_text.lower():
            # Entity extraction prompt
            reply = '[{"name": "Test Entity", "entity_type": "OTHER", "aliases": []}]'
            usage = TokenUsage(input_tokens=10, output_tokens=20)
        elif "<memories>" in prompt_text and "<query>" in prompt_text:
            # Memory synthesis prompt — extract most relevant memory text
            reply = _extractive_synthesize(prompt_text)
            usage = TokenUsage(input_tokens=50, output_tokens=30)
        else:
            reply = self._default_response
            usage = TokenUsage(input_tokens=10, output_tokens=20)

        return Completion(text=reply, model=model or "mock", usage=usage)

    async def embed(self, texts: list[str], model: str | None = None) -> list[list[float]]:
        """Embed each text as a hashed bag-of-words vector.

        Texts that share words get a high cosine similarity, which is
        what makes retrieval against this mock meaningful. ``model`` is
        accepted but not used.
        """
        embeddings = []
        for text in texts:
            embeddings.append(self._bow_embed(text))
        return embeddings

    def _bow_embed(self, text: str) -> list[float]:
        """Build an L2-normalised term-frequency vector via the hashing trick.

        Each token (whitespace-split, surrounding punctuation stripped,
        lowercased) is mapped to a bucket by its MD5 digest modulo the
        embedding width, and that bucket is incremented. Components are
        never negative, so cosine similarity between any two vectors is
        never negative either — required because ``VectorHit.score``
        enforces ``>= 0.0``. Text with no tokens yields the zero vector.
        """
        import hashlib
        from string import punctuation

        width = self._embed_dim
        counts = [0.0] * width

        for raw in text.split():
            word = raw.strip(punctuation).lower()
            if not word:
                continue
            digest = hashlib.md5(word.encode()).hexdigest()
            counts[int(digest, 16) % width] += 1.0

        # L2 normalize; the zero vector is returned unchanged.
        length = math.sqrt(sum(c * c for c in counts))
        if length > 0:
            return [c / length for c in counts]
        return counts
def _normalize_terms(text: str) -> set[str]:
    """Return the distinct lowercase words of ``text`` longer than two characters.

    Words are whitespace-separated; punctuation surrounding a word is
    stripped before its length is checked.
    """
    from string import punctuation

    terms: set[str] = set()
    for raw in text.split():
        word = raw.strip(punctuation)
        if len(word) > 2:
            terms.add(word.lower())
    return terms
def _extractive_synthesize(prompt_text: str) -> str:
    """Answer a synthesis prompt by quoting its most query-relevant memories.

    The prompt is expected to carry a ``<query>`` block and a
    ``<memories>`` block. Each memory is scored by how many normalised
    terms it shares with the query; up to three memories with at least
    one shared term are returned, best first, joined by spaces. This
    gives the mock provider meaningful reflect answers without a real
    LLM.

    Returns a fixed fallback sentence when the query block is missing,
    the memories block is missing, or no memory shares a term with the
    query.
    """
    import re

    query_found = re.search(r"<query>\s*(.*?)\s*</query>", prompt_text, re.DOTALL)
    if query_found is None:
        return "No query found."
    wanted_terms = _normalize_terms(query_found.group(1).strip())

    memories_found = re.search(r"<memories>\s*(.*?)\s*</memories>", prompt_text, re.DOTALL)
    if memories_found is None:
        return "No memories found."

    # A line starting with "[Memory N]" opens a new entry (its header is
    # stripped for cleaner output); any other non-blank line continues
    # the entry before it, or opens one if none exists yet.
    header = re.compile(r"^\[Memory \d+\]")
    header_prefix = re.compile(r"^\[Memory \d+\](\s*\([^)]*\))?(\s*\[[^\]]*\])?\s*:\s*")
    entries: list[str] = []
    for raw_line in memories_found.group(1).strip().split("\n"):
        content = raw_line.strip()
        if not content:
            continue
        if header.match(content):
            entries.append(header_prefix.sub("", content))
        elif entries:
            entries[-1] += " " + content
        else:
            entries.append(content)

    # Score by term overlap with the query and keep only real matches.
    overlaps = ((len(wanted_terms & _normalize_terms(entry)), entry) for entry in entries)
    ranked = [pair for pair in overlaps if pair[0] > 0]
    if not ranked:
        return "I don't have relevant information to answer this question."

    # Stable sort: equally relevant memories keep their prompt order.
    ranked.sort(key=lambda pair: pair[0], reverse=True)
    return " ".join(entry for _, entry in ranked[:3])