Coverage for astrocyte/pipeline/recent_buffer.py: 93%

84 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Recent memory buffer — ring buffer of recently stored chunks for Tier 1 fuzzy matching. 

2 

3Provides sub-5ms fuzzy text search on recent memories without embedding cost. 

4Handles typos and morphological variations via character-level similarity. 

5 

6Sync, self-contained — Rust migration candidate. 

7""" 

8 

9from __future__ import annotations 

10 

11import logging 

12import threading 

13from collections import deque 

14from dataclasses import dataclass 

15from difflib import SequenceMatcher 

16from string import punctuation 

17 

18from astrocyte.types import MemoryHit 

19 

logger: logging.Logger = logging.getLogger(__name__)

#: Default max recent items per bank.
DEFAULT_MAX_PER_BANK: int = 100

#: Minimum character-level similarity (``SequenceMatcher.ratio``) for a
#: non-identical token to count as a fuzzy match.
_TOKEN_SIMILARITY_THRESHOLD: float = 0.75

#: Words too common to be meaningful for fuzzy matching. The tokenizer drops
#: them from both stored text and queries.
_STOP_WORDS: frozenset[str] = frozenset(
    """
    a an the is are was were be been being have has had do does did
    will would shall should may might can could of in on at to for
    with by from and or but not no nor so yet both either neither
    this that these those it its he she they them his her their
    who what which when where how i me my we our you your
    about also just like very much more than then some any
    """.split()
)

37 

38 

39@dataclass(slots=True) 

40class _RecentEntry: 

41 memory_id: str 

42 text: str 

43 tokens: list[str] # pre-tokenized for fast matching 

44 bank_id: str 

45 metadata: dict | None = None 

46 

47 

def _tokenize(text: str) -> list[str]:
    """Normalize *text* into the token list used for fuzzy matching.

    The text is lowercased and split on whitespace. Leading and trailing
    ``string.punctuation`` characters are trimmed from each word. Words that
    end up one character long or shorter, and words in ``_STOP_WORDS``, are
    discarded. Original word order is kept.
    """
    kept: list[str] = []
    for piece in text.lower().split():
        word = piece.strip(punctuation)
        if len(word) <= 1:
            continue  # empty after trimming, or a single character
        if word in _STOP_WORDS:
            continue
        kept.append(word)
    return kept

51 

52 

def _fuzzy_token_score(
    query_tokens: list[str],
    memory_tokens: list[str],
    *,
    threshold: float | None = None,
) -> float:
    """Score a memory against query tokens using character-level fuzzy matching.

    A query token matches in either of two cases:
    - it appears verbatim among the memory tokens;
    - some memory token of comparable length has a ``SequenceMatcher`` ratio
      of at least ``threshold`` with it (handles typos).

    Each occurrence of a repeated query token is counted separately.

    Args:
        query_tokens: Tokens from the query.
        memory_tokens: Tokens from the stored memory.
        threshold: Minimum similarity for a fuzzy token match. Defaults to
            the module-level ``_TOKEN_SIMILARITY_THRESHOLD``.

    Returns:
        The fraction of query tokens that matched, in ``[0.0, 1.0]``.
        Returns 0.0 if either token list is empty.
    """
    if not query_tokens or not memory_tokens:
        return 0.0
    if threshold is None:
        threshold = _TOKEN_SIMILARITY_THRESHOLD

    # Distinct memory tokens: used by the exact-match fast path, and duplicate
    # tokens can never change whether *some* token is similar enough.
    memory_set = set(memory_tokens)

    matched = 0
    for qt in query_tokens:
        # Fast path: exact match
        if qt in memory_set:
            matched += 1
            continue
        # Slow path: fuzzy match (character-level similarity)
        for mt in memory_set:
            # Quick length filter — very different lengths can't match well
            if abs(len(qt) - len(mt)) > max(len(qt), len(mt)) * 0.5:
                continue
            matcher = SequenceMatcher(None, qt, mt)
            # real_quick_ratio() and quick_ratio() are cheap upper bounds on
            # ratio(). Checking them first skips the full ratio() computation
            # for candidates that cannot reach the threshold.
            if (
                matcher.real_quick_ratio() >= threshold
                and matcher.quick_ratio() >= threshold
                and matcher.ratio() >= threshold
            ):
                matched += 1
                break  # Good enough, stop searching

    return matched / len(query_tokens)

89 

90 

class RecentMemoryBuffer:
    """Ring buffer of recently stored text chunks per bank for Tier 1 fuzzy matching.

    Designed for sub-5ms search on the last N stored memories per bank.
    Uses character-level fuzzy matching (SequenceMatcher) to handle typos
    and morphological variations that BM25 would miss.

    Capacity rules:
    - Each bank holds at most ``max_per_bank`` entries; once full, appending
      a new entry drops the oldest one.
    - At most ``_MAX_BANKS`` banks are tracked. Creating a new bank beyond
      that evicts the least recently used bank; both ``add`` and ``search``
      count as a use.

    Thread-safe: every access to the bank map happens under a lock. Scoring
    in ``search`` runs on a snapshot, outside the lock.
    """

    #: Maximum number of banks tracked before LRU eviction kicks in.
    _MAX_BANKS = 500

    def __init__(self, max_per_bank: int = DEFAULT_MAX_PER_BANK) -> None:
        """Create an empty buffer that keeps at most ``max_per_bank`` entries per bank."""
        self.max_per_bank = max_per_bank
        # Dict insertion order doubles as recency order: the first key is the
        # least recently used bank (see ``_touch_locked``).
        self._buffers: dict[str, deque[_RecentEntry]] = {}
        self._lock = threading.Lock()

    def _touch_locked(self, bank_id: str) -> deque[_RecentEntry] | None:
        """Mark ``bank_id`` as most recently used and return its deque.

        Returns None, and changes nothing, when the bank is not buffered.
        The caller must hold ``self._lock``.
        """
        buf = self._buffers.pop(bank_id, None)
        if buf is not None:
            # Re-inserting moves the key to the end of the dict's order.
            self._buffers[bank_id] = buf
        return buf

    def add(self, bank_id: str, memory_id: str, text: str, metadata: dict | None = None) -> None:
        """Add a recently stored chunk to the buffer.

        Text that tokenizes to nothing is ignored. That covers empty text,
        punctuation-only text, and text made only of stop words or
        single-character words. Adding to a bank marks it as most recently
        used.
        """
        tokens = _tokenize(text)
        if not tokens:
            return  # Skip empty/stop-word-only text

        entry = _RecentEntry(
            memory_id=memory_id,
            text=text,
            tokens=tokens,
            bank_id=bank_id,
            metadata=metadata,
        )

        evicted: list[str] = []
        with self._lock:
            buf = self._touch_locked(bank_id)
            if buf is None:
                # Make room by evicting least recently used banks (front of the dict).
                while self._buffers and len(self._buffers) >= self._MAX_BANKS:
                    lru_bank = next(iter(self._buffers))
                    del self._buffers[lru_bank]
                    evicted.append(lru_bank)
                buf = deque(maxlen=self.max_per_bank)
                self._buffers[bank_id] = buf
            buf.append(entry)

        # Log after releasing the lock so slow log handlers never block other threads.
        for lru_bank in evicted:
            logger.warning(
                "RecentMemoryBuffer at capacity (%d banks); evicting LRU bank '%s'",
                self._MAX_BANKS,
                lru_bank,
            )

    def search(
        self,
        query: str,
        bank_id: str,
        limit: int = 10,
        min_score: float = 0.3,
    ) -> list[MemoryHit]:
        """Fuzzy search recent memories for a bank.

        Args:
            query: Free-text query, tokenized the same way as stored text.
            bank_id: Bank to search. Unknown banks yield no hits.
            limit: Maximum number of hits to return. Values <= 0 return no hits.
            min_score: Minimum fraction of query tokens that must match for an
                entry to be returned.

        Returns:
            MemoryHits sorted by descending score. Entries with equal scores
            keep their buffer order (oldest first). Typically completes in
            <5ms for buffers of 100 entries.
        """
        if limit <= 0:
            return []
        query_tokens = _tokenize(query)
        if not query_tokens:
            return []

        with self._lock:
            buf = self._touch_locked(bank_id)
            # Snapshot so scoring runs without holding the lock.
            entries = list(buf) if buf is not None else []

        scored: list[tuple[float, _RecentEntry]] = []
        for entry in entries:
            score = _fuzzy_token_score(query_tokens, entry.tokens)
            if score >= min_score:
                scored.append((score, entry))

        # Stable sort: ties keep buffer order.
        scored.sort(key=lambda x: x[0], reverse=True)

        return [
            MemoryHit(
                text=entry.text,
                score=score,
                metadata=entry.metadata,
                memory_id=entry.memory_id,
                bank_id=bank_id,
            )
            for score, entry in scored[:limit]
        ]

    def clear_bank(self, bank_id: str) -> None:
        """Drop all buffered entries for ``bank_id``; no-op if the bank is unknown."""
        with self._lock:
            self._buffers.pop(bank_id, None)

    def size(self, bank_id: str | None = None) -> int:
        """Return the number of buffered entries.

        If ``bank_id`` is given, counts only that bank (0 for unknown banks);
        otherwise returns the total across all banks. An empty-string bank id
        is treated as a real bank, not as "all banks".
        """
        with self._lock:
            if bank_id is not None:
                return len(self._buffers.get(bank_id, ()))
            return sum(len(b) for b in self._buffers.values())