Coverage for astrocyte/testing/in_memory.py: 84%

1057 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""In-memory provider implementations for testing. 

2 

3These are fully functional providers backed by Python dicts/lists. 

4Used by conformance tests and integration tests. 

5""" 

6 

7from __future__ import annotations 

8 

9import math 

10import uuid 

11from datetime import UTC, datetime 

12from typing import ClassVar 

13 

14from astrocyte.types import ( 

15 Completion, 

16 Document, 

17 DocumentFilters, 

18 DocumentHit, 

19 EngineCapabilities, 

20 Entity, 

21 EntityCandidateMatch, 

22 EntityLink, 

23 ForgetRequest, 

24 ForgetResult, 

25 GraphHit, 

26 HealthStatus, 

27 LLMCapabilities, 

28 MemoryEntityAssociation, 

29 MemoryHit, 

30 MentalModel, 

31 Message, 

32 PageIndexDocument, 

33 PageIndexFact, 

34 PageIndexSection, 

35 PageIndexSectionEntity, 

36 PageIndexSectionLink, 

37 RecallRequest, 

38 RecallResult, 

39 ReflectRequest, 

40 ReflectResult, 

41 RetainRequest, 

42 RetainResult, 

43 SourceChunk, 

44 SourceDocument, 

45 TokenUsage, 

46 VectorFilters, 

47 VectorHit, 

48 VectorItem, 

49 WikiPage, 

50) 

51 

52 

53def _cosine_sim(a: list[float], b: list[float]) -> float: 

54 dot = sum(x * y for x, y in zip(a, b)) 

55 na = math.sqrt(sum(x * x for x in a)) 

56 nb = math.sqrt(sum(x * x for x in b)) 

57 if na == 0 or nb == 0: 

58 return 0.0 

59 return dot / (na * nb) 

60 

61 

62# --------------------------------------------------------------------------- 

63# In-memory Vector Store 

64# --------------------------------------------------------------------------- 

65 

66 

class InMemoryVectorStore:
    """Dict-backed vector store for tests.

    Every item lives in one dict keyed by ``item.id``; the owning bank is
    read from ``item.bank_id``. Similarity search is a brute-force cosine
    scan over the items of the requested bank.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        # item id → VectorItem (shared across banks)
        self._vectors: dict[str, VectorItem] = {}

    @staticmethod
    def _matches(item: VectorItem, filters: VectorFilters) -> bool:
        """Return ``True`` when *item* survives every active filter.

        Tag, fact-type, ``as_of`` and ``time_range`` filters are only
        applied when the item carries the corresponding value; an item
        missing that value passes. The ``session_id`` filter is strict:
        an item without a matching ``metadata['session_id']`` is rejected.
        """
        if filters.tags and item.tags and not set(filters.tags) & set(item.tags):
            return False
        if filters.fact_types and item.fact_type and item.fact_type not in filters.fact_types:
            return False
        # M9: time-travel filter — drop items retained after ``as_of``.
        if filters.as_of is not None and item.retained_at is not None and item.retained_at > filters.as_of:
            return False
        if filters.time_range is not None and item.occurred_at is not None:
            start, end = filters.time_range
            if item.occurred_at < start or item.occurred_at > end:
                return False
        # M31 Fix 2 — session filter. ``metadata['session_id']`` is the
        # canonical location; items without it are excluded when the
        # filter is set.
        if filters.session_id is not None:
            meta = item.metadata
            item_session = meta.get("session_id") if isinstance(meta, dict) else None
            if item_session != filters.session_id:
                return False
        return True

    @staticmethod
    def _as_hit(item: VectorItem, score: float) -> VectorHit:
        """Project a stored item into a ``VectorHit`` carrying *score*."""
        return VectorHit(
            id=item.id,
            text=item.text,
            score=score,
            metadata=item.metadata,
            tags=item.tags,
            fact_type=item.fact_type,
            occurred_at=item.occurred_at,
            memory_layer=item.memory_layer,
            retained_at=item.retained_at,  # M9
            chunk_id=item.chunk_id,  # M10
        )

    async def store_vectors(self, items: list[VectorItem]) -> list[str]:
        """Insert or overwrite *items* by id and return their ids in input order."""
        for item in items:
            self._vectors[item.id] = item
        return [item.id for item in items]

    async def search_similar(
        self,
        query_vector: list[float],
        bank_id: str,
        limit: int = 10,
        filters: VectorFilters | None = None,
    ) -> list[VectorHit]:
        """Return up to *limit* hits from *bank_id*, best cosine score first.

        Items from other banks and items rejected by *filters* are skipped
        before scoring.
        """
        ranked = [
            (_cosine_sim(query_vector, item.vector), item)
            for item in self._vectors.values()
            if item.bank_id == bank_id and (not filters or self._matches(item, filters))
        ]
        ranked.sort(key=lambda pair: pair[0], reverse=True)
        return [self._as_hit(item, score) for score, item in ranked[:limit]]

    async def delete(self, ids: list[str], bank_id: str) -> int:
        """Remove the listed ids that belong to *bank_id*; return how many were removed."""
        removed = 0
        for vector_id in ids:
            item = self._vectors.get(vector_id)
            if item is not None and item.bank_id == bank_id:
                del self._vectors[vector_id]
                removed += 1
        return removed

    async def get_by_chunk_ids(
        self,
        chunk_ids: list[str],
        bank_id: str,
    ) -> list[VectorHit]:
        """M10 chunk expansion: return every vector in *bank_id* whose ``chunk_id`` is listed.

        Each hit is given a score of ``1.0`` (the caller applies the
        expansion multiplier) and keeps its ``chunk_id`` so results can be
        grouped by source. An empty *chunk_ids* yields an empty list.
        """
        if not chunk_ids:
            return []
        wanted = set(chunk_ids)
        return [
            self._as_hit(item, 1.0)
            for item in self._vectors.values()
            if item.bank_id == bank_id and item.chunk_id is not None and item.chunk_id in wanted
        ]

    async def list_vectors(
        self,
        bank_id: str,
        offset: int = 0,
        limit: int = 100,
    ) -> list[VectorItem]:
        """Return one page of the bank's items, ordered by item id."""
        in_bank = [item for item in self._vectors.values() if item.bank_id == bank_id]
        in_bank.sort(key=lambda item: item.id)
        return in_bank[offset : offset + limit]

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory vector store")

201 

202 

203# --------------------------------------------------------------------------- 

204# In-memory Graph Store 

205# --------------------------------------------------------------------------- 

206 

207 

class InMemoryGraphStore:
    """Dict/list-backed graph store for tests.

    Per bank it holds entities keyed by id, entity-to-entity links,
    memory↔entity associations and memory-to-memory links. The link ids
    returned by the ``store_*`` methods are freshly generated on each call
    and are not kept alongside the stored link objects.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        self._entities: dict[str, dict[str, Entity]] = {}  # bank_id → {entity_id → Entity}
        self._links: dict[str, list[EntityLink]] = {}  # bank_id → links
        self._memory_entity_map: dict[str, list[MemoryEntityAssociation]] = {}  # bank_id → assocs
        # Memory-to-memory links (Hindsight parity) — caused_by, semantic, etc.
        self._memory_links: dict[str, list] = {}  # bank_id → list[MemoryLink]

    async def store_entities(self, entities: list[Entity], bank_id: str) -> list[str]:
        """Insert or overwrite *entities* by id in *bank_id*; return their ids in input order."""
        bank_entities = self._entities.setdefault(bank_id, {})
        ids: list[str] = []
        for entity in entities:
            bank_entities[entity.id] = entity
            ids.append(entity.id)
        return ids

    async def store_links(self, links: list[EntityLink], bank_id: str) -> list[str]:
        """Append *links* to the bank and return one generated 12-hex-char id per link."""
        bank_links = self._links.setdefault(bank_id, [])
        ids: list[str] = []
        for link in links:
            bank_links.append(link)
            ids.append(uuid.uuid4().hex[:12])
        return ids

    async def link_memories_to_entities(
        self,
        associations: list[MemoryEntityAssociation],
        bank_id: str,
    ) -> None:
        """Append memory↔entity *associations* to the bank's association list."""
        self._memory_entity_map.setdefault(bank_id, []).extend(associations)

    async def query_neighbors(
        self,
        entity_ids: list[str],
        bank_id: str,
        max_depth: int = 2,
        limit: int = 20,
    ) -> list[GraphHit]:
        """Return memories associated with any of *entity_ids* in *bank_id*.

        Each hit's ``connected_entities`` is the sorted SUBSET of
        *entity_ids* the memory is associated with — spreading activation
        uses it to decide which activated entity surfaced the memory.
        Every hit has ``depth=1``, ``score=0.5`` and a placeholder text;
        *max_depth* is accepted but not used by this implementation.
        """
        # Build the lookup set once instead of scanning the list per association.
        wanted = set(entity_ids)
        per_memory: dict[str, set[str]] = {}
        for assoc in self._memory_entity_map.get(bank_id, []):
            if assoc.entity_id in wanted:
                per_memory.setdefault(assoc.memory_id, set()).add(assoc.entity_id)

        return [
            GraphHit(
                memory_id=mid,
                text=f"[graph result for {mid}]",
                connected_entities=sorted(connected),
                depth=1,
                score=0.5,
            )
            for mid, connected in list(per_memory.items())[:limit]
        ]

    async def query_entities(self, query: str, bank_id: str, limit: int = 10) -> list[Entity]:
        """Return up to *limit* entities whose name contains *query* (case-insensitive)."""
        query_lower = query.lower()
        bank_entities = self._entities.get(bank_id, {})
        results = [e for e in bank_entities.values() if query_lower in e.name.lower()]
        return results[:limit]

    async def find_entity_candidates(
        self,
        name: str,
        bank_id: str,
        threshold: float = 0.8,
        limit: int = 5,
    ) -> list[Entity]:
        """Return entities whose name contains *name*, or is contained in it (case-insensitive).

        The in-memory implementation uses substring overlap as a proxy for
        similarity; production adapters use vector or edit-distance similarity.
        The *threshold* parameter is accepted for interface compatibility but
        ignored — substring match is all-or-nothing.
        """
        name_lower = name.lower()
        results: list[Entity] = []
        for entity in self._entities.get(bank_id, {}).values():
            candidate = entity.name.lower()
            if name_lower in candidate or candidate in name_lower:
                results.append(entity)
        return results[:limit]

    async def find_entity_candidates_scored(
        self,
        name: str,
        bank_id: str,
        *,
        name_embedding: list[float] | None = None,
        trigram_threshold: float = 0.3,
        limit: int = 10,
    ) -> list[EntityCandidateMatch]:
        """In-memory equivalent of pg_trgm + cosine candidate scoring.

        - ``name_similarity`` is computed with :class:`difflib.SequenceMatcher`
          on stripped, lowercased names, which approximates PostgreSQL's
          pg_trgm ``similarity()`` closely enough for tests.
        - ``embedding_similarity`` is :func:`_cosine_sim` between
          *name_embedding* and the candidate's stored embedding, clamped to
          ``[0, 1]``, when both sides have one; ``None`` otherwise.
        - ``co_occurring_names`` are derived from the in-memory link list —
          for each ``co_occurs`` link involving the candidate, the OTHER
          entity's lowercased name is collected.
        - ``last_seen`` is parsed from an ISO-format string in
          ``entity.metadata['last_seen']``; it is ``None`` when the key is
          absent, the value is not a string, or parsing fails.

        Candidates with ``name_similarity < trigram_threshold`` are dropped.
        Results are ordered by ``max(name_similarity, embedding_similarity or 0)``
        descending and truncated to *limit*. An empty bank or a blank *name*
        yields an empty list.
        """
        from difflib import SequenceMatcher

        bank_entities = self._entities.get(bank_id, {})
        if not bank_entities:
            return []

        name_norm = (name or "").strip().lower()
        if not name_norm:
            return []

        # candidate_id → names of entities it co-occurs with. Built in one
        # pass over the bank's links so the scoring loop is a plain lookup.
        cooccurrence_map: dict[str, list[str]] = {}
        for link in self._links.get(bank_id, []):
            if link.link_type != "co_occurs":
                continue
            # Each endpoint records the opposite endpoint's name, provided
            # that opposite endpoint is a known entity in this bank.
            for owner_id, partner_id in (
                (link.entity_a, link.entity_b),
                (link.entity_b, link.entity_a),
            ):
                partner = bank_entities.get(partner_id)
                if partner is not None:
                    cooccurrence_map.setdefault(owner_id, []).append((partner.name or "").strip().lower())

        scored: list[EntityCandidateMatch] = []
        for entity in bank_entities.values():
            cand_name = (entity.name or "").strip().lower()
            if not cand_name:
                continue
            name_sim = SequenceMatcher(None, name_norm, cand_name).ratio()
            if name_sim < trigram_threshold:
                continue

            emb_sim: float | None = None
            if name_embedding is not None and entity.embedding is not None:
                # Clamp to [0, 1] — _cosine_sim can return negative values
                # for vectors that aren't non-negative.
                emb_sim = max(0.0, min(1.0, _cosine_sim(name_embedding, entity.embedding)))

            last_seen = None
            if entity.metadata and "last_seen" in entity.metadata:
                # InMemoryGraphStore doesn't auto-stamp last_seen — tests can
                # populate via metadata to exercise the temporal tier.
                raw = entity.metadata["last_seen"]
                if isinstance(raw, str):
                    try:
                        last_seen = datetime.fromisoformat(raw)
                    except ValueError:
                        last_seen = None

            scored.append(
                EntityCandidateMatch(
                    entity=entity,
                    name_similarity=name_sim,
                    embedding_similarity=emb_sim,
                    co_occurring_names=cooccurrence_map.get(entity.id, []),
                    last_seen=last_seen,
                    mention_count=getattr(entity, "mention_count", 1),
                )
            )

        def _sort_key(m: EntityCandidateMatch) -> float:
            return max(m.name_similarity, m.embedding_similarity or 0.0)

        scored.sort(key=_sort_key, reverse=True)
        return scored[:limit]

    async def store_memory_links(
        self,
        links: list,  # list[MemoryLink] — kept loose to avoid forward-ref noise
        bank_id: str,
    ) -> list[str]:
        """Persist memory-to-memory links (Hindsight-parity).

        Returns one generated 12-hex-char id per stored link.
        """
        store = self._memory_links.setdefault(bank_id, [])
        ids: list[str] = []
        for link in links:
            store.append(link)
            ids.append(uuid.uuid4().hex[:12])
        return ids

    async def find_memory_links(
        self,
        seed_memory_ids: list[str],
        bank_id: str,
        *,
        link_types: list[str] | None = None,
        limit: int = 200,
    ) -> list:  # list[MemoryLink]
        """Find links touching any seed memory in either direction.

        Links are returned in storage order, optionally restricted to
        *link_types*, and capped at *limit*. An empty seed list or a
        non-positive *limit* yields an empty list.
        """
        # Guard the limit up front: the in-loop check runs only after an
        # append, so without this a limit of 0 would still return one link.
        if not seed_memory_ids or limit <= 0:
            return []
        seeds = set(seed_memory_ids)
        type_filter = set(link_types) if link_types else None
        out: list = []
        for link in self._memory_links.get(bank_id, []):
            if type_filter is not None and link.link_type not in type_filter:
                continue
            if link.source_memory_id in seeds or link.target_memory_id in seeds:
                out.append(link)
                if len(out) >= limit:
                    break
        return out

    async def get_entity_ids_for_memories(
        self,
        memory_ids: list[str],
        bank_id: str,
    ) -> dict[str, list[str]]:
        """Return ``{memory_id: [entity_id, ...]}`` from the association table.

        Used by the spreading-activation pipeline to seed entity IDs
        directly from memory↔entity associations (the most-accurate
        path; metadata-based extraction is the fallback). Memories with
        no association are absent from the result.
        """
        if not memory_ids:
            return {}
        target = set(memory_ids)
        out: dict[str, list[str]] = {}
        for assoc in self._memory_entity_map.get(bank_id, []):
            if assoc.memory_id in target:
                out.setdefault(assoc.memory_id, []).append(assoc.entity_id)
        return out

    async def expand_entities_via_links(
        self,
        entity_ids: list[str],
        bank_id: str,
        *,
        max_hops: int = 2,
        link_types: list[str] | None = None,
    ) -> dict[str, int]:
        """BFS over entity-to-entity links to ``max_hops`` distance.

        Returns ``{entity_id: hop_distance}`` where ``hop_distance == 0``
        for entities in the input set, ``1`` for direct neighbors, etc.
        Only links whose type is in *link_types* (default ``co_occurs``)
        are followed, in both directions. At least one hop is always
        attempted, even when *max_hops* is below 1.

        Used by the spreading-activation pipeline to walk
        ``co_occurs`` links from seed entities to their graph
        neighborhood. Mirrors the AGE adapter's recursive-CTE
        implementation so behavior is consistent across stores.
        """
        if not entity_ids:
            return {}
        accepted_types = set(link_types or ["co_occurs"])

        # Neighbor map keyed by entity_id for O(1) lookup during the
        # BFS frontier expansion.
        neighbors: dict[str, set[str]] = {}
        for link in self._links.get(bank_id, []):
            if link.link_type not in accepted_types:
                continue
            neighbors.setdefault(link.entity_a, set()).add(link.entity_b)
            neighbors.setdefault(link.entity_b, set()).add(link.entity_a)

        distances: dict[str, int] = {eid: 0 for eid in entity_ids}
        frontier = set(entity_ids)
        for hop in range(1, max(1, max_hops) + 1):
            next_frontier: set[str] = set()
            for eid in frontier:
                for nb in neighbors.get(eid, ()):
                    if nb not in distances:
                        distances[nb] = hop
                        next_frontier.add(nb)
            if not next_frontier:
                break
            frontier = next_frontier
        return distances

    async def increment_mention_counts(
        self,
        entity_ids: list[str],
        bank_id: str,
    ) -> None:
        """Bump ``mention_count`` by 1 for each listed entity id.

        Unknown ids are skipped; an id listed more than once is bumped once
        per occurrence. Mirrors the AGE adapter's
        :meth:`increment_mention_counts` for in-memory tests of the
        resolver's mention-count flow.
        """
        from dataclasses import replace as _dc_replace

        bank_entities = self._entities.get(bank_id, {})
        for eid in entity_ids:
            entity = bank_entities.get(eid)
            if entity is None:
                continue
            # Store a replaced copy rather than mutating the original
            # object, which may be held elsewhere.
            bank_entities[eid] = _dc_replace(
                entity,
                mention_count=int(getattr(entity, "mention_count", 1)) + 1,
            )

    async def store_entity_link(self, link: EntityLink, bank_id: str) -> str:
        """Store a single resolved entity link (M11 entity resolution).

        Returns a generated 12-hex-char id for the link.
        """
        self._links.setdefault(bank_id, []).append(link)
        return uuid.uuid4().hex[:12]

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory graph store")

534 

535 

536# --------------------------------------------------------------------------- 

537# In-memory Document Store 

538# --------------------------------------------------------------------------- 

539 

540 

class InMemoryDocumentStore:
    """Dict-backed document store for tests.

    Documents are keyed by ``document.id`` across all banks; each entry
    remembers the bank it was stored under.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        # document id → (document, owning bank_id)
        self._docs: dict[str, tuple[Document, str]] = {}

    async def store_document(self, document: Document, bank_id: str) -> str:
        """Insert or overwrite *document* under *bank_id* and return its id."""
        doc_id = document.id
        self._docs[doc_id] = (document, bank_id)
        return doc_id

    async def search_fulltext(
        self,
        query: str,
        bank_id: str,
        limit: int = 10,
        filters: DocumentFilters | None = None,
    ) -> list[DocumentHit]:
        """Rank the bank's documents by whitespace-token overlap with *query*.

        The score is the number of distinct lowercased query terms found in
        the document divided by the number of distinct query terms.
        Documents sharing no term are omitted. *filters* is accepted but
        not applied by this implementation.
        """
        wanted_terms = set(query.lower().split())
        denominator = max(len(wanted_terms), 1)
        ranked: list[tuple[float, Document]] = []
        for doc, owner in self._docs.values():
            if owner != bank_id:
                continue
            shared = len(wanted_terms & set(doc.text.lower().split()))
            if shared > 0:
                ranked.append((shared / denominator, doc))

        ranked.sort(key=lambda pair: pair[0], reverse=True)
        hits: list[DocumentHit] = []
        for score, doc in ranked[:limit]:
            hits.append(DocumentHit(document_id=doc.id, text=doc.text, score=score, metadata=doc.metadata))
        return hits

    async def get_document(self, document_id: str, bank_id: str) -> Document | None:
        """Return the document with *document_id* if it is stored under *bank_id*, else ``None``."""
        entry = self._docs.get(document_id)
        if not entry:
            return None
        doc, owner = entry
        return doc if owner == bank_id else None

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory document store")

585 

586 

587# --------------------------------------------------------------------------- 

588# In-memory Engine Provider 

589# --------------------------------------------------------------------------- 

590 

591 

592class InMemoryEngineProvider: 

593 """Fully functional in-memory engine provider for testing.""" 

594 

595 SPI_VERSION: ClassVar[int] = 1 

596 

597 def __init__(self, supports_reflect: bool = True, supports_forget: bool = True) -> None: 

598 self._memories: dict[str, list[MemoryHit]] = {} # bank_id -> memories 

599 self._supports_reflect = supports_reflect 

600 self._supports_forget = supports_forget 

601 

602 def capabilities(self) -> EngineCapabilities: 

603 return EngineCapabilities( 

604 supports_reflect=self._supports_reflect, 

605 supports_forget=self._supports_forget, 

606 supports_semantic_search=True, 

607 supports_tags=True, 

608 supports_metadata=True, 

609 ) 

610 

611 async def health(self) -> HealthStatus: 

612 return HealthStatus(healthy=True, message="in-memory engine") 

613 

614 async def retain(self, request: RetainRequest) -> RetainResult: 

615 if request.bank_id not in self._memories: 

616 self._memories[request.bank_id] = [] 

617 mem_id = uuid.uuid4().hex[:16] 

618 # Stamp _created_at so MIP forget min_age_days can be enforced. Existing 

619 # callers that already set _created_at (eg. import/replay) win. 

620 meta = dict(request.metadata) if request.metadata else {} 

621 meta.setdefault("_created_at", datetime.now(UTC).isoformat()) 

622 self._memories[request.bank_id].append( 

623 MemoryHit( 

624 text=request.content, 

625 score=1.0, 

626 fact_type="world", 

627 metadata=meta, 

628 tags=request.tags, 

629 occurred_at=request.occurred_at, 

630 source=request.source, 

631 memory_id=mem_id, 

632 bank_id=request.bank_id, 

633 ) 

634 ) 

635 return RetainResult(stored=True, memory_id=mem_id) 

636 

637 async def recall(self, request: RecallRequest) -> RecallResult: 

638 memories = self._memories.get(request.bank_id, []) 

639 

640 # Apply filters before scoring 

641 # MIP soft-delete: hide records flagged with `_deleted: true` in metadata. 

642 # The flag is set by `soft_delete()` when forget.mode == "soft" so that 

643 # records survive on disk for audit/restore but disappear from recall. 

644 filtered = [m for m in memories if not (m.metadata and m.metadata.get("_deleted") is True)] 

645 if request.tags: 

646 tag_set = set(request.tags) 

647 filtered = [m for m in filtered if m.tags and tag_set & set(m.tags)] 

648 if request.fact_types: 

649 ft_set = set(request.fact_types) 

650 filtered = [m for m in filtered if m.fact_type in ft_set] 

651 if request.time_range: 

652 t_start, t_end = request.time_range 

653 filtered = [m for m in filtered if m.occurred_at and t_start <= m.occurred_at <= t_end] 

654 

655 # Wildcard query returns all memories (used by export) 

656 if request.query.strip() == "*": 

657 scored = [(1.0, mem) for mem in filtered] 

658 else: 

659 # Simple keyword matching 

660 query_terms = set(request.query.lower().split()) 

661 scored = [] 

662 for mem in filtered: 

663 mem_terms = set(mem.text.lower().split()) 

664 overlap = len(query_terms & mem_terms) 

665 if overlap > 0: 

666 score = overlap / max(len(query_terms), 1) 

667 scored.append((score, mem)) 

668 

669 scored.sort(key=lambda x: x[0], reverse=True) 

670 hits = [ 

671 MemoryHit( 

672 text=mem.text, 

673 score=score, 

674 fact_type=mem.fact_type, 

675 metadata=mem.metadata, 

676 tags=mem.tags, 

677 occurred_at=mem.occurred_at, 

678 source=mem.source, 

679 memory_id=mem.memory_id, 

680 bank_id=request.bank_id, 

681 ) 

682 for score, mem in scored[: request.max_results] 

683 ] 

684 return RecallResult(hits=hits, total_available=len(scored), truncated=len(scored) > request.max_results) 

685 

686 async def reflect(self, request: ReflectRequest) -> ReflectResult: 

687 if not self._supports_reflect: 

688 raise NotImplementedError("reflect not supported") 

689 recall_result = await self.recall(RecallRequest(query=request.query, bank_id=request.bank_id)) 

690 answer = "Synthesis: " + "; ".join(h.text for h in recall_result.hits[:5]) 

691 return ReflectResult(answer=answer, sources=recall_result.hits[:5]) 

692 

693 async def soft_delete(self, bank_id: str, memory_ids: list[str]) -> int: 

694 """Mark records as soft-deleted by setting ``_deleted: true`` in metadata. 

695 

696 Used by MIP forget when ``forget.mode == "soft"``: records remain in 

697 the store (auditable, restorable) but are hidden from recall. 

698 Returns the number of records that were transitioned to deleted. 

699 """ 

700 if not memory_ids: 

701 return 0 

702 ids_set = set(memory_ids) 

703 count = 0 

704 for m in self._memories.get(bank_id, []): 

705 if m.memory_id in ids_set: 

706 meta = dict(m.metadata) if m.metadata else {} 

707 if meta.get("_deleted") is True: 

708 continue # already soft-deleted 

709 meta["_deleted"] = True 

710 m.metadata = meta 

711 count += 1 

712 return count 

713 

714 async def forget(self, request: ForgetRequest) -> ForgetResult: 

715 if not self._supports_forget: 

716 raise NotImplementedError("forget not supported") 

717 bank_memories = self._memories.get(request.bank_id, []) 

718 

719 if request.scope == "all": 

720 count = len(bank_memories) 

721 self._memories[request.bank_id] = [] 

722 return ForgetResult(deleted_count=count) 

723 

724 if request.memory_ids: 

725 ids_set = set(request.memory_ids) 

726 before = len(bank_memories) 

727 bank_memories = [m for m in bank_memories if m.memory_id not in ids_set] 

728 self._memories[request.bank_id] = bank_memories 

729 return ForgetResult(deleted_count=before - len(bank_memories)) 

730 

731 # Tag-based and/or date-based deletion 

732 if request.tags or request.before_date: 

733 before = len(bank_memories) 

734 tag_set = set(request.tags) if request.tags else None 

735 

736 def _should_keep(m: MemoryHit) -> bool: 

737 if tag_set and m.tags and tag_set & set(m.tags): 

738 return False 

739 if request.before_date and m.occurred_at and m.occurred_at < request.before_date: 

740 return False 

741 return True 

742 

743 bank_memories = [m for m in bank_memories if _should_keep(m)] 

744 self._memories[request.bank_id] = bank_memories 

745 return ForgetResult(deleted_count=before - len(bank_memories)) 

746 

747 return ForgetResult(deleted_count=0) 

748 

749 

750# --------------------------------------------------------------------------- 

751# In-memory Wiki Store (M8) 

752# --------------------------------------------------------------------------- 

753 

754 

class InMemoryWikiStore:
    """Dict-backed wiki store for tests (M8).

    Keeps the current revision of each page plus an audit log of the
    revisions it replaced. Vector embeddings of pages are managed
    separately by the VectorStore (stored with ``memory_layer="compiled"``);
    this store holds only the structured WikiPage metadata.
    """

    SPI_VERSION: ClassVar[int] = 1

    def __init__(self) -> None:
        # bank_id → {page_id → WikiPage}  (current revisions)
        self._pages: dict[str, dict[str, WikiPage]] = {}
        # bank_id → {page_id → list[WikiPage]}  (past revisions, newest last)
        self._history: dict[str, dict[str, list[WikiPage]]] = {}

    async def upsert_page(self, page: WikiPage, bank_id: str) -> str:
        """Store *page* as the current revision and return its ``page_id``.

        When the page already exists, the current revision is archived and
        the incoming page is stored with ``revision`` set to the archived
        revision plus one, regardless of the revision it arrived with.
        """
        if bank_id not in self._pages:
            # First write to this bank: start both tables together.
            self._pages[bank_id] = {}
            self._history[bank_id] = {}

        bank_pages = self._pages[bank_id]
        previous = bank_pages.get(page.page_id)
        if previous is not None:
            from dataclasses import replace as _replace

            # Archive the outgoing revision, then renumber the incoming one.
            self._history[bank_id].setdefault(page.page_id, []).append(previous)
            page = _replace(page, revision=previous.revision + 1)

        bank_pages[page.page_id] = page
        return page.page_id

    async def get_page(self, page_id: str, bank_id: str) -> WikiPage | None:
        """Return the current revision of *page_id* in *bank_id*, or ``None``."""
        bank_pages = self._pages.get(bank_id, {})
        return bank_pages.get(page_id)

    async def list_pages(
        self,
        bank_id: str,
        scope: str | None = None,
        kind: str | None = None,
    ) -> list[WikiPage]:
        """Return the bank's current pages, optionally restricted to a *scope* and/or *kind*."""
        return [
            page
            for page in self._pages.get(bank_id, {}).values()
            if (scope is None or page.scope == scope) and (kind is None or page.kind == kind)
        ]

    async def delete_page(self, page_id: str, bank_id: str) -> bool:
        """Remove a page and its revision history; return ``False`` if it did not exist."""
        bank_pages = self._pages.get(bank_id, {})
        if page_id in bank_pages:
            del bank_pages[page_id]
            self._history.get(bank_id, {}).pop(page_id, None)
            return True
        return False

    async def health(self) -> HealthStatus:
        """Report the store as healthy."""
        return HealthStatus(healthy=True, message="in-memory wiki store")

    def revision_history(self, page_id: str, bank_id: str) -> list[WikiPage]:
        """Return a copy of the past revisions for a page (oldest first). Testing helper."""
        archived = self._history.get(bank_id, {}).get(page_id, [])
        return list(archived)

819 

820 

821# --------------------------------------------------------------------------- 

822# In-Memory PageIndex Store (M9 — section recall; see ADR-006/007) 

823# --------------------------------------------------------------------------- 

824 

825 

826class InMemoryPageIndexStore: 

827 """Fully functional in-memory PageIndex store for testing (M9 section recall). 

828 

829 Holds documents (one per conversation), sections (tree nodes), entity 

830 mentions per section, and section-to-section links. PR1 only exercises 

831 the document/section path; PR2 populates entities and links. 

832 

833 See :class:`~astrocyte.provider.PageIndexStore` for the SPI. 

834 """ 

835 

836 SPI_VERSION: ClassVar[int] = 1 

837 

838 def __init__(self) -> None: 

839 # (bank_id, source_id) → PageIndexDocument 

840 self._documents: dict[tuple[str, str], PageIndexDocument] = {} 

841 # document_id → list[PageIndexSection], ordered by line_num 

842 self._sections: dict[str, list[PageIndexSection]] = {} 

843 # document_id → list[PageIndexSectionEntity] 

844 self._section_entities: dict[str, list[PageIndexSectionEntity]] = {} 

845 # document_id → list[PageIndexSectionLink] 

846 self._section_links: dict[str, list[PageIndexSectionLink]] = {} 

847 # M10.1: bank_id → list[(WikiPage, embedding)]; provenance: 

848 # page_id → list[(document_id, line_num)] 

849 self._wiki_pages: dict[str, list[tuple["WikiPage", list[float] | None]]] = {} 

850 self._wiki_provenance: dict[str, list[tuple[str, int]]] = {} 

851 # M12.1: document_id → list[PageIndexFact] 

852 self._facts: dict[str, list[PageIndexFact]] = {} 

853 

854 async def save_document(self, doc: PageIndexDocument) -> str: 

855 # Upsert keyed on (bank_id, source_id). Behaviour matches the 

856 # Postgres adapter: 

857 # - If an existing row matches the key, preserve its id so 

858 # child rows (sections, links) remain valid. 

859 # - If no existing row AND the input id is empty/falsy, 

860 # assign a fresh UUID (mirrors gen_random_uuid() default). 

861 # - Otherwise honour the caller-supplied id. 

862 import uuid as _uuid 

863 from dataclasses import replace as _replace 

864 

865 key = (doc.bank_id, doc.source_id) 

866 existing = self._documents.get(key) 

867 if existing is not None: 

868 doc = _replace(doc, id=existing.id) 

869 elif not doc.id: 

870 doc = _replace(doc, id=str(_uuid.uuid4())) 

871 self._documents[key] = doc 

872 return doc.id 

873 

async def save_sections(
    self,
    document_id: str,
    sections: list[PageIndexSection],
) -> int:
    """Replace the whole section tree of ``document_id``.

    The incoming sections are stored ordered by ``line_num`` and any
    previously stored tree is dropped in the same step, so a partial tree
    is never observable. Entities and links recorded for the document are
    wiped as well (the original comments liken this to an FK
    ``ON DELETE CASCADE``).

    Returns the number of sections stored.
    """
    ordered = list(sections)
    ordered.sort(key=lambda section: section.line_num)
    self._sections[document_id] = ordered
    # Dependent rows belong to the old tree; discard them.
    for dependent_index in (self._section_entities, self._section_links):
        dependent_index.pop(document_id, None)
    return len(ordered)

889 

async def load_document(
    self,
    bank_id: str,
    source_id: str,
) -> PageIndexDocument | None:
    """Return the document stored for ``(bank_id, source_id)``, or ``None``."""
    lookup_key = (bank_id, source_id)
    return self._documents.get(lookup_key)

896 

async def load_skeleton(self, document_id: str) -> list[PageIndexSection]:
    """Return copies of the document's sections with ``summary_embedding`` cleared.

    The stored sections are left untouched; an unknown ``document_id``
    yields an empty list. The original comment explains the blanking as
    mirroring the SQL adapter's projection.
    """
    from dataclasses import replace as _replace

    skeleton = []
    for section in self._sections.get(document_id, []):
        skeleton.append(_replace(section, summary_embedding=None))
    return skeleton

903 

async def save_section_embeddings(
    self,
    document_id: str,
    embeddings: list[tuple[int, list[float]]],
) -> int:
    """Attach summary embeddings to already-stored sections.

    ``embeddings`` holds ``(line_num, vector)`` pairs. Pairs whose
    ``line_num`` matches no stored section are ignored: sections are never
    created from embeddings. Each vector is copied before being stored.

    Returns the number of sections updated.
    """
    from dataclasses import replace as _replace

    current = self._sections.get(document_id, [])
    if not current:
        return 0

    pending = dict(embeddings)  # line_num -> vector
    rebuilt: list[PageIndexSection] = []
    touched = 0
    for section in current:
        # pop() hands each vector to at most one section.
        vector = pending.pop(section.line_num, None)
        if vector is None:
            rebuilt.append(section)
            continue
        rebuilt.append(_replace(section, summary_embedding=list(vector)))
        touched += 1

    # Swap the list in only when something actually changed.
    if touched > 0:
        self._sections[document_id] = rebuilt
    return touched

931 

async def save_section_entities(
    self,
    entities: list[PageIndexSectionEntity],
) -> int:
    """Record section entities, skipping rows that are already present.

    Within a document, a row is identified by ``(line_num, entity_name)``;
    a second row with the same identity is ignored.

    Returns the number of rows actually added.
    """
    inserted = 0
    for entity in entities:
        rows = self._section_entities.setdefault(entity.document_id, [])
        identity = (entity.line_num, entity.entity_name)
        duplicate = any((row.line_num, row.entity_name) == identity for row in rows)
        if duplicate:
            continue
        rows.append(entity)
        inserted += 1
    return inserted

946 

async def save_section_links(
    self,
    links: list[PageIndexSectionLink],
) -> int:
    """Record section links, grouped by their ``from_doc``.

    A link whose ``(from_line, to_doc, to_line, link_type)`` already exists
    in its source document's list is skipped, so repeated calls do not
    duplicate edges.

    Returns the number of links actually added.
    """
    added = 0
    for link in links:
        outgoing = self._section_links.setdefault(link.from_doc, [])
        identity = (link.from_line, link.to_doc, link.to_line, link.link_type)
        for known in outgoing:
            if (known.from_line, known.to_doc, known.to_line, known.link_type) == identity:
                break  # already stored
        else:
            outgoing.append(link)
            added += 1
    return added

962 

async def populate_semantic_knn_links(
    self,
    document_id: str,
    *,
    top_k: int = 5,
    min_similarity: float = 0.5,
) -> int:
    """Create ``semantic_knn`` links between similar sections of one document.

    Every section that has a summary embedding is compared, by cosine
    similarity, with every other embedded section of the same document.
    For each section the ``top_k`` best neighbours scoring at least
    ``min_similarity`` become links weighted by that similarity. The
    original docstring describes this as the in-memory counterpart of the
    SQL adapter's LATERAL kNN (PR2 D.7.1).

    Links are persisted through ``save_section_links``, so existing edges
    are not duplicated. Returns the number of links actually added.
    """
    with_vectors = [
        section for section in self._sections.get(document_id, []) if section.summary_embedding
    ]
    if len(with_vectors) < 2:
        return 0

    candidate_links: list[PageIndexSectionLink] = []
    for source in with_vectors:
        neighbours: list[tuple[float, PageIndexSection]] = []
        for target in with_vectors:
            if target.line_num == source.line_num:
                continue  # never link a section to itself
            similarity = _cosine_sim(source.summary_embedding, target.summary_embedding)
            if similarity >= min_similarity:
                neighbours.append((similarity, target))
        best_first = sorted(neighbours, key=lambda pair: pair[0], reverse=True)
        candidate_links.extend(
            PageIndexSectionLink(
                from_doc=document_id,
                from_line=source.line_num,
                to_doc=document_id,
                to_line=target.line_num,
                link_type="semantic_knn",
                weight=float(similarity),
            )
            for similarity, target in best_first[:top_k]
        )
    # Delegating gives idempotent insert semantics for free.
    return await self.save_section_links(candidate_links)

1002 

1003 # ── PR2 commit B: parallel-strategy query methods ───────────────── 

1004 

def _docs_in_bank(self, bank_id: str) -> set[str]:
    """Return the ids of all stored documents whose key belongs to ``bank_id``."""
    found: set[str] = set()
    for (doc_bank, _source_id), document in self._documents.items():
        if doc_bank == bank_id:
            found.add(document.id)
    return found

1007 

def _section_passes_session(self, doc_id: str, line_num: int, session_filter: str | None) -> bool:
    """Tell whether the section at ``(doc_id, line_num)`` passes a session filter.

    M31 Fix 2: section-level counterpart of ``_matches_session``. With no
    filter every section passes. Otherwise the first stored section with
    that ``line_num`` must carry a ``session_id`` equal to
    ``session_filter``; an unknown section fails.
    """
    if session_filter is None:
        return True
    candidates = (
        section for section in self._sections.get(doc_id, []) if section.line_num == line_num
    )
    located = next(candidates, None)
    if located is None:
        return False
    return located.session_id == session_filter

1019 

async def search_sections_semantic(
    self,
    bank_id: str,
    query_embedding: list[float],
    *,
    top_k: int = 20,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Rank a bank's embedded sections by cosine similarity to the query.

    Sections without a summary embedding are skipped, as are sections
    whose ``session_id`` differs from a supplied ``session_filter``. An
    empty ``query_embedding`` yields no results.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first.
    """
    if not query_embedding:
        return []

    ranked: list[tuple[str, int, float]] = []
    for doc_id in self._docs_in_bank(bank_id):
        for section in self._sections.get(doc_id, []):
            vector = section.summary_embedding
            if not vector:
                continue
            wrong_session = session_filter is not None and section.session_id != session_filter
            if wrong_session:
                continue
            # Plain cosine similarity; the original comment notes the
            # Postgres adapter derives its score as 1 - distance.
            ranked.append((doc_id, section.line_num, _cosine_sim(query_embedding, vector)))
    ranked.sort(key=lambda hit: hit[2], reverse=True)
    return ranked[:top_k]

1043 

async def search_sections_keyword(
    self,
    bank_id: str,
    query: str,
    *,
    top_k: int = 20,
    speaker: str | None = None,
    document_id: str | None = None,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Score sections by how many distinct query tokens they contain.

    The query is lower-cased and split on whitespace; each *distinct*
    token that occurs as a substring of a section's ``title`` + ``summary``
    (case-insensitive) adds one point. Repeating a word in the query does
    not inflate the score, which keeps this method consistent with
    ``search_facts_keyword``. The original comments describe this as a
    lightweight stand-in for ``ts_rank_cd`` ("more matched terms = higher
    score"), not a tsvector re-implementation.

    Args:
        bank_id: Bank whose documents are searched when ``document_id`` is
            not given.
        query: Free-text query; blank queries return no results.
        top_k: Maximum number of results.
        speaker: When set, only sections with this exact ``speaker`` match.
        document_id: When set, search only this document and skip the bank
            fan-out entirely (PR2.6 ``find_event_date`` use case).
        session_filter: When set, only sections with this ``session_id``
            match.

    Returns:
        ``(document_id, line_num, score)`` tuples, highest score first.
    """
    # dict.fromkeys de-duplicates while keeping first-seen order; a blank
    # query splits to nothing and falls out here.
    terms = list(dict.fromkeys(query.lower().split()))
    if not terms:
        return []

    if document_id is not None:
        doc_ids = [document_id] if document_id in self._sections else []
    else:
        doc_ids = self._docs_in_bank(bank_id)

    scored: list[tuple[str, int, float]] = []
    for doc_id in doc_ids:
        for section in self._sections.get(doc_id, []):
            if speaker is not None and section.speaker != speaker:
                continue
            if session_filter is not None and section.session_id != session_filter:
                continue
            haystack = (section.title or "").lower() + " " + (section.summary or "").lower()
            hits = sum(1 for term in terms if term in haystack)
            if hits > 0:
                scored.append((doc_id, section.line_num, float(hits)))
    scored.sort(key=lambda hit: hit[2], reverse=True)
    return scored[:top_k]

1082 

async def search_sections_by_entities(
    self,
    bank_id: str,
    entity_names: list[str],
    *,
    top_k: int = 20,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Score sections by the number of distinct requested entities they mention.

    Entity names are compared case-insensitively; empty or blank names in
    ``entity_names`` are ignored. Sections failing ``session_filter`` are
    excluded. The original comment attributes the counting scheme to
    Hindsight's CTE pattern.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first.
    """
    if not entity_names:
        return []
    requested = {name.lower() for name in entity_names if name and name.strip()}
    if not requested:
        return []

    matched_by_section: dict[tuple[str, int], set[str]] = {}
    for doc_id in self._docs_in_bank(bank_id):
        for mention in self._section_entities.get(doc_id, []):
            lowered = mention.entity_name.lower()
            if lowered not in requested:
                continue
            if not self._section_passes_session(doc_id, mention.line_num, session_filter):
                continue
            matched_by_section.setdefault((doc_id, mention.line_num), set()).add(lowered)

    ranked = [
        (doc_id, line_num, float(len(names)))
        for (doc_id, line_num), names in matched_by_section.items()
    ]
    ranked.sort(key=lambda hit: hit[2], reverse=True)
    return ranked[:top_k]

1110 

async def search_sections_temporal(
    self,
    bank_id: str,
    date_range,
    *,
    top_k: int = 20,
    session_filter: str | None = None,
) -> list[tuple[str, int, float]]:
    """Return the bank's sections whose ``session_date`` lies in ``date_range``.

    ``date_range`` unpacks to ``(start, end)`` and both bounds are
    inclusive. Sections with no ``session_date`` never match. Every hit
    scores ``1.0``; results are ordered by ``(document_id, line_num)`` and
    capped at ``top_k``.
    """
    lower, upper = date_range
    in_window: list[tuple[str, int, float]] = []
    for doc_id in self._docs_in_bank(bank_id):
        for section in self._sections.get(doc_id, []):
            if session_filter is not None and section.session_id != session_filter:
                continue
            when = section.session_date
            if when is None:
                continue
            if lower <= when <= upper:
                in_window.append((doc_id, section.line_num, 1.0))
    # All scores are equal, so order by position for a stable result.
    in_window.sort(key=lambda hit: (hit[0], hit[1]))
    return in_window[:top_k]

1131 

async def expand_section_links(
    self,
    seeds: list[tuple[str, int]],
    *,
    link_types: list[str] | None = None,
    top_k: int = 20,
) -> list[tuple[str, int, float]]:
    """Follow stored links one hop out from the seed sections.

    Links are treated as undirected: an edge leaving a seed contributes
    its weight to the target section, and an edge arriving at a seed
    contributes its weight to the source section. Weights reaching the
    same section are summed. A non-empty ``link_types`` restricts which
    edges are followed.

    Returns up to ``top_k`` ``(document_id, line_num, weight)`` tuples,
    heaviest first.
    """
    if not seeds:
        return []
    seed_keys = {(doc, line) for doc, line in seeds}
    allowed_types = set(link_types) if link_types else None

    totals: dict[tuple[str, int], float] = {}
    for doc_links in self._section_links.values():
        for edge in doc_links:
            if allowed_types is not None and edge.link_type not in allowed_types:
                continue
            tail = (edge.from_doc, edge.from_line)
            head = (edge.to_doc, edge.to_line)
            if tail in seed_keys:  # outgoing edge from a seed
                totals[head] = totals.get(head, 0.0) + edge.weight
            if head in seed_keys:  # incoming edge to a seed
                totals[tail] = totals.get(tail, 0.0) + edge.weight

    ranked = [(doc, line, weight) for (doc, line), weight in totals.items()]
    ranked.sort(key=lambda hit: hit[2], reverse=True)
    return ranked[:top_k]

1159 

async def expand_sections_by_shared_entities(
    self,
    bank_id: str,
    seeds: list[tuple[str, int]],
    *,
    top_k: int = 20,
    exclude_seeds: bool = True,
) -> list[tuple[str, int, float]]:
    """Find sections in the bank that share entities with the seed sections.

    The (case-insensitive) entity names mentioned by the seeds are
    collected first; every section of the bank is then scored by how many
    distinct names from that pool it mentions. With ``exclude_seeds`` the
    seed sections themselves are left out of the result.

    Returns up to ``top_k`` ``(document_id, line_num, score)`` tuples,
    highest score first.
    """
    if not seeds:
        return []
    seed_keys = {(doc, line) for doc, line in seeds}
    bank_docs = self._docs_in_bank(bank_id)

    # Pool of lower-cased entity names mentioned by any seed section.
    pool: set[str] = set()
    for seed_doc, seed_line in seed_keys:
        pool.update(
            mention.entity_name.lower()
            for mention in self._section_entities.get(seed_doc, [])
            if mention.line_num == seed_line
        )
    if not pool:
        return []

    overlap: dict[tuple[str, int], set[str]] = {}
    for doc_id in bank_docs:
        for mention in self._section_entities.get(doc_id, []):
            lowered = mention.entity_name.lower()
            section_key = (doc_id, mention.line_num)
            if lowered in pool and not (exclude_seeds and section_key in seed_keys):
                overlap.setdefault(section_key, set()).add(lowered)

    ranked = [
        (doc_id, line_num, float(len(names)))
        for (doc_id, line_num), names in overlap.items()
    ]
    ranked.sort(key=lambda hit: hit[2], reverse=True)
    return ranked[:top_k]

1194 

async def list_distinct_entities(
    self,
    bank_id: str,
    document_id: str,
    *,
    pattern: str | None = None,
    limit: int = 50,
) -> list[tuple[str, int]]:
    """List a document's entity names with the number of sections mentioning each.

    ``bank_id`` is accepted but not consulted here; rows are looked up by
    ``document_id`` alone. When ``pattern`` is given, every ``%`` is
    removed from it and the remainder is matched as a case-insensitive
    substring of the entity name (an in-memory approximation of the
    Postgres ``ILIKE`` filter mentioned in the original comments).

    Repeated ``(line_num, entity_name)`` rows count once, so the count is
    "sections mentioning this exact name". Results are ordered by count
    descending, then name ascending, and at least one row is returned
    whenever any exists (``limit`` is floored at 1).
    """
    rows = self._section_entities.get(document_id, [])
    if not rows:
        return []
    if pattern is not None:
        fragment = pattern.lower().replace("%", "")
        rows = [row for row in rows if fragment in row.entity_name.lower()]

    # A set collapses duplicate mentions within the same section.
    distinct_mentions = {(row.line_num, row.entity_name) for row in rows}
    tally: dict[str, int] = {}
    for _line_num, name in distinct_mentions:
        tally[name] = tally.get(name, 0) + 1

    ranked = sorted(tally.items(), key=lambda item: (-item[1], item[0]))
    return ranked[: max(1, limit)]

1228 

1229 # ── M12.1 fact-grain ─────────────────────────────────────────── 

1230 

async def save_facts(self, facts) -> int:
    """Append ``facts`` to the per-document fact lists (M12.1).

    Facts are grouped by their ``document_id``; nothing is de-duplicated.
    Returns the number of facts supplied, or 0 for an empty input.
    """
    if not facts:
        return 0
    for fact in facts:
        per_document = self._facts.setdefault(fact.document_id, [])
        per_document.append(fact)
    return len(facts)

1237 

async def update_fact_embeddings(self, embeddings) -> int:
    """Set the ``embedding`` of stored facts from ``(fact_id, vector)`` pairs.

    Every stored fact whose ``id`` appears in ``embeddings`` is replaced,
    in position, by a copy carrying the new vector. Ids that match no
    stored fact are ignored.

    Returns the number of facts updated.
    """
    from dataclasses import replace as _replace

    if not embeddings:
        return 0
    vector_for = dict(embeddings)
    changed = 0
    for rows in self._facts.values():
        for position, fact in enumerate(rows):
            if fact.id not in vector_for:
                continue
            rows[position] = _replace(fact, embedding=vector_for[fact.id])
            changed += 1
    return changed

1251 

1252 def _matches_session(self, fact: "PageIndexFact", session_filter: str | None) -> bool: 

1253 """M31 Fix 2 — InMemory analogue of the Postgres EXISTS clause. 

1254 

1255 Returns True when the fact's anchoring section's session_id matches 

1256 ``session_filter``. Top-level facts (no anchor) and facts whose 

1257 section has no session_id both fail the filter — intentional, since 

1258 ``session_filter`` semantics is "limit to this session's content". 

1259 """ 

1260 if session_filter is None: 

1261 return True 

1262 if fact.document_id is None or fact.line_num is None: 

1263 return False 

1264 for s in self._sections.get(fact.document_id, []): 

1265 if s.line_num == fact.line_num: 

1266 return s.session_id == session_filter 

1267 return False 

1268 

async def search_facts_semantic(
    self,
    bank_id: str,
    query_embedding: list[float],
    *,
    top_k: int = 20,
    document_id: str | None = None,
    fact_type: str | None = None,
    session_filter: str | None = None,
):
    """Rank a bank's embedded facts by cosine similarity to ``query_embedding``.

    Candidates come from ``document_id`` when it is given, otherwise from
    every stored document. Facts are dropped when they belong to another
    bank, have a different ``fact_type`` (if one is requested), carry no
    embedding, or fail ``session_filter`` (M31 Fix 2).

    Returns ``PageIndexFactHit`` objects, best first. At least one hit is
    returned whenever any candidate survives (``top_k`` is floored at 1).
    An empty ``query_embedding`` yields an empty list.
    """
    from astrocyte.types import PageIndexFactHit

    if not query_embedding:
        return []

    def _similarity(u, v):
        # Cosine similarity; a zero-norm vector on either side scores 0.0.
        numerator = sum(p * q for p, q in zip(u, v))
        norm_u = math.sqrt(sum(p * p for p in u))
        norm_v = math.sqrt(sum(q * q for q in v))
        if norm_u > 0 and norm_v > 0:
            return numerator / (norm_u * norm_v)
        return 0.0

    if document_id:
        candidates = self._facts.get(document_id, [])
    else:
        candidates = [fact for rows in self._facts.values() for fact in rows]

    ranked: list[tuple[float, PageIndexFact]] = []
    for fact in candidates:
        if fact.bank_id != bank_id:
            continue
        if fact_type is not None and fact.fact_type != fact_type:
            continue
        if not fact.embedding:
            continue
        if not self._matches_session(fact, session_filter):
            continue
        ranked.append((_similarity(query_embedding, fact.embedding), fact))
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    results = []
    for similarity, fact in ranked[: max(1, top_k)]:
        results.append(
            PageIndexFactHit(
                fact_id=fact.id,
                document_id=fact.document_id,
                line_num=fact.line_num,
                text=fact.text,
                fact_type=fact.fact_type,
                speaker=fact.speaker,
                occurred_start=fact.occurred_start,
                occurred_end=fact.occurred_end,
                entities=list(fact.entities or []),
                confidence_score=getattr(fact, "confidence_score", None),  # M27 / M28-A
                mentioned_at=getattr(fact, "mentioned_at", None),  # M27
                event_date=getattr(fact, "event_date", None),  # M31 Fix 4
                score=float(similarity),
            )
        )
    return results

1325 

async def search_facts_by_entity(
    self,
    bank_id: str,
    entity_name: str,
    *,
    top_k: int = 50,
    document_id: str | None = None,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
):
    """Return the bank's facts that mention ``entity_name``.

    A fact matches when the case-folded ``entity_name`` is a substring of
    any of its (case-folded) entity strings. Candidates come from
    ``document_id`` when given, otherwise from every stored document, and
    are further narrowed by ``fact_type`` and ``session_filter``
    (M31 Fix 2) when those are set.

    Hits are returned in storage order with a constant score of ``1.0``;
    collection stops as soon as ``top_k`` hits have been gathered.
    """
    from astrocyte.types import PageIndexFactHit

    wanted = entity_name.casefold()
    if document_id:
        candidates = self._facts.get(document_id, [])
    else:
        candidates = [fact for rows in self._facts.values() for fact in rows]

    results: list["PageIndexFactHit"] = []
    for fact in candidates:
        if fact.bank_id != bank_id:
            continue
        if fact_type is not None and fact.fact_type != fact_type:
            continue
        mentioned = any(wanted in (name or "").casefold() for name in (fact.entities or []))
        if not mentioned:
            continue
        if not self._matches_session(fact, session_filter):
            continue
        hit = PageIndexFactHit(
            fact_id=fact.id,
            document_id=fact.document_id,
            line_num=fact.line_num,
            text=fact.text,
            fact_type=fact.fact_type,
            speaker=fact.speaker,
            occurred_start=fact.occurred_start,
            occurred_end=fact.occurred_end,
            entities=list(fact.entities or []),
            confidence_score=getattr(fact, "confidence_score", None),
            mentioned_at=getattr(fact, "mentioned_at", None),
            event_date=getattr(fact, "event_date", None),  # M31 Fix 4
            score=1.0,
        )
        results.append(hit)
        if len(results) >= top_k:
            break
    return results

1372 

async def search_facts_keyword(
    self,
    bank_id: str,
    query: str,
    *,
    top_k: int = 20,
    document_id: str | None = None,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
):
    """Rank facts by how many distinct query tokens occur in their text.

    M31c: a token-overlap stand-in for Postgres BM25 over fact text, in
    the same spirit as ``search_sections_keyword``. The query is
    lower-cased and split on whitespace; one-character tokens are dropped
    and duplicates removed. A fact scores one point per remaining token
    found as a substring of its lower-cased ``text``; facts scoring zero
    are omitted. This is not BM25 proper; per the original docstring it
    targets the same "literal keyword match beats a generic thematic
    match" behaviour.

    Candidates come from ``document_id`` when given, otherwise from every
    stored document, narrowed by ``bank_id``, ``fact_type`` and
    ``session_filter`` (M31 Fix 2). Returns ``PageIndexFactHit`` objects,
    best first, with ``top_k`` floored at 1.
    """
    from astrocyte.types import PageIndexFactHit

    if not query or not query.strip():
        return []
    # dict.fromkeys de-duplicates while preserving first-seen order.
    tokens = list(dict.fromkeys(word for word in query.lower().split() if word and len(word) > 1))
    if not tokens:
        return []

    if document_id:
        candidates = self._facts.get(document_id, [])
    else:
        candidates = [fact for rows in self._facts.values() for fact in rows]

    ranked: list[tuple[float, "PageIndexFact"]] = []
    for fact in candidates:
        if fact.bank_id != bank_id:
            continue
        if fact_type is not None and fact.fact_type != fact_type:
            continue
        if not self._matches_session(fact, session_filter):
            continue
        body = (fact.text or "").lower()
        matched = sum(1 for token in tokens if token in body)
        if matched > 0:
            ranked.append((float(matched), fact))
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    results = []
    for points, fact in ranked[: max(1, top_k)]:
        results.append(
            PageIndexFactHit(
                fact_id=fact.id,
                document_id=fact.document_id,
                line_num=fact.line_num,
                text=fact.text,
                fact_type=fact.fact_type,
                speaker=fact.speaker,
                occurred_start=fact.occurred_start,
                occurred_end=fact.occurred_end,
                entities=list(fact.entities or []),
                confidence_score=getattr(fact, "confidence_score", None),
                mentioned_at=getattr(fact, "mentioned_at", None),
                event_date=getattr(fact, "event_date", None),
                score=float(points),
            )
        )
    return results

1436 

async def search_facts_temporal(
    self,
    bank_id: str,
    date_range,
    *,
    top_k: int = 50,
    document_id: str | None = None,
    fact_type: str | None = None,  # M34-3 — per-fact-type segmentation
    session_filter: str | None = None,
):
    """Return the bank's facts that touch ``date_range``, newest first.

    ``date_range`` unpacks to ``(start, end)``, both inclusive. M33-3b
    (Hindsight-parity 4-way OR, mirroring the Postgres SQL per the
    original comments): a fact qualifies when ANY of these holds:

    * its ``[occurred_start, occurred_end]`` span overlaps the range;
    * ``mentioned_at`` lies in the range;
    * ``occurred_start`` lies in the range;
    * ``occurred_end`` lies in the range.

    Candidates come from ``document_id`` when given, otherwise from every
    stored document, narrowed by ``bank_id``, ``fact_type`` (M34-3) and
    ``session_filter`` (M31 Fix 2). Results are ordered descending by
    ``COALESCE(occurred_start, mentioned_at, occurred_end)`` so the
    ``top_k`` cap keeps the newest candidates. Every hit scores ``1.0``.
    """
    from astrocyte.types import PageIndexFactHit

    start, end = date_range
    if document_id:
        scope = self._facts.get(document_id, [])
    else:
        scope = [fact for rows in self._facts.values() for fact in rows]

    def _in_range(moment) -> bool:
        return moment is not None and start <= moment <= end

    def _qualifies(fact) -> bool:
        occurred_start = fact.occurred_start
        occurred_end = fact.occurred_end
        mentioned_at = getattr(fact, "mentioned_at", None)
        spans_range = (
            occurred_start is not None
            and occurred_end is not None
            and occurred_start <= end
            and occurred_end >= start
        )
        if spans_range:
            return True
        return _in_range(mentioned_at) or _in_range(occurred_start) or _in_range(occurred_end)

    def _sort_key(fact):
        # COALESCE(occurred_start, mentioned_at, occurred_end).
        freshest = fact.occurred_start or getattr(fact, "mentioned_at", None) or fact.occurred_end
        # The list is sorted with reverse=True, so the flag must be True
        # for DATED rows to keep undated rows at the bottom. The previous
        # ``(d is None, d)`` key would have floated them to the top.
        # _qualifies currently guarantees a date, so this is defensive.
        return (freshest is not None, freshest)

    matched = [
        fact
        for fact in scope
        if fact.bank_id == bank_id
        and (fact_type is None or fact.fact_type == fact_type)  # M34-3
        and _qualifies(fact)
        and self._matches_session(fact, session_filter)  # M31 Fix 2
    ]
    matched.sort(key=_sort_key, reverse=True)

    return [
        PageIndexFactHit(
            fact_id=fact.id,
            document_id=fact.document_id,
            line_num=fact.line_num,
            text=fact.text,
            fact_type=fact.fact_type,
            speaker=fact.speaker,
            occurred_start=fact.occurred_start,
            occurred_end=fact.occurred_end,
            entities=list(fact.entities or []),
            confidence_score=getattr(fact, "confidence_score", None),
            mentioned_at=getattr(fact, "mentioned_at", None),
            event_date=getattr(fact, "event_date", None),  # M31 Fix 4
            score=1.0,
        )
        for fact in matched[:top_k]
    ]

1506 

async def count_facts_matching(
    self,
    bank_id: str,
    document_id: str,
    *,
    entity_pattern: str | None = None,
    fact_type: str | None = None,
) -> int:
    """Count the facts of one document that satisfy the given filters.

    A fact counts when it belongs to ``bank_id``, has the requested
    ``fact_type`` (if any) and, when ``entity_pattern`` is set, has at
    least one entity containing the pattern. The pattern is lower-cased,
    stripped of ``%`` characters and matched as a case-insensitive
    substring.
    """
    # The needle does not depend on the fact; compute it once.
    needle = None if entity_pattern is None else entity_pattern.lower().replace("%", "")

    def _accepts(fact) -> bool:
        if fact.bank_id != bank_id:
            return False
        if fact_type is not None and fact.fact_type != fact_type:
            return False
        if needle is None:
            return True
        return any(needle in (name or "").lower() for name in (fact.entities or []))

    return sum(1 for fact in self._facts.get(document_id, []) if _accepts(fact))

1527 

async def save_section_event_dates(
    self,
    document_id: str,
    event_dates,
) -> int:
    """Set ``occurred_start`` / ``occurred_end`` on stored sections.

    ``event_dates`` is an iterable of ``(line_num, start, end)`` triples.
    Triples whose ``line_num`` matches no stored section are ignored; when
    a ``line_num`` repeats, the last triple wins.

    Returns the number of sections updated (0 when there is nothing to
    apply or the document has no sections).
    """
    if not event_dates:
        return 0
    from dataclasses import replace as _replace

    current = self._sections.get(document_id) or []
    if not current:
        return 0

    span_for = {line_num: (begin, finish) for line_num, begin, finish in event_dates}
    touched = 0
    rebuilt = []
    for section in current:
        span = span_for.get(section.line_num) if section.line_num in span_for else None
        if span is None:
            rebuilt.append(section)
            continue
        begin, finish = span
        rebuilt.append(_replace(section, occurred_start=begin, occurred_end=finish))
        touched += 1
    self._sections[document_id] = rebuilt
    return touched

1552 

1553 # ── M10.1 wiki / consolidation ───────────────────────────────── 

1554 

async def load_sections_with_embeddings(
    self,
    bank_id: str,
    document_id: str,
) -> list[PageIndexSection]:
    """Return a new list holding the document's stored sections, embeddings included.

    ``bank_id`` is not used for the lookup; the original comment notes
    that documents are already bank-scoped through the ``_documents`` key.
    """
    del bank_id  # intentionally unused
    stored = self._sections.get(document_id, [])
    return [*stored]

1562 

1563 async def save_wiki_page( 

1564 self, 

1565 *, 

1566 page, # WikiPage 

1567 embedding: list[float] | None, 

1568 provenance: list[tuple[str, int]], 

1569 ) -> str: 

1570 bucket = self._wiki_pages.setdefault(page.bank_id, []) 

1571 # Upsert: replace any existing page with the same page_id. 

1572 bucket[:] = [(p, e) for (p, e) in bucket if p.page_id != page.page_id] 

1573 bucket.append((page, embedding)) 

1574 self._wiki_provenance[page.page_id] = list(provenance) 

1575 return page.page_id 

1576 

async def search_wiki_pages_semantic(
    self,
    bank_id: str,
    query_embedding: list[float],
    *,
    top_k: int = 5,
    document_id: str | None = None,
):
    """Rank a bank's wiki pages by cosine similarity to ``query_embedding``.

    Pages stored without an embedding are skipped. When ``document_id`` is
    given, only pages whose ``scope`` equals ``"document:<document_id>"``
    are considered.

    Returns ``WikiPageHit`` objects, best first, with ``top_k`` floored at
    1. An empty bank or empty ``query_embedding`` yields an empty list.
    """
    from astrocyte.types import WikiPageHit

    stored = self._wiki_pages.get(bank_id, [])
    if not (stored and query_embedding):
        return []
    required_scope = None if document_id is None else f"document:{document_id}"

    def _similarity(u: list[float], v: list[float]) -> float:
        # Cosine similarity; a zero-norm vector on either side scores 0.0.
        norm_u = math.sqrt(sum(p * p for p in u))
        norm_v = math.sqrt(sum(q * q for q in v))
        numerator = sum(p * q for p, q in zip(u, v))
        if norm_u == 0 or norm_v == 0:
            return 0.0
        return numerator / (norm_u * norm_v)

    ranked: list[tuple[float, WikiPage]] = [
        (_similarity(query_embedding, vector), page)
        for page, vector in stored
        if (required_scope is None or page.scope == required_scope) and vector
    ]
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    return [
        WikiPageHit(
            page_id=page.page_id,
            title=page.title,
            content=page.content,
            scope=page.scope,
            kind=page.kind,
            score=float(similarity),
            source_ids=list(page.source_ids),
            bank_id=page.bank_id,
        )
        for similarity, page in ranked[: max(1, top_k)]
    ]

1625 

async def count_wiki_pages_for_doc(
    self,
    bank_id: str,
    document_id: str,
) -> int:
    """Count the bank's wiki pages whose scope is ``"document:<document_id>"``."""
    wanted_scope = f"document:{document_id}"
    matching = [page for page, _embedding in self._wiki_pages.get(bank_id, []) if page.scope == wanted_scope]
    return len(matching)

1634 

async def list_wiki_pages_for_doc(
    self,
    bank_id: str,
    document_id: str,
) -> list[WikiPage]:
    """Return the bank's wiki pages scoped to ``"document:<document_id>"``, in storage order."""
    wanted_scope = f"document:{document_id}"
    pages = []
    for page, _embedding in self._wiki_pages.get(bank_id, []):
        if page.scope == wanted_scope:
            pages.append(page)
    return pages

1643 

async def list_wikis_affected_by_entities(
    self,
    bank_id: str,
    entities: list[str],
    *,
    min_overlap: int = 1,
    limit: int = 8,
) -> list[tuple[WikiPage, int, list[str]]]:
    """M14.2: find wiki pages whose source sections mention the given entities.

    For each wiki page of the bank, the entity names attached to its
    provenance sections are intersected (exact, case-sensitive match) with
    ``entities``. Pages without provenance, or sharing fewer than
    ``min_overlap`` names, are dropped. The original docstring describes
    this as mirroring the Postgres JOIN.

    Returns up to ``limit`` ``(page, overlap_count, shared_names)`` tuples,
    ordered by overlap count descending, then ``page_id`` ascending;
    ``shared_names`` is sorted.
    """
    if not entities:
        return []
    requested = set(entities)

    affected: list[tuple[WikiPage, int, list[str]]] = []
    for page, _embedding in self._wiki_pages.get(bank_id, []):
        sources = self._wiki_provenance.get(page.page_id, [])
        if not sources:
            continue
        page_entities = {
            mention.entity_name
            for doc_id, line_num in sources
            for mention in self._section_entities.get(doc_id, [])
            if mention.line_num == line_num
        }
        common = sorted(page_entities & requested)
        if len(common) < min_overlap:
            continue
        affected.append((page, len(common), common))

    affected.sort(key=lambda row: (-row[1], row[0].page_id))
    return affected[:limit]

1679 

async def health(self) -> HealthStatus:
    """Report the store as healthy; there is no backing service to probe."""
    status = HealthStatus(healthy=True, message="in-memory pageindex store")
    return status

1682 

1683 

1684# --------------------------------------------------------------------------- 

1685# In-Memory Mental Model Store (M9 — first-class, replaces wiki piggyback) 

1686# --------------------------------------------------------------------------- 

1687 

1688 

1689class InMemoryMentalModelStore: 

1690 """Fully functional in-memory mental-model store for testing. 

1691 

1692 Mirrors :class:`InMemoryWikiStore`'s shape but for the dedicated 

1693 :class:`~astrocyte.provider.MentalModelStore` SPI. Holds the current 

1694 revision of each model plus a history list of past revisions. 

1695 """ 

1696 

1697 SPI_VERSION: ClassVar[int] = 1 

1698 

def __init__(self) -> None:
    """Start with no models and no revision history."""
    # Current revision of every model: bank_id -> {model_id -> MentalModel}.
    self._models: dict[str, dict[str, "MentalModel"]] = {}
    # Superseded revisions, oldest first:
    # bank_id -> {model_id -> [MentalModel, ...]}.
    self._history: dict[str, dict[str, list["MentalModel"]]] = {}

1704 

1705 async def upsert(self, model: "MentalModel", bank_id: str) -> int: 

1706 from dataclasses import replace as _replace 

1707 from datetime import UTC, datetime 

1708 

1709 if bank_id not in self._models: 

1710 self._models[bank_id] = {} 

1711 self._history[bank_id] = {} 

1712 

1713 existing = self._models[bank_id].get(model.model_id) 

1714 new_revision = (existing.revision + 1) if existing is not None else 1 

1715 if existing is not None: 

1716 # Archive the prior current-revision before replacing. 

1717 self._history[bank_id].setdefault(model.model_id, []).append(existing) 

1718 

1719 # Store always assigns revision + refreshed_at (callers don't need 

1720 # to fill these in). M40 — validate source_timestamps alignment; 

1721 # drop on mismatch rather than persist a wrong-alignment array 

1722 # that would silently corrupt trend computation downstream. 

1723 src_ts = model.source_timestamps 

1724 if src_ts is not None and len(src_ts) != len(model.source_ids): 

1725 src_ts = None 

1726 stamped = _replace( 

1727 model, 

1728 bank_id=bank_id, 

1729 revision=new_revision, 

1730 refreshed_at=datetime.now(UTC), 

1731 source_timestamps=src_ts, 

1732 ) 

1733 self._models[bank_id][model.model_id] = stamped 

1734 return new_revision 

1735 

async def get(self, model_id: str, bank_id: str) -> "MentalModel | None":
    """Return the current revision of ``model_id`` in ``bank_id``, or ``None``."""
    bank_models = self._models.get(bank_id, {})
    return bank_models.get(model_id)

1738 

1739 async def list( 

1740 self, 

1741 bank_id: str, 

1742 *, 

1743 scope: str | None = None, 

1744 kind: str | None = None, 

1745 ) -> list["MentalModel"]: 

1746 models = list(self._models.get(bank_id, {}).values()) 

1747 if scope is not None: 

1748 models = [m for m in models if m.scope == scope] 

1749 if kind is not None: 

1750 models = [m for m in models if m.kind == kind] 

1751 return models 

1752 

async def delete(self, model_id: str, bank_id: str) -> bool:
    """Remove the current revision of ``model_id``; return whether it existed.

    Revision history is deliberately left in place for audit, so past
    revisions remain observable after a delete.
    """
    bank_models = self._models.get(bank_id, {})
    if model_id in bank_models:
        del bank_models[model_id]
        return True
    return False

1761 

async def update_via_ops(
    self,
    model_id: str,
    bank_id: str,
    operations_json: list[dict],
) -> "tuple[int, dict] | None":
    """M21 — apply structured delta operations and re-upsert.

    The model's structured document is taken from ``structured_doc`` when
    present; otherwise it is parsed from the markdown ``content`` (lazy
    migration of legacy rows). When the operations change the document,
    ``content`` is re-rendered from the new document and the model is
    written back through :meth:`upsert`.

    Returns ``None`` when the model does not exist. Otherwise returns
    ``(revision, summary)`` where ``summary`` has the keys ``applied``,
    ``skipped`` and ``changed``. The revision is the existing one when
    nothing changed (including when the operation list fails schema
    validation) and the value returned by :meth:`upsert` otherwise.
    """
    from dataclasses import replace as _replace

    from astrocyte.pipeline.delta_ops import (
        DeltaOperationList,
        apply_operations,
    )
    from astrocyte.pipeline.structured_doc import (
        StructuredDocument,
        parse_markdown,
        render_document,
    )

    current = self._models.get(bank_id, {}).get(model_id)
    if current is None:
        return None

    # Legacy rows carry no structured doc yet: derive one from the markdown.
    doc = (
        StructuredDocument.model_validate(current.structured_doc)
        if current.structured_doc is not None
        else parse_markdown(current.content or "")
    )

    # Conservative-failure contract: an op list that does not validate
    # against the schema applies nothing and leaves the revision alone.
    try:
        ops_container = DeltaOperationList.model_validate({"operations": operations_json})
    except Exception:
        return (current.revision, {"applied": [], "skipped": [], "changed": False})

    outcome = apply_operations(doc, ops_container.operations)
    if not outcome.changed:
        unchanged = {"applied": outcome.applied, "skipped": outcome.skipped, "changed": False}
        return (current.revision, unchanged)

    updated = _replace(
        current,
        structured_doc=outcome.document.model_dump(),
        content=render_document(outcome.document),
    )
    bumped_revision = await self.upsert(updated, bank_id)
    changed = {"applied": outcome.applied, "skipped": outcome.skipped, "changed": True}
    return (bumped_revision, changed)

1828 

async def refresh(
    self,
    model_id: str,
    bank_id: str,
    new_source_ids: list[str],
) -> "MentalModel | None":
    """M28 — merge new source_ids into an existing model via :meth:`upsert`.

    The existing ``source_ids`` are kept in their original order and any
    id from ``new_source_ids`` not already present is appended once, in
    input order. The merged model is passed to :meth:`upsert`, and the
    model stored afterwards is returned. Returns ``None`` without
    writing anything when the model does not exist.

    In-memory semantics: no recompilation happens here; the test store
    only records the new sources so tests exercise wiring and contract
    rather than LLM behaviour.
    """
    from dataclasses import replace as _replace

    current = self._models.get(bank_id, {}).get(model_id)
    if current is None:
        return None

    already_known = set(current.source_ids)
    # dict.fromkeys collapses repeats inside the input, keeping first-seen order.
    novel_ids = [sid for sid in dict.fromkeys(new_source_ids) if sid not in already_known]
    combined = [*current.source_ids, *novel_ids]

    await self.upsert(_replace(current, source_ids=combined), bank_id)
    return self._models.get(bank_id, {}).get(model_id)

1862 

async def health(self) -> HealthStatus:
    """Report a healthy status; this method performs no checks."""
    status = HealthStatus(healthy=True, message="in-memory mental model store")
    return status

1865 

def revision_history(self, model_id: str, bank_id: str) -> list["MentalModel"]:
    """Return past revisions for a model (oldest first). Testing helper.

    The result is a fresh list, so mutating it does not affect the
    archive. Unknown banks or models give an empty list.
    """
    bank_history = self._history.get(bank_id, {})
    return [*bank_history.get(model_id, ())]

1869 

1870 

1871# --------------------------------------------------------------------------- 

1872# In-Memory Source Store (M10 — documents + chunks normalisation) 

1873# --------------------------------------------------------------------------- 

1874 

1875 

1876class InMemorySourceStore: 

1877 """Fully functional in-memory source-document store for testing. 

1878 

1879 Mirrors the :class:`~astrocyte.provider.SourceStore` SPI: document 

1880 create/get/list/delete + bulk-chunk insert with content_hash dedup. 

1881 Soft-deletes match the Postgres adapter's lifecycle (deleted_at on 

1882 documents; cascade-by-store-eviction on chunks). 

1883 """ 

1884 

1885 SPI_VERSION: ClassVar[int] = 1 

1886 

1887 def __init__(self) -> None: 

1888 # bank_id → {document_id → SourceDocument} 

1889 self._docs: dict[str, dict[str, "SourceDocument"]] = {} 

1890 # bank_id → {chunk_id → SourceChunk} 

1891 self._chunks: dict[str, dict[str, "SourceChunk"]] = {} 

1892 # bank_id → {(content_hash) → document_id} for fast dedup lookup 

1893 self._doc_hash_index: dict[str, dict[str, str]] = {} 

1894 # bank_id → {(content_hash) → chunk_id} for fast dedup lookup 

1895 self._chunk_hash_index: dict[str, dict[str, str]] = {} 

1896 

1897 async def store_document(self, document: "SourceDocument") -> str: 

1898 from dataclasses import replace as _replace 

1899 from datetime import UTC 

1900 from datetime import datetime as _dt 

1901 

1902 bank_id = document.bank_id 

1903 self._docs.setdefault(bank_id, {}) 

1904 self._doc_hash_index.setdefault(bank_id, {}) 

1905 

1906 # Dedup by content_hash when set. 

1907 if document.content_hash: 

1908 existing_id = self._doc_hash_index[bank_id].get(document.content_hash) 

1909 if existing_id is not None and existing_id in self._docs[bank_id]: 

1910 return existing_id 

1911 

1912 # Stamp created_at if not provided (store owns this field). 

1913 stamped = _replace( 

1914 document, 

1915 created_at=document.created_at or _dt.now(UTC), 

1916 ) 

1917 self._docs[bank_id][document.id] = stamped 

1918 if document.content_hash: 

1919 self._doc_hash_index[bank_id][document.content_hash] = document.id 

1920 return document.id 

1921 

1922 async def get_document( 

1923 self, 

1924 document_id: str, 

1925 bank_id: str, 

1926 ) -> "SourceDocument | None": 

1927 return self._docs.get(bank_id, {}).get(document_id) 

1928 

1929 async def find_document_by_hash( 

1930 self, 

1931 content_hash: str, 

1932 bank_id: str, 

1933 ) -> "SourceDocument | None": 

1934 doc_id = self._doc_hash_index.get(bank_id, {}).get(content_hash) 

1935 if doc_id is None: 

1936 return None 

1937 return self._docs.get(bank_id, {}).get(doc_id) 

1938 

1939 async def list_documents( 

1940 self, 

1941 bank_id: str, 

1942 *, 

1943 limit: int = 100, 

1944 ) -> list["SourceDocument"]: 

1945 docs = list(self._docs.get(bank_id, {}).values()) 

1946 # Newest-first by created_at; tolerate Nones via a fallback. 

1947 from datetime import UTC 

1948 from datetime import datetime as _dt 

1949 

1950 docs.sort(key=lambda d: d.created_at or _dt.fromtimestamp(0, UTC), reverse=True) 

1951 return docs[:limit] 

1952 

1953 async def delete_document(self, document_id: str, bank_id: str) -> bool: 

1954 bank_docs = self._docs.get(bank_id, {}) 

1955 doc = bank_docs.pop(document_id, None) 

1956 if doc is None: 

1957 return False 

1958 # Drop the hash-index entry too so a follow-up re-store doesn't 

1959 # silently dedup against the deleted row. 

1960 if doc.content_hash: 

1961 self._doc_hash_index.get(bank_id, {}).pop(doc.content_hash, None) 

1962 # Cascade: drop all chunks of this document. 

1963 bank_chunks = self._chunks.get(bank_id, {}) 

1964 to_drop = [cid for cid, chunk in bank_chunks.items() if chunk.document_id == document_id] 

1965 for cid in to_drop: 

1966 chunk = bank_chunks.pop(cid) 

1967 if chunk.content_hash: 

1968 self._chunk_hash_index.get(bank_id, {}).pop(chunk.content_hash, None) 

1969 return True 

1970 

1971 async def store_chunks(self, chunks: list["SourceChunk"]) -> list[str]: 

1972 from dataclasses import replace as _replace 

1973 from datetime import UTC 

1974 from datetime import datetime as _dt 

1975 

1976 ids: list[str] = [] 

1977 for chunk in chunks: 

1978 bank_id = chunk.bank_id 

1979 self._chunks.setdefault(bank_id, {}) 

1980 self._chunk_hash_index.setdefault(bank_id, {}) 

1981 

1982 # Dedup by content_hash when set. 

1983 if chunk.content_hash: 

1984 existing = self._chunk_hash_index[bank_id].get(chunk.content_hash) 

1985 if existing is not None and existing in self._chunks[bank_id]: 

1986 ids.append(existing) 

1987 continue 

1988 

1989 stamped = _replace( 

1990 chunk, 

1991 created_at=chunk.created_at or _dt.now(UTC), 

1992 ) 

1993 self._chunks[bank_id][chunk.id] = stamped 

1994 if chunk.content_hash: 

1995 self._chunk_hash_index[bank_id][chunk.content_hash] = chunk.id 

1996 ids.append(chunk.id) 

1997 return ids 

1998 

1999 async def get_chunk(self, chunk_id: str, bank_id: str) -> "SourceChunk | None": 

2000 return self._chunks.get(bank_id, {}).get(chunk_id) 

2001 

2002 async def list_chunks( 

2003 self, 

2004 document_id: str, 

2005 bank_id: str, 

2006 ) -> list["SourceChunk"]: 

2007 chunks = [c for c in self._chunks.get(bank_id, {}).values() if c.document_id == document_id] 

2008 chunks.sort(key=lambda c: c.chunk_index) 

2009 return chunks 

2010 

2011 async def find_chunk_by_hash( 

2012 self, 

2013 content_hash: str, 

2014 bank_id: str, 

2015 ) -> "SourceChunk | None": 

2016 chunk_id = self._chunk_hash_index.get(bank_id, {}).get(content_hash) 

2017 if chunk_id is None: 

2018 return None 

2019 return self._chunks.get(bank_id, {}).get(chunk_id) 

2020 

2021 async def health(self) -> HealthStatus: 

2022 return HealthStatus(healthy=True, message="in-memory source store") 

2023 

2024 

2025# --------------------------------------------------------------------------- 

2026# Mock LLM Provider 

2027# --------------------------------------------------------------------------- 

2028 

2029 

class MockLLMProvider:
    """Mock LLM provider with bag-of-words embeddings and extractive synthesis.

    Embeddings are term-frequency vectors built with the hashing trick, so
    texts that share words have high cosine similarity — enabling
    meaningful vector retrieval without an API key.

    Synthesis extracts the most relevant memory text from the prompt context
    rather than returning a static string.
    """

    SPI_VERSION: ClassVar[int] = 1

    _embed_dim: ClassVar[int] = 128

    def __init__(
        self,
        default_response: str = "Mock LLM response",
        embedding_dimensions: int | None = None,
    ) -> None:
        """Configure the mock.

        Args:
            default_response: Text returned by :meth:`complete` when the
                prompt matches none of the recognised prompt shapes.
            embedding_dimensions: Length of the vectors produced by
                :meth:`embed`; the class default (128) applies when ``None``.
        """
        self._default_response = default_response
        if embedding_dimensions is not None:
            self._embed_dim = int(embedding_dimensions)
        self._call_count = 0
        #: Last ``complete()`` user message content (for tests asserting prompt structure).
        self.last_user_message: str | None = None

    def capabilities(self) -> LLMCapabilities:
        """Return a default-constructed ``LLMCapabilities``."""
        return LLMCapabilities()

    async def complete(
        self,
        messages: list[Message],
        model: str | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.0,
        tools: list | None = None,  # list[ToolDefinition] | None — kept untyped to avoid import cycle
        tool_choice: str | None = None,
        response_format: dict | None = None,
    ) -> Completion:
        """Return a canned completion chosen from the prompt's shape.

        Every call increments the call counter and records the most recent
        string-valued ``user`` message in ``last_user_message``. The reply
        is then picked from the concatenated string contents of all
        messages:

        * text containing "extract named entities" (case-insensitive)
          yields a fixed one-entity JSON array;
        * text containing both ``<memories>`` and ``<query>`` yields an
          extractive answer built from the memories;
        * anything else yields the configured default response.

        ``max_tokens``, ``temperature``, ``tools``, ``tool_choice`` and
        ``response_format`` are accepted but not used. The returned
        ``model`` is the requested model, or ``"mock"`` when none is given.
        """
        self._call_count += 1
        for m in reversed(messages):
            if m.role == "user" and isinstance(m.content, str):
                self.last_user_message = m.content
                break
        all_text = " ".join(m.content for m in messages if isinstance(m.content, str))
        # Entity extraction prompt
        if "extract named entities" in all_text.lower():
            return Completion(
                text='[{"name": "Test Entity", "entity_type": "OTHER", "aliases": []}]',
                model=model or "mock",
                usage=TokenUsage(input_tokens=10, output_tokens=20),
            )
        # Memory synthesis prompt — extract most relevant memory text
        if "<memories>" in all_text and "<query>" in all_text:
            answer = _extractive_synthesize(all_text)
            return Completion(
                text=answer,
                model=model or "mock",
                usage=TokenUsage(input_tokens=50, output_tokens=30),
            )
        return Completion(
            text=self._default_response,
            model=model or "mock",
            usage=TokenUsage(input_tokens=10, output_tokens=20),
        )

    async def embed(self, texts: list[str], model: str | None = None) -> list[list[float]]:
        """Generate one bag-of-words embedding per input text.

        Texts that share words get overlapping non-zero dimensions and
        therefore high cosine similarity. ``model`` is accepted but not
        used.
        """
        return [self._bow_embed(text) for text in texts]

    def _bow_embed(self, text: str) -> list[float]:
        """Build an L2-normalised term-frequency vector using the hashing trick.

        The text is split on whitespace; each token is stripped of leading
        and trailing ASCII punctuation and lowercased, and tokens that are
        empty after stripping are dropped. Each remaining token selects a
        bucket via the MD5 digest of its encoded bytes modulo the embedding
        dimension and increments it. Using a digest instead of the builtin
        ``hash()`` keeps buckets stable across processes, since ``hash()``
        of ``str`` is salted per interpreter run by default.

        All components are non-negative, guaranteeing non-negative cosine
        similarities between any two vectors — required because
        ``VectorHit.score`` enforces ``>= 0.0``. Text without usable
        tokens yields the all-zero vector.
        """
        import hashlib
        from string import punctuation

        dim = self._embed_dim
        vec = [0.0] * dim

        stripped_words = (word.strip(punctuation) for word in text.split())
        for token in (word.lower() for word in stripped_words if word):
            # MD5 serves purely as a bucketing function here; flagging it
            # as non-security use keeps it available on FIPS-restricted builds.
            digest = hashlib.md5(token.encode(), usedforsecurity=False).hexdigest()
            vec[int(digest, 16) % dim] += 1.0

        # L2 normalize
        norm = math.sqrt(sum(x * x for x in vec))
        if norm > 0:
            vec = [x / norm for x in vec]
        return vec

2132 

2133 

2134def _normalize_terms(text: str) -> set[str]: 

2135 """Lowercase, strip punctuation, drop short words.""" 

2136 from string import punctuation 

2137 

2138 return {t.strip(punctuation).lower() for t in text.split() if len(t.strip(punctuation)) > 2} 

2139 

2140 

def _extractive_synthesize(prompt_text: str) -> str:
    """Build an answer from the memories most relevant to the query.

    Reads the ``<query>`` and ``<memories>`` blocks out of the prompt,
    splits the memories into one block per ``[Memory N]`` header (other
    non-empty lines continue the preceding block), scores each block by
    how many normalized terms it shares with the query, and joins the
    three best-scoring blocks with spaces. Blocks sharing no term with
    the query are never returned. Fixed fallback sentences are returned
    when the query block, the memories block, or any overlap is missing.
    """
    import re

    # Extract query
    query_hit = re.search(r"<query>\s*(.*?)\s*</query>", prompt_text, re.DOTALL)
    if query_hit is None:
        return "No query found."
    wanted_terms = _normalize_terms(query_hit.group(1).strip())

    # Extract individual memories
    memories_hit = re.search(r"<memories>\s*(.*?)\s*</memories>", prompt_text, re.DOTALL)
    if memories_hit is None:
        return "No memories found."

    header_re = re.compile(r"^\[Memory \d+\]")
    # Matches the whole "[Memory N] (…) […]: " lead-in so it can be removed.
    lead_in_re = re.compile(r"^\[Memory \d+\](\s*\([^)]*\))?(\s*\[[^\]]*\])?\s*:\s*")

    blocks: list[str] = []
    for raw_line in memories_hit.group(1).strip().split("\n"):
        piece = raw_line.strip()
        if not piece:
            continue
        if header_re.match(piece):
            # New memory: keep only its text for cleaner output.
            blocks.append(lead_in_re.sub("", piece))
        elif blocks:
            # Continuation line of the memory currently being collected.
            blocks[-1] += " " + piece
        else:
            blocks.append(piece)

    # Rank by overlap; the sort is stable, so ties keep prompt order.
    ranked = sorted(
        ((len(wanted_terms & _normalize_terms(block)), block) for block in blocks),
        key=lambda pair: pair[0],
        reverse=True,
    )
    best = [block for overlap, block in ranked if overlap > 0][:3]
    if not best:
        return "I don't have relevant information to answer this question."
    return " ".join(best)