Coverage for astrocyte/pipeline/section_reflect.py: 93%
81 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""Section reflect adapter (M9 PR2.6).

Bridges the section recall pipeline (``section_recall``,
``expand_section_links``) into the agentic reflect loop's MemoryHit-
shaped tool surface. Reflect doesn't know about sections — it sees
memories with ``id`` / ``text`` / ``score``.

The conversion convention:
- ``MemoryHit.memory_id = f"{document_id}:{line_num}"``
- ``MemoryHit.text`` is a windowed slice of the markdown around the
  section's line_num (same slicer the bench's synth uses, kept short
  enough that ~30 hits fit the agent's context budget).
- ``MemoryHit.score`` carries the upstream score (rrf / cosine /
  link weight) verbatim — useful for the agent to triage results.

Three factories produce closures the reflect loop can call:

- :func:`make_section_recall_fn` — runs ``section_recall`` on demand for
  a sub-query the agent issues. The mode is fixed per closure
  (``sub_recall_mode``, default ``"default"``) so all always-on
  strategies fire (semantic + keyword + entity); the agent refines by
  query text, not by mode.
- :func:`make_section_expand_fn` — given a section memory_id, calls
  ``expand_section_links`` for 1-hop graph expansion. Counting /
  multi-session questions benefit most: "I see one mention, give me
  adjacent sections."
- :func:`make_list_entities_fn` — lists distinct entity names (with
  counts) for a single bound document, optionally filtered by pattern.

Why this exists: PR2.5 (counting synth) showed the picker undercounts
when multi-session aggregation is required — it fetches 5-6 sections
when the answer needs 8-10. Reflect's iterative tool-call loop resolves
that by letting the agent re-query until it has enough evidence.
See ``docs/_design/recall.md`` §7 (PR2.6 reflect dispatch).
"""
34from __future__ import annotations
36import logging
37from typing import TYPE_CHECKING, Awaitable, Callable
39from astrocyte.types import MemoryHit
41if TYPE_CHECKING:
42 from astrocyte.provider import LLMProvider, PageIndexStore
# Module logger used by the tool closures below to report best-effort
# failures (they log and return an empty result instead of raising).
_logger = logging.getLogger("astrocyte.pipeline.section_reflect")
# ── Memory-id conventions ───────────────────────────────────────────
def format_section_memory_id(document_id: str, line_num: int) -> str:
    """Build the ``MemoryHit.memory_id`` string for one section.

    The id is ``"<document_id>:<line_num>"``. Every section handed to
    the reflect agent uses this shape, so the ids the agent cites can be
    split back into ``(document_id, line_num)`` deterministically by
    :func:`parse_section_memory_id` (which splits on the LAST colon, so
    colons inside ``document_id`` are fine).
    """
    return "{}:{}".format(document_id, line_num)
def parse_section_memory_id(memory_id: str) -> tuple[str, int] | None:
    """Inverse of :func:`format_section_memory_id`.

    Returns ``(document_id, line_num)`` for an id of the form
    ``"<document_id>:<line_num>"``, or ``None`` when the id doesn't
    conform — the caller decides whether to skip or raise. The reflect
    loop already validates citations against the seen-id pool, so an
    out-of-shape id here means the agent invented one (filter, don't
    crash).

    The split is on the LAST colon, so document ids may themselves
    contain colons. The line-number suffix must be exactly what
    ``format_section_memory_id`` would emit (``str(line_num)``): forms
    that ``int()`` tolerates but the formatter never produces — padding
    whitespace, a leading ``+``, underscores, leading zeros, non-ASCII
    digits — are rejected so parse/format stay a strict round trip.
    """
    if not memory_id:
        return None
    doc_id, sep, suffix = memory_id.rpartition(":")
    if not sep or not doc_id:
        # No colon at all, or an empty document id (":12").
        return None
    try:
        line_num = int(suffix)
    except ValueError:
        return None
    # int() is lenient (" 7", "+7", "0_7", "٧" all parse); require the
    # canonical spelling so only ids the formatter could have produced pass.
    if str(line_num) != suffix:
        return None
    return doc_id, line_num
# ── Section → MemoryHit conversion ──────────────────────────────────
def section_tuples_to_memory_hits(
    tuples: list[tuple[str, int, float]],
    *,
    md_text_by_doc: dict[str, str],
    slice_fn: Callable[[str, int], str],
    max_chars: int = 600,
) -> list[MemoryHit]:
    """Convert ``[(doc_id, line_num, score), ...]`` to MemoryHits.

    ``slice_fn(md_text, line_num) -> str`` is the bench's section
    slicer (kept caller-supplied so this module doesn't take a hard
    dep on the bench file's slicing rules). Truncates each hit's text
    to ``max_chars`` so 20-30 hits comfortably fit the agent's
    1024-token reply window when round-tripped as tool results.

    Args:
        tuples: ``(document_id, line_num, score)`` triples, in the order
            the hits should be returned.
        md_text_by_doc: Markdown source keyed by document id. Documents
            missing from the map (or mapped to ``""``) yield hits with
            empty text; ``slice_fn`` is not called for them.
        slice_fn: Returns the text window around ``line_num``.
        max_chars: Upper bound on each hit's ``text`` length. Longer
            slices are cut and end with ``"..."``. When ``max_chars`` is
            below 3 there is no room for the ellipsis, so the text is
            hard-cut instead (a non-positive value yields empty text).

    Returns:
        One ``MemoryHit`` per input tuple, with ``memory_id`` built by
        :func:`format_section_memory_id` and ``score`` coerced to float.
    """
    ellipsis = "..."
    out: list[MemoryHit] = []
    for doc_id, line_num, score in tuples:
        md = md_text_by_doc.get(doc_id, "")
        text = slice_fn(md, line_num) if md else ""
        if len(text) > max_chars:
            if max_chars >= len(ellipsis):
                text = text[: max_chars - len(ellipsis)] + ellipsis
            else:
                # Too small for an ellipsis: ``text[: max_chars - 3]``
                # would be a negative bound that slices from the END,
                # leaving the text far longer than ``max_chars``.
                text = text[: max(max_chars, 0)]
        out.append(
            MemoryHit(
                text=text,
                score=float(score),
                memory_id=format_section_memory_id(doc_id, line_num),
            )
        )
    return out
# ── Closure factories for the reflect loop ─────────────────────────
# Async tool callables handed to the reflect loop. These aliases are
# evaluated at import time (they are not annotations), so the
# ``str | None`` below requires Python 3.10+.

# ``recall_fn(query, max_results) -> [MemoryHit]`` — see make_section_recall_fn.
RecallFn = Callable[[str, int], Awaitable[list[MemoryHit]]]
# ``expand_fn(memory_id, max_sources) -> [MemoryHit]`` — see make_section_expand_fn.
ExpandFn = Callable[[str, int], Awaitable[list[MemoryHit]]]
# ``list_entities_fn(pattern, limit) -> [(name, count)]`` — see make_list_entities_fn.
ListEntitiesFn = Callable[[str | None, int], Awaitable[list[tuple[str, int]]]]
def make_section_recall_fn(
    *,
    store: PageIndexStore,
    bank_id: str,
    embedding_provider: LLMProvider,
    md_text_by_doc: dict[str, str],
    slice_fn: Callable[[str, int], str],
    sub_recall_mode: str = "default",
) -> RecallFn:
    """Build a ``recall_fn(query, max_results) -> [MemoryHit]`` that
    re-runs section recall on the agent's sub-query.

    Mode is fixed (default ``"default"``) so the reflect loop's
    sub-queries always exercise the always-on baseline strategies —
    we don't try to re-detect mode per sub-query. The agent refines
    via query text.

    Args:
        store: Page-index store passed through to ``section_recall``.
        bank_id: Memory bank the sub-queries run against.
        embedding_provider: Provider ``section_recall`` uses for
            embeddings.
        md_text_by_doc: Markdown keyed by document id, used to render
            hit text.
        slice_fn: Section slicer (see
            :func:`section_tuples_to_memory_hits`).
        sub_recall_mode: ``section_recall`` mode used for every
            sub-query.

    Returns:
        An async callable. It returns at most ``max_results`` hits (none
        when ``max_results <= 0``), and ``[]`` if recall raises — the
        failure is logged, not propagated.
    """
    # Local import to keep the module pure (no top-level dep on
    # section_recall — callers without pgvector still parse it cleanly).
    from astrocyte.pipeline.section_recall import section_recall

    async def _recall(query: str, max_results: int) -> list[MemoryHit]:
        # A non-positive budget can never yield hits, so skip the recall
        # round-trip. Without this guard a negative value would slice
        # from the end (``fused[:-2]``) and return almost every hit.
        if max_results <= 0:
            return []
        try:
            result = await section_recall(
                store=store,
                bank_id=bank_id,
                question=query,
                mode=sub_recall_mode,
                embedding_provider=embedding_provider,
                # No annotator on sub-queries — the agent's query text
                # IS the refinement signal. Pre-extracted entities /
                # date_range from the outer pass would only match the
                # original question's context.
                question_entities=None,
                date_range=None,
            )
        except Exception as exc:  # noqa: BLE001
            _logger.warning(
                "section_reflect.recall_fn failed for q=%r: %s: %s",
                query[:60],
                type(exc).__name__,
                exc,
            )
            return []
        # Promote the top-N fused hits into MemoryHit shape. We don't
        # rerun the cross-encoder reranker here — the agent's
        # iterative loop is itself a form of reranking, and avoiding
        # the model load keeps each tool call cheap.
        tuples = [(h.document_id, h.line_num, h.rrf_score) for h in result.fused[:max_results]]
        return section_tuples_to_memory_hits(
            tuples,
            md_text_by_doc=md_text_by_doc,
            slice_fn=slice_fn,
        )

    return _recall
def make_section_expand_fn(
    *,
    store: PageIndexStore,
    md_text_by_doc: dict[str, str],
    slice_fn: Callable[[str, int], str],
    link_types: list[str] | None = None,
) -> ExpandFn:
    """Build an ``expand_fn(memory_id, max_sources) -> [MemoryHit]``
    that 1-hop expands a section through ``section_links``.

    ``link_types=None`` returns all link types (causal / supersedes /
    elaborates / semantic_knn). Counting questions benefit most from
    ``semantic_knn`` (sibling sections of a known mention); causal
    questions benefit from causal/supersedes. We default to all and
    let the rrf-style score on the link expose what mattered.

    Args:
        store: Store whose ``expand_section_links`` performs the hop. Its
            result is fed to :func:`section_tuples_to_memory_hits`, so it
            is expected to be ``(document_id, line_num, score)`` triples.
        md_text_by_doc: Markdown keyed by document id, used to render
            hit text.
        slice_fn: Section slicer (see
            :func:`section_tuples_to_memory_hits`).
        link_types: Link types to follow, forwarded verbatim to the store.

    Returns:
        An async callable. It returns ``[]`` for an unparseable
        ``memory_id`` (logged at INFO) or when the store call raises
        (logged at WARNING together with the offending ``memory_id``).
        ``max_sources`` is forwarded to the store as ``top_k``.
    """

    async def _expand(memory_id: str, max_sources: int) -> list[MemoryHit]:
        parsed = parse_section_memory_id(memory_id)
        if parsed is None:
            # Malformed / agent-invented id: there is no seed to expand.
            _logger.info("section_reflect.expand_fn: unparseable memory_id=%r", memory_id)
            return []
        doc_id, line_num = parsed
        try:
            tuples = await store.expand_section_links(
                [(doc_id, line_num)],
                link_types=link_types,
                top_k=max_sources,
            )
        except Exception as exc:  # noqa: BLE001
            # Best-effort tool call. Include the seed id so a failing
            # expansion can be traced, mirroring recall_fn's query log.
            _logger.warning(
                "section_reflect.expand_fn failed for memory_id=%r: %s: %s",
                memory_id,
                type(exc).__name__,
                exc,
            )
            return []
        return section_tuples_to_memory_hits(
            tuples,
            md_text_by_doc=md_text_by_doc,
            slice_fn=slice_fn,
        )

    return _expand
def make_list_entities_fn(
    *,
    store: PageIndexStore,
    bank_id: str,
    document_id: str,
) -> ListEntitiesFn:
    """Return an async ``list_entities_fn(pattern, limit) -> [(name, count)]``.

    The closure forwards to :meth:`PageIndexStore.list_distinct_entities`
    with ``bank_id`` and ``document_id`` pre-bound. The bench answers one
    question per document, so binding the document here means the agent's
    tool call only carries ``pattern`` (and ``limit``); the storage shape
    never leaks into the prompt.

    Any exception from the store is logged as a warning and turned into
    an empty list, so a failed lookup never aborts the reflect loop.
    """

    async def _list_entities(
        pattern: str | None,
        limit: int,
    ) -> list[tuple[str, int]]:
        try:
            entities = await store.list_distinct_entities(
                bank_id,
                document_id,
                pattern=pattern,
                limit=limit,
            )
        except Exception as exc:  # noqa: BLE001
            _logger.warning(
                "section_reflect.list_entities_fn failed pattern=%r: %s: %s",
                pattern,
                type(exc).__name__,
                exc,
            )
            entities = []
        return entities

    return _list_entities
# ── Citation → line_nums ────────────────────────────────────────────
def cited_ids_to_line_nums(
    cited_memory_ids: list[str],
    *,
    expected_doc_id: str | None = None,
) -> list[int]:
    """Turn a reflect ``ReflectResult.sources`` id list into line_nums.

    Ids that :func:`parse_section_memory_id` rejects are dropped. When
    ``expected_doc_id`` is given, citations pointing at any other
    document are dropped too (the bench builds one document per
    question, so a cross-doc citation means the agent confused itself);
    pass ``None`` to keep line_nums from every document.

    The result is de-duplicated by line_num, keeping the order of first
    appearance.
    """
    # dict preserves insertion order, so its keys double as an ordered set.
    first_seen: dict[int, None] = {}
    for memory_id in cited_memory_ids:
        parsed = parse_section_memory_id(memory_id)
        if parsed is None:
            continue
        doc_id, line_num = parsed
        if expected_doc_id is None or doc_id == expected_doc_id:
            first_seen.setdefault(line_num, None)
    return list(first_seen)