Coverage for astrocyte/documents/builders/summarizer.py: 98%
49 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Adaptive per-node summarization (PageIndex parity).
3Walks a ``DocumentTree`` and populates each node's ``summary`` field
4according to PageIndex's rule:
  - If node text < threshold tokens: ``summary.text = node.text`` (no
    LLM call; ``kind="raw"``).
  - If node text ≥ threshold tokens: call the LLM with PageIndex's prompt;
    ``kind="llm"``. If the LLM call raises, the node falls back to its
    raw text (``kind="raw"``) so downstream still has something to use.
  - For internal nodes (nodes with children), ``summarize_tree`` relabels
    the summary as ``kind="prefix"`` — it is meant to be read as "what this
    whole subtree is about."
13Default threshold: 200 tokens (PageIndex default; configurable).
15Public API:
16 AdaptiveSummarizer(llm_call, threshold_tokens=200, model="gpt-4o-mini")
17 .summarize_tree(tree) # in-place: sets node.summary on every node
18 .summarize_node(node) # single-node helper
20``llm_call`` is a coroutine ``async def fn(prompt: str) -> str``. Pass a
21real LLM provider in production; pass a fake in tests. Keeps the
22summarizer testable without an OpenAI key.
23"""
25from __future__ import annotations
27import asyncio
28import logging
29from typing import Awaitable, Callable
31from astrocyte.documents.types import DocumentTree, NodeSummary, TreeNode
32from astrocyte.policy.homeostasis import count_tokens
logger = logging.getLogger(__name__)


# PageIndex parity — the exact prompt from PageIndex's
# ``pageindex/utils.py:generate_node_summary``. Kept verbatim (including
# line breaks and indentation) so generated summaries match PageIndex's;
# do not reformat this literal.
# NOTE(review): verify the continuation-line indentation matches the
# upstream prompt byte-for-byte.
PAGEINDEX_SUMMARY_PROMPT = """You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document.

    Partial Document Text: {node_text}

    Directly return the description, do not include any other text.
    """

# Nodes whose text has fewer tokens than this are stored verbatim
# (``kind="raw"``) instead of being sent to the LLM (PageIndex default).
DEFAULT_THRESHOLD_TOKENS = 200


# Signature of the injected LLM function: takes the fully formatted prompt
# and returns an awaitable resolving to the model's text response.
LlmCall = Callable[[str], Awaitable[str]]
class AdaptiveSummarizer:
    """Per-node summarizer with PageIndex-style adaptive LLM gating.

    Cost shape: most LME / LoCoMo session-blocks are < 200 tokens and
    skip the LLM entirely. LLM is only called for genuinely long nodes
    where the raw text would be too long to use directly downstream.

    Args:
        llm_call: Coroutine function ``async def fn(prompt: str) -> str``
            invoked for nodes whose text is at or above the threshold.
        threshold_tokens: Nodes with fewer tokens than this keep their raw
            text as the summary (``kind="raw"``). Must be >= 1.
        model: Informational model name, stored on ``self.model``. This
            class does not pass it to ``llm_call``; bake it into the
            callable if the provider needs it.
        max_concurrent_llm: Maximum number of ``llm_call`` invocations in
            flight at once. Must be >= 1.

    Raises:
        ValueError: If ``threshold_tokens`` or ``max_concurrent_llm`` < 1.
    """

    def __init__(
        self,
        llm_call: LlmCall,
        *,
        threshold_tokens: int = DEFAULT_THRESHOLD_TOKENS,
        model: str = "gpt-4o-mini",
        max_concurrent_llm: int = 8,
    ) -> None:
        if threshold_tokens < 1:
            raise ValueError(f"threshold_tokens must be >= 1, got {threshold_tokens}")
        # A zero-permit semaphore would make every LLM-bound node wait
        # forever, so reject it up front with a clear message.
        if max_concurrent_llm < 1:
            raise ValueError(
                f"max_concurrent_llm must be >= 1, got {max_concurrent_llm}"
            )
        self._llm_call = llm_call
        self._threshold = threshold_tokens
        self.model = model
        self._sem = asyncio.Semaphore(max_concurrent_llm)

    # ── single-node helper ────────────────────────────────────────────

    @staticmethod
    def _set_raw(node: TreeNode, text: str, tokens: int) -> NodeSummary:
        """Store ``text`` verbatim as ``node.summary`` (``kind="raw"``) and return it."""
        summary = NodeSummary(text=text, kind="raw", token_count=tokens)
        node.summary = summary
        return summary

    async def summarize_node(self, node: TreeNode) -> NodeSummary:
        """Produce a NodeSummary for a single node.

        Returns the summary (and sets ``node.summary``). Internal-node
        determination is left to the caller because it depends on tree
        traversal context — use ``summarize_tree`` for the typical case.

        Behaviour:
          - Text below the threshold is used as-is (``kind="raw"``).
          - Otherwise the LLM is called under the concurrency semaphore and
            its stripped response becomes the summary (``kind="llm"``).
          - If the LLM call raises, or returns an empty/whitespace-only
            description, the node falls back to its raw text
            (``kind="raw"``) so downstream always has something.
        """
        text = node.text or ""
        tokens = count_tokens(text)
        if tokens < self._threshold:
            return self._set_raw(node, text, tokens)

        async with self._sem:
            try:
                prompt = PAGEINDEX_SUMMARY_PROMPT.format(node_text=text)
                description = await self._llm_call(prompt)
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "summarize_node: LLM call failed for node=%s title=%r tokens=%d: %s",
                    node.id,
                    node.title,
                    tokens,
                    exc,
                )
                # Failure mode: degrade to raw text. Downstream still has SOMETHING.
                return self._set_raw(node, text, tokens)

        summary_text = (description or "").strip()
        if not summary_text:
            # An empty answer is effectively a failure: keep the raw text and
            # label it honestly so ``kind="llm"`` always means generated text.
            logger.warning(
                "summarize_node: LLM returned empty description for node=%s title=%r tokens=%d",
                node.id,
                node.title,
                tokens,
            )
            return self._set_raw(node, text, tokens)

        summary = NodeSummary(
            text=summary_text,
            kind="llm",
            token_count=count_tokens(summary_text),
        )
        node.summary = summary
        return summary

    # ── tree-level walk ───────────────────────────────────────────────

    async def summarize_tree(self, tree: DocumentTree) -> None:
        """Summarize every node in the tree. Mutates in place.

        Runs node summarizations concurrently (bounded by
        ``max_concurrent_llm``). After completion, each node's
        ``summary`` is populated; internal nodes get ``kind="prefix"``
        re-labeled so callers can distinguish "this is summary of a
        subtree" from "this is summary of a leaf."
        """
        nodes = tree.all_nodes()
        if not nodes:
            return

        # Step 1: summarize every node concurrently; the semaphore inside
        # summarize_node caps how many LLM calls are in flight at once.
        await asyncio.gather(*(self.summarize_node(n) for n in nodes))

        # Step 2: re-label internal nodes' summary kind.
        # PageIndex distinguishes prefix_summary (internal) from summary (leaf).
        for n in nodes:
            if n.summary is not None and not n.is_leaf():
                # Promote raw→prefix when an internal node had small text
                # (so its summary is its own heading-and-intro lines) AND
                # promote llm→prefix when it had a generated description.
                # Both indicate "this represents the subtree."
                n.summary = NodeSummary(
                    text=n.summary.text,
                    kind="prefix",
                    token_count=n.summary.token_count,
                )

    # ── introspection ─────────────────────────────────────────────────

    @property
    def threshold_tokens(self) -> int:
        """Token threshold below which nodes are kept verbatim."""
        return self._threshold