Coverage for astrocyte/documents/builders/summarizer.py: 98%

49 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Adaptive per-node summarization (PageIndex parity). 

2 

3Walks a ``DocumentTree`` and populates each node's ``summary`` field 

4according to PageIndex's rule: 

5 

6 - If node text < threshold tokens: ``summary.text = node.text`` (no 

7 LLM call; ``kind="raw"``). 

8 - If node text ≥ threshold tokens: call LLM with PageIndex's prompt; 

9 ``kind="llm"``. 

10 - For internal nodes (have children): ``kind="prefix"`` — the summary 

11 is meant to be read as "what this whole subtree is about." 

12 

13Default threshold: 200 tokens (PageIndex default; configurable). 

14 

15Public API: 

16 AdaptiveSummarizer(llm_call, threshold_tokens=200, model="gpt-4o-mini", max_concurrent_llm=8) 

17 .summarize_tree(tree) # in-place: sets node.summary on every node 

18 .summarize_node(node) # single-node helper 

19 

20``llm_call`` is a coroutine ``async def fn(prompt: str) -> str``. Pass a 

21real LLM provider in production; pass a fake in tests. Keeps the 

22summarizer testable without an OpenAI key. 

23""" 

24 

25from __future__ import annotations 

26 

27import asyncio 

28import logging 

29from typing import Awaitable, Callable 

30 

31from astrocyte.documents.types import DocumentTree, NodeSummary, TreeNode 

32from astrocyte.policy.homeostasis import count_tokens 

33 

34logger = logging.getLogger(__name__) 

35 

36 

37# PageIndex parity — the exact prompt from the upstream PageIndex repo, 

38# pageindex/utils.py:generate_node_summary 

39PAGEINDEX_SUMMARY_PROMPT = """You are given a part of a document, your task is to generate a description of the partial document about what are main points covered in the partial document. 

40 

41 Partial Document Text: {node_text} 

42 

43 Directly return the description, do not include any other text. 

44 """ 

45 

46DEFAULT_THRESHOLD_TOKENS = 200 

47 

48 

49LlmCall = Callable[[str], Awaitable[str]] 

50 

51 

class AdaptiveSummarizer:
    """Per-node summarizer with PageIndex-style adaptive LLM gating.

    Cost shape: most LME / LoCoMo session-blocks are < 200 tokens and
    skip the LLM entirely. LLM is only called for genuinely long nodes
    where the raw text would be too long to use directly downstream.
    """

    def __init__(
        self,
        llm_call: LlmCall,
        *,
        threshold_tokens: int = DEFAULT_THRESHOLD_TOKENS,
        model: str = "gpt-4o-mini",
        max_concurrent_llm: int = 8,
    ) -> None:
        """Configure the summarizer.

        Args:
            llm_call: Coroutine function taking the fully formatted prompt
                and returning the generated description.
            threshold_tokens: Nodes whose text counts fewer tokens than this
                are summarized by their own text without calling the LLM.
            model: Stored on the instance as ``model``; this class does not
                read it itself.
            max_concurrent_llm: Upper bound on LLM calls in flight at once.

        Raises:
            ValueError: If ``threshold_tokens`` or ``max_concurrent_llm`` is
                smaller than 1.
        """
        if threshold_tokens < 1:
            raise ValueError(f"threshold_tokens must be >= 1, got {threshold_tokens}")
        # A zero-valued semaphore is constructible but can never be acquired,
        # so the first long node would hang forever. Reject it up front, the
        # same way threshold_tokens is rejected.
        if max_concurrent_llm < 1:
            raise ValueError(f"max_concurrent_llm must be >= 1, got {max_concurrent_llm}")
        self._llm_call = llm_call
        self._threshold = threshold_tokens
        self.model = model
        self._sem = asyncio.Semaphore(max_concurrent_llm)

    # ── single-node helper ────────────────────────────────────────────

    @staticmethod
    def _set_raw_summary(node: TreeNode, text: str, tokens: int) -> NodeSummary:
        """Attach a ``kind="raw"`` summary holding the node's own text."""
        summary = NodeSummary(text=text, kind="raw", token_count=tokens)
        node.summary = summary
        return summary

    async def summarize_node(self, node: TreeNode) -> NodeSummary:
        """Produce a NodeSummary for a single node.

        Returns the summary (and sets ``node.summary``). Internal-node
        determination is left to the caller because it depends on tree
        traversal context — use ``summarize_tree`` for the typical case.

        Short nodes (token count below the threshold) and nodes whose LLM
        call raises get ``kind="raw"`` with the node's own text; otherwise
        the summary is ``kind="llm"``.
        """
        text = node.text or ""
        tokens = count_tokens(text)
        if tokens < self._threshold:
            return self._set_raw_summary(node, text, tokens)

        # Only the LLM round-trip is throttled by the semaphore.
        async with self._sem:
            try:
                prompt = PAGEINDEX_SUMMARY_PROMPT.format(node_text=text)
                description = await self._llm_call(prompt)
            except Exception as exc:  # noqa: BLE001
                logger.warning(
                    "summarize_node: LLM call failed for node=%s title=%r tokens=%d: %s",
                    node.id,
                    node.title,
                    tokens,
                    exc,
                )
                # Failure mode: degrade to raw text. Downstream still has SOMETHING.
                return self._set_raw_summary(node, text, tokens)

        # An empty / whitespace-only response falls back to the node's own
        # text; the summary is still labelled kind="llm" in that case.
        summary_text = (description or "").strip() or text
        summary = NodeSummary(
            text=summary_text,
            kind="llm",
            token_count=count_tokens(summary_text),
        )
        node.summary = summary
        return summary

    # ── tree-level walk ───────────────────────────────────────────────

    async def summarize_tree(self, tree: DocumentTree) -> None:
        """Summarize every node in the tree. Mutates in place.

        Runs node summarizations concurrently (bounded by
        ``max_concurrent_llm``). After completion, each node's
        ``summary`` is populated; internal nodes get ``kind="prefix"``
        re-labeled so callers can distinguish "this is summary of a
        subtree" from "this is summary of a leaf."
        """
        nodes = tree.all_nodes()
        if not nodes:
            return

        # Step 1: summarize every node (concurrent, semaphore-bounded inside)
        await asyncio.gather(*(self.summarize_node(n) for n in nodes))

        # Step 2: re-label internal nodes' summary kind.
        # PageIndex distinguishes prefix_summary (internal) from summary (leaf).
        for n in nodes:
            if n.summary is not None and not n.is_leaf():
                # Promote raw→prefix when an internal node had small text
                # (so its summary is its own heading-and-intro lines) AND
                # promote llm→prefix when it had a generated description.
                # Both indicate "this represents the subtree."
                n.summary = NodeSummary(
                    text=n.summary.text,
                    kind="prefix",
                    token_count=n.summary.token_count,
                )

    # ── introspection ─────────────────────────────────────────────────

    @property
    def threshold_tokens(self) -> int:
        """Token count at or above which a node is sent to the LLM."""
        return self._threshold