Coverage for astrocyte/mip/loader.py: 91%
231 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""MIP config loader — YAML loading, env var substitution, validation."""
3from __future__ import annotations
5from datetime import datetime
6from pathlib import Path
8import yaml
10from astrocyte.config import _substitute_env_recursive
11from astrocyte.errors import ConfigError
12from astrocyte.mip.presets import (
13 expand_forget_preset,
14 expand_preset,
15 is_known_forget_preset,
16 is_known_preset,
17 list_forget_presets,
18 list_presets,
19)
20from astrocyte.mip.schema import (
21 ActionSpec,
22 BankDefinition,
23 ChunkerSpec,
24 DedupSpec,
25 EscalationCondition,
26 ForgetSpec,
27 IntentPolicy,
28 MatchBlock,
29 MatchSpec,
30 MipConfig,
31 PipelineSpec,
32 ReflectSpec,
33 RerankSpec,
34 RoutingRule,
35)
# Vocabulary accepted inside an ``action.pipeline`` block and its nested
# sections. Keys outside these sets are not rejected: the loader only warns
# about them, so the vocabulary can grow without breaking older configs.
_PIPELINE_KEYS = {"version", "preset", "chunker", "dedup", "rerank", "reflect", "temporal_half_life_days"}
_CHUNKER_KEYS = {"strategy", "max_size", "overlap"}
_DEDUP_KEYS = {"threshold", "action"}
_RERANK_KEYS = {"keyword_weight", "proper_noun_weight"}
_REFLECT_KEYS = {"prompt", "promote_metadata"}

# Vocabulary accepted inside an ``action.forget`` block, plus the closed sets
# of values allowed for its ``mode`` and ``audit`` fields.
_FORGET_KEYS = {
    "version", "preset",
    "mode", "audit",
    "cascade", "respect_legal_hold",
    "min_age_days", "max_per_call",
}
_FORGET_MODES = {"soft", "hard", "tombstone"}
_FORGET_AUDIT = {"none", "recommended", "required"}

# Allowed values for the top-level ``tie_breaker`` setting.
_TIE_BREAKERS = {"first", "error", "most_specific"}

# P4: upper bound on the number of entries in reflect.promote_metadata.
_PROMOTE_METADATA_MAX = 5
def load_mip_config(path: str | Path) -> MipConfig:
    """Load a MIP YAML file and return the validated ``MipConfig``.

    The file is parsed with ``yaml.safe_load``, passed through
    ``_substitute_env_recursive`` and then handed to ``_parse_mip_config``.
    An empty document is treated as an empty mapping.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed and validated configuration.

    Raises:
        ConfigError: If the file does not exist, if the top-level YAML node
            is not a mapping, or if ``_parse_mip_config`` rejects the content.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise ConfigError(f"MIP config file not found: {config_path}")

    # Decode explicitly as UTF-8; without ``encoding`` the platform locale
    # decides how the file is read.
    with open(config_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}

    # A list or scalar at the top level would otherwise surface later as an
    # AttributeError on ``.get``; report it as a configuration problem.
    if not isinstance(raw, dict):
        raise ConfigError(
            f"MIP config must be a YAML mapping at the top level (got {type(raw).__name__}): {config_path}"
        )

    raw = _substitute_env_recursive(raw)
    return _parse_mip_config(raw)
def _parse_mip_config(data: dict) -> MipConfig:
    """Build a validated ``MipConfig`` from an already-loaded YAML mapping.

    Reads ``version`` (default ``"1.0"``), ``tie_breaker`` (default
    ``"first"``), ``banks``, ``rules`` and ``intent_policy``, then runs
    ``_validate_mip_config`` on the result.

    Raises:
        ConfigError: If ``tie_breaker`` is not a recognised value or the
            consistency validation reports any error.
    """
    config = MipConfig(version=data.get("version", "1.0"))

    # Phase 5: tie-breaker policy.
    chosen_tie_breaker = data.get("tie_breaker", "first")
    if chosen_tie_breaker not in _TIE_BREAKERS:
        raise ConfigError(f"tie_breaker must be one of {sorted(_TIE_BREAKERS)} (got {chosen_tie_breaker!r})")
    config.tie_breaker = chosen_tie_breaker

    # Banks: a missing or empty section leaves the MipConfig default in place.
    bank_entries = data.get("banks")
    if bank_entries:
        banks = []
        for entry in bank_entries:
            banks.append(
                BankDefinition(
                    id=entry["id"],
                    description=entry.get("description"),
                    access=entry.get("access"),
                    compliance=entry.get("compliance"),
                )
            )
        config.banks = banks

    # Rules.
    rule_entries = data.get("rules")
    if rule_entries:
        config.rules = list(map(_parse_rule, rule_entries))

    # Intent policy.
    policy = data.get("intent_policy")
    if policy:
        conditions = None
        raw_conditions = policy.get("escalate_when")
        if raw_conditions:
            conditions = []
            for entry in raw_conditions:
                # Entries that are not mappings are skipped.
                if not isinstance(entry, dict):
                    continue
                for name, spec in entry.items():
                    if isinstance(spec, dict):
                        # Operator form: one condition per operator/value pair.
                        conditions.extend(
                            EscalationCondition(condition=name, operator=op, value=operand)
                            for op, operand in spec.items()
                        )
                    else:
                        # Bare value: shorthand for equality.
                        conditions.append(EscalationCondition(condition=name, operator="eq", value=spec))

        config.intent_policy = IntentPolicy(
            escalate_when=conditions,
            model_context=policy.get("model_context"),
            constraints=policy.get("constraints"),
        )

    problems = _validate_mip_config(config)
    if problems:
        raise ConfigError(f"MIP config validation errors: {'; '.join(problems)}")

    return config
def _parse_rule(data: dict) -> RoutingRule:
    """Parse a single rule mapping into a ``RoutingRule``.

    Defaults: ``priority`` 100, ``override`` False, ``shadow`` False. A
    ``match`` or ``action`` key that is absent or set to null is treated as
    an empty mapping.

    Raises:
        ConfigError: If the rule is not a mapping, has no ``name`` key, has a
            non-boolean ``shadow``, an invalid or inverted activation window,
            or ``observability_tags`` that is not a list of strings. Errors
            from the match/action sub-parsers propagate unchanged.
    """
    if not isinstance(data, dict):
        raise ConfigError(f"Each rule must be a mapping (got {type(data).__name__})")
    # Report a missing name as a config error rather than a bare KeyError
    # from the ``data["name"]`` lookup further down.
    if "name" not in data:
        raise ConfigError("Rule missing 'name'")

    # ``match:`` / ``action:`` with no value load as None from YAML.
    match_data = data.get("match") or {}
    action_data = data.get("action") or {}
    rule_name = data["name"]

    # Phase 5 — shadow / activation window / observability tags
    shadow = data.get("shadow", False)
    if not isinstance(shadow, bool):
        raise ConfigError(f"Rule '{rule_name}': shadow must be a boolean")

    active_from = _parse_iso_datetime(data.get("active_from"), rule_name, "active_from")
    active_until = _parse_iso_datetime(data.get("active_until"), rule_name, "active_until")
    if active_from and active_until and active_until <= active_from:
        raise ConfigError(f"Rule '{rule_name}': active_until must be strictly after active_from")

    obs_tags = data.get("observability_tags")
    if obs_tags is not None:
        if not isinstance(obs_tags, list) or not all(isinstance(t, str) for t in obs_tags):
            raise ConfigError(f"Rule '{rule_name}': observability_tags must be a list of strings")

    return RoutingRule(
        name=data["name"],
        priority=data.get("priority", 100),
        match=_parse_match_block(match_data),
        action=_parse_action(action_data, rule_name=rule_name, match_data=match_data),
        override=data.get("override", False),
        shadow=shadow,
        active_from=active_from,
        active_until=active_until,
        observability_tags=obs_tags,
    )
def _parse_iso_datetime(value: object, rule_name: str, field: str) -> datetime | None:
    """Normalise an activation-window value to a timezone-aware datetime.

    Accepted inputs:
    - ``None``: returned unchanged.
    - ``datetime``: used as is.
    - ``date``: interpreted as midnight of that day, the same result the
      equivalent ``YYYY-MM-DD`` string produces.
    - ``str``: parsed with ``datetime.fromisoformat``.

    Naive results are tagged as UTC so comparisons are unambiguous.

    Args:
        value: Raw value taken from the rule mapping.
        rule_name: Rule name, used only in error messages.
        field: Field name, used only in error messages.

    Raises:
        ConfigError: If a string cannot be parsed or the value has an
            unsupported type.
    """
    from datetime import date, timezone

    if value is None:
        return None
    # ``datetime`` is a subclass of ``date``, so it must be tested first.
    if isinstance(value, datetime):
        dt = value
    elif isinstance(value, date):
        # A YAML loader may hand over a plain date for an unquoted
        # ``2026-01-01``; accept it instead of rejecting valid input.
        dt = datetime(value.year, value.month, value.day)
    elif isinstance(value, str):
        try:
            dt = datetime.fromisoformat(value)
        except ValueError as exc:
            raise ConfigError(f"Rule '{rule_name}': {field} must be ISO 8601 datetime (got {value!r}): {exc}") from exc
    else:
        raise ConfigError(f"Rule '{rule_name}': {field} must be a datetime string (got {type(value).__name__})")
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt
def _parse_match_block(data: dict) -> MatchBlock:
    """Turn a YAML ``match`` mapping into a ``MatchBlock``.

    Each of the ``all`` / ``any`` / ``none`` groups is ``None`` when its key
    is absent, an empty list when the key is present but empty, and a list of
    parsed specs otherwise. When none of the groups ends up with any spec,
    every remaining key is read as a shorthand condition and the resulting
    specs (if any) become the ``all`` group.
    """
    group_keys = ("all", "any", "none")

    def _specs_for(key):
        # Absent key -> None; present but falsy -> []; otherwise parse each entry.
        if key not in data:
            return None
        entries = data[key]
        return [_parse_match_spec(entry) for entry in entries] if entries else []

    all_specs = _specs_for("all")
    any_specs = _specs_for("any")
    none_specs = _specs_for("none")

    # Single-level shorthand: {"content_type": "student_answer", "pii_detected": true}
    if not (all_specs or any_specs or none_specs):
        shorthand = [
            _parse_match_spec({key: val})
            for key, val in data.items()
            if key not in group_keys
        ]
        if shorthand:
            all_specs = shorthand

    return MatchBlock(
        all_conditions=all_specs,
        any_conditions=any_specs,
        none_conditions=none_specs,
    )
def _parse_match_spec(data: dict) -> MatchSpec:
    """Convert one match-spec mapping into a ``MatchSpec``.

    Only the first usable entry is consumed: the first field whose value is
    not an empty mapping, and for operator form only the first
    operator/value pair of that field.

    Raises:
        ConfigError: If no entry yields a spec (``"Empty match spec"``).
    """
    for name, raw in data.items():
        if isinstance(raw, dict):
            # Operator form: {"metadata.count": {"gte": 5}}
            if not raw:
                # No operator given for this field; move on to the next one.
                continue
            operator, operand = next(iter(raw.items()))
            return MatchSpec(field=name, operator=operator, value=operand)
        if raw == "present":
            return MatchSpec(field=name, operator="present")
        if raw == "absent":
            return MatchSpec(field=name, operator="absent")
        # Simple equality: {"content_type": "student_answer"}
        return MatchSpec(field=name, operator="eq", value=raw)
    raise ConfigError("Empty match spec")
def _parse_action(
    data: dict,
    rule_name: str = "<unnamed>",
    match_data: dict | None = None,
) -> ActionSpec:
    """Turn a YAML ``action`` mapping into an ``ActionSpec``.

    ``rule_name`` and ``match_data`` are forwarded to the ``pipeline`` and
    ``forget`` sub-parsers, which use them for guardrail diagnostics
    (P2/P4/P5). Either sub-block is left as ``None`` when its key is absent
    or null. ``confidence`` defaults to 1.0.
    """
    pipeline = None
    raw_pipeline = data.get("pipeline")
    if raw_pipeline is not None:
        pipeline = _parse_pipeline(raw_pipeline, rule_name=rule_name, match_data=match_data or {})

    forget = None
    raw_forget = data.get("forget")
    if raw_forget is not None:
        forget = _parse_forget(raw_forget, rule_name=rule_name)

    return ActionSpec(
        bank=data.get("bank"),
        tags=data.get("tags"),
        retain_policy=data.get("retain_policy"),
        escalate=data.get("escalate"),
        confidence=data.get("confidence", 1.0),
        pipeline=pipeline,
        forget=forget,
    )
def _parse_forget(data: dict, rule_name: str) -> ForgetSpec:
    """Parse and validate an ``action.forget`` sub-block (Phase 4).

    Enforces:
    - P2: ``version`` is required and must be an integer
    - ``mode`` ∈ {soft, hard, tombstone}
    - ``audit`` ∈ {none, recommended, required}
    - ``min_age_days`` is a non-negative int
    - ``max_per_call`` is a positive int
    - ``mode: hard`` requires ``audit: required`` (compliance discipline),
      checked on the spec returned by ``expand_forget_preset``
    - Unknown preset names error with a list of valid presets

    Booleans are not accepted where an integer is expected, even though
    ``bool`` is a subclass of ``int``. Unknown keys only produce a warning.

    Args:
        data: The raw ``forget`` mapping.
        rule_name: Rule name, used in diagnostics.

    Returns:
        The ``ForgetSpec`` produced by ``expand_forget_preset``.

    Raises:
        ConfigError: On any of the violations listed above.
    """
    if not isinstance(data, dict):
        raise ConfigError(f"Rule '{rule_name}': forget must be a mapping")

    def _is_int(candidate: object) -> bool:
        # ``True``/``False`` pass ``isinstance(x, int)``; a YAML ``yes`` must
        # not be read as the number 1.
        return isinstance(candidate, int) and not isinstance(candidate, bool)

    _warn_unknown_keys(data, _FORGET_KEYS, f"rule '{rule_name}' forget")

    version = data.get("version")
    if version is None:
        raise ConfigError(f"Rule '{rule_name}': forget.version is required when forget block is set")
    if not _is_int(version):
        raise ConfigError(f"Rule '{rule_name}': forget.version must be an integer (got {type(version).__name__})")

    preset = data.get("preset")
    if preset is not None and not is_known_forget_preset(preset):
        raise ConfigError(
            f"Rule '{rule_name}': unknown forget preset '{preset}' (known: {', '.join(list_forget_presets())})"
        )

    mode = data.get("mode")
    if mode is not None and mode not in _FORGET_MODES:
        raise ConfigError(f"Rule '{rule_name}': forget.mode must be one of {sorted(_FORGET_MODES)} (got {mode!r})")

    audit = data.get("audit")
    if audit is not None and audit not in _FORGET_AUDIT:
        raise ConfigError(f"Rule '{rule_name}': forget.audit must be one of {sorted(_FORGET_AUDIT)} (got {audit!r})")

    min_age = data.get("min_age_days")
    if min_age is not None and (not _is_int(min_age) or min_age < 0):
        raise ConfigError(f"Rule '{rule_name}': forget.min_age_days must be a non-negative int")

    max_per_call = data.get("max_per_call")
    if max_per_call is not None and (not _is_int(max_per_call) or max_per_call <= 0):
        raise ConfigError(f"Rule '{rule_name}': forget.max_per_call must be a positive int")

    spec = ForgetSpec(
        version=version,
        preset=preset,
        mode=mode,
        audit=audit,
        cascade=data.get("cascade"),
        respect_legal_hold=data.get("respect_legal_hold"),
        min_age_days=min_age,
        max_per_call=max_per_call,
    )
    resolved = expand_forget_preset(spec)

    # Compliance discipline: hard delete demands audit. Check after preset
    # expansion so the gdpr preset (mode=hard, audit=required) passes cleanly.
    if resolved.mode == "hard" and resolved.audit != "required":
        raise ConfigError(f"Rule '{rule_name}': forget.mode='hard' requires forget.audit='required'")

    return resolved
def _parse_pipeline(
    data: dict,
    rule_name: str,
    match_data: dict,
) -> PipelineSpec:
    """Parse and validate a ``pipeline`` action sub-block.

    Enforces guardrails:
    - P2: ``version`` is required and must be an integer
    - P4: ``reflect.promote_metadata`` capped at 5 fields (checked by
      ``_parse_reflect``)
    - P5: pipeline fields require ``content_type`` in the match block
    - ``temporal_half_life_days`` must be a positive number

    Booleans are not accepted for ``version`` or ``temporal_half_life_days``
    even though ``bool`` is a subclass of ``int``. Unknown keys at any level
    emit warnings (forward-compatible vocabulary). The spec is passed through
    ``expand_preset`` before being returned.

    Args:
        data: The raw ``pipeline`` mapping.
        rule_name: Rule name, used in diagnostics.
        match_data: The rule's raw ``match`` mapping, used for the P5 check.

    Raises:
        ConfigError: On any of the violations listed above or an unknown
            preset name.
    """
    if not isinstance(data, dict):
        raise ConfigError(f"Rule '{rule_name}': pipeline must be a mapping")

    _warn_unknown_keys(data, _PIPELINE_KEYS, f"rule '{rule_name}' pipeline")

    # P2: version required
    version = data.get("version")
    if version is None:
        raise ConfigError(f"Rule '{rule_name}': pipeline.version is required when pipeline block is set")
    # ``True``/``False`` pass ``isinstance(x, int)``, so exclude them explicitly.
    if isinstance(version, bool) or not isinstance(version, int):
        raise ConfigError(f"Rule '{rule_name}': pipeline.version must be an integer (got {type(version).__name__})")

    # P5: content_type must be referenced in match block
    if not _match_references_content_type(match_data):
        raise ConfigError(f"Rule '{rule_name}': pipeline fields require 'content_type' in the match block")

    preset = data.get("preset")
    if preset is not None and not is_known_preset(preset):
        raise ConfigError(f"Rule '{rule_name}': unknown preset '{preset}' (known: {', '.join(list_presets())})")

    half_life = data.get("temporal_half_life_days")
    if half_life is not None:
        # Without the bool guard, ``true`` would be accepted and stored as 1.0.
        if isinstance(half_life, bool) or not isinstance(half_life, (int, float)) or half_life <= 0:
            raise ConfigError(
                f"Rule '{rule_name}': pipeline.temporal_half_life_days must be a positive number (got {half_life!r})",
            )

    spec = PipelineSpec(
        version=version,
        preset=preset,
        chunker=_parse_chunker(data.get("chunker"), rule_name),
        dedup=_parse_dedup(data.get("dedup"), rule_name),
        rerank=_parse_rerank(data.get("rerank"), rule_name),
        reflect=_parse_reflect(data.get("reflect"), rule_name),
        temporal_half_life_days=float(half_life) if half_life is not None else None,
    )

    return expand_preset(spec)
def _parse_chunker(data: dict | None, rule_name: str) -> ChunkerSpec | None:
    """Parse the optional ``pipeline.chunker`` section.

    Returns ``None`` when the section is absent. Unknown keys only produce a
    warning; a non-mapping value raises ``ConfigError``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        _warn_unknown_keys(data, _CHUNKER_KEYS, f"rule '{rule_name}' pipeline.chunker")
        strategy, max_size, overlap = (data.get(key) for key in ("strategy", "max_size", "overlap"))
        return ChunkerSpec(strategy=strategy, max_size=max_size, overlap=overlap)
    raise ConfigError(f"Rule '{rule_name}': pipeline.chunker must be a mapping")
def _parse_dedup(data: dict | None, rule_name: str) -> DedupSpec | None:
    """Parse the optional ``pipeline.dedup`` section.

    Returns ``None`` when the section is absent. Unknown keys only produce a
    warning; a non-mapping value raises ``ConfigError``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        _warn_unknown_keys(data, _DEDUP_KEYS, f"rule '{rule_name}' pipeline.dedup")
        threshold, action = (data.get(key) for key in ("threshold", "action"))
        return DedupSpec(threshold=threshold, action=action)
    raise ConfigError(f"Rule '{rule_name}': pipeline.dedup must be a mapping")
def _parse_rerank(data: dict | None, rule_name: str) -> RerankSpec | None:
    """Parse the optional ``pipeline.rerank`` section.

    Returns ``None`` when the section is absent. Unknown keys only produce a
    warning; a non-mapping value raises ``ConfigError``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        _warn_unknown_keys(data, _RERANK_KEYS, f"rule '{rule_name}' pipeline.rerank")
        keyword_weight, proper_noun_weight = (
            data.get(key) for key in ("keyword_weight", "proper_noun_weight")
        )
        return RerankSpec(keyword_weight=keyword_weight, proper_noun_weight=proper_noun_weight)
    raise ConfigError(f"Rule '{rule_name}': pipeline.rerank must be a mapping")
def _parse_reflect(data: dict | None, rule_name: str) -> ReflectSpec | None:
    """Parse the optional ``pipeline.reflect`` section.

    Returns ``None`` when the section is absent. Unknown keys only produce a
    warning.

    Raises:
        ConfigError: If the section is not a mapping, if ``promote_metadata``
            is not a list, or if it has more than ``_PROMOTE_METADATA_MAX``
            entries (P4).
    """
    if data is None:
        return None
    if not isinstance(data, dict):
        raise ConfigError(f"Rule '{rule_name}': pipeline.reflect must be a mapping")

    _warn_unknown_keys(data, _REFLECT_KEYS, f"rule '{rule_name}' pipeline.reflect")

    promoted = data.get("promote_metadata")
    if promoted is not None and not isinstance(promoted, list):
        raise ConfigError(f"Rule '{rule_name}': pipeline.reflect.promote_metadata must be a list")
    # P4: hard cap
    if promoted is not None and len(promoted) > _PROMOTE_METADATA_MAX:
        raise ConfigError(
            f"Rule '{rule_name}': pipeline.reflect.promote_metadata "
            f"capped at {_PROMOTE_METADATA_MAX} fields (got {len(promoted)})"
        )

    return ReflectSpec(prompt=data.get("prompt"), promote_metadata=promoted)
def _match_references_content_type(match_data: dict) -> bool:
    """Report whether a raw match mapping mentions ``content_type``.

    The key counts when it appears at the top level (shorthand form) or as a
    key of any mapping listed under ``all``, ``any`` or ``none``. An empty or
    missing match mapping never qualifies.
    """
    if not match_data:
        return False
    if "content_type" in match_data:
        return True
    # Entries that are not mappings are ignored.
    return any(
        isinstance(spec, dict) and "content_type" in spec
        for group in ("all", "any", "none")
        for spec in (match_data.get(group) or [])
    )
def _warn_unknown_keys(data: dict, known: set[str], context: str) -> None:
    """Emit one warning listing the keys of ``data`` that are not in ``known``.

    Nothing is emitted when every key is known. The keys are listed in sorted
    order; if they cannot be compared with each other (for example a mix of
    integer and string keys) they are ordered by type name and ``repr``
    instead of raising ``TypeError``.

    Args:
        data: Mapping whose keys are checked.
        known: Recognised key names.
        context: Description of where the mapping came from, used in the
            warning text.
    """
    import warnings

    unknown = set(data.keys()) - known
    if not unknown:
        return
    try:
        ordered = sorted(unknown)
    except TypeError:
        # Mixed key types have no natural ordering; a diagnostic helper
        # should not crash the load because of that.
        ordered = sorted(unknown, key=lambda key: (type(key).__name__, repr(key)))
    warnings.warn(
        f"MIP loader: unknown keys in {context}: {ordered}",
        stacklevel=3,
    )
def _validate_mip_config(config: MipConfig) -> list[str]:
    """Check cross-rule consistency and return the problems found.

    Reports, in this order: duplicated rule names (once), then per rule an
    empty name and the combination of ``override`` with ``action.escalate``.
    An empty list means the configuration is consistent; a configuration
    without rules is always consistent.
    """
    if not config.rules:
        return []

    problems: list[str] = []

    seen_names = [rule.name for rule in config.rules]
    if len(set(seen_names)) != len(seen_names):
        problems.append("Duplicate rule names found")

    for rule in config.rules:
        if not rule.name:
            problems.append("Rule missing 'name'")
        if rule.action.escalate and rule.override:
            problems.append(f"Rule '{rule.name}' cannot have both override=true and escalate=mip")

    return problems