Coverage for astrocyte/providers/openai.py: 71%
86 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
"""OpenAI LLMProvider adapter.

Wraps the openai Python SDK to implement the Astrocyte LLMProvider protocol.

Usage via config YAML:
    llm_provider: openai
    llm_provider_config:
      api_key: ${OPENAI_API_KEY}
      model: gpt-4o-mini
      embedding_model: text-embedding-3-small

Usage programmatically:
    from astrocyte.providers.openai import OpenAIProvider

    llm = OpenAIProvider(api_key="sk-...", model="gpt-4o-mini")
"""
18from __future__ import annotations
20import json
21import os
22from typing import ClassVar
24from astrocyte.types import (
25 Completion,
26 LLMCapabilities,
27 Message,
28 TokenUsage,
29 ToolCall,
30 ToolDefinition,
31)
class OpenAIProvider:
    """LLMProvider backed by the OpenAI API.

    Holds an ``openai.AsyncOpenAI`` client and exposes ``complete`` (chat
    completions, with optional native function calling and structured
    outputs) and ``embed`` (text embeddings, sent in bounded batches).
    """

    SPI_VERSION: ClassVar[int] = 1

    # Maximum number of inputs sent in one embeddings request. The OpenAI
    # embeddings endpoint rejects ``input`` arrays longer than 2048 items,
    # so ``embed`` splits larger lists into several sequential requests.
    _EMBED_BATCH_SIZE: ClassVar[int] = 2048

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        embedding_model: str = "text-embedding-3-small",
        base_url: str | None = None,
    ) -> None:
        """Create the provider and its underlying async OpenAI client.

        Args:
            api_key: OpenAI API key. Falls back to the ``OPENAI_API_KEY``
                environment variable when omitted or empty.
            model: Default chat-completion model used by ``complete``.
            embedding_model: Default embedding model used by ``embed``.
            base_url: Optional API base URL, forwarded to ``AsyncOpenAI``.

        Raises:
            ImportError: The ``openai`` package is not installed.
            ValueError: No API key was given and ``OPENAI_API_KEY`` is unset.
        """
        try:
            import openai
        except ImportError as e:
            raise ImportError(
                "The 'openai' package is required for OpenAIProvider. Install it with: pip install 'astrocyte[openai]'"
            ) from e

        resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not resolved_key:
            raise ValueError("OpenAI API key is required. Pass api_key= or set OPENAI_API_KEY.")

        # Use HTTP/2 multiplexing when ``h2`` is installed. Without it, the
        # OpenAI SDK's default httpx client falls back to HTTP/1.1, which can
        # only carry one in-flight request per TCP connection — a hard
        # serialisation bottleneck for concurrent retain/recall workloads.
        # With HTTP/2, a single connection multiplexes dozens of requests
        # (typical 10-30x throughput improvement on embedding/completion
        # heavy phases). On benchmark hardware retaining 272 LoCoMo sessions
        # drops from ~6h on HTTP/1.1 to under a minute with HTTP/2.
        try:
            import h2  # noqa: F401 — presence enables httpx HTTP/2
            import httpx

            http_client: httpx.AsyncClient | None = httpx.AsyncClient(
                http2=True,
                # Match OpenAI's defaults but with explicit pool settings so
                # we don't suddenly serialise on bursty workloads.
                limits=httpx.Limits(
                    max_connections=100,
                    max_keepalive_connections=20,
                    keepalive_expiry=30.0,
                ),
                # Read timeout was 600s historically — far too long when an
                # HTTP/2 stream stalls (no RST/FIN, just no bytes). Such
                # stalls block for the full timeout × ``max_retries`` (≈30min
                # per stuck call) and silently murder benchmark wall-clock.
                # OpenAI completions normally take 5-30s; 90s is generous
                # while still failing fast on stuck streams. ``max_retries=3``
                # below handles transient stalls.
                timeout=httpx.Timeout(connect=5.0, read=90.0, write=90.0, pool=90.0),
            )
        except ImportError:
            # h2 not installed — let the OpenAI SDK use its default HTTP/1.1
            # client. Throughput will be lower for concurrent workloads but
            # functionality is unchanged.
            http_client = None

        client_kwargs: dict[str, object] = {
            "api_key": resolved_key,
            "base_url": base_url,
            "max_retries": 3,  # Retry on 429 / 5xx with exponential backoff
        }
        if http_client is not None:
            client_kwargs["http_client"] = http_client

        self._client = openai.AsyncOpenAI(**client_kwargs)
        self._model = model
        self._embedding_model = embedding_model

    def capabilities(self) -> LLMCapabilities:
        """Describe what this provider supports (multimodal chat, batch embeddings)."""
        return LLMCapabilities(
            supports_multimodal_completion=True,
            modalities_supported=("text", "image_url", "image_base64"),
            supports_multimodal_embedding=False,
            supports_batch_embed=True,
        )

    async def complete(
        self,
        messages: list[Message],
        model: str | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.0,
        tools: list[ToolDefinition] | None = None,
        tool_choice: str | None = None,
        response_format: dict | None = None,
    ) -> Completion:
        """Run one chat completion and translate the result to a ``Completion``.

        Args:
            messages: Conversation history. Each message is converted to the
                OpenAI wire format by ``_to_oai_message``.
            model: Model override. Defaults to the instance's ``model``.
            max_tokens: Forwarded as ``max_tokens``.
            temperature: Forwarded as ``temperature``.
            tools: Optional function tools, forwarded as OpenAI
                ``{"type": "function", ...}`` entries.
            tool_choice: Only used when ``tools`` is non-empty. Accepts
                ``"auto"``, ``"required"`` or ``"none"`` verbatim. Any other
                non-empty value is treated as a function name to force.
                ``None``/empty means ``"auto"``.
            response_format: Forwarded unchanged when not ``None``.

        Returns:
            A ``Completion`` with the first choice's text (``""`` when the
            model returned no content), the response model name, token usage
            when reported, and parsed tool calls (``None`` when none were
            emitted). Tool-call arguments that are not valid JSON objects
            are replaced by ``{}``.

        Errors raised by the OpenAI SDK propagate unchanged.
        """
        oai_messages = [_to_oai_message(m) for m in messages]
        use_model = model or self._model

        # Native function-calling pass-through (Hindsight parity).
        # When ``tools`` is set, translate to OpenAI's wire format and
        # forward; otherwise the request is exactly as before so the
        # legacy text-only callers see no change.
        kwargs: dict = {
            "model": use_model,
            "messages": oai_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        # Structured outputs (Phase 2 of Hindsight cost-control port).
        # When set, forwards ``response_format`` straight to the OpenAI
        # ``chat.completions.create`` call. Eliminates malformed-JSON
        # failures by constraining the decoder to the supplied schema.
        # Caller is responsible for the schema shape; we don't validate
        # it here so any future ``response_format`` variant (json_object,
        # json_schema, custom) works without an SDK bump.
        if response_format is not None:
            kwargs["response_format"] = response_format
        if tools:
            kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in tools
            ]
            # OpenAI accepts "auto" / "required" / "none" / specific function ref.
            if tool_choice in ("auto", "required", "none"):
                kwargs["tool_choice"] = tool_choice
            elif tool_choice:
                kwargs["tool_choice"] = {"type": "function", "function": {"name": tool_choice}}
            else:
                kwargs["tool_choice"] = "auto"

        response = await self._client.chat.completions.create(**kwargs)

        choice = response.choices[0]
        usage = None
        if response.usage:
            usage = TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            )

        # Parse tool calls when the model emitted them. The OpenAI SDK
        # returns ``message.tool_calls`` as a list (or ``None``); each
        # item has ``id``, ``function.name``, ``function.arguments`` (a
        # JSON string we parse to a dict).
        tool_calls: list[ToolCall] | None = None
        raw_tool_calls = getattr(choice.message, "tool_calls", None)
        if raw_tool_calls:
            tool_calls = []
            for tc in raw_tool_calls:
                fn = getattr(tc, "function", None)
                if fn is None:
                    # Entries without a ``function`` payload are skipped.
                    continue
                try:
                    args = json.loads(fn.arguments) if fn.arguments else {}
                except json.JSONDecodeError:
                    # Best-effort: malformed model output degrades to no args.
                    args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.id,
                        name=fn.name,
                        arguments=args if isinstance(args, dict) else {},
                    )
                )

        return Completion(
            text=choice.message.content or "",
            model=response.model,
            usage=usage,
            tool_calls=tool_calls,
        )

    async def embed(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[list[float]]:
        """Embed ``texts`` and return one vector per input, in input order.

        Each text is sanitised with ``_sanitize_text`` and truncated to
        28,000 characters before sending. Inputs are sent in sequential
        requests of at most ``_EMBED_BATCH_SIZE`` items, because the API
        rejects larger arrays.

        Args:
            texts: Texts to embed. An empty list returns ``[]`` without
                calling the API, which rejects an empty ``input`` array.
            model: Model override. Defaults to the instance's
                ``embedding_model``.

        Returns:
            A list of embedding vectors aligned with ``texts``.

        Errors raised by the OpenAI SDK propagate unchanged.
        """
        if not texts:
            return []

        use_model = model or self._embedding_model

        # Truncate texts that would exceed the model's input-token limit
        # (8192 tokens for text-embedding-3-small). 28K chars leaves margin
        # for typical English text (roughly 4 chars per token). Scripts that
        # tokenize at about one char per token (e.g. CJK) can still exceed
        # the limit and be rejected by the API.
        max_chars = 28_000
        safe_texts = [_sanitize_text(t)[:max_chars] for t in texts]

        batch_size = self._EMBED_BATCH_SIZE
        vectors: list[list[float]] = []
        for start in range(0, len(safe_texts), batch_size):
            response = await self._client.embeddings.create(
                model=use_model,
                input=safe_texts[start : start + batch_size],
            )
            # ``index`` is relative to this request's input, so sorting
            # restores the batch's input order before appending.
            vectors.extend(d.embedding for d in sorted(response.data, key=lambda d: d.index))
        return vectors
230def _sanitize_text(text: str) -> str:
231 """Remove control characters that break the OpenAI API's JSON parser."""
232 # Remove null bytes and other C0 control chars except \t \n \r
233 return "".join(ch for ch in text if ch in ("\t", "\n", "\r") or (ord(ch) >= 32))
236def _to_oai_message(msg: Message) -> dict:
237 """Convert an Astrocyte Message to an OpenAI API message dict.
239 Tool-call round-trip support (Hindsight-parity agentic reflect):
240 - ``role="tool"`` messages carry ``tool_call_id`` (required by
241 OpenAI) and the tool's serialized output as ``content``.
242 - ``role="assistant"`` messages with ``tool_calls`` are translated
243 back into the OpenAI ``tool_calls`` array so the model sees its
244 own prior calls when the loop continues.
245 """
246 if msg.role == "tool":
247 return {
248 "role": "tool",
249 "tool_call_id": msg.tool_call_id or "",
250 "content": _sanitize_text(msg.content) if isinstance(msg.content, str) else "",
251 }
252 if msg.role == "assistant" and msg.tool_calls:
253 return {
254 "role": "assistant",
255 "content": _sanitize_text(msg.content) if isinstance(msg.content, str) else None,
256 "tool_calls": [
257 {
258 "id": tc.id,
259 "type": "function",
260 "function": {
261 "name": tc.name,
262 "arguments": json.dumps(tc.arguments),
263 },
264 }
265 for tc in msg.tool_calls
266 ],
267 }
268 if isinstance(msg.content, str):
269 return {"role": msg.role, "content": _sanitize_text(msg.content)}
271 # Multimodal: list of ContentPart
272 parts = []
273 for part in msg.content:
274 if part.type == "text" and part.text:
275 parts.append({"type": "text", "text": _sanitize_text(part.text)})
276 elif part.type == "image_url" and part.image_url:
277 parts.append(
278 {
279 "type": "image_url",
280 "image_url": {"url": part.image_url},
281 }
282 )
283 elif part.type == "image_base64" and part.image_base64:
284 parts.append(
285 {
286 "type": "image_url",
287 "image_url": {"url": f"data:image/png;base64,{part.image_base64}"},
288 }
289 )
291 return {"role": msg.role, "content": parts}