Coverage for astrocyte/_policy.py: 96%

110 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Policy enforcer — access control, rate limiting, PII scanning, input validation.""" 

2 

3from __future__ import annotations 

4 

5import threading 

6from typing import TYPE_CHECKING 

7 

8from astrocyte._validation import validate_bank_id 

9from astrocyte.config import AstrocyteConfig 

10from astrocyte.errors import AccessDenied, ConfigError, RateLimited 

11from astrocyte.identity import ( 

12 BankResolver, 

13 accessible_read_banks, 

14 context_principal_label, 

15 effective_permissions, 

16) 

17from astrocyte.policy.barriers import ContentValidator, MetadataSanitizer, PiiScanner 

18from astrocyte.policy.escalation import CircuitBreaker, DegradedModeHandler 

19from astrocyte.policy.homeostasis import QuotaTracker, RateLimiter 

20from astrocyte.types import ( 

21 AccessGrant, 

22 AstrocyteContext, 

23 PiiMatch, 

24 RecallResult, 

25 RetainResult, 

26) 

27 

28if TYPE_CHECKING: 

29 from astrocyte.types import Metadata 

30 

31 

class PolicyEnforcer:
    """Centralizes all policy enforcement: access control, rate limiting,
    PII scanning, content validation, metadata sanitization, circuit breaking,
    and input validation.

    One instance of each barrier component (PII scanner, content validator,
    metadata sanitizer) is built from the supplied ``AstrocyteConfig``. The same
    applies to each homeostasis component (per-operation rate limiters, quota
    tracker) and each escalation component (circuit breaker, degraded-mode
    handler). Access grants start empty and are installed with
    :meth:`set_access_grants`.
    """

    # Maximum number of tags accepted by ``validate_retain_input``. This is a
    # class attribute so subclasses can override it without rewriting the method.
    _MAX_RETAIN_TAGS = 100

    def __init__(self, config: AstrocyteConfig) -> None:
        self._config = config

        # PII scanner
        self._pii_scanner = PiiScanner(
            mode=config.barriers.pii.mode,
            action=config.barriers.pii.action,
            countries=config.barriers.pii.countries,
            type_overrides=config.barriers.pii.type_overrides,
        )

        # Content validation
        self._content_validator = ContentValidator(
            max_content_length=config.barriers.validation.max_content_length,
            reject_empty=config.barriers.validation.reject_empty_content,
            allowed_content_types=config.barriers.validation.allowed_content_types,
        )

        # Metadata sanitization
        self._metadata_sanitizer = MetadataSanitizer(
            blocked_keys=config.barriers.metadata.blocked_keys,
            max_size_bytes=config.barriers.metadata.max_metadata_size_bytes,
        )

        # Rate limiters (per operation). An operation gets a limiter only when
        # its configured per-minute limit is truthy; operations without a
        # limiter are skipped by ``_check_rate_limit``.
        self._rate_limiters: dict[str, RateLimiter] = {}
        rl = config.homeostasis.rate_limits
        for operation, per_minute in (
            ("retain", rl.retain_per_minute),
            ("recall", rl.recall_per_minute),
            ("reflect", rl.reflect_per_minute),
        ):
            if per_minute:
                self._rate_limiters[operation] = RateLimiter(per_minute)

        # Quota tracker. Only retain and reflect have configured daily limits;
        # any other operation is looked up as ``None`` in ``_check_quota``.
        self._quota_tracker = QuotaTracker()
        self._quota_limits: dict[str, int | None] = {
            "retain": config.homeostasis.quotas.retain_per_day,
            "reflect": config.homeostasis.quotas.reflect_per_day,
        }

        # Guards both the rate limiters and the quota tracker, so that the
        # checks in ``check_rate_and_quota`` and the write in ``record_quota``
        # never interleave.
        self._rate_quota_lock = threading.Lock()

        # Circuit breaker
        cb = config.escalation.circuit_breaker
        self._circuit_breaker = CircuitBreaker(
            failure_threshold=cb.failure_threshold,
            recovery_timeout_seconds=cb.recovery_timeout_seconds,
            half_open_max_calls=cb.half_open_max_calls,
        )
        self._degraded_handler = DegradedModeHandler(mode=config.escalation.degraded_mode)

        # Access control grants
        self._access_grants: list[AccessGrant] = []

    # -- Access grants --

    def set_access_grants(self, grants: list[AccessGrant]) -> None:
        """Configure access grants.

        The list is stored as-is (not copied), so later mutations of ``grants``
        by the caller are visible to this enforcer.
        """
        self._access_grants = grants

    @property
    def access_grants(self) -> list[AccessGrant]:
        """The currently configured access grants (the stored list itself)."""
        return self._access_grants

    # -- Access control --

    def check_access(self, bank_id: str, permission: str, context: AstrocyteContext | None) -> None:
        """Check access control. Raises AccessDenied if denied.

        Decision order:

        1. Access control disabled: allow.
        2. No context (anonymous): allow only under an ``"open"`` default
           policy; otherwise deny as ``"anonymous"``.
        3. ``permission`` is among the context's effective permissions for
           ``bank_id``: allow.
        4. Otherwise, allow under an ``"open"`` default policy, else deny.

        Raises:
            AccessDenied: when the request is not allowed.
        """
        access_control = self._config.access_control
        if not access_control.enabled:
            return
        if context is None:
            if access_control.default_policy == "open":
                return
            raise AccessDenied("anonymous", bank_id, permission)

        granted = effective_permissions(context, self._access_grants, bank_id)
        if permission in granted:
            return

        if access_control.default_policy == "open":
            return

        raise AccessDenied(context_principal_label(context), bank_id, permission)

    # -- Bank resolution --

    def make_bank_resolver(self) -> BankResolver:
        """Build a BankResolver from the identity section's bank-prefix settings."""
        identity = self._config.identity
        return BankResolver(
            user_prefix=identity.user_bank_prefix,
            agent_prefix=identity.agent_bank_prefix,
            service_prefix=identity.service_bank_prefix,
        )

    def resolve_read_bank_ids(
        self,
        bank_id: str | None,
        banks: list[str] | None,
        context: AstrocyteContext | None,
    ) -> list[str]:
        """Resolve bank list for recall/reflect; optional identity-driven auto-resolve.

        A non-empty ``banks`` takes precedence over ``bank_id``. If neither is
        given, ``identity.auto_resolve_banks`` is enabled, and a context is
        present, the banks readable by that context are derived from the
        access grants. The configured bank ids are passed as known banks when
        any exist; otherwise ``None`` is passed.

        Returns:
            The resolved bank ids. When ``banks`` is used, this is that same
            list object.

        Raises:
            ConfigError: if no bank ids could be determined.
        """
        bank_ids = banks or ([bank_id] if bank_id else [])
        if not bank_ids and self._config.identity.auto_resolve_banks and context is not None:
            known = list((self._config.banks or {}).keys())
            bank_ids = accessible_read_banks(
                context,
                self._access_grants,
                known_bank_ids=known or None,
                resolver=self.make_bank_resolver(),
            )
        if not bank_ids:
            raise ConfigError("Either bank_id or banks must be provided")
        for bid in bank_ids:
            # presumably raises on a malformed id — see astrocyte._validation
            validate_bank_id(bid)
        return bank_ids

    # -- Rate limiting + quota --

    def check_rate_and_quota(self, bank_id: str, operation: str) -> None:
        """Atomically check quota and rate limit under a shared lock.

        The quota is checked first. ``QuotaTracker.check`` returns a bool, and
        usage is recorded separately by :meth:`record_quota`, so it presumably
        does not consume anything. ``RateLimiter.check_and_record`` both checks
        and records the call. Running the quota check first means a call
        rejected for quota reasons does not also burn a per-minute rate-limit
        slot.

        Raises:
            RateLimited: if the daily quota for ``operation`` is exhausted. The
                rate limiter presumably raises its own error when the
                per-minute limit is exceeded (TODO confirm).
        """
        with self._rate_quota_lock:
            self._check_quota(bank_id, operation)
            self._check_rate_limit(bank_id, operation)

    def _check_rate_limit(self, bank_id: str, operation: str) -> None:
        """Check and record a call against the operation's limiter, if one is configured.

        Must be called with ``_rate_quota_lock`` held.
        """
        limiter = self._rate_limiters.get(operation)
        if limiter is not None:
            limiter.check_and_record(bank_id, operation)

    def _check_quota(self, bank_id: str, operation: str) -> None:
        """Raise RateLimited if the bank has exhausted its quota for ``operation``.

        Must be called with ``_rate_quota_lock`` held. ``limit`` is ``None`` for
        operations without a configured quota; the tracker presumably treats
        that as unlimited (TODO confirm).
        """
        limit = self._quota_limits.get(operation)
        if not self._quota_tracker.check(bank_id, operation, limit):
            raise RateLimited(bank_id=bank_id, operation=operation)

    def record_quota(self, bank_id: str, operation: str) -> None:
        """Record a successful operation against the quota tracker.

        Takes the same lock as :meth:`check_rate_and_quota`, so writes to the
        tracker never interleave with a concurrent quota check.
        """
        with self._rate_quota_lock:
            self._quota_tracker.record(bank_id, operation)

    # -- PII scanning --

    async def scan_pii(self, content: str, mode: str) -> tuple[str, list[PiiMatch]]:
        """Scan content for PII. Returns (possibly redacted content, matches).

        The ``"llm"`` and ``"rules_then_llm"`` modes use the scanner's async
        path; every other mode uses the synchronous path.
        """
        if mode in ("llm", "rules_then_llm"):
            return await self._pii_scanner.apply_async(content)
        return self._pii_scanner.apply(content)

    @property
    def pii_action(self) -> str:
        """The action configured on the PII scanner."""
        return self._pii_scanner.action

    def scan_pii_output(self, text: str) -> list[PiiMatch]:
        """Scan text for PII matches (for DLP output scanning)."""
        return self._pii_scanner.scan(text)

    # -- Content validation --

    def validate_content(self, content: str, content_type: str) -> list[str]:
        """Validate content. Returns list of error strings (empty = valid)."""
        return self._content_validator.validate(content, content_type)

    # -- Metadata sanitization --

    def sanitize_metadata(self, metadata: Metadata | None) -> tuple[Metadata | None, list[str]]:
        """Sanitize metadata. Returns (sanitized metadata, warnings)."""
        return self._metadata_sanitizer.sanitize(metadata)

    # -- Circuit breaker --

    def check_circuit(self, provider_name: str) -> None:
        """Check circuit breaker. Raises ProviderUnavailable if open."""
        self._circuit_breaker.check(provider_name)

    def record_success(self) -> None:
        """Report a successful provider call to the circuit breaker."""
        self._circuit_breaker.record_success()

    def record_failure(self) -> None:
        """Report a failed provider call to the circuit breaker."""
        self._circuit_breaker.record_failure()

    def handle_degraded_retain(self, provider_name: str) -> RetainResult:
        """Handle a retain while the provider is unavailable.

        Delegates to the degraded-mode handler and ignores its return value.
        Depending on the configured mode, the handler may raise (TODO confirm).
        Otherwise a not-stored RetainResult is returned.
        """
        self._degraded_handler.handle_retain(provider_name)
        return RetainResult(stored=False, error="Provider unavailable (degraded mode)")

    def handle_degraded_recall(self, provider_name: str) -> RecallResult:
        """Handle a recall while the provider is unavailable (delegates to the handler)."""
        return self._degraded_handler.handle_recall(provider_name)

    # -- Input validation --

    def validate_retain_input(self, content: str, tags: list[str] | None) -> str | None:
        """Validate retain input sizes. Returns error string or None if valid.

        Content size is measured in UTF-8 bytes and compared against
        ``homeostasis.retain_max_content_bytes``; a falsy limit disables the
        check. At most ``_MAX_RETAIN_TAGS`` tags are accepted.
        """
        max_content_bytes = self._config.homeostasis.retain_max_content_bytes
        if max_content_bytes:
            size = len(content.encode("utf-8"))
            if size > max_content_bytes:
                return f"Content exceeds maximum size ({size} > {max_content_bytes} bytes)"
        max_tags = self._MAX_RETAIN_TAGS
        if tags and len(tags) > max_tags:
            return f"Too many tags ({len(tags)} > {max_tags})"
        return None