Coverage for astrocyte/recall/authority.py: 91%

67 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Structured recall authority — label hits by tier for precedence-in-the-prompt (M7). 

2 

3Hits are grouped using ``MemoryHit.metadata["authority_tier"]`` matching 

4:class:`~astrocyte.config.RecallAuthorityTierConfig` ``id``. Unmatched hits appear in a final 

5``[UNASSIGNED]`` section. Multi-query retrieval per tier is a later phase; this module formats 

6a single fused ``RecallResult``. 

7""" 

8 

9from __future__ import annotations 

10 

11from dataclasses import replace 

12from pathlib import Path 

13 

14from astrocyte.config import RecallAuthorityConfig 

15from astrocyte.types import MemoryHit, Metadata, RecallResult 

16 

# Metadata key holding a hit's authority tier: written at retain time by
# merge_retain_metadata_authority_tier and read when grouping hits in
# build_authority_context.
_METADATA_KEY = "authority_tier"


def merge_retain_metadata_authority_tier(
    metadata: Metadata | None,
    *,
    bank_id: str,
    profile_authority_tier: str | None,
    recall_authority: RecallAuthorityConfig | None,
) -> Metadata | None:
    """Stamp ``metadata["authority_tier"]`` at retain time (M7 producers).

    The tier is taken from ``profile_authority_tier`` when it is non-blank;
    otherwise it is looked up in ``recall_authority.tier_by_bank`` by
    ``bank_id``. Either value is whitespace-stripped before use.

    Returns:
        ``metadata`` itself (same object, possibly ``None``) when
        ``recall_authority`` is ``None`` or disabled, or when no non-blank
        tier can be resolved. Otherwise a new dict: a shallow copy of
        ``metadata`` (empty when ``metadata`` is ``None``) with the tier set,
        replacing any tier value already present.
    """
    if recall_authority is None or not recall_authority.enabled:
        return metadata

    resolved: str | None = None
    profile_text = str(profile_authority_tier).strip() if profile_authority_tier else ""
    if profile_text:
        resolved = profile_text
    elif recall_authority.tier_by_bank:
        # Bank-level fallback; a missing or falsy mapping entry leaves the tier unresolved.
        bank_tier = recall_authority.tier_by_bank.get(bank_id)
        if bank_tier:
            resolved = str(bank_tier).strip()

    if not resolved:
        return metadata

    merged: Metadata = dict(metadata or {})
    merged[_METADATA_KEY] = resolved
    return merged

41 

42 

def load_authority_rules(cfg: RecallAuthorityConfig) -> str:
    """Return the authority rules text, or ``""`` when none is configured.

    ``cfg.rules_inline`` wins when it is non-blank. Otherwise, when
    ``cfg.rules_path`` is non-blank, that file is read as UTF-8. Both the
    inline text and the file contents are returned whitespace-stripped.

    Raises:
        OSError: If ``rules_path`` is set but the file cannot be read
            (e.g. ``FileNotFoundError``); a bad path is not silently ignored.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    inline = str(cfg.rules_inline).strip() if cfg.rules_inline else ""
    if inline:
        return inline
    # Open exactly the value that was validated: a path padded with whitespace
    # (e.g. a trailing space/newline from YAML or an env var) previously passed
    # the blank check but was then opened verbatim and failed to resolve.
    path_text = str(cfg.rules_path).strip() if cfg.rules_path else ""
    if path_text:
        return Path(path_text).read_text(encoding="utf-8").strip()
    return ""

51 

52 

def build_authority_context(cfg: RecallAuthorityConfig, hits: list[MemoryHit]) -> str:
    """Render hits as priority-ordered, labelled sections for model context.

    Layout, with sections separated by a blank line:

    * the rules text from :func:`load_authority_rules` followed by a ``---``
      separator, when rules are configured;
    * one section per configured tier id, ordered by ``(priority, id)``: the
      tier label (or ``[<id>]`` when the label is missing or blank), then one
      ``- <text>`` line per hit, or ``(no hits in this tier)``;
    * an ``[UNASSIGNED]`` section for hits whose ``metadata["authority_tier"]``
      is missing, blank, or not a configured tier id (omitted when empty).

    Hits keep their input order within each section. The returned string
    always ends with exactly one newline.
    """
    rules = load_authority_rules(cfg)
    tiers_sorted = sorted(cfg.tiers, key=lambda t: (t.priority, t.id))
    buckets: dict[str, list[MemoryHit]] = {t.id: [] for t in tiers_sorted}
    unassigned: list[MemoryHit] = []

    for h in hits:
        raw = (h.metadata or {}).get(_METADATA_KEY)
        key = str(raw).strip() if raw is not None else ""
        # A blank tier value is always unassigned, even if some tier id were blank.
        if key and key in buckets:
            buckets[key].append(h)
        else:
            unassigned.append(h)

    lines: list[str] = []
    if rules:
        lines.extend([rules, "", "---", ""])

    rendered: set[str] = set()
    for t in tiers_sorted:
        # Duplicate tier ids share one bucket; rendering each config entry would
        # repeat the same hits under a second heading. Emit only the first
        # occurrence in sorted order.
        if t.id in rendered:
            continue
        rendered.add(t.id)
        # A whitespace-only label is truthy but strips to "", which would leave
        # the section without a heading; fall back to the id in that case too.
        label = (t.label or "").strip() or f"[{t.id}]"
        lines.append(label)
        section_hits = buckets[t.id]
        if section_hits:
            lines.extend(f"- {h.text.strip()}" for h in section_hits)
        else:
            lines.append("(no hits in this tier)")
        lines.append("")

    if unassigned:
        lines.append("[UNASSIGNED]")
        lines.extend(f"- {h.text.strip()}" for h in unassigned)
        lines.append("")

    # Drop trailing blank separators, then terminate with a single newline.
    return "\n".join(lines).rstrip() + "\n"

94 

95 

def apply_recall_authority(result: RecallResult, cfg: RecallAuthorityConfig | None) -> RecallResult:
    """Return ``result`` with ``authority_context`` populated according to ``cfg``.

    * ``cfg`` is ``None`` or not enabled: ``result`` is returned as-is (same object).
    * ``cfg.tiers`` is empty: ``authority_context`` is the rules text alone
      (from :func:`load_authority_rules`), or ``None`` when there are no rules.
    * Otherwise: ``authority_context`` is the tiered layout that
      :func:`build_authority_context` renders from ``result.hits``.

    In the last two cases a new result is produced with
    :func:`dataclasses.replace`; the input object is not mutated.
    """
    enabled = cfg is not None and cfg.enabled
    if not enabled:
        return result
    context = build_authority_context(cfg, result.hits) if cfg.tiers else load_authority_rules(cfg)
    # An empty string is normalised to None so callers can test for "no context".
    return replace(result, authority_context=context or None)