Coverage for astrocyte/pipeline/query_rewrite.py: 0%
33 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""Query rewriting for retrieval (M38).
3Reformulates the user's natural-language question into a search-optimized
4form *before* retrieval. The rewritten query is fed to the recall
5pipeline; the ORIGINAL question still reaches the answerer so reasoning
6context is preserved.
8Why
9---
11Natural questions often have phrasing that doesn't match how facts are
12indexed:
14- "Which book did I read a week ago?" → search target is "books read",
15 date filter is "week ago"
16- "How many weeks since I quit smoking?" → search target is "quit
17 smoking" event, answer requires date math
18- "What is my favourite coffee?" → search target is "coffee preference"
20The cheap pattern: one LLM call (gpt-4o-mini, ~300ms, ~$0.0001) maps
21the question into a search query optimized for vector similarity +
22keyword match. The answerer still sees the original question so it
23reasons in the user's voice.
25This is *retrieval-side* rewriting, not query expansion. We produce ONE
26rewritten string, not N variants. Multi-query expansion (M38b) is a
27future extension.
29Compared to Hindsight
30---------------------
32Hindsight doesn't have an explicit query rewriter at retain or query
33time. Their `agentic_reflect` loop achieves a similar effect by letting
34the LLM call ``recall(sub_query=...)`` iteratively with refined
35queries. M38 is the *non-agentic* equivalent — one rewrite, one
36retrieval, much cheaper than a full reflect loop. Stacks naturally
37with M36 reflect routing (rewrite first, then route).
39Reference
40---------
42- ``docs/_design/m36-reflect-loop.md`` — M36 reflect routing
43- ``docs/_design/m34-query-intent-routing.md`` — query intent + analyzer
44"""
46from __future__ import annotations
48import logging
49from typing import Any, Protocol
# Module-level logger; every warning emitted by the rewriter (timeouts,
# provider errors, discarded over-long rewrites) goes through it.
_logger = logging.getLogger("astrocyte.pipeline.query_rewrite")
class _LLMProvider(Protocol):
    """Structural type describing the single LLM capability used for rewriting."""

    async def complete(self, *, messages: list, model: str | None = None, **kwargs: Any) -> Any:
        """Submit ``messages`` to the LLM and hand back its completion.

        All arguments are keyword-only. ``model`` selects a specific model
        when given; any further keyword arguments are passed through as
        ``**kwargs``.

        The return annotation is deliberately loose (``Any``): the rewriter
        only needs the response text, so adapters are free to return their
        richer ``Completion`` shape.
        """
        ...
# System prompt sent as the first message of every rewrite call: it states the
# rewriting rules and gives a handful of few-shot examples. The backslash
# after the opening quotes keeps the prompt from starting with a blank line.
_REWRITE_SYSTEM_PROMPT = """\
You rewrite user questions into search queries optimized for semantic + keyword retrieval over a personal memory store.

Rules:
1. Extract the core search target (event, entity, preference, fact).
2. Preserve specific names, numbers, dates, and proper nouns.
3. Drop conversational fluff ("did I tell you about", "do you remember when").
4. For "how many/long" questions about durations, focus on the EVENT being asked about (the duration is answer-side, not retrieval-side).
5. For comparison questions ("which X first", "before or after"), keep BOTH entities in the search query.
6. Output one line, no quotes, no markdown, no commentary.

Examples:
Question: "Which book did I finish reading a week ago?"
Rewrite: books I finished reading

Question: "How many weeks ago did I attend the music festival?"
Rewrite: music festival attendance date

Question: "What's my favourite coffee?"
Rewrite: coffee preference favourite

Question: "Did I tell you about my promotion before or after meeting my new manager?"
Rewrite: promotion event and meeting new manager event

Question: "What is my brother's name?"
Rewrite: brother name family
"""
def _strip_wrapping_quotes(text: str) -> str:
    """Remove one pair of matching single or double quotes wrapping ``text``."""
    if len(text) >= 2 and text[0] == text[-1] and text[0] in "\"'":
        return text[1:-1].strip()
    return text


def _clean_rewrite(raw: str) -> str:
    """Normalise raw model output into a bare search query.

    Quotes are stripped both before and after the ``Rewrite:`` label so the
    result does not depend on whether the model quoted the whole line
    (``"Rewrite: foo"``) or only the query (``Rewrite: "foo"``).
    """
    cleaned = _strip_wrapping_quotes(raw.strip())
    # Strip a "Rewrite:" prefix if the model echoes it.
    if cleaned.lower().startswith("rewrite:"):
        cleaned = cleaned[len("rewrite:"):].strip()
    return _strip_wrapping_quotes(cleaned)


async def rewrite_query(
    question: str,
    *,
    llm_provider: _LLMProvider,
    model: str | None = None,
    timeout_sec: float = 5.0,
) -> str | None:
    """Rewrite ``question`` into a search-optimized query.

    Returns the rewritten string, or ``None`` if the call fails, times out,
    or the rewrite is empty / suspiciously short / suspiciously long. Caller
    should fall back to the original question on ``None``.

    Single LLM call (~300ms, ~$0.0001 at gpt-4o-mini pricing). Designed
    to be cheap enough to call on every recall.

    Args:
        question: The user's natural question. Empty or whitespace-only
            input returns ``None`` without calling the provider.
        llm_provider: An object with an awaitable, keyword-only
            ``complete(messages=..., model=..., **kwargs)`` method. The
            result may be a plain string or an object exposing the
            response on a ``text`` attribute.
        model: Optional model override, forwarded to the provider as-is.
        timeout_sec: Hard timeout in seconds. The provider call is wrapped
            in ``asyncio.wait_for``; when it expires the call is cancelled,
            a warning is logged and ``None`` is returned.

    Returns:
        Rewritten query string, or ``None`` on failure / unusable output.
        Provider exceptions are logged and swallowed, never propagated.
    """
    msg = question.strip() if question else ""
    if not msg:
        # Nothing to rewrite; bail out before paying for the lazy imports.
        return None

    import asyncio

    from astrocyte.types import Message  # noqa: PLC0415

    try:
        completion = await asyncio.wait_for(
            llm_provider.complete(
                messages=[
                    Message(role="system", content=_REWRITE_SYSTEM_PROMPT),
                    Message(role="user", content=f"Question: {msg}\nRewrite:"),
                ],
                model=model,
                max_tokens=120,
                temperature=0.0,
            ),
            timeout=timeout_sec,
        )
    except asyncio.TimeoutError:
        _logger.warning("query_rewrite: timeout after %.1fs on question: %s", timeout_sec, msg[:80])
        return None
    except Exception as exc:  # noqa: BLE001
        # Best-effort by design: any provider failure degrades to "no rewrite".
        _logger.warning("query_rewrite: %s on question: %s", type(exc).__name__, msg[:80])
        return None

    # The provider's return type is ``Any``: accept a bare string as well as
    # an object carrying the response on ``.text``.
    raw = completion if isinstance(completion, str) else getattr(completion, "text", None)
    if not isinstance(raw, str):
        # Missing or non-text payload: treat as a failed rewrite, don't crash.
        return None

    rewritten = _clean_rewrite(raw)
    if len(rewritten) < 3:
        return None
    # Sanity guard: output longer than twice the question plus 100 characters
    # of slack is most likely hallucinated commentary rather than a query, so
    # discard it and let the caller fall back to the original question.
    if len(rewritten) > 2 * len(msg) + 100:
        _logger.warning("query_rewrite: suspiciously long output (%dc vs %dc), discarding", len(rewritten), len(msg))
        return None
    return rewritten
# Public API of this module: only the rewrite entry point is exported.
__all__ = ["rewrite_query"]