Coverage for astrocyte/documents/retrieval/navigator.py: 16%
127 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""DocumentNavigator — agentic tree-search retrieval.
3PageIndex's core retrieval innovation: instead of vector similarity,
4an LLM navigates the document tree structure to identify the exact
5relevant section.
7Agent loop per document:
8 1. get_document_structure() → TreeSkeleton (titles + summaries, no text)
9 2. LLM: "Which node_ids are most relevant to {query}?" → JSON list
10 3. get_node_content() for each identified node
11 4. Optional: explore children if the LLM requests deeper navigation
12 5. Compile SectionHit[] with breadcrumb + relevance reasoning
14Cost: 1-2 LLM calls per document. No embedding at query time.
15Same llm_call pattern as AdaptiveSummarizer — injectable, testable
16without an OpenAI key.
17"""
19from __future__ import annotations
21import asyncio
22import json
23import logging
24from typing import Awaitable, Callable
26from astrocyte.documents.retrieval.retriever import DocumentRetriever
27from astrocyte.documents.retrieval.types import (
28 DocumentSearchResult,
29 SectionHit,
30 SkeletonNode,
31 TreeSkeleton,
32)
33from astrocyte.documents.storage import DocumentNotFoundError
# Module-level logger, named after this module.
35logger = logging.getLogger(__name__)
# Signature of the injectable LLM hook: an async callable that receives the
# prompt text and resolves to the model's raw reply text.
37LlmCall = Callable[[str], Awaitable[str]]
39# ── Prompts ──────────────────────────────────────────────────────────────────
# First-pass prompt. Filled with str.format using the fields: title, query,
# structure_lines, max_sections. The doubled braces in the example line are
# format escapes and come out as single literal braces.
41_STRUCTURE_PROMPT = """\
42You are navigating a document tree to find sections relevant to a query.
44Document: {title}
45Query: {query}
47Document structure (node_id | title | summary):
48{structure_lines}
50Select the most relevant sections. Return a JSON object:
51 "node_ids": list of node_id strings, most relevant first (max {max_sections})
52 "reasoning": one sentence explaining why these sections answer the query
54Return ONLY valid JSON. No markdown fences, no extra text.
55Example: {{"node_ids": ["abc123", "def456"], "reasoning": "Section 3.2 covers pricing directly."}}
56"""
# Follow-up prompt used when drilling into a section's children. Filled with
# str.format using the fields: query, parent_title, children_lines. The doubled
# braces again render as literal braces.
58_CHILD_PROMPT = """\
59You are drilling into a document section to answer a query more precisely.
61Query: {query}
62Parent section: "{parent_title}"
64Child sections available:
65{children_lines}
67Which child node_ids should be retrieved? Return JSON:
68{{"node_ids": [...], "reasoning": "..."}}
69Return ONLY valid JSON.
70"""
73# ── Helpers ───────────────────────────────────────────────────────────────────
def _format_structure(skeleton: TreeSkeleton) -> str:
    """Render the skeleton as one ``node_id | title | summary`` row per node.

    Rows are prefixed with one space per depth level beyond the first.
    A summary is cut to its first 80 characters with newlines turned into
    spaces, and gets a trailing ellipsis when the full summary was longer.
    Nodes with an empty summary omit the last column. Rows are joined with
    newlines.
    """

    def render(node: SkeletonNode) -> str:
        prefix = " " * max(0, node.depth - 1)
        head = f"{prefix}{node.node_id} | {node.title}"
        if not node.summary:
            return head
        preview = node.summary[:80].replace("\n", " ")
        if len(node.summary) > 80:
            preview += "…"
        return f"{head} | {preview}"

    return "\n".join(render(node) for node in skeleton.nodes)
def _format_children(children: list[SkeletonNode]) -> str:
    """Render child nodes as ``node_id | title | summary`` rows, one per line.

    No indentation is applied. A summary is cut to its first 80 characters
    with newlines turned into spaces, and an ellipsis is appended when the
    full summary was longer. Children with an empty summary omit the last
    column.
    """
    rows: list[str] = []
    for node in children:
        row = f"{node.node_id} | {node.title}"
        summary = node.summary
        if summary:
            preview = summary[:80].replace("\n", " ")
            suffix = "…" if len(summary) > 80 else ""
            row = f"{row} | {preview}{suffix}"
        rows.append(row)
    return "\n".join(rows)
def _parse_llm_json(text: str) -> dict:
    """Extract a JSON object from an LLM response, tolerating markdown fences.

    Parsing is attempted on the whole (fence-stripped) reply first, then on
    the span between the first ``{`` and the last ``}``. Only a JSON *object*
    is accepted: a reply that parses to a list, string, number or ``null``
    is treated as unusable, so callers can always rely on ``.get``.

    Args:
        text: Raw reply text from the LLM.

    Returns:
        The parsed dict, or ``{"node_ids": [], "reasoning": ""}`` when no
        JSON object can be recovered (including when ``text`` is not a str).
    """
    fallback: dict = {"node_ids": [], "reasoning": ""}
    if not isinstance(text, str):
        return fallback

    cleaned = text.strip()
    if cleaned.startswith("```"):
        # Drop the fence lines themselves and keep whatever sits between them.
        cleaned = "\n".join(
            line for line in cleaned.split("\n") if not line.startswith("```")
        ).strip()

    candidates = [cleaned]
    start, end = cleaned.find("{"), cleaned.rfind("}") + 1
    if start >= 0 and end > start:
        # Salvage an object embedded in surrounding chatter or in an array.
        candidates.append(cleaned[start:end])

    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if isinstance(parsed, dict):
            return parsed
    return fallback
def _build_breadcrumb(skeleton: TreeSkeleton, node_id: str) -> list[str]:
    """Return ancestor titles from root to the parent of ``node_id``.

    The walk follows ``parent_id`` links upward through ``skeleton.nodes``.
    An ancestor id that is not present in the skeleton contributes its raw
    id in place of a title and ends the walk. The walk also stops if a
    parent link leads back to a node already visited, so a malformed
    skeleton with cyclic parents cannot loop forever.

    Args:
        skeleton: Document skeleton whose nodes expose ``node_id``,
            ``parent_id`` and ``title``.
        node_id: Id of the node whose ancestry is wanted.

    Returns:
        Ancestor titles ordered root-first; empty when ``node_id`` is a root
        or is not in the skeleton.
    """
    parent_of = {n.node_id: n.parent_id for n in skeleton.nodes}
    title_of = {n.node_id: n.title for n in skeleton.nodes}

    trail: list[str] = []
    visited = {node_id}
    ancestor = parent_of.get(node_id)
    while ancestor is not None and ancestor not in visited:
        visited.add(ancestor)
        trail.append(title_of.get(ancestor, ancestor))
        ancestor = parent_of.get(ancestor)

    trail.reverse()
    return trail
131# ── Navigator ─────────────────────────────────────────────────────────────────
class DocumentNavigator:
    """Agentic tree-search — PageIndex's core retrieval innovation.

    For each document the navigator asks the LLM to pick node ids from the
    document skeleton, fetches those nodes, and may issue follow-up LLM
    calls to pick among a fetched node's children.

    Usage:
        retriever = DocumentRetriever(document_store)
        navigator = DocumentNavigator(retriever, llm_call)
        result = await navigator.search(query, [doc_id])

    ``llm_call`` is ``async def fn(prompt: str) -> str``.
    Pass a real LLM wrapper in production; pass a fake in tests.
    """

    def __init__(
        self,
        retriever: DocumentRetriever,
        llm_call: LlmCall,
        *,
        max_iterations: int = 5,
        max_sections: int = 3,
    ) -> None:
        """Store the collaborators and the per-document limits.

        Args:
            retriever: Source of document skeletons and node content.
            llm_call: Async callable mapping a prompt string to reply text.
            max_iterations: Ceiling on LLM calls per document, checked before
                each child-exploration call (the first, skeleton-level call
                is always made).
            max_sections: Number of top-level node ids taken from the LLM's
                answer, and the section count at which child exploration
                stops adding hits.
        """
        self._retriever = retriever
        self._llm_call = llm_call
        self.max_iterations = max_iterations
        self.max_sections = max_sections

    async def search(
        self,
        query: str,
        document_ids: list[str],
    ) -> DocumentSearchResult:
        """Tree-search across one or more documents.

        Documents are searched concurrently. Results are ordered by document
        then by LLM relevance ranking within each document. A document whose
        search raised is logged and skipped; it contributes no sections and
        no iterations. The combined list is truncated to
        ``max_sections * len(document_ids)`` entries.
        """
        tasks = [self._search_one(query, doc_id) for doc_id in document_ids]
        per_doc = await asyncio.gather(*tasks, return_exceptions=True)

        all_sections: list[SectionHit] = []
        iterations_used = 0
        for doc_id, outcome in zip(document_ids, per_doc):
            # gather(return_exceptions=True) can hand back CancelledError,
            # which derives from BaseException rather than Exception; testing
            # only for Exception would let it reach the tuple unpacking below.
            if isinstance(outcome, BaseException):
                logger.warning("tree-search failed for doc=%s: %s", doc_id, outcome)
                continue
            sections, iters = outcome
            all_sections.extend(sections)
            iterations_used += iters

        return DocumentSearchResult(
            query=query,
            sections=all_sections[: self.max_sections * len(document_ids)],
            documents_searched=len(document_ids),
            iterations_used=iterations_used,
        )

    @staticmethod
    def _selection(parsed: object, default_reasoning: str = "") -> tuple[list[str], str]:
        """Pull a validated ``(node_ids, reasoning)`` pair out of parsed LLM JSON.

        LLM output is untrusted: anything that is not a dict yields no ids,
        a ``node_ids`` value that is not a list yields no ids, and non-string
        entries are dropped. ``reasoning`` falls back to ``default_reasoning``
        when absent or not a string.
        """
        if not isinstance(parsed, dict):
            return [], default_reasoning
        raw_ids = parsed.get("node_ids")
        node_ids = (
            [item for item in raw_ids if isinstance(item, str)]
            if isinstance(raw_ids, list)
            else []
        )
        reasoning = parsed.get("reasoning", default_reasoning)
        if not isinstance(reasoning, str):
            reasoning = default_reasoning
        return node_ids, reasoning

    @staticmethod
    def _hit(skeleton: TreeSkeleton, doc_id: str, node_id: str, content, reasoning: str) -> SectionHit:
        """Build a SectionHit from fetched node content plus its breadcrumb."""
        return SectionHit(
            document_id=doc_id,
            node_id=node_id,
            node_title=content.title,
            node_depth=content.depth,
            breadcrumb=_build_breadcrumb(skeleton, node_id),
            text=content.text,
            relevance_reasoning=reasoning,
        )

    async def _search_one(
        self,
        query: str,
        doc_id: str,
    ) -> tuple[list[SectionHit], int]:
        """Run the navigation loop for a single document.

        Returns:
            ``(sections, iterations)`` where ``iterations`` counts the LLM
            calls that completed. A missing document, an empty skeleton, or
            a failed first LLM call yields an empty section list.
        """
        try:
            skeleton = await self._retriever.get_document_structure(doc_id)
        except DocumentNotFoundError:
            logger.warning("document not found: %s", doc_id)
            return [], 0

        if not skeleton.nodes:
            return [], 0

        iterations = 0
        sections: list[SectionHit] = []
        collected: set[str] = set()

        # ── Step 1: reason over full skeleton ────────────────────────
        prompt = _STRUCTURE_PROMPT.format(
            title=skeleton.title,
            query=query,
            structure_lines=_format_structure(skeleton),
            max_sections=self.max_sections,
        )
        try:
            response = await self._llm_call(prompt)
        except Exception as exc:  # noqa: BLE001
            logger.warning("navigator LLM call failed for doc=%s: %s", doc_id, exc)
            return [], iterations
        iterations += 1

        node_ids, reasoning = self._selection(_parse_llm_json(response))

        # ── Step 2: fetch identified nodes ────────────────────────────
        for node_id in node_ids[: self.max_sections]:
            if node_id in collected:
                continue
            try:
                content = await self._retriever.get_node_content(doc_id, node_id)
            except (KeyError, DocumentNotFoundError) as exc:
                logger.debug("node fetch failed %s/%s: %s", doc_id, node_id, exc)
                continue

            sections.append(self._hit(skeleton, doc_id, node_id, content, reasoning))
            collected.add(node_id)

            # ── Step 3: optional child exploration ───────────────────
            if (
                not content.children
                or len(sections) >= self.max_sections
                or iterations >= self.max_iterations
            ):
                continue

            child_prompt = _CHILD_PROMPT.format(
                query=query,
                parent_title=content.title,
                children_lines=_format_children(content.children),
            )
            try:
                child_response = await self._llm_call(child_prompt)
                iterations += 1
                child_ids, child_reasoning = self._selection(
                    _parse_llm_json(child_response), reasoning
                )
            except Exception as exc:  # noqa: BLE001
                # Child exploration is best-effort; keep the parent hit.
                logger.debug("child exploration LLM failed: %s", exc)
                continue

            for child_id in child_ids:
                if len(sections) >= self.max_sections:
                    break
                if child_id in collected:
                    # Skip only the duplicate; later children may still be new.
                    continue
                try:
                    cc = await self._retriever.get_node_content(doc_id, child_id)
                except (KeyError, DocumentNotFoundError) as exc:
                    logger.debug("child fetch failed %s/%s: %s", doc_id, child_id, exc)
                    continue
                sections.append(self._hit(skeleton, doc_id, child_id, cc, child_reasoning))
                collected.add(child_id)

        return sections, iterations