Coverage for astrocyte/_policy.py: 96%
110 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Policy enforcer — access control, rate limiting, PII scanning, input validation."""
3from __future__ import annotations
5import threading
6from typing import TYPE_CHECKING
8from astrocyte._validation import validate_bank_id
9from astrocyte.config import AstrocyteConfig
10from astrocyte.errors import AccessDenied, ConfigError, RateLimited
11from astrocyte.identity import (
12 BankResolver,
13 accessible_read_banks,
14 context_principal_label,
15 effective_permissions,
16)
17from astrocyte.policy.barriers import ContentValidator, MetadataSanitizer, PiiScanner
18from astrocyte.policy.escalation import CircuitBreaker, DegradedModeHandler
19from astrocyte.policy.homeostasis import QuotaTracker, RateLimiter
20from astrocyte.types import (
21 AccessGrant,
22 AstrocyteContext,
23 PiiMatch,
24 RecallResult,
25 RetainResult,
26)
28if TYPE_CHECKING:
29 from astrocyte.types import Metadata
class PolicyEnforcer:
    """Centralizes all policy enforcement: access control, rate limiting,
    PII scanning, content validation, metadata sanitization, circuit breaking,
    and input validation.

    Each concern is delegated to a helper object built from the
    ``AstrocyteConfig`` passed to ``__init__``; this class wires them
    together and exposes one method per policy check.
    """

    # Maximum number of tags accepted by ``validate_retain_input``. A class
    # attribute (rather than an inline literal) so it can be overridden.
    MAX_RETAIN_TAGS = 100

    def __init__(self, config: AstrocyteConfig) -> None:
        """Build every policy helper from *config*.

        Access grants start empty; install them with ``set_access_grants``.
        """
        self._config = config
        barriers = config.barriers

        # PII scanner
        pii = barriers.pii
        self._pii_scanner = PiiScanner(
            mode=pii.mode,
            action=pii.action,
            countries=pii.countries,
            type_overrides=pii.type_overrides,
        )

        # Content validation
        validation = barriers.validation
        self._content_validator = ContentValidator(
            max_content_length=validation.max_content_length,
            reject_empty=validation.reject_empty_content,
            allowed_content_types=validation.allowed_content_types,
        )

        # Metadata sanitization
        self._metadata_sanitizer = MetadataSanitizer(
            blocked_keys=barriers.metadata.blocked_keys,
            max_size_bytes=barriers.metadata.max_metadata_size_bytes,
        )

        # Rate limiters, one per operation. A falsy (unset/zero) per-minute
        # limit creates no limiter, so that operation is never rate limited.
        rl = config.homeostasis.rate_limits
        per_minute = {
            "retain": rl.retain_per_minute,
            "recall": rl.recall_per_minute,
            "reflect": rl.reflect_per_minute,
        }
        self._rate_limiters: dict[str, RateLimiter] = {
            op: RateLimiter(limit) for op, limit in per_minute.items() if limit
        }

        # Quota tracker. Operations without an entry here (e.g. "recall")
        # reach QuotaTracker.check with a limit of None.
        self._quota_tracker = QuotaTracker()
        self._quota_limits: dict[str, int | None] = {
            "retain": config.homeostasis.quotas.retain_per_day,
            "reflect": config.homeostasis.quotas.reflect_per_day,
        }

        # Guards the rate limiters and the quota tracker. Held by both
        # check_rate_and_quota and record_quota. Not reentrant.
        self._rate_quota_lock = threading.Lock()

        # Circuit breaker
        cb = config.escalation.circuit_breaker
        self._circuit_breaker = CircuitBreaker(
            failure_threshold=cb.failure_threshold,
            recovery_timeout_seconds=cb.recovery_timeout_seconds,
            half_open_max_calls=cb.half_open_max_calls,
        )
        self._degraded_handler = DegradedModeHandler(mode=config.escalation.degraded_mode)

        # Access control grants
        self._access_grants: list[AccessGrant] = []

    # -- Access grants --

    def set_access_grants(self, grants: list[AccessGrant]) -> None:
        """Replace the configured access grants with *grants*."""
        self._access_grants = grants

    @property
    def access_grants(self) -> list[AccessGrant]:
        """The currently configured access grants (the live list, not a copy)."""
        return self._access_grants

    # -- Access control --

    def check_access(self, bank_id: str, permission: str, context: AstrocyteContext | None) -> None:
        """Check that *context* holds *permission* on *bank_id*.

        Does nothing when access control is disabled. A ``None`` context is
        treated as the anonymous principal. If the principal lacks the
        permission, the request is still allowed when
        ``access_control.default_policy`` is ``"open"``.

        Raises:
            AccessDenied: if the permission is missing and the default
                policy is not ``"open"``.
        """
        access_control = self._config.access_control
        if not access_control.enabled:
            return
        open_by_default = access_control.default_policy == "open"

        if context is None:
            if open_by_default:
                return
            raise AccessDenied("anonymous", bank_id, permission)

        if permission in effective_permissions(context, self._access_grants, bank_id):
            return
        if open_by_default:
            return
        raise AccessDenied(context_principal_label(context), bank_id, permission)

    # -- Bank resolution --

    def make_bank_resolver(self) -> BankResolver:
        """Build a ``BankResolver`` from the identity bank-prefix settings."""
        i = self._config.identity
        return BankResolver(
            user_prefix=i.user_bank_prefix,
            agent_prefix=i.agent_bank_prefix,
            service_prefix=i.service_bank_prefix,
        )

    def resolve_read_bank_ids(
        self,
        bank_id: str | None,
        banks: list[str] | None,
        context: AstrocyteContext | None,
    ) -> list[str]:
        """Resolve the bank list for recall/reflect.

        Explicit *banks* take precedence over *bank_id*. If neither is
        given, ``identity.auto_resolve_banks`` is enabled and a *context* is
        present, the banks readable by that context are resolved from the
        access grants (restricted to configured banks when any exist).

        Raises:
            ConfigError: if no bank could be determined.
        """
        bank_ids = banks or ([bank_id] if bank_id else [])
        if not bank_ids and self._config.identity.auto_resolve_banks and context is not None:
            known = list((self._config.banks or {}).keys())
            bank_ids = accessible_read_banks(
                context,
                self._access_grants,
                known_bank_ids=known or None,
                resolver=self.make_bank_resolver(),
            )
        if not bank_ids:
            raise ConfigError("Either bank_id or banks must be provided")
        for bid in bank_ids:
            validate_bank_id(bid)
        return bank_ids

    # -- Rate limiting + quota --

    def check_rate_and_quota(self, bank_id: str, operation: str) -> None:
        """Atomically check the daily quota and per-minute rate limit.

        Both checks run under ``_rate_quota_lock``. The quota is checked
        first because ``_check_quota`` only calls ``QuotaTracker.check``;
        usage is recorded separately by ``record_quota``. A quota rejection
        therefore no longer consumes a rate-limiter slot through
        ``RateLimiter.check_and_record``. This assumes ``check`` has no side
        effects; confirm against ``QuotaTracker``.

        Raises:
            RateLimited: if the daily quota for *operation* is exhausted.
                The rate limiter's own rejection is whatever
                ``check_and_record`` raises.
        """
        with self._rate_quota_lock:
            self._check_quota(bank_id, operation)
            self._check_rate_limit(bank_id, operation)

    def _check_rate_limit(self, bank_id: str, operation: str) -> None:
        """Apply the per-minute limiter for *operation*, if one is configured."""
        limiter = self._rate_limiters.get(operation)
        if limiter:
            limiter.check_and_record(bank_id, operation)

    def _check_quota(self, bank_id: str, operation: str) -> None:
        """Raise ``RateLimited`` if the tracker rejects *operation* for *bank_id*."""
        limit = self._quota_limits.get(operation)
        if not self._quota_tracker.check(bank_id, operation, limit):
            raise RateLimited(bank_id=bank_id, operation=operation)

    def record_quota(self, bank_id: str, operation: str) -> None:
        """Record a successful operation against the quota tracker."""
        # Same lock as check_rate_and_quota, so tracker writes never
        # interleave with a check that is in progress.
        with self._rate_quota_lock:
            self._quota_tracker.record(bank_id, operation)

    # -- PII scanning --

    async def scan_pii(self, content: str, mode: str) -> tuple[str, list[PiiMatch]]:
        """Scan content for PII. Returns (possibly redacted content, matches).

        Modes ``"llm"`` and ``"rules_then_llm"`` use the scanner's async path.
        Any other mode uses the synchronous path.
        """
        if mode in ("llm", "rules_then_llm"):
            return await self._pii_scanner.apply_async(content)
        return self._pii_scanner.apply(content)

    @property
    def pii_action(self) -> str:
        """The action configured on the PII scanner."""
        return self._pii_scanner.action

    def scan_pii_output(self, text: str) -> list[PiiMatch]:
        """Scan text for PII matches (for DLP output scanning)."""
        return self._pii_scanner.scan(text)

    # -- Content validation --

    def validate_content(self, content: str, content_type: str) -> list[str]:
        """Validate content. Returns list of error strings (empty = valid)."""
        return self._content_validator.validate(content, content_type)

    # -- Metadata sanitization --

    def sanitize_metadata(self, metadata: "Metadata | None") -> tuple["Metadata | None", list[str]]:
        """Sanitize metadata. Returns (sanitized metadata, warnings)."""
        return self._metadata_sanitizer.sanitize(metadata)

    # -- Circuit breaker --

    def check_circuit(self, provider_name: str) -> None:
        """Check circuit breaker. Raises ProviderUnavailable if open."""
        self._circuit_breaker.check(provider_name)

    def record_success(self) -> None:
        """Report a successful provider call to the circuit breaker."""
        self._circuit_breaker.record_success()

    def record_failure(self) -> None:
        """Report a failed provider call to the circuit breaker."""
        self._circuit_breaker.record_failure()

    def handle_degraded_retain(self, provider_name: str) -> RetainResult:
        """Notify the degraded-mode handler, then return a not-stored result."""
        self._degraded_handler.handle_retain(provider_name)
        return RetainResult(stored=False, error="Provider unavailable (degraded mode)")

    def handle_degraded_recall(self, provider_name: str) -> RecallResult:
        """Return the degraded-mode handler's recall result for *provider_name*."""
        return self._degraded_handler.handle_recall(provider_name)

    # -- Input validation --

    def validate_retain_input(self, content: str, tags: list[str] | None) -> str | None:
        """Validate retain input sizes. Returns error string or None if valid.

        Content size is measured in UTF-8 bytes and only enforced when
        ``homeostasis.retain_max_content_bytes`` is truthy. The tag count is
        capped at ``MAX_RETAIN_TAGS``.
        """
        max_content_bytes = self._config.homeostasis.retain_max_content_bytes
        if max_content_bytes:
            size = len(content.encode("utf-8"))
            if size > max_content_bytes:
                return f"Content exceeds maximum size ({size} > {max_content_bytes} bytes)"
        max_tags = self.MAX_RETAIN_TAGS
        if tags and len(tags) > max_tags:
            return f"Too many tags ({len(tags)} > {max_tags})"
        return None