Coverage for astrocyte/pipeline/query_rewrite.py: 0%

33 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""Query rewriting for retrieval (M38). 

2 

3Reformulates the user's natural-language question into a search-optimized 

4form *before* retrieval. The rewritten query is fed to the recall 

5pipeline; the ORIGINAL question still reaches the answerer so reasoning 

6context is preserved. 

7 

8Why 

9--- 

10 

11Natural questions often have phrasing that doesn't match how facts are 

12indexed: 

13 

14- "Which book did I read a week ago?" → search target is "books read", 

15 date filter is "week ago" 

16- "How many weeks since I quit smoking?" → search target is "quit 

17 smoking" event, answer requires date math 

18- "What is my favourite coffee?" → search target is "coffee preference" 

19 

20The cheap pattern: one LLM call (gpt-4o-mini, ~300ms, ~$0.0001) maps 

21the question into a search query optimized for vector similarity + 

22keyword match. The answerer still sees the original question so it 

23reasons in the user's voice. 

24 

25This is *retrieval-side* rewriting, not query expansion. We produce ONE 

26rewritten string, not N variants. Multi-query expansion (M38b) is a 

27future extension. 

28 

29Compared to Hindsight 

30--------------------- 

31 

32Hindsight doesn't have an explicit query rewriter at retain or query 

33time. Their ``agentic_reflect`` loop achieves a similar effect by letting

34the LLM call ``recall(sub_query=...)`` iteratively with refined 

35queries. M38 is the *non-agentic* equivalent — one rewrite, one 

36retrieval, much cheaper than a full reflect loop. Stacks naturally 

37with M36 reflect routing (rewrite first, then route). 

38 

39Reference 

40--------- 

41 

42- ``docs/_design/m36-reflect-loop.md`` — M36 reflect routing 

43- ``docs/_design/m34-query-intent-routing.md`` — query intent + analyzer 

44""" 

45 

46from __future__ import annotations 

47 

48import logging 

49from typing import Any, Protocol 

50 

# Module-level logger, named explicitly after this module's dotted path.
_logger = logging.getLogger(name="astrocyte.pipeline.query_rewrite")

52 

53 

class _LLMProvider(Protocol):
    """Structural type for the LLM client that ``rewrite_query`` calls.

    Any object exposing a matching async ``complete`` method satisfies
    this Protocol; no inheritance is required.
    """

    async def complete(
        self,
        *,
        messages: list,
        model: str | None = None,
        **kwargs: Any,
    ) -> Any:
        """Send ``messages`` to the LLM and return its completion.

        The rewriter consumes only the response text, so adapters may hand
        back a richer ``Completion`` object. The return type is left as
        ``Any`` on purpose to keep this Protocol permissive.
        """
        ...

60 

61 

# System prompt for the single rewrite call made by ``rewrite_query``.
# It lists the rewriting rules, then few-shot "Question: / Rewrite:" pairs;
# the user message that ``rewrite_query`` sends ends with "Rewrite:" so the
# model continues the same pattern. The exact wording is runtime behavior:
# any edit here changes what the model produces.
_REWRITE_SYSTEM_PROMPT = """\
You rewrite user questions into search queries optimized for semantic + keyword retrieval over a personal memory store.

Rules:
1. Extract the core search target (event, entity, preference, fact).
2. Preserve specific names, numbers, dates, and proper nouns.
3. Drop conversational fluff ("did I tell you about", "do you remember when").
4. For "how many/long" questions about durations, focus on the EVENT being asked about (the duration is answer-side, not retrieval-side).
5. For comparison questions ("which X first", "before or after"), keep BOTH entities in the search query.
6. Output one line, no quotes, no markdown, no commentary.

Examples:
Question: "Which book did I finish reading a week ago?"
Rewrite: books I finished reading

Question: "How many weeks ago did I attend the music festival?"
Rewrite: music festival attendance date

Question: "What's my favourite coffee?"
Rewrite: coffee preference favourite

Question: "Did I tell you about my promotion before or after meeting my new manager?"
Rewrite: promotion event and meeting new manager event

Question: "What is my brother's name?"
Rewrite: brother name family
"""

89 

90 

async def rewrite_query(
    question: str,
    *,
    llm_provider: _LLMProvider,
    model: str | None = None,
    timeout_sec: float = 5.0,
) -> str | None:
    """Rewrite ``question`` into a search-optimized query.

    Returns the rewritten string, or ``None`` if the call fails or times
    out, or if the rewrite is empty / suspiciously short or long. Callers
    should fall back to the original question on ``None``.

    Single LLM call (~300ms, ~$0.0001 at gpt-4o-mini pricing). Designed
    to be cheap enough to call on every recall. Any ``Exception`` raised
    by the provider is logged and converted to ``None``; it is never
    propagated to the caller.

    Args:
        question: The user's natural question. Empty or whitespace-only
            input returns ``None`` without calling the LLM.
        llm_provider: An object with an awaitable keyword-only
            ``complete(messages=..., model=..., **kwargs)`` method matching
            the in-tree ``LLMProvider`` shape. The completion may be an
            object exposing a ``text`` attribute, or a plain ``str``.
        model: Optional model override. Defaults to the provider's
            default (typically gpt-4o-mini in our bench config).
        timeout_sec: Hard timeout in seconds, enforced with
            ``asyncio.wait_for``. If the call takes longer, the pending
            call is cancelled and ``None`` is returned.

    Returns:
        Rewritten query string, or ``None`` on failure / empty output.
    """
    import asyncio

    # Function-local import kept as in the original; presumably to avoid an
    # import cycle with astrocyte.types — confirm before hoisting.
    from astrocyte.types import Message  # noqa: PLC0415

    if not question or not question.strip():
        return None
    msg = question.strip()
    try:
        completion = await asyncio.wait_for(
            llm_provider.complete(
                messages=[
                    Message(role="system", content=_REWRITE_SYSTEM_PROMPT),
                    Message(role="user", content=f"Question: {msg}\nRewrite:"),
                ],
                model=model,
                max_tokens=120,
                temperature=0.0,
            ),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError:
        _logger.warning("query_rewrite: timeout after %.1fs on question: %s", timeout_sec, msg[:80])
        return None
    except Exception as exc:  # noqa: BLE001
        _logger.warning("query_rewrite: %s on question: %s", type(exc).__name__, msg[:80])
        return None

    # The Protocol leaves the return type open: accept a plain string or any
    # object with a ``text`` attribute. Anything else (missing/None/non-str
    # text) is treated as empty output instead of crashing on ``.strip()``.
    raw = completion if isinstance(completion, str) else getattr(completion, "text", None)
    rewritten = raw.strip() if isinstance(raw, str) else ""

    # Strip a "Rewrite:" prefix if the model echoes it.
    if rewritten.lower().startswith("rewrite:"):
        rewritten = rewritten[len("rewrite:"):].strip()

    # Strip one pair of surrounding quotes the model sometimes adds,
    # including typographic ("smart") quotes.
    for opener, closer in (('"', '"'), ("'", "'"), ("\u201c", "\u201d"), ("\u2018", "\u2019")):
        if len(rewritten) >= 2 and rewritten.startswith(opener) and rewritten.endswith(closer):
            rewritten = rewritten[1:-1].strip()
            break

    if len(rewritten) < 3:
        return None
    # Sanity guard: output more than twice the original length (plus 100
    # chars of slack for very short questions) is likely the model rambling
    # or hallucinating — discard so the caller falls back to the original.
    if len(rewritten) > 2 * len(msg) + 100:
        _logger.warning("query_rewrite: suspiciously long output (%dc vs %dc), discarding", len(rewritten), len(msg))
        return None
    return rewritten

164 

165 

# Public API of this module: only the query rewriter is exported.
__all__ = [
    "rewrite_query",
]