Coverage for astrocyte/pipeline/recall_cache.py: 82%

67 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Recall cache — LRU cache keyed by query embedding similarity. 

2 

3Avoids redundant retrieval for repeated or similar queries. 

4Invalidated on retain (bank contents changed). 

5 

6Sync, self-contained — Rust migration candidate. 

7Inspired by ByteRover's Tier 0/1 progressive retrieval. 

8""" 

9 

10from __future__ import annotations 

11 

12import threading 

13import time 

14from dataclasses import dataclass 

15 

16from astrocyte.policy.signal_quality import cosine_similarity 

17from astrocyte.types import RecallResult 

18 

#: Default maximum number of cached entries per bank. ``RecallCache`` also
#: applies a cache-wide ceiling of four times this value.
DEFAULT_CACHE_MAX_ENTRIES = 256

#: Default cache TTL in seconds.
DEFAULT_CACHE_TTL_SECONDS = 300.0


@dataclass
class _CacheEntry:
    """A single cached recall result together with the query that produced it."""

    # Embedding of the query; compared against incoming queries on lookup.
    query_vector: list[float]
    # Result handed back on a hit.
    result: RecallResult
    # ``time.monotonic()`` reading taken when the entry was stored.
    timestamp: float


class RecallCache:
    """LRU recall cache with similarity-based lookup.

    Entries are grouped per bank. A lookup hits when the cosine similarity
    between the incoming query vector and a cached query vector is at least
    ``similarity_threshold``. Entries expire ``ttl_seconds`` after insertion.

    Each bank holds at most ``max_entries`` entries and the cache as a whole
    holds at most ``4 * max_entries``. A ``max_entries`` of zero or less
    disables caching (``put`` becomes a no-op).

    Invalidate a bank's cache on retain (contents changed).

    Thread-safe: every access to the internal table happens under one lock.
    """

    #: The cache-wide ceiling is this multiple of the per-bank limit.
    _GLOBAL_CAP_FACTOR = 4

    def __init__(
        self,
        similarity_threshold: float = 0.95,
        max_entries: int = DEFAULT_CACHE_MAX_ENTRIES,
        ttl_seconds: float = DEFAULT_CACHE_TTL_SECONDS,
    ) -> None:
        """Create an empty cache.

        Args:
            similarity_threshold: Minimum cosine similarity counted as a hit.
            max_entries: Per-bank entry limit.
            ttl_seconds: Lifetime of an entry, measured from insertion.
        """
        self.similarity_threshold = similarity_threshold
        self.max_entries = max_entries
        self.ttl_seconds = ttl_seconds
        # bank_id -> entries, ordered least recently used first.
        self._cache: dict[str, list[_CacheEntry]] = {}
        self._lock = threading.Lock()

    def get(self, bank_id: str, query_vector: list[float]) -> RecallResult | None:
        """Look up a cached recall result by query similarity.

        Expired entries of the bank are evicted as a side effect, and a bank
        left without entries is dropped from the table.

        Returns:
            The cached result of the first sufficiently similar entry, or
            ``None`` on a miss.
        """
        with self._lock:
            entries = self._cache.get(bank_id)
            if not entries:
                return None

            now = time.monotonic()
            entries[:] = [e for e in entries if (now - e.timestamp) < self.ttl_seconds]
            if not entries:
                # Do not keep empty per-bank lists around indefinitely.
                del self._cache[bank_id]
                return None

            for index, entry in enumerate(entries):
                similarity = cosine_similarity(query_vector, entry.query_vector)
                if similarity >= self.similarity_threshold:
                    # Mark as most recently used. Popping by index avoids the
                    # field-by-field equality scan that list.remove() performs.
                    entries.append(entries.pop(index))
                    return entry.result

            return None

    def put(self, bank_id: str, query_vector: list[float], result: RecallResult) -> None:
        """Store a recall result, evicting older entries to respect the limits.

        The per-bank limit is enforced first, then the cache-wide limit. The
        new entry is always stored unless caching is disabled
        (``max_entries <= 0``).
        """
        with self._lock:
            if self.max_entries <= 0:
                # Nothing can be retained; bail out rather than evicting from
                # an empty list.
                return

            entries = self._cache.setdefault(bank_id, [])

            # Make room for one new entry in this bank (drop LRU entries).
            overflow = len(entries) - self.max_entries + 1
            if overflow > 0:
                del entries[:overflow]

            # Enforce the cache-wide ceiling.
            global_cap = self.max_entries * self._GLOBAL_CAP_FACTOR
            total = sum(len(bank) for bank in self._cache.values())
            while total >= global_cap:
                victim = self._oldest_populated_bank()
                if victim is None:
                    break
                victim_entries = self._cache[victim]
                victim_entries.pop(0)
                total -= 1
                # Keep the target bank registered even when emptied, otherwise
                # the append below would land in a list the cache no longer
                # references and the new entry would be lost.
                if not victim_entries and victim != bank_id:
                    del self._cache[victim]

            entries.append(
                _CacheEntry(
                    query_vector=query_vector,
                    result=result,
                    timestamp=time.monotonic(),
                )
            )

    def _oldest_populated_bank(self) -> str | None:
        """Return the bank whose head entry has the smallest timestamp.

        Must be called with the lock held. Only the head (least recently used)
        entry of each non-empty bank is inspected. Returns ``None`` when every
        bank is empty.
        """
        return min(
            (bid for bid, bank in self._cache.items() if bank),
            key=lambda bid: self._cache[bid][0].timestamp,
            default=None,
        )

    def invalidate_bank(self, bank_id: str) -> None:
        """Clear all cached results for a bank (called on retain)."""
        with self._lock:
            self._cache.pop(bank_id, None)

    def invalidate_all(self) -> None:
        """Clear the entire cache."""
        with self._lock:
            self._cache.clear()

    def size(self, bank_id: str | None = None) -> int:
        """Return the number of cached entries.

        Args:
            bank_id: Bank to count. ``None`` (the default) counts all banks.
        """
        with self._lock:
            # Explicit None check so an empty-string bank id is still treated
            # as a bank rather than as "all banks".
            if bank_id is not None:
                return len(self._cache.get(bank_id, ()))
            return sum(len(entries) for entries in self._cache.values())

137 return sum(len(entries) for entries in self._cache.values())