Coverage for astrocyte/documents/retrieval/navigator.py: 16%

127 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""DocumentNavigator — agentic tree-search retrieval. 

2 

3PageIndex's core retrieval innovation: instead of vector similarity, 

4an LLM navigates the document tree structure to identify the exact 

5relevant section. 

6 

7Agent loop per document: 

8 1. get_document_structure() → TreeSkeleton (titles + summaries, no text) 

9 2. LLM: "Which node_ids are most relevant to {query}?" → JSON list 

10 3. get_node_content() for each identified node 

11 4. Optional: explore children if the LLM requests deeper navigation 

12 5. Compile SectionHit[] with breadcrumb + relevance reasoning 

13 

14Cost: typically 1-2 LLM calls per document; child exploration can add more, capped by ``max_iterations``. No embedding at query time.

15Same llm_call pattern as AdaptiveSummarizer — injectable, testable 

16without an OpenAI key. 

17""" 

18 

19from __future__ import annotations 

20 

21import asyncio 

22import json 

23import logging 

24from typing import Awaitable, Callable 

25 

26from astrocyte.documents.retrieval.retriever import DocumentRetriever 

27from astrocyte.documents.retrieval.types import ( 

28 DocumentSearchResult, 

29 SectionHit, 

30 SkeletonNode, 

31 TreeSkeleton, 

32) 

33from astrocyte.documents.storage import DocumentNotFoundError 

34 

logger = logging.getLogger(__name__)

# Injectable LLM hook: ``await llm_call(prompt)`` resolves to the model's raw
# text reply. Any coroutine function with this shape (real client or test
# fake) can be passed to DocumentNavigator.
LlmCall = Callable[[str], Awaitable[str]]

38 

# ── Prompts ──────────────────────────────────────────────────────────────────

# First-pass prompt: the LLM sees the whole document skeleton (one
# ``node_id | title | summary`` line per node) and is asked for up to
# {max_sections} node_ids plus a one-sentence reasoning, as bare JSON.
# The template is rendered with str.format(), so the literal braces in the
# example JSON are doubled ({{ }}).
_STRUCTURE_PROMPT = """\
You are navigating a document tree to find sections relevant to a query.

Document: {title}
Query: {query}

Document structure (node_id | title | summary):
{structure_lines}

Select the most relevant sections. Return a JSON object:
 "node_ids": list of node_id strings, most relevant first (max {max_sections})
 "reasoning": one sentence explaining why these sections answer the query

Return ONLY valid JSON. No markdown fences, no extra text.
Example: {{"node_ids": ["abc123", "def456"], "reasoning": "Section 3.2 covers pricing directly."}}
"""

# Follow-up prompt for a fetched section that has children: lists the
# children and asks which to retrieve. It states no maximum; the navigator
# enforces its own section budget when consuming the answer. Braces in the
# JSON shape are doubled for str.format().
_CHILD_PROMPT = """\
You are drilling into a document section to answer a query more precisely.

Query: {query}
Parent section: "{parent_title}"

Child sections available:
{children_lines}

Which child node_ids should be retrieved? Return JSON:
{{"node_ids": [...], "reasoning": "..."}}
Return ONLY valid JSON.
"""

71 

72 

73# ── Helpers ─────────────────────────────────────────────────────────────────── 

74 

def _format_structure(skeleton: TreeSkeleton) -> str:
    """Render every skeleton node as a ``node_id | title [| summary]`` line.

    Lines are indented in proportion to node depth (depth 1 and shallower
    get no indent). A summary, when present, is flattened to one line, cut
    to its first 80 characters and suffixed with "…" if it was longer.
    Nodes without a summary omit the third column.
    """
    rendered: list[str] = []
    for entry in skeleton.nodes:
        prefix = " " * max(0, entry.depth - 1)
        line = f"{prefix}{entry.node_id} | {entry.title}"
        if entry.summary:
            excerpt = entry.summary[:80].replace("\n", " ")
            if len(entry.summary) > 80:
                excerpt += "…"
            line = f"{line} | {excerpt}"
        rendered.append(line)
    return "\n".join(rendered)

86 

87 

def _format_children(children: list[SkeletonNode]) -> str:
    """Render child nodes as flat ``node_id | title [| summary]`` lines.

    No depth indentation is applied. A summary, when present, is flattened
    to one line, cut to its first 80 characters and suffixed with "…" if it
    was longer; children without a summary omit the third column.
    """

    def describe(child: SkeletonNode) -> str:
        head = f"{child.node_id} | {child.title}"
        if not child.summary:
            return head
        excerpt = child.summary[:80].replace("\n", " ")
        if len(child.summary) > 80:
            excerpt = excerpt + "…"
        return f"{head} | {excerpt}"

    return "\n".join(describe(c) for c in children)

98 

99 

def _parse_llm_json(text: str) -> dict:
    """Extract a JSON object from an LLM response, tolerating markdown fences.

    Candidates are tried in order: the (de-fenced) text as-is, then the span
    from its first ``{`` to its last ``}`` (to skip chatter around the JSON).
    Only a JSON *object* is accepted. A top-level list, string, number or
    ``null`` counts as unparseable, because callers call ``.get()`` on the
    result.

    When present, ``node_ids`` is normalised to a list of strings (a bare
    string becomes a one-element list, a non-list becomes ``[]``, non-string
    items are dropped). A non-string ``reasoning`` becomes ``""`` for
    ``null`` and ``str(value)`` otherwise. Absent keys stay absent so
    callers' ``.get(key, default)`` fallbacks keep working.

    Args:
        text: Raw model output.

    Returns:
        The parsed object, or ``{"node_ids": [], "reasoning": ""}`` when no
        JSON object can be recovered.
    """
    fallback = {"node_ids": [], "reasoning": ""}
    if not text:
        return fallback
    text = text.strip()
    if text.startswith("```"):
        # Drop fence lines (``` / ```json) and keep the payload between them.
        text = "\n".join(line for line in text.split("\n") if not line.startswith("```")).strip()

    candidates = [text]
    start, end = text.find("{"), text.rfind("}") + 1
    if start >= 0 and end > start:
        candidates.append(text[start:end])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if not isinstance(parsed, dict):
            # Valid JSON but the wrong shape; try the next candidate.
            continue

        if "node_ids" in parsed:
            ids = parsed["node_ids"]
            if isinstance(ids, str):
                ids = [ids]
            elif not isinstance(ids, list):
                ids = []
            parsed["node_ids"] = [i for i in ids if isinstance(i, str)]

        if "reasoning" in parsed and not isinstance(parsed["reasoning"], str):
            value = parsed["reasoning"]
            parsed["reasoning"] = "" if value is None else str(value)
        return parsed

    return fallback

116 

117 

def _build_breadcrumb(skeleton: TreeSkeleton, node_id: str) -> list[str]:
    """Return ancestor titles from root to the parent of *node_id*.

    Follows ``parent_id`` links upward from *node_id*. An ancestor id with no
    node in the skeleton contributes the raw id in place of a title and ends
    the walk, since its own parent is unknown. The walk also stops on the
    first revisited node. A malformed skeleton with a ``parent_id`` cycle
    (including a node that is its own parent) therefore yields a finite
    breadcrumb instead of looping forever.

    Args:
        skeleton: Tree whose ``nodes`` carry ``node_id``, ``parent_id`` and
            ``title``.
        node_id: Node whose ancestors are wanted.

    Returns:
        Ancestor titles, root first. Empty for a root node or for an id that
        is not in the skeleton.
    """
    parent_of = {n.node_id: n.parent_id for n in skeleton.nodes}
    title_of = {n.node_id: n.title for n in skeleton.nodes}
    chain: list[str] = []
    # Seed with the start node so a cycle back to it is not reported as an ancestor.
    seen = {node_id}
    current = parent_of.get(node_id)
    while current is not None and current not in seen:
        seen.add(current)
        chain.append(title_of.get(current, current))
        current = parent_of.get(current)
    chain.reverse()
    return chain

129 

130 

131# ── Navigator ───────────────────────────────────────────────────────────────── 

132 

class DocumentNavigator:
    """Agentic tree-search — PageIndex's core retrieval innovation.

    Usage:
        retriever = DocumentRetriever(document_store)
        navigator = DocumentNavigator(retriever, llm_call)
        result = await navigator.search(query, [doc_id])

    ``llm_call`` is ``async def fn(prompt: str) -> str``.
    Pass a real LLM wrapper in production; pass a fake in tests.

    Per document, one LLM call ranks nodes of the document skeleton. For each
    ranked node that has children, one more LLM call may pick children to
    add. Both stages share the per-document budget below.

    Args:
        retriever: Provides ``get_document_structure`` and
            ``get_node_content``.
        llm_call: Async ``prompt -> reply text`` hook.
        max_iterations: Child-exploration LLM calls are only made while fewer
            than this many LLM calls have succeeded for the document. The
            first ranking call is always made.
        max_sections: Maximum number of sections returned per document. It is
            also the number of top-level node_ids requested from the LLM.
    """

    def __init__(
        self,
        retriever: DocumentRetriever,
        llm_call: LlmCall,
        *,
        max_iterations: int = 5,
        max_sections: int = 3,
    ) -> None:
        self._retriever = retriever
        self._llm_call = llm_call
        self.max_iterations = max_iterations
        self.max_sections = max_sections

    async def search(
        self,
        query: str,
        document_ids: list[str],
    ) -> DocumentSearchResult:
        """Tree-search across one or more documents.

        Documents are searched concurrently. Results are ordered by document,
        then in discovery order within each document: a ranked node, any
        children chosen for it, then the next ranked node. Each document
        contributes at most ``max_sections`` hits. A document whose search
        raises is logged and skipped.
        """
        per_doc = await asyncio.gather(
            *(self._search_one(query, doc_id) for doc_id in document_ids),
            return_exceptions=True,
        )

        all_sections: list[SectionHit] = []
        iterations_used = 0
        for doc_id, outcome in zip(document_ids, per_doc):
            # return_exceptions=True also hands back BaseException subclasses
            # such as asyncio.CancelledError; those must not reach the tuple
            # unpacking below.
            if isinstance(outcome, BaseException):
                logger.warning(
                    "tree-search failed for doc=%s: %s", doc_id, outcome, exc_info=outcome
                )
                continue
            sections, iters = outcome
            # Per-document cap, so one over-producing document cannot crowd
            # out the hits of the documents after it.
            all_sections.extend(sections[: self.max_sections])
            iterations_used += iters

        return DocumentSearchResult(
            query=query,
            sections=all_sections,
            documents_searched=len(document_ids),
            iterations_used=iterations_used,
        )

    async def _search_one(
        self,
        query: str,
        doc_id: str,
    ) -> tuple[list[SectionHit], int]:
        """Tree-search a single document.

        Returns:
            ``(sections, iterations)``. ``sections`` holds at most
            ``max_sections`` hits in discovery order. ``iterations`` is the
            number of LLM calls that succeeded. A missing document, an empty
            skeleton or a failed first LLM call yields no sections.
        """
        try:
            skeleton = await self._retriever.get_document_structure(doc_id)
        except DocumentNotFoundError:
            logger.warning("document not found: %s", doc_id)
            return [], 0

        if not skeleton.nodes:
            return [], 0

        iterations = 0
        sections: list[SectionHit] = []
        collected: set[str] = set()

        # ── Step 1: reason over full skeleton ────────────────────────
        prompt = _STRUCTURE_PROMPT.format(
            title=skeleton.title,
            query=query,
            structure_lines=_format_structure(skeleton),
            max_sections=self.max_sections,
        )
        try:
            response = await self._llm_call(prompt)
        except Exception as exc:  # noqa: BLE001
            logger.warning("navigator LLM call failed for doc=%s: %s", doc_id, exc)
            return [], iterations
        iterations += 1

        parsed = _parse_llm_json(response)
        if not isinstance(parsed, dict):
            parsed = {}
        # Validated, de-duplicated ids, so duplicates do not consume slots.
        node_ids = self._as_node_ids(parsed.get("node_ids"))
        reasoning = self._as_text(parsed.get("reasoning"), "")

        # ── Step 2: fetch identified nodes ────────────────────────────
        for node_id in node_ids[: self.max_sections]:
            if len(sections) >= self.max_sections:
                # Budget already filled (e.g. by an earlier node's children).
                break
            if node_id in collected:
                continue
            try:
                content = await self._retriever.get_node_content(doc_id, node_id)
            except (KeyError, DocumentNotFoundError) as exc:
                logger.debug("node fetch failed %s/%s: %s", doc_id, node_id, exc)
                continue

            sections.append(self._make_hit(skeleton, doc_id, node_id, content, reasoning))
            collected.add(node_id)

            # ── Step 3: optional child exploration ───────────────────
            if (
                content.children
                and len(sections) < self.max_sections
                and iterations < self.max_iterations
            ):
                child_prompt = _CHILD_PROMPT.format(
                    query=query,
                    parent_title=content.title,
                    children_lines=_format_children(content.children),
                )
                child_ids: list[str] = []
                child_reasoning = reasoning
                try:
                    child_response = await self._llm_call(child_prompt)
                    iterations += 1
                    child_parsed = _parse_llm_json(child_response)
                    if isinstance(child_parsed, dict):
                        child_ids = self._as_node_ids(child_parsed.get("node_ids"))
                        child_reasoning = self._as_text(
                            child_parsed.get("reasoning"), reasoning
                        )
                except Exception as exc:  # noqa: BLE001
                    logger.debug("child exploration LLM failed: %s", exc)
                    child_ids, child_reasoning = [], reasoning

                for child_id in child_ids:
                    if len(sections) >= self.max_sections:
                        break
                    if child_id in collected:
                        # Already returned (e.g. also a top-level pick): skip
                        # it rather than abandoning the remaining children.
                        continue
                    try:
                        child_content = await self._retriever.get_node_content(doc_id, child_id)
                    except (KeyError, DocumentNotFoundError):
                        continue
                    sections.append(
                        self._make_hit(skeleton, doc_id, child_id, child_content, child_reasoning)
                    )
                    collected.add(child_id)

        return sections, iterations

    @staticmethod
    def _make_hit(
        skeleton: TreeSkeleton,
        doc_id: str,
        node_id: str,
        content,
        reasoning: str,
    ) -> SectionHit:
        """Build a SectionHit from fetched node content plus its breadcrumb."""
        return SectionHit(
            document_id=doc_id,
            node_id=node_id,
            node_title=content.title,
            node_depth=content.depth,
            breadcrumb=_build_breadcrumb(skeleton, node_id),
            text=content.text,
            relevance_reasoning=reasoning,
        )

    @staticmethod
    def _as_node_ids(value: object) -> list[str]:
        """Coerce an LLM-supplied ``node_ids`` value into unique string ids.

        A bare string becomes a one-element list, and a non-list yields
        ``[]``. Non-string items are dropped. Duplicates are removed, keeping
        the first (highest-ranked) occurrence.
        """
        if isinstance(value, str):
            value = [value]
        if not isinstance(value, list):
            return []
        return list(dict.fromkeys(v for v in value if isinstance(v, str)))

    @staticmethod
    def _as_text(value: object, default: str) -> str:
        """Return *value* as a string; ``None`` (missing/null) yields *default*."""
        if value is None:
            return default
        return value if isinstance(value, str) else str(value)