Coverage for astrocyte/mip/loader.py: 91%

231 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""MIP config loader — YAML loading, env var substitution, validation.""" 

2 

3from __future__ import annotations 

4 

5from datetime import datetime 

6from pathlib import Path 

7 

8import yaml 

9 

10from astrocyte.config import _substitute_env_recursive 

11from astrocyte.errors import ConfigError 

12from astrocyte.mip.presets import ( 

13 expand_forget_preset, 

14 expand_preset, 

15 is_known_forget_preset, 

16 is_known_preset, 

17 list_forget_presets, 

18 list_presets, 

19) 

20from astrocyte.mip.schema import ( 

21 ActionSpec, 

22 BankDefinition, 

23 ChunkerSpec, 

24 DedupSpec, 

25 EscalationCondition, 

26 ForgetSpec, 

27 IntentPolicy, 

28 MatchBlock, 

29 MatchSpec, 

30 MipConfig, 

31 PipelineSpec, 

32 ReflectSpec, 

33 RerankSpec, 

34 RoutingRule, 

35) 

36 

# Vocabulary the loader recognises inside an action's ``pipeline:`` block and
# its nested sub-blocks. Keys outside these sets only trigger a warning at
# load time, so a config written against a newer vocabulary still loads.
_PIPELINE_KEYS = {"version", "preset", "chunker", "dedup", "rerank", "reflect", "temporal_half_life_days"}
_CHUNKER_KEYS = {
    "strategy",
    "max_size",
    "overlap",
}
_DEDUP_KEYS = {
    "threshold",
    "action",
}
_RERANK_KEYS = {
    "keyword_weight",
    "proper_noun_weight",
}
_REFLECT_KEYS = {
    "prompt",
    "promote_metadata",
}

# Vocabulary for an action's ``forget:`` block, plus the closed sets of values
# accepted for ``forget.mode`` and ``forget.audit``.
_FORGET_KEYS = {"version", "preset", "mode", "audit"} | {
    "cascade",
    "respect_legal_hold",
    "min_age_days",
    "max_per_call",
}
_FORGET_MODES = {
    "soft",
    "hard",
    "tombstone",
}
_FORGET_AUDIT = {
    "none",
    "recommended",
    "required",
}

# Values accepted for the top-level ``tie_breaker`` setting.
_TIE_BREAKERS = {
    "first",
    "error",
    "most_specific",
}

# P4: upper bound on the number of metadata fields that may be promoted into
# the reflect prompt.
_PROMOTE_METADATA_MAX = 5

68 

69 

def load_mip_config(path: str | Path) -> MipConfig:
    """Load a MIP YAML file and return the validated ``MipConfig``.

    The file is read as UTF-8, parsed with ``yaml.safe_load``, run through
    ``_substitute_env_recursive`` and then handed to ``_parse_mip_config``.
    A document that parses to a falsy value (for example an empty file) is
    treated as an empty mapping.

    Args:
        path: Location of the YAML file.

    Returns:
        The parsed and validated configuration.

    Raises:
        ConfigError: If the file does not exist, if the top-level YAML value
            is not a mapping, or if ``_parse_mip_config`` rejects the content.
    """
    config_path = Path(path)
    if not config_path.exists():
        raise ConfigError(f"MIP config file not found: {config_path}")

    # Read explicitly as UTF-8 so the result does not depend on the
    # platform's locale encoding.
    with open(config_path, encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}

    # A list or scalar at the top level would otherwise surface later as an
    # AttributeError from ``data.get``; report it as a config problem instead.
    if not isinstance(raw, dict):
        raise ConfigError(
            f"MIP config must be a mapping at the top level (got {type(raw).__name__}): {config_path}"
        )

    raw = _substitute_env_recursive(raw)
    return _parse_mip_config(raw)

81 

82 

def _parse_mip_config(data: dict) -> MipConfig:
    """Parse an already-loaded YAML mapping into a validated ``MipConfig``.

    Reads the optional top-level keys ``version`` (default ``"1.0"``),
    ``tie_breaker`` (default ``"first"``), ``banks``, ``rules`` and
    ``intent_policy``. Empty or missing sections are left at whatever
    ``MipConfig`` initialises them to.

    Raises:
        ConfigError: If ``tie_breaker`` is not a recognised value, a bank
            entry is not a mapping or lacks ``id``, a rule fails to parse, or
            ``_validate_mip_config`` reports any problem.
    """
    config = MipConfig(version=data.get("version", "1.0"))

    # Phase 5: tie-breaker policy. The isinstance guard keeps an unhashable
    # value (e.g. a YAML list) from raising TypeError on the set lookup.
    tie_breaker = data.get("tie_breaker", "first")
    if not isinstance(tie_breaker, str) or tie_breaker not in _TIE_BREAKERS:
        raise ConfigError(f"tie_breaker must be one of {sorted(_TIE_BREAKERS)} (got {tie_breaker!r})")
    config.tie_breaker = tie_breaker

    banks = data.get("banks")
    if banks:
        config.banks = [_parse_bank(entry, index) for index, entry in enumerate(banks)]

    rules = data.get("rules")
    if rules:
        config.rules = [_parse_rule(r) for r in rules]

    ip = data.get("intent_policy")
    if ip:
        config.intent_policy = IntentPolicy(
            escalate_when=_parse_escalation_conditions(ip.get("escalate_when")),
            model_context=ip.get("model_context"),
            constraints=ip.get("constraints"),
        )

    errors = _validate_mip_config(config)
    if errors:
        raise ConfigError(f"MIP config validation errors: {'; '.join(errors)}")

    return config


def _parse_bank(entry: object, index: int) -> BankDefinition:
    """Build one ``BankDefinition`` from a ``banks:`` list entry; ``id`` is required."""
    if not isinstance(entry, dict):
        raise ConfigError(f"banks[{index}] must be a mapping (got {type(entry).__name__})")
    if "id" not in entry:
        raise ConfigError(f"banks[{index}] is missing required key 'id'")
    return BankDefinition(
        id=entry["id"],
        description=entry.get("description"),
        access=entry.get("access"),
        compliance=entry.get("compliance"),
    )


def _parse_escalation_conditions(entries: object) -> list[EscalationCondition] | None:
    """Flatten ``intent_policy.escalate_when`` into ``EscalationCondition`` objects.

    Returns ``None`` when the section is missing or empty. Each mapping entry
    contributes one condition per ``{operator: value}`` pair, or a single
    ``eq`` condition when the value is not a mapping. Entries that are not
    mappings are skipped.
    """
    if not entries:
        return None
    conditions: list[EscalationCondition] = []
    for entry in entries:
        if not isinstance(entry, dict):
            continue
        for key, val in entry.items():
            if isinstance(val, dict):
                conditions.extend(
                    EscalationCondition(condition=key, operator=op, value=v) for op, v in val.items()
                )
            else:
                conditions.append(EscalationCondition(condition=key, operator="eq", value=val))
    return conditions

135 

136 

def _parse_rule(data: dict) -> RoutingRule:
    """Parse one entry of the ``rules:`` list into a ``RoutingRule``.

    Validates the Phase 5 fields (``shadow``, ``active_from``,
    ``active_until``, ``observability_tags``) before delegating the ``match``
    and ``action`` sub-blocks to their own parsers. ``priority`` defaults to
    100 and ``override`` to ``False``.

    Raises:
        ConfigError: If ``name`` is missing, ``shadow`` is not a boolean, a
            timestamp is invalid, ``active_until`` is not strictly after
            ``active_from``, ``observability_tags`` is not a list of strings,
            or a sub-block parser rejects its input.
    """
    match_data = data.get("match", {})
    action_data = data.get("action", {})
    # Only used in diagnostics, so a placeholder is fine while validating the
    # other fields; the real name is required further down.
    rule_name = data.get("name", "<unnamed>")

    # Phase 5 — shadow / activation window / observability tags
    shadow = data.get("shadow", False)
    if not isinstance(shadow, bool):
        raise ConfigError(f"Rule '{rule_name}': shadow must be a boolean")

    active_from = _parse_iso_datetime(data.get("active_from"), rule_name, "active_from")
    active_until = _parse_iso_datetime(data.get("active_until"), rule_name, "active_until")
    if active_from and active_until and active_until <= active_from:
        raise ConfigError(f"Rule '{rule_name}': active_until must be strictly after active_from")

    obs_tags = data.get("observability_tags")
    if obs_tags is not None:
        if not isinstance(obs_tags, list) or not all(isinstance(t, str) for t in obs_tags):
            raise ConfigError(f"Rule '{rule_name}': observability_tags must be a list of strings")

    # Report a nameless rule as a config problem rather than a bare KeyError.
    if "name" not in data:
        raise ConfigError("Rule missing 'name'")

    return RoutingRule(
        name=data["name"],
        priority=data.get("priority", 100),
        match=_parse_match_block(match_data),
        action=_parse_action(action_data, rule_name=rule_name, match_data=match_data),
        override=data.get("override", False),
        shadow=shadow,
        active_from=active_from,
        active_until=active_until,
        observability_tags=obs_tags,
    )

169 

170 

def _parse_iso_datetime(value: object, rule_name: str, field: str) -> datetime | None:
    """Normalise a rule timestamp into a timezone-aware ``datetime``.

    Accepted inputs:
    - ``None`` — passed through unchanged.
    - ``datetime`` — used as-is.
    - ``date`` — taken as midnight at the start of that day. YAML loaders may
      hand back a plain ``date`` for an unquoted ``YYYY-MM-DD`` value.
    - ``str`` — parsed with ``datetime.fromisoformat``.

    Naive results are interpreted as UTC so comparisons are unambiguous.

    Args:
        value: Raw value taken from the rule mapping.
        rule_name: Rule name, used only in error messages.
        field: Name of the field being parsed, used only in error messages.

    Raises:
        ConfigError: If a string cannot be parsed as ISO 8601, or ``value`` is
            of any other type.
    """
    from datetime import date, time, timezone

    if value is None:
        return None
    # ``datetime`` is a subclass of ``date``, so it must be tested first.
    if isinstance(value, datetime):
        dt = value
    elif isinstance(value, date):
        dt = datetime.combine(value, time.min)
    elif isinstance(value, str):
        try:
            dt = datetime.fromisoformat(value)
        except ValueError as exc:
            raise ConfigError(f"Rule '{rule_name}': {field} must be ISO 8601 datetime (got {value!r}): {exc}") from exc
    else:
        raise ConfigError(f"Rule '{rule_name}': {field} must be a datetime string (got {type(value).__name__})")
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt

192 

193 

def _parse_match_block(data: dict) -> MatchBlock:
    """Turn a rule's ``match:`` mapping into a ``MatchBlock``.

    The ``all`` / ``any`` / ``none`` keys, when present, each become a list of
    parsed specs (an empty or null value becomes an empty list). When none of
    them yields any spec, every remaining key is read as single-level
    shorthand, e.g. ``{"content_type": "student_answer", "pii_detected": true}``,
    and the resulting specs become the ``all`` conditions.
    """
    grouped: dict = {}
    for keyword in ("all", "any", "none"):
        if keyword in data:
            grouped[keyword] = [_parse_match_spec(entry) for entry in (data[keyword] or [])]

    # Shorthand only applies when no explicit group produced a condition.
    if not any(grouped.values()):
        shorthand = [
            _parse_match_spec({field: expected})
            for field, expected in data.items()
            if field not in ("all", "any", "none")
        ]
        if shorthand:
            grouped["all"] = shorthand

    return MatchBlock(
        all_conditions=grouped.get("all"),
        any_conditions=grouped.get("any"),
        none_conditions=grouped.get("none"),
    )

222 

223 

def _parse_match_spec(data: dict) -> MatchSpec:
    """Convert a one-entry match mapping into a ``MatchSpec``.

    Three shapes are understood:
    - operator form, ``{"metadata.count": {"gte": 5}}`` — only the first
      operator/value pair of the inner mapping is used;
    - presence form, where the value is the string ``"present"`` or
      ``"absent"``;
    - anything else is a plain equality check.

    The first entry that produces a spec wins. An entry whose value is an
    empty mapping produces nothing and the next entry is tried.

    Raises:
        ConfigError: If no entry produces a spec.
    """
    for field_name, value in data.items():
        if isinstance(value, dict):
            first_pair = next(iter(value.items()), None)
            if first_pair is None:
                # Empty operator mapping: nothing to build from this entry.
                continue
            operator, operand = first_pair
            return MatchSpec(field=field_name, operator=operator, value=operand)

        for presence in ("present", "absent"):
            if value == presence:
                return MatchSpec(field=field_name, operator=presence)

        # Simple equality: {"content_type": "student_answer"}
        return MatchSpec(field=field_name, operator="eq", value=value)

    raise ConfigError("Empty match spec")

239 

240 

def _parse_action(
    data: dict,
    rule_name: str = "<unnamed>",
    match_data: dict | None = None,
) -> ActionSpec:
    """Build an ``ActionSpec`` from a rule's ``action:`` mapping.

    The optional ``pipeline:`` and ``forget:`` sub-blocks are parsed by their
    dedicated helpers, and only when the key is present with a non-null value.
    ``rule_name`` and ``match_data`` are forwarded so those helpers can produce
    guardrail diagnostics (P2/P4/P5). ``confidence`` defaults to 1.0.
    """
    pipeline = None
    pipeline_data = data.get("pipeline")
    if pipeline_data is not None:
        pipeline = _parse_pipeline(pipeline_data, rule_name=rule_name, match_data=match_data or {})

    forget = None
    forget_data = data.get("forget")
    if forget_data is not None:
        forget = _parse_forget(forget_data, rule_name=rule_name)

    return ActionSpec(
        bank=data.get("bank"),
        tags=data.get("tags"),
        retain_policy=data.get("retain_policy"),
        escalate=data.get("escalate"),
        confidence=data.get("confidence", 1.0),
        pipeline=pipeline,
        forget=forget,
    )

270 

271 

def _parse_forget(data: dict, rule_name: str) -> ForgetSpec:
    """Parse and validate an ``action.forget`` sub-block (Phase 4).

    Enforces:
    - P2: ``version`` is required and must be an integer (booleans rejected)
    - ``mode`` ∈ {soft, hard, tombstone}
    - ``audit`` ∈ {none, recommended, required}
    - ``min_age_days`` is a non-negative int
    - ``max_per_call`` is a positive int
    - ``mode: hard`` requires ``audit: required`` (compliance discipline),
      checked after preset expansion
    - Unknown preset names error with a list of valid presets

    Unknown keys in the block only produce a warning.

    Args:
        data: The raw ``forget:`` value from the rule's action.
        rule_name: Rule name, used only in diagnostics.

    Returns:
        The spec as returned by ``expand_forget_preset``.

    Raises:
        ConfigError: If any of the checks above fails.
    """
    if not isinstance(data, dict):
        raise ConfigError(f"Rule '{rule_name}': forget must be a mapping")

    _warn_unknown_keys(data, _FORGET_KEYS, f"rule '{rule_name}' forget")

    version = data.get("version")
    if version is None:
        raise ConfigError(f"Rule '{rule_name}': forget.version is required when forget block is set")
    # bool is a subclass of int; ``version: true`` is not a version number.
    if isinstance(version, bool) or not isinstance(version, int):
        raise ConfigError(f"Rule '{rule_name}': forget.version must be an integer (got {type(version).__name__})")

    preset = data.get("preset")
    if preset is not None and not is_known_forget_preset(preset):
        raise ConfigError(
            f"Rule '{rule_name}': unknown forget preset '{preset}' (known: {', '.join(list_forget_presets())})"
        )

    # The isinstance guards keep unhashable values (e.g. a YAML list) from
    # raising TypeError on the set lookups below.
    mode = data.get("mode")
    if mode is not None and (not isinstance(mode, str) or mode not in _FORGET_MODES):
        raise ConfigError(f"Rule '{rule_name}': forget.mode must be one of {sorted(_FORGET_MODES)} (got {mode!r})")

    audit = data.get("audit")
    if audit is not None and (not isinstance(audit, str) or audit not in _FORGET_AUDIT):
        raise ConfigError(f"Rule '{rule_name}': forget.audit must be one of {sorted(_FORGET_AUDIT)} (got {audit!r})")

    min_age = data.get("min_age_days")
    if min_age is not None and (isinstance(min_age, bool) or not isinstance(min_age, int) or min_age < 0):
        raise ConfigError(f"Rule '{rule_name}': forget.min_age_days must be a non-negative int")

    max_per_call = data.get("max_per_call")
    if max_per_call is not None and (
        isinstance(max_per_call, bool) or not isinstance(max_per_call, int) or max_per_call <= 0
    ):
        raise ConfigError(f"Rule '{rule_name}': forget.max_per_call must be a positive int")

    spec = ForgetSpec(
        version=version,
        preset=preset,
        mode=mode,
        audit=audit,
        cascade=data.get("cascade"),
        respect_legal_hold=data.get("respect_legal_hold"),
        min_age_days=min_age,
        max_per_call=max_per_call,
    )
    resolved = expand_forget_preset(spec)

    # Compliance discipline: hard delete demands audit. Check after preset
    # expansion so the gdpr preset (mode=hard, audit=required) passes cleanly.
    if resolved.mode == "hard" and resolved.audit != "required":
        raise ConfigError(f"Rule '{rule_name}': forget.mode='hard' requires forget.audit='required'")

    return resolved

334 

335 

def _parse_pipeline(
    data: dict,
    rule_name: str,
    match_data: dict,
) -> PipelineSpec:
    """Parse and validate a pipeline action sub-block.

    Enforces guardrails:
    - P2: version is required when the pipeline block is set, and must be an
      integer (booleans rejected)
    - P4: reflect.promote_metadata capped at 5 fields (checked in
      ``_parse_reflect``)
    - P5: pipeline fields require content_type in match block
    - ``temporal_half_life_days`` must be a positive number (booleans and
      NaN rejected)

    Unknown keys at any level emit warnings (forward-compatible vocabulary).
    Presets are expanded at load time so downstream code never sees `preset`.

    Args:
        data: The raw ``pipeline:`` value from the rule's action.
        rule_name: Rule name, used only in diagnostics.
        match_data: The rule's raw ``match:`` mapping, used for the P5 check.

    Returns:
        The spec as returned by ``expand_preset``.

    Raises:
        ConfigError: If any guardrail is violated or the preset is unknown.
    """
    import math

    if not isinstance(data, dict):
        raise ConfigError(f"Rule '{rule_name}': pipeline must be a mapping")

    _warn_unknown_keys(data, _PIPELINE_KEYS, f"rule '{rule_name}' pipeline")

    # P2: version required
    version = data.get("version")
    if version is None:
        raise ConfigError(f"Rule '{rule_name}': pipeline.version is required when pipeline block is set")
    # bool is a subclass of int; ``version: true`` is not a version number.
    if isinstance(version, bool) or not isinstance(version, int):
        raise ConfigError(f"Rule '{rule_name}': pipeline.version must be an integer (got {type(version).__name__})")

    # P5: content_type must be referenced in match block
    if not _match_references_content_type(match_data):
        raise ConfigError(f"Rule '{rule_name}': pipeline fields require 'content_type' in the match block")

    preset = data.get("preset")
    if preset is not None and not is_known_preset(preset):
        raise ConfigError(f"Rule '{rule_name}': unknown preset '{preset}' (known: {', '.join(list_presets())})")

    half_life = data.get("temporal_half_life_days")
    if half_life is not None:
        is_number = isinstance(half_life, (int, float)) and not isinstance(half_life, bool)
        # NaN compares False against everything, so ``<= 0`` alone lets it through.
        if not is_number or math.isnan(half_life) or half_life <= 0:
            raise ConfigError(
                f"Rule '{rule_name}': pipeline.temporal_half_life_days must be a positive number (got {half_life!r})",
            )

    spec = PipelineSpec(
        version=version,
        preset=preset,
        chunker=_parse_chunker(data.get("chunker"), rule_name),
        dedup=_parse_dedup(data.get("dedup"), rule_name),
        rerank=_parse_rerank(data.get("rerank"), rule_name),
        reflect=_parse_reflect(data.get("reflect"), rule_name),
        temporal_half_life_days=float(half_life) if half_life is not None else None,
    )

    return expand_preset(spec)

389 

390 

def _parse_chunker(data: dict | None, rule_name: str) -> ChunkerSpec | None:
    """Parse ``pipeline.chunker``; a missing block yields ``None``.

    Unknown keys produce a warning. A non-mapping value raises ``ConfigError``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        _warn_unknown_keys(data, _CHUNKER_KEYS, f"rule '{rule_name}' pipeline.chunker")
        strategy = data.get("strategy")
        max_size = data.get("max_size")
        overlap = data.get("overlap")
        return ChunkerSpec(strategy=strategy, max_size=max_size, overlap=overlap)
    raise ConfigError(f"Rule '{rule_name}': pipeline.chunker must be a mapping")

402 

403 

def _parse_dedup(data: dict | None, rule_name: str) -> DedupSpec | None:
    """Parse ``pipeline.dedup``; a missing block yields ``None``.

    Unknown keys produce a warning. A non-mapping value raises ``ConfigError``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        _warn_unknown_keys(data, _DEDUP_KEYS, f"rule '{rule_name}' pipeline.dedup")
        threshold = data.get("threshold")
        action = data.get("action")
        return DedupSpec(threshold=threshold, action=action)
    raise ConfigError(f"Rule '{rule_name}': pipeline.dedup must be a mapping")

414 

415 

def _parse_rerank(data: dict | None, rule_name: str) -> RerankSpec | None:
    """Parse ``pipeline.rerank``; a missing block yields ``None``.

    Unknown keys produce a warning. A non-mapping value raises ``ConfigError``.
    """
    if data is None:
        return None
    if isinstance(data, dict):
        _warn_unknown_keys(data, _RERANK_KEYS, f"rule '{rule_name}' pipeline.rerank")
        keyword_weight = data.get("keyword_weight")
        proper_noun_weight = data.get("proper_noun_weight")
        return RerankSpec(keyword_weight=keyword_weight, proper_noun_weight=proper_noun_weight)
    raise ConfigError(f"Rule '{rule_name}': pipeline.rerank must be a mapping")

426 

427 

def _parse_reflect(data: dict | None, rule_name: str) -> ReflectSpec | None:
    """Parse ``pipeline.reflect``; a missing block yields ``None``.

    Unknown keys produce a warning. ``promote_metadata``, when given, must be
    a list of at most ``_PROMOTE_METADATA_MAX`` entries (guardrail P4).

    Raises:
        ConfigError: If the block is not a mapping, or ``promote_metadata`` is
            not a list or exceeds the cap.
    """
    if data is None:
        return None
    if not isinstance(data, dict):
        raise ConfigError(f"Rule '{rule_name}': pipeline.reflect must be a mapping")

    _warn_unknown_keys(data, _REFLECT_KEYS, f"rule '{rule_name}' pipeline.reflect")

    promote = data.get("promote_metadata")
    if promote is not None and not isinstance(promote, list):
        raise ConfigError(f"Rule '{rule_name}': pipeline.reflect.promote_metadata must be a list")
    # P4: hard cap
    if promote is not None and len(promote) > _PROMOTE_METADATA_MAX:
        raise ConfigError(
            f"Rule '{rule_name}': pipeline.reflect.promote_metadata "
            f"capped at {_PROMOTE_METADATA_MAX} fields (got {len(promote)})"
        )

    prompt = data.get("prompt")
    return ReflectSpec(prompt=prompt, promote_metadata=promote)

448 

449 

def _match_references_content_type(match_data: dict) -> bool:
    """Report whether a raw match mapping mentions ``content_type``.

    The key counts either directly on the mapping (shorthand form) or on any
    mapping entry listed under ``all``, ``any`` or ``none``. An empty or
    missing match block never qualifies.
    """
    if not match_data:
        return False
    if "content_type" in match_data:
        return True
    return any(
        isinstance(entry, dict) and "content_type" in entry
        for group in ("all", "any", "none")
        for entry in (match_data.get(group) or [])
    )

462 

463 

def _warn_unknown_keys(data: dict, known: set[str], context: str) -> None:
    """Emit one warning listing the keys of ``data`` that are not in ``known``.

    Nothing is emitted when every key is recognised. ``context`` is included
    in the message to identify where the keys were found.

    Args:
        data: Mapping whose keys are checked.
        known: Recognised key names.
        context: Human-readable location, included in the warning text.
    """
    import warnings

    unknown = set(data.keys()) - known
    if not unknown:
        return

    try:
        listed = sorted(unknown)
    except TypeError:
        # YAML keys need not be strings; mixed types (e.g. int and str) are
        # not mutually orderable, so fall back to ordering by repr.
        listed = sorted(unknown, key=repr)

    warnings.warn(
        f"MIP loader: unknown keys in {context}: {listed}",
        stacklevel=3,
    )

473 

474 

def _validate_mip_config(config: MipConfig) -> list[str]:
    """Check cross-rule consistency and collect the problems found.

    Reports duplicate rule names, rules with an empty name, and rules that
    set both ``override`` and an ``escalate`` action. Returns the messages in
    that order of discovery; an empty list means the config is consistent.
    """
    problems: list[str] = []
    rules = config.rules
    if not rules:
        return problems

    # A set collapses repeated names, so a size mismatch means duplicates.
    if len({rule.name for rule in rules}) != len(rules):
        problems.append("Duplicate rule names found")

    for rule in rules:
        if not rule.name:
            problems.append("Rule missing 'name'")
        if rule.action.escalate and rule.override:
            problems.append(f"Rule '{rule.name}' cannot have both override=true and escalate=mip")

    return problems