Coverage for astrocyte/policy/homeostasis.py: 80%

131 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Homeostasis policies — rate limiting, token budgets, quotas. 

2 

3All functions are sync (Rust migration candidates). 

4See docs/_design/policy-layer.md section 1 and docs/_design/implementation-language-strategy.md. 

5""" 

6 

7from __future__ import annotations 

8 

9import threading 

10import time 

11from dataclasses import dataclass 

12 

13from astrocyte.errors import RateLimited 

14from astrocyte.types import MemoryHit 

15 

16# --------------------------------------------------------------------------- 

17# Rate limiter (token bucket algorithm) 

18# --------------------------------------------------------------------------- 

19 

20 

21@dataclass 

22class _BucketState: 

23 tokens: float 

24 last_refill: float 

25 

26 

class RateLimiter:
    """Per-bank, per-operation token bucket rate limiter.

    Every ``"{bank_id}:{operation}"`` pair gets its own bucket. A bucket holds
    at most ``max_per_minute`` tokens, starts full, and refills continuously at
    ``max_per_minute / 60`` tokens per second (measured with
    ``time.monotonic()``). Each recorded operation consumes one token.

    At most ``_MAX_BUCKETS`` buckets are retained. When a new bucket is needed
    at capacity, the least recently used bucket is evicted; an evicted key that
    comes back starts again with a full bucket.

    All public methods serialize on a single ``threading.Lock``.
    Sync, stateful, self-contained — Rust migration candidate.
    """

    _MAX_BUCKETS = 10000

    def __init__(self, max_per_minute: int) -> None:
        """Allow ``max_per_minute`` operations per bank/operation pair per minute.

        With a non-positive ``max_per_minute`` no whole token is ever
        available, so ``check`` and ``check_and_record`` always raise
        ``RateLimited`` (with an infinite ``retry_after_seconds``).
        """
        self._max_per_minute = max_per_minute
        self._tokens_per_second = max_per_minute / 60.0
        self._max_tokens = float(max_per_minute)
        # Insertion order doubles as recency order: _get_bucket re-inserts a
        # key on every access, so the first key is the least recently used.
        self._buckets: dict[str, _BucketState] = {}
        self._lock = threading.Lock()

    def _get_bucket(self, key: str) -> _BucketState:
        """Return the bucket for *key* (created full if absent) and mark it most recently used.

        Caller must hold ``self._lock``.
        """
        bucket = self._buckets.pop(key, None)
        if bucket is None:
            if self._buckets and len(self._buckets) >= self._MAX_BUCKETS:
                # Evict the least recently used bucket (first in dict order).
                del self._buckets[next(iter(self._buckets))]
            bucket = _BucketState(tokens=self._max_tokens, last_refill=time.monotonic())
        self._buckets[key] = bucket
        return bucket

    def _refill(self, bucket: _BucketState) -> None:
        """Credit the tokens accrued since the last refill, capped at bucket capacity."""
        now = time.monotonic()
        elapsed = now - bucket.last_refill
        bucket.tokens = min(self._max_tokens, bucket.tokens + elapsed * self._tokens_per_second)
        bucket.last_refill = now

    def _refilled_bucket(self, bank_id: str, operation: str) -> _BucketState:
        """Look up and refill the bucket for *bank_id*/*operation*. Caller must hold ``self._lock``."""
        bucket = self._get_bucket(f"{bank_id}:{operation}")
        self._refill(bucket)
        return bucket

    def _raise_if_empty(self, bucket: _BucketState, bank_id: str, operation: str) -> None:
        """Raise ``RateLimited`` when *bucket* holds less than one whole token.

        ``retry_after_seconds`` is the time until one full token has accrued.
        """
        if bucket.tokens < 1.0:
            if self._tokens_per_second > 0:
                retry_after = (1.0 - bucket.tokens) / self._tokens_per_second
            else:
                # A non-positive rate never refills, so there is no finite wait;
                # dividing here would raise ZeroDivisionError at a rate of 0.
                retry_after = float("inf")
            raise RateLimited(bank_id=bank_id, operation=operation, retry_after_seconds=retry_after)

    def check(self, bank_id: str, operation: str) -> None:
        """Check rate limit without consuming a token. Raises RateLimited if exceeded."""
        with self._lock:
            bucket = self._refilled_bucket(bank_id, operation)
            self._raise_if_empty(bucket, bank_id, operation)

    def record(self, bank_id: str, operation: str) -> None:
        """Record a successful operation (consume one token, never going below zero)."""
        with self._lock:
            bucket = self._refilled_bucket(bank_id, operation)
            bucket.tokens = max(0.0, bucket.tokens - 1.0)

    def check_and_record(self, bank_id: str, operation: str) -> None:
        """Check and consume one token atomically. Raises RateLimited if exceeded."""
        with self._lock:
            bucket = self._refilled_bucket(bank_id, operation)
            self._raise_if_empty(bucket, bank_id, operation)
            bucket.tokens = max(0.0, bucket.tokens - 1.0)

87 

88 

# ---------------------------------------------------------------------------
# Token counting and token budget enforcement
# (tiktoken if available, heuristic fallback; enforce_token_budget follows
# the counting helpers)
# ---------------------------------------------------------------------------

97 

98 

99class _TiktokenCache: 

100 """Lazy singleton for optional tiktoken encoder.""" 

101 

102 encoder: object | None = None 

103 checked: bool = False 

104 

105 @classmethod 

106 def get(cls) -> object | None: 

107 """Return tiktoken encoder (cl100k_base), or None if not installed.""" 

108 if cls.checked: 

109 return cls.encoder 

110 cls.checked = True 

111 try: 

112 import tiktoken # type: ignore[import-untyped] 

113 

114 cls.encoder = tiktoken.get_encoding("cl100k_base") 

115 except (ImportError, Exception): 

116 cls.encoder = None 

117 return cls.encoder 

118 

119 

120def _heuristic_token_count(text: str) -> int: 

121 """Hybrid heuristic: character-based + word-based, CJK-aware.""" 

122 char_count = len(text) 

123 word_count = len(text.split()) 

124 

125 # Detect CJK content 

126 cjk_chars = sum( 

127 1 for c in text if "\u4e00" <= c <= "\u9fff" or "\u3040" <= c <= "\u30ff" or "\uac00" <= c <= "\ud7af" 

128 ) 

129 cjk_ratio = cjk_chars / max(char_count, 1) 

130 

131 if cjk_ratio > 0.3: 

132 char_estimate = int(char_count / 1.5) 

133 else: 

134 char_estimate = int(char_count / 4) 

135 

136 word_estimate = int(word_count * 1.33) 

137 return max(1, max(char_estimate, word_estimate)) 

138 

139 

def count_tokens(text: str) -> int:
    """Return the number of tokens in *text*.

    When ``tiktoken`` can be loaded, the exact ``cl100k_base`` count is used;
    otherwise the estimate from ``_heuristic_token_count`` is returned.
    Empty (falsy) input is reported as a single token.

    Install ``tiktoken`` for accurate counts: ``pip install tiktoken``.
    """
    if text:
        encoder = _TiktokenCache.get()
        if encoder is None:
            return _heuristic_token_count(text)
        return len(encoder.encode(text))  # type: ignore[union-attr]
    return 1

153 

154 

def enforce_token_budget(hits: list[MemoryHit], max_tokens: int) -> tuple[list[MemoryHit], bool]:
    """Keep the longest prefix of *hits* whose combined token count fits *max_tokens*.

    Each hit is costed with ``count_tokens(hit.text)``. Scanning stops at the
    first hit that would push the running total past ``max_tokens``; that hit
    and every later one are dropped, even if a later one would fit on its own.

    Returns ``(kept_hits, was_truncated)``, where ``kept_hits`` is a new list
    and ``was_truncated`` is True iff at least one hit was dropped.
    Sync, pure computation — Rust migration candidate.
    """
    kept: list[MemoryHit] = []
    used = 0
    for hit in hits:
        cost = count_tokens(hit.text)
        if used + cost > max_tokens:
            return kept, True
        kept.append(hit)
        used += cost
    return kept, False

174 

175 

176# --------------------------------------------------------------------------- 

177# Quota tracker (daily counters) 

178# --------------------------------------------------------------------------- 

179 

180 

class QuotaTracker:
    """Per-bank, per-operation daily quota tracking.

    Counts are stored per ``"{bank_id}:{operation}"`` key together with the day
    number they belong to (whole days since the Unix epoch, so days roll over
    at midnight UTC). A count stored for an earlier day reads as zero, which is
    how quotas reset.

    ``_MAX_COUNTERS`` is a soft cap: once it is exceeded, counters from previous
    days are discarded. Counters for the current day are never evicted, because
    dropping one would silently reset that bank's quota.

    Every public method holds an internal ``threading.Lock`` while touching the
    counters. ``check`` followed by ``record`` is still two separate steps, so
    another caller may record in between.
    Sync, self-contained — Rust migration candidate.
    """

    _MAX_COUNTERS = 10000

    def __init__(self) -> None:
        # {bank_id:operation -> (count, day_number)}
        self._counters: dict[str, tuple[int, int]] = {}
        # Day on which previous-day counters were last swept. Stale entries can
        # only appear after the day changes, so one sweep per day is enough;
        # without this, an over-capacity dict full of today's entries would be
        # rescanned on every record().
        self._swept_day: int | None = None
        self._lock = threading.Lock()

    @staticmethod
    def _today() -> int:
        """Return the current day number (whole days since the Unix epoch, UTC)."""
        return int(time.time() // 86400)

    @staticmethod
    def _key(bank_id: str, operation: str) -> str:
        """Build the counter key for a bank/operation pair."""
        return f"{bank_id}:{operation}"

    def _count_for(self, key: str, today: int) -> int:
        """Return *key*'s count for *today* (0 if absent or from another day). Caller holds the lock."""
        count, day = self._counters.get(key, (0, today))
        return count if day == today else 0

    def _sweep_stale(self, today: int) -> None:
        """Drop counters not belonging to *today*. Caller holds the lock."""
        stale = [k for k, (_, d) in self._counters.items() if d != today]
        for k in stale:
            del self._counters[k]
        self._swept_day = today

    def check(self, bank_id: str, operation: str, limit: int | None) -> bool:
        """Check if quota allows the operation. Returns True if allowed.

        A ``limit`` of None means unlimited.
        """
        if limit is None:
            return True
        key = self._key(bank_id, operation)
        with self._lock:
            return self._count_for(key, self._today()) < limit

    def record(self, bank_id: str, operation: str) -> None:
        """Record one operation against the quota (resetting the count on a new day)."""
        key = self._key(bank_id, operation)
        with self._lock:
            today = self._today()
            self._counters[key] = (self._count_for(key, today) + 1, today)

            # Evict stale entries if over capacity (at most one sweep per day).
            if len(self._counters) > self._MAX_COUNTERS and self._swept_day != today:
                self._sweep_stale(today)

    def get_count(self, bank_id: str, operation: str) -> int:
        """Get current count for a bank+operation (0 if none recorded today)."""
        key = self._key(bank_id, operation)
        with self._lock:
            return self._count_for(key, self._today())