Coverage for astrocyte/pipeline/recall_cache.py: 82%
67 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Recall cache — LRU cache keyed by query embedding similarity.

Avoids redundant retrieval for repeated or similar queries.
Invalidated on retain (bank contents changed).

Sync, self-contained — Rust migration candidate.
Inspired by ByteRover's Tier 0/1 progressive retrieval.
"""
10from __future__ import annotations
12import threading
13import time
14from dataclasses import dataclass
16from astrocyte.policy.signal_quality import cosine_similarity
17from astrocyte.types import RecallResult
#: Default entry limit. RecallCache applies it to each bank and caps the
#: total across all banks at four times this value.
DEFAULT_CACHE_MAX_ENTRIES: int = 256

#: Default time-to-live of a cached entry, in seconds.
DEFAULT_CACHE_TTL_SECONDS: float = 300.0
26@dataclass
27class _CacheEntry:
28 query_vector: list[float]
29 result: RecallResult
30 timestamp: float
33class RecallCache:
34 """LRU recall cache with similarity-based lookup.
36 Entries are keyed by (bank_id, query_vector). A cache hit occurs when
37 cosine similarity between the query vector and a cached vector exceeds
38 the threshold. Entries expire after ttl_seconds.
40 Invalidate a bank's cache on retain (contents changed).
42 Thread-safe: all mutations are protected by a lock, consistent with
43 RateLimiter and CircuitBreaker in the policy layer.
44 """
46 def __init__(
47 self,
48 similarity_threshold: float = 0.95,
49 max_entries: int = DEFAULT_CACHE_MAX_ENTRIES,
50 ttl_seconds: float = DEFAULT_CACHE_TTL_SECONDS,
51 ) -> None:
52 self.similarity_threshold = similarity_threshold
53 self.max_entries = max_entries
54 self.ttl_seconds = ttl_seconds
55 self._cache: dict[str, list[_CacheEntry]] = {} # bank_id → entries
56 self._lock = threading.Lock()
58 def get(self, bank_id: str, query_vector: list[float]) -> RecallResult | None:
59 """Look up a cached recall result by query similarity.
61 Returns None on miss. Evicts expired entries during lookup.
62 """
63 with self._lock:
64 entries = self._cache.get(bank_id)
65 if not entries:
66 return None
68 now = time.monotonic()
70 # Evict expired entries
71 entries[:] = [e for e in entries if (now - e.timestamp) < self.ttl_seconds]
73 # Search for similar query
74 for entry in entries:
75 sim = cosine_similarity(query_vector, entry.query_vector)
76 if sim >= self.similarity_threshold:
77 # Move to end (LRU)
78 entries.remove(entry)
79 entries.append(entry)
80 return entry.result
82 return None
84 def put(self, bank_id: str, query_vector: list[float], result: RecallResult) -> None:
85 """Store a recall result in the cache."""
86 with self._lock:
87 if bank_id not in self._cache:
88 self._cache[bank_id] = []
90 entries = self._cache[bank_id]
92 # Evict LRU if this bank is at capacity
93 while len(entries) >= self.max_entries:
94 entries.pop(0)
96 # Enforce global capacity across all banks
97 total = sum(len(e) for e in self._cache.values())
98 while total >= self.max_entries * 4: # Global cap: 4x per-bank limit
99 # Evict oldest entry across all banks
100 oldest_bank = None
101 oldest_time = float("inf")
102 for bid, bank_entries in self._cache.items():
103 if bank_entries and bank_entries[0].timestamp < oldest_time:
104 oldest_time = bank_entries[0].timestamp
105 oldest_bank = bid
106 if oldest_bank is not None:
107 self._cache[oldest_bank].pop(0)
108 if not self._cache[oldest_bank]:
109 del self._cache[oldest_bank]
110 total -= 1
111 else:
112 break
114 entries.append(
115 _CacheEntry(
116 query_vector=query_vector,
117 result=result,
118 timestamp=time.monotonic(),
119 )
120 )
122 def invalidate_bank(self, bank_id: str) -> None:
123 """Clear all cached results for a bank (called on retain)."""
124 with self._lock:
125 self._cache.pop(bank_id, None)
127 def invalidate_all(self) -> None:
128 """Clear the entire cache."""
129 with self._lock:
130 self._cache.clear()
132 def size(self, bank_id: str | None = None) -> int:
133 """Number of cached entries (total or per bank)."""
134 with self._lock:
135 if bank_id:
136 return len(self._cache.get(bank_id, []))
137 return sum(len(entries) for entries in self._cache.values())