Coverage for astrocyte/policy/homeostasis.py: 80%
131 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Homeostasis policies — rate limiting, token budgets, quotas.
3All functions are sync (Rust migration candidates).
4See docs/_design/policy-layer.md section 1 and docs/_design/implementation-language-strategy.md.
5"""
7from __future__ import annotations
9import threading
10import time
11from dataclasses import dataclass
13from astrocyte.errors import RateLimited
14from astrocyte.types import MemoryHit
16# ---------------------------------------------------------------------------
17# Rate limiter (token bucket algorithm)
18# ---------------------------------------------------------------------------
@dataclass
class _BucketState:
    """Mutable token-bucket state for one ``bank_id:operation`` key."""

    tokens: float  # tokens currently available (fractional while refilling)
    last_refill: float  # time.monotonic() reading taken at the last refill


class RateLimiter:
    """Per-bank, per-operation token bucket rate limiter.

    Each ``bank_id:operation`` pair owns a bucket holding at most
    ``max_per_minute`` tokens. The bucket refills continuously at
    ``max_per_minute / 60`` tokens per second, and an operation is allowed
    when at least one whole token is available.

    At most ``_MAX_BUCKETS`` buckets are kept. When a new key arrives at
    capacity, the least recently used bucket is evicted.

    Sync, stateful, self-contained — Rust migration candidate. Every public
    method runs under an internal ``threading.Lock``.
    """

    _MAX_BUCKETS = 10000

    def __init__(self, max_per_minute: int) -> None:
        """Create a limiter.

        Args:
            max_per_minute: Bucket capacity and per-minute refill rate.

        Raises:
            ValueError: If ``max_per_minute`` is not positive. A zero rate
                would make the retry-after computation divide by zero.
        """
        if max_per_minute <= 0:
            raise ValueError(f"max_per_minute must be positive, got {max_per_minute!r}")
        self._max_per_minute = max_per_minute
        self._tokens_per_second = max_per_minute / 60.0
        self._max_tokens = float(max_per_minute)
        # Insertion order doubles as recency order (see _get_bucket).
        self._buckets: dict[str, _BucketState] = {}
        self._lock = threading.Lock()

    def _get_bucket(self, key: str) -> _BucketState:
        """Return the bucket for *key*, creating a full one if absent.

        The caller must hold ``self._lock``. A found bucket is popped and
        re-inserted at the end, so the first key in iteration order is always
        the least recently used one.
        """
        bucket = self._buckets.pop(key, None)
        if bucket is None:
            if len(self._buckets) >= self._MAX_BUCKETS:
                # Evict the LRU key rather than the oldest-inserted one, so a
                # bucket that is actively being throttled is never reset to a
                # full allowance just because it was created early.
                del self._buckets[next(iter(self._buckets))]
            bucket = _BucketState(tokens=self._max_tokens, last_refill=time.monotonic())
        self._buckets[key] = bucket
        return bucket

    def _refill(self, bucket: _BucketState) -> None:
        """Add tokens accrued since the last refill, capped at capacity."""
        now = time.monotonic()
        elapsed = now - bucket.last_refill
        bucket.tokens = min(self._max_tokens, bucket.tokens + elapsed * self._tokens_per_second)
        bucket.last_refill = now

    def _apply(self, bank_id: str, operation: str, *, enforce: bool, consume: bool) -> None:
        """Refill the bucket, then optionally enforce the limit and/or consume a token.

        Raises:
            RateLimited: If *enforce* is set and fewer than one token is available.
        """
        with self._lock:
            bucket = self._get_bucket(f"{bank_id}:{operation}")
            self._refill(bucket)

            if enforce and bucket.tokens < 1.0:
                # Time until the bucket refills back up to one whole token.
                retry_after = (1.0 - bucket.tokens) / self._tokens_per_second
                raise RateLimited(bank_id=bank_id, operation=operation, retry_after_seconds=retry_after)

            if consume:
                bucket.tokens = max(0.0, bucket.tokens - 1.0)

    def check(self, bank_id: str, operation: str) -> None:
        """Check rate limit without consuming a token. Raises RateLimited if exceeded."""
        self._apply(bank_id, operation, enforce=True, consume=False)

    def record(self, bank_id: str, operation: str) -> None:
        """Record a successful operation (consume one token, never below zero)."""
        self._apply(bank_id, operation, enforce=False, consume=True)

    def check_and_record(self, bank_id: str, operation: str) -> None:
        """Check and consume in one atomic call. Raises RateLimited if exceeded."""
        self._apply(bank_id, operation, enforce=True, consume=True)
89# ---------------------------------------------------------------------------
90# Token budget enforcement
91# ---------------------------------------------------------------------------
94# ---------------------------------------------------------------------------
95# Token counting — tiktoken if available, heuristic fallback
96# ---------------------------------------------------------------------------
99class _TiktokenCache:
100 """Lazy singleton for optional tiktoken encoder."""
102 encoder: object | None = None
103 checked: bool = False
105 @classmethod
106 def get(cls) -> object | None:
107 """Return tiktoken encoder (cl100k_base), or None if not installed."""
108 if cls.checked:
109 return cls.encoder
110 cls.checked = True
111 try:
112 import tiktoken # type: ignore[import-untyped]
114 cls.encoder = tiktoken.get_encoding("cl100k_base")
115 except (ImportError, Exception):
116 cls.encoder = None
117 return cls.encoder
120def _heuristic_token_count(text: str) -> int:
121 """Hybrid heuristic: character-based + word-based, CJK-aware."""
122 char_count = len(text)
123 word_count = len(text.split())
125 # Detect CJK content
126 cjk_chars = sum(
127 1 for c in text if "\u4e00" <= c <= "\u9fff" or "\u3040" <= c <= "\u30ff" or "\uac00" <= c <= "\ud7af"
128 )
129 cjk_ratio = cjk_chars / max(char_count, 1)
131 if cjk_ratio > 0.3:
132 char_estimate = int(char_count / 1.5)
133 else:
134 char_estimate = int(char_count / 4)
136 word_estimate = int(word_count * 1.33)
137 return max(1, max(char_estimate, word_estimate))
def count_tokens(text: str) -> int:
    """Return the token count of *text*.

    Empty text counts as 1. Non-empty text is measured with tiktoken
    (cl100k_base) when available, and with a heuristic estimate otherwise.

    Install ``tiktoken`` for accurate counts: ``pip install tiktoken``.
    """
    if text:
        encoder = _TiktokenCache.get()
        if encoder is None:
            return _heuristic_token_count(text)
        return len(encoder.encode(text))  # type: ignore[union-attr]
    return 1
def enforce_token_budget(hits: list[MemoryHit], max_tokens: int) -> tuple[list[MemoryHit], bool]:
    """Keep the longest prefix of *hits* whose combined token count fits *max_tokens*.

    Hits are taken in order. Scanning stops at the first hit that would
    overflow the budget, and later (possibly smaller) hits are not considered.
    Token counts come from ``count_tokens(hit.text)``.

    Returns (kept_hits, was_truncated).
    Sync, pure computation — Rust migration candidate.
    """
    kept: list[MemoryHit] = []
    used = 0
    for hit in hits:
        cost = count_tokens(hit.text)
        if used + cost > max_tokens:
            return kept, True
        kept.append(hit)
        used += cost
    return kept, False
176# ---------------------------------------------------------------------------
177# Quota tracker (daily counters)
178# ---------------------------------------------------------------------------
class QuotaTracker:
    """Per-bank daily quota tracking.

    Sync, self-contained — Rust migration candidate.
    Counters reset at UTC day boundaries. The day number is
    ``int(time.time() // 86400)``, i.e. whole days since the Unix epoch.
    Every public method runs under an internal ``threading.Lock``.
    """

    _MAX_COUNTERS = 10000

    def __init__(self) -> None:
        # {"bank_id:operation" -> (count, day_number)}
        self._counters: dict[str, tuple[int, int]] = {}
        # Day number of the last stale-entry sweep. After a sweep, every
        # remaining counter belongs to that day, so sweeping again before the
        # day changes cannot free anything.
        self._last_sweep_day: int | None = None
        self._lock = threading.Lock()

    @staticmethod
    def _today() -> int:
        """Return the current UTC day number (whole days since the Unix epoch)."""
        return int(time.time() // 86400)

    def _current_count(self, key: str, today: int) -> int:
        """Return *key*'s count for *today*; counters from other days read as 0.

        The caller must hold ``self._lock``.
        """
        count, day = self._counters.get(key, (0, today))
        return count if day == today else 0

    def _evict_stale(self, today: int) -> None:
        """Drop every counter not belonging to *today*. The caller must hold the lock."""
        stale = [k for k, (_, day) in self._counters.items() if day != today]
        for k in stale:
            del self._counters[k]
        self._last_sweep_day = today

    def check(self, bank_id: str, operation: str, limit: int | None) -> bool:
        """Check if quota allows the operation. Returns True if allowed.

        A *limit* of None means unlimited.
        """
        if limit is None:
            return True
        key = f"{bank_id}:{operation}"
        with self._lock:
            return self._current_count(key, self._today()) < limit

    def record(self, bank_id: str, operation: str) -> None:
        """Record one operation against today's quota."""
        key = f"{bank_id}:{operation}"
        with self._lock:
            today = self._today()
            self._counters[key] = (self._current_count(key, today) + 1, today)

            # Evict stale entries when over capacity, at most once per day.
            # Current-day counters are never evicted, since that would reset
            # quotas.
            if len(self._counters) > self._MAX_COUNTERS and self._last_sweep_day != today:
                self._evict_stale(today)

    def get_count(self, bank_id: str, operation: str) -> int:
        """Get today's count for a bank+operation (0 if none recorded today)."""
        key = f"{bank_id}:{operation}"
        with self._lock:
            return self._current_count(key, self._today())