Coverage for astrocyte/pipeline/section_reflect.py: 93%

81 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Section reflect adapter (M9 PR2.6). 

2 

3Bridges the section recall pipeline (``section_recall``, 

4``expand_section_links``) into the agentic reflect loop's MemoryHit- 

5shaped tool surface. Reflect doesn't know about sections — it sees 

6memories with ``id`` / ``text`` / ``score``. 

7 

8The conversion convention: 

9- ``MemoryHit.memory_id = f"{document_id}:{line_num}"`` 

10- ``MemoryHit.text`` is a windowed slice of the markdown around the 

11 section's line_num (same slicer the bench's synth uses, kept short 

12 enough that ~30 hits fit the agent's context budget). 

13- ``MemoryHit.score`` carries the upstream score (rrf / cosine / 

14 link weight) verbatim — useful for the agent to triage results. 

15 

16Two factories produce closures the reflect loop can call: 

17 

18- :func:`make_section_recall_fn` — runs ``section_recall`` on demand for 

19 a sub-query the agent issues. Mode is fixed to ``"default"`` so all 

20 always-on strategies fire (semantic + keyword + entity); the agent 

21 refines by query text, not by mode. 

22- :func:`make_section_expand_fn` — given a section memory_id, calls 

23 ``expand_section_links`` for 1-hop graph expansion. Counting / 

24 multi-session benefit most: "I see one mention, give me adjacent 

25 sections." 

26 

27Why this exists: PR2.5 (counting synth) showed the picker undercounts 

28when multi-session aggregation is required — fetches 5-6 sections when 

29the answer needs 8-10. Reflect's iterative tool-call loop resolves 

30that by letting the agent re-query until it has enough evidence. 

31See ``docs/_design/recall.md`` §7 (PR2.6 reflect dispatch). 

32""" 

33 

34from __future__ import annotations 

35 

36import logging 

37from typing import TYPE_CHECKING, Awaitable, Callable 

38 

39from astrocyte.types import MemoryHit 

40 

41if TYPE_CHECKING: 

42 from astrocyte.provider import LLMProvider, PageIndexStore 

43 

44_logger = logging.getLogger("astrocyte.pipeline.section_reflect") 

45 

46 

47# ── Memory-id conventions ─────────────────────────────────────────── 

48 

49 

def format_section_memory_id(document_id: str, line_num: int) -> str:
    """Build the ``MemoryHit.memory_id`` string for a (document, line) pair.

    The two parts are joined by a single ``:``. Every section handed to
    the reflect agent goes through this helper, which keeps the agent's
    ``cited_ids`` in one fixed shape that can be parsed back
    deterministically.
    """
    return "{}:{}".format(document_id, line_num)

56 

57 

def parse_section_memory_id(memory_id: str) -> tuple[str, int] | None:
    """Inverse of :func:`format_section_memory_id`.

    Splits on the LAST ``:`` so document ids that themselves contain
    colons survive the round trip.

    Returns ``(document_id, line_num)`` on success, or ``None`` when the
    id doesn't conform — caller decides whether to skip or raise. The
    reflect loop already validates citations against the seen-id pool,
    so an out-of-shape id here means the agent invented one (filter,
    don't crash).

    Only the canonical integer spelling is accepted for the line part:
    ``int()`` on its own would also swallow surrounding whitespace, a
    leading ``+``, underscores, leading zeros and non-ASCII digits —
    none of which ``format_section_memory_id`` can ever emit.
    """
    if not memory_id:
        return None
    doc_id, sep, suffix = memory_id.rpartition(":")
    if not sep or not doc_id:
        # No separator at all, or nothing in front of it.
        return None
    try:
        line_num = int(suffix)
    except ValueError:
        return None
    # Round-trip check: reject anything that is not exactly what
    # ``f"{line_num}"`` would have produced.
    if str(line_num) != suffix:
        return None
    return doc_id, line_num

77 

78 

79# ── Section → MemoryHit conversion ────────────────────────────────── 

80 

81 

def _clip_hit_text(text: str, max_chars: int) -> str:
    """Return ``text`` cut to at most ``max_chars`` characters, ending in
    ``...`` whenever there is room for the ellipsis."""
    if len(text) <= max_chars:
        return text
    if max_chars < 3:
        # No room for the three-character ellipsis: a plain hard cut keeps
        # the result within the limit (and a negative limit yields "").
        return text[: max(max_chars, 0)]
    return text[: max_chars - 3] + "..."


def section_tuples_to_memory_hits(
    tuples: list[tuple[str, int, float]],
    *,
    md_text_by_doc: dict[str, str],
    slice_fn: Callable[[str, int], str],
    max_chars: int = 600,
) -> list[MemoryHit]:
    """Convert ``[(doc_id, line_num, score), ...]`` to MemoryHits.

    ``slice_fn(md_text, line_num) -> str`` is the bench's section
    slicer (kept caller-supplied so this module doesn't take a hard
    dep on the bench file's slicing rules). It is only invoked when
    ``md_text_by_doc`` holds non-empty markdown for the document;
    otherwise the hit's text is the empty string.

    Each hit's text is truncated to at most ``max_chars`` characters so
    many hits fit the agent's context when round-tripped as tool
    results. Truncated text ends in ``...``; when ``max_chars`` is
    smaller than the ellipsis itself the text is hard-cut instead, so
    the limit is never exceeded.

    Returns one MemoryHit per input tuple, in input order, with
    ``score`` coerced to ``float`` and ``memory_id`` built by
    :func:`format_section_memory_id`.
    """
    out: list[MemoryHit] = []
    for doc_id, line_num, score in tuples:
        md = md_text_by_doc.get(doc_id, "")
        # Skip the slicer entirely for unknown / empty documents.
        text = slice_fn(md, line_num) if md else ""
        out.append(
            MemoryHit(
                text=_clip_hit_text(text, max_chars),
                score=float(score),
                memory_id=format_section_memory_id(doc_id, line_num),
            )
        )
    return out

111 

112 

113# ── Closure factories for the reflect loop ───────────────────────── 

114 

115 

# Closure type returned by make_section_recall_fn:
# (query, max_results) -> awaitable list of MemoryHit.
RecallFn = Callable[[str, int], Awaitable[list[MemoryHit]]]
# Closure type returned by make_section_expand_fn:
# (memory_id, max_sources) -> awaitable list of MemoryHit.
ExpandFn = Callable[[str, int], Awaitable[list[MemoryHit]]]
# Closure type returned by make_list_entities_fn:
# (pattern or None, limit) -> awaitable list of (name, count) pairs.
ListEntitiesFn = Callable[[str | None, int], Awaitable[list[tuple[str, int]]]]

119 

120 

def make_section_recall_fn(
    *,
    store: PageIndexStore,
    bank_id: str,
    embedding_provider: LLMProvider,
    md_text_by_doc: dict[str, str],
    slice_fn: Callable[[str, int], str],
    sub_recall_mode: str = "default",
) -> RecallFn:
    """Build a ``recall_fn(query, max_results) -> [MemoryHit]`` that
    re-runs section recall on the agent's sub-query.

    Mode is fixed (default ``"default"``) so the reflect loop's
    sub-queries always exercise the always-on baseline strategies —
    we don't try to re-detect mode per sub-query. The agent refines
    via query text.

    The returned coroutine never raises on a recall failure: the
    exception is logged at WARNING level and an empty list is returned.
    A non-positive ``max_results`` also yields an empty list.
    """
    # Local import to keep the module pure (no top-level dep on
    # section_recall — callers without pgvector still parse it cleanly).
    from astrocyte.pipeline.section_recall import section_recall

    async def _recall(query: str, max_results: int) -> list[MemoryHit]:
        if max_results <= 0:
            # A negative bound would make ``fused[:max_results]`` drop hits
            # from the END of the list rather than return none; short-circuit
            # instead, which also saves a recall round trip.
            return []
        try:
            result = await section_recall(
                store=store,
                bank_id=bank_id,
                question=query,
                mode=sub_recall_mode,
                embedding_provider=embedding_provider,
                # No annotator on sub-queries — the agent's query text
                # IS the refinement signal. Pre-extracted entities /
                # date_range from the outer pass would only match the
                # original question's context.
                question_entities=None,
                date_range=None,
            )
        except Exception as exc:  # noqa: BLE001
            # Best-effort tool call: a failed recall must not abort the
            # agent's loop, so report it and hand back no hits.
            _logger.warning(
                "section_reflect.recall_fn failed for q=%r: %s: %s",
                query[:60],
                type(exc).__name__,
                exc,
            )
            return []
        # Promote the top-N fused hits into MemoryHit shape. We don't
        # rerun the cross-encoder reranker here — the agent's
        # iterative loop is itself a form of reranking, and avoiding
        # the model load keeps each tool call cheap.
        tuples = [(h.document_id, h.line_num, h.rrf_score) for h in result.fused[:max_results]]
        return section_tuples_to_memory_hits(
            tuples,
            md_text_by_doc=md_text_by_doc,
            slice_fn=slice_fn,
        )

    return _recall

177 

178 

def make_section_expand_fn(
    *,
    store: PageIndexStore,
    md_text_by_doc: dict[str, str],
    slice_fn: Callable[[str, int], str],
    link_types: list[str] | None = None,
) -> ExpandFn:
    """Create the ``expand_fn(memory_id, max_sources) -> [MemoryHit]``
    closure that follows ``section_links`` one hop out from a section.

    With ``link_types=None`` every link type is returned (causal /
    supersedes / elaborates / semantic_knn). Counting questions gain
    most from ``semantic_knn`` (siblings of a known mention), causal
    questions from causal/supersedes; the default is to take them all
    and let the rrf-style link score show what mattered.

    The closure returns ``[]`` both for a memory id that cannot be
    parsed (logged at INFO) and for a store failure (logged at WARNING).
    """

    async def _expand(memory_id: str, max_sources: int) -> list[MemoryHit]:
        location = parse_section_memory_id(memory_id)
        if location is None:
            _logger.info("section_reflect.expand_fn: unparseable memory_id=%r", memory_id)
            return []
        try:
            # ``location`` is already the (doc_id, line_num) pair the
            # store expects as a seed.
            linked = await store.expand_section_links(
                [location],
                link_types=link_types,
                top_k=max_sources,
            )
        except Exception as exc:  # noqa: BLE001
            _logger.warning(
                "section_reflect.expand_fn failed: %s: %s",
                type(exc).__name__,
                exc,
            )
            return []
        else:
            return section_tuples_to_memory_hits(
                linked,
                md_text_by_doc=md_text_by_doc,
                slice_fn=slice_fn,
            )

    return _expand

222 

223 

def make_list_entities_fn(
    *,
    store: PageIndexStore,
    bank_id: str,
    document_id: str,
) -> ListEntitiesFn:
    """Create the ``list_entities_fn(pattern, limit) -> [(name, count)]``
    closure, a thin wrapper over
    :meth:`PageIndexStore.list_distinct_entities` pinned to one document.

    The bench answers one question per document, so fixing
    ``document_id`` up front lets the agent make a single-argument tool
    call (``pattern`` only) and keeps storage shape out of the prompt.

    A store failure is logged at WARNING level and reported to the
    caller as an empty list.
    """

    async def _list(
        pattern: str | None,
        limit: int,
    ) -> list[tuple[str, int]]:
        try:
            entities = await store.list_distinct_entities(
                bank_id,
                document_id,
                pattern=pattern,
                limit=limit,
            )
        except Exception as exc:  # noqa: BLE001
            _logger.warning(
                "section_reflect.list_entities_fn failed pattern=%r: %s: %s",
                pattern,
                type(exc).__name__,
                exc,
            )
            return []
        else:
            return entities

    return _list

260 

261 

262# ── Citation → line_nums ──────────────────────────────────────────── 

263 

264 

def cited_ids_to_line_nums(
    cited_memory_ids: list[str],
    *,
    expected_doc_id: str | None = None,
) -> list[int]:
    """Pull the line numbers out of a reflect ``ReflectResult.sources``.

    Ids that :func:`parse_section_memory_id` rejects are skipped. When
    ``expected_doc_id`` is given, citations pointing at any other
    document are dropped as well (the bench builds one document per
    question, so a cross-doc citation means the agent confused itself);
    pass ``None`` to keep every parsed line number regardless of doc.

    The result holds each line number once, in first-seen order.
    """
    # A dict doubles as an insertion-ordered set: first occurrence wins.
    ordered: dict[int, None] = {}
    for memory_id in cited_memory_ids:
        location = parse_section_memory_id(memory_id)
        if location is None:
            continue
        doc_id, line_num = location
        if expected_doc_id is None or doc_id == expected_doc_id:
            ordered.setdefault(line_num, None)
    return list(ordered)