Coverage for astrocyte/providers/openai.py: 71%

86 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""OpenAI LLMProvider adapter. 

2 

3Wraps the openai Python SDK to implement the Astrocyte LLMProvider protocol. 

4 

5Usage via config YAML: 

6 llm_provider: openai 

7 llm_provider_config: 

8 api_key: ${OPENAI_API_KEY} 

9 model: gpt-4o-mini 

10 embedding_model: text-embedding-3-small 

11 

12Usage programmatically: 

13 from astrocyte.providers.openai import OpenAIProvider 

14 

15 llm = OpenAIProvider(api_key="sk-...", model="gpt-4o-mini") 

16""" 

17 

18from __future__ import annotations 

19 

20import json 

21import os 

22from typing import ClassVar 

23 

24from astrocyte.types import ( 

25 Completion, 

26 LLMCapabilities, 

27 Message, 

28 TokenUsage, 

29 ToolCall, 

30 ToolDefinition, 

31) 

32 

33 

class OpenAIProvider:
    """LLMProvider backed by the OpenAI API.

    Chat completions go through ``chat.completions.create`` and embeddings
    through ``embeddings.create`` on a single ``openai.AsyncOpenAI`` client
    created in ``__init__``. Call :meth:`aclose` when the provider is no
    longer needed so the client's HTTP connection pool is released.
    """

    SPI_VERSION: ClassVar[int] = 1

    # Per-text character cap applied before embedding. text-embedding-3-small
    # has an 8192-token limit (~30K chars of typical text); 28K chars leaves a
    # margin. This is a character heuristic, not a token count — text in
    # dense scripts may still exceed the token limit (NOTE(review): a
    # tokenizer-based cap would be exact).
    _EMBED_MAX_CHARS: ClassVar[int] = 28_000
    # The embeddings endpoint rejects ``input`` arrays longer than 2048 items,
    # so larger requests are split into batches of this size.
    _EMBED_BATCH_SIZE: ClassVar[int] = 2048

    def __init__(
        self,
        *,
        api_key: str | None = None,
        model: str = "gpt-4o-mini",
        embedding_model: str = "text-embedding-3-small",
        base_url: str | None = None,
    ) -> None:
        """Create the provider and its async OpenAI client.

        Args:
            api_key: OpenAI API key. Falls back to the ``OPENAI_API_KEY``
                environment variable when omitted or empty.
            model: Default chat-completion model for :meth:`complete`.
            embedding_model: Default model for :meth:`embed`.
            base_url: Optional API base URL forwarded to the SDK client.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If no API key is passed or found in the environment.
        """
        try:
            import openai
        except ImportError as e:
            raise ImportError(
                "The 'openai' package is required for OpenAIProvider. Install it with: pip install 'astrocyte[openai]'"
            ) from e

        resolved_key = api_key or os.environ.get("OPENAI_API_KEY")
        if not resolved_key:
            raise ValueError("OpenAI API key is required. Pass api_key= or set OPENAI_API_KEY.")

        # httpx is a hard dependency of the openai SDK, so it is importable
        # whenever ``openai`` is.
        import httpx

        # The SDK's default read timeout is 600s — far too long when a stream
        # stalls (no RST/FIN, just no bytes). Such stalls block for the full
        # timeout × ``max_retries`` (≈30min per stuck call) and silently
        # murder benchmark wall-clock. OpenAI completions normally take
        # 5-30s; 90s is generous while still failing fast on stuck streams.
        # ``max_retries=3`` below handles transient stalls. Passed to the SDK
        # on BOTH transport paths so the HTTP/1.1 fallback fails fast too.
        timeout = httpx.Timeout(connect=5.0, read=90.0, write=90.0, pool=90.0)

        # Use HTTP/2 multiplexing when ``h2`` is installed. Without it, the
        # OpenAI SDK's default httpx client falls back to HTTP/1.1, which can
        # only carry one in-flight request per TCP connection — a hard
        # serialisation bottleneck for concurrent retain/recall workloads.
        # With HTTP/2, a single connection multiplexes dozens of requests
        # (typical 10-30x throughput improvement on embedding/completion
        # heavy phases). On benchmark hardware retaining 272 LoCoMo sessions
        # drops from ~6h on HTTP/1.1 to under a minute with HTTP/2.
        http_client: httpx.AsyncClient | None = None
        try:
            import h2  # noqa: F401 — presence enables httpx HTTP/2
        except ImportError:
            # h2 not installed — let the OpenAI SDK use its default HTTP/1.1
            # client. Throughput will be lower for concurrent workloads but
            # functionality is unchanged.
            pass
        else:
            http_client = httpx.AsyncClient(
                http2=True,
                # Explicit pool settings so we don't suddenly serialise on
                # bursty workloads.
                limits=httpx.Limits(
                    max_connections=100,
                    max_keepalive_connections=20,
                    keepalive_expiry=30.0,
                ),
                timeout=timeout,
            )

        client_kwargs: dict[str, object] = {
            "api_key": resolved_key,
            "base_url": base_url,
            "max_retries": 3,  # Retry on 429 / 5xx with exponential backoff
            "timeout": timeout,
        }
        if http_client is not None:
            client_kwargs["http_client"] = http_client

        self._client = openai.AsyncOpenAI(**client_kwargs)
        self._model = model
        self._embedding_model = embedding_model

    async def aclose(self) -> None:
        """Close the underlying OpenAI client and its HTTP connection pool."""
        await self._client.close()

    def capabilities(self) -> LLMCapabilities:
        """Describe what this provider supports (multimodal input, batch embed)."""
        return LLMCapabilities(
            supports_multimodal_completion=True,
            modalities_supported=("text", "image_url", "image_base64"),
            supports_multimodal_embedding=False,
            supports_batch_embed=True,
        )

    async def complete(
        self,
        messages: list[Message],
        model: str | None = None,
        max_tokens: int = 1024,
        temperature: float = 0.0,
        tools: list[ToolDefinition] | None = None,
        tool_choice: str | None = None,
        response_format: dict | None = None,
    ) -> Completion:
        """Run one chat completion and return its text, usage, and tool calls.

        Args:
            messages: Conversation history; each entry is converted with
                ``_to_oai_message``.
            model: Per-call override of the default chat model.
            max_tokens: Forwarded as ``max_tokens``.
            temperature: Sampling temperature.
            tools: Optional function definitions exposed to the model.
            tool_choice: ``"auto"``, ``"required"``, ``"none"``, or the name
                of a specific function. Defaults to ``"auto"`` when ``tools``
                is given; ignored when it is not.
            response_format: Forwarded verbatim (structured outputs).

        Returns:
            A ``Completion`` whose ``text`` is ``""`` when the model returned
            no content, ``usage`` is ``None`` when the API reported none, and
            ``tool_calls`` is ``None`` when the model made no calls. Tool-call
            arguments that are not a JSON object are reported as ``{}``.
        """
        oai_messages = [_to_oai_message(m) for m in messages]
        use_model = model or self._model

        # Native function-calling pass-through (Hindsight parity).
        # When ``tools`` is set, translate to OpenAI's wire format and
        # forward; otherwise the request is exactly as before so the
        # legacy text-only callers see no change.
        kwargs: dict = {
            "model": use_model,
            "messages": oai_messages,
            "max_tokens": max_tokens,
            "temperature": temperature,
        }
        # Structured outputs (Phase 2 of Hindsight cost-control port).
        # Forwarded straight to ``chat.completions.create`` so the decoder is
        # constrained to the supplied schema. The schema shape is the
        # caller's responsibility; it is not validated here so any future
        # ``response_format`` variant works without an SDK bump.
        if response_format is not None:
            kwargs["response_format"] = response_format
        if tools:
            kwargs["tools"] = [
                {
                    "type": "function",
                    "function": {
                        "name": t.name,
                        "description": t.description,
                        "parameters": t.parameters,
                    },
                }
                for t in tools
            ]
            # OpenAI accepts "auto" / "required" / "none" / specific function ref.
            if tool_choice in ("auto", "required", "none"):
                kwargs["tool_choice"] = tool_choice
            elif tool_choice:
                kwargs["tool_choice"] = {"type": "function", "function": {"name": tool_choice}}
            else:
                kwargs["tool_choice"] = "auto"

        response = await self._client.chat.completions.create(**kwargs)

        choice = response.choices[0]
        usage = None
        if response.usage:
            usage = TokenUsage(
                input_tokens=response.usage.prompt_tokens,
                output_tokens=response.usage.completion_tokens,
            )

        # Parse tool calls when the model emitted them. ``message.tool_calls``
        # is a list (or ``None``); each item has ``id``, ``function.name`` and
        # ``function.arguments`` (a JSON string parsed to a dict here).
        tool_calls: list[ToolCall] | None = None
        raw_tool_calls = getattr(choice.message, "tool_calls", None)
        if raw_tool_calls:
            tool_calls = []
            for tc in raw_tool_calls:
                fn = getattr(tc, "function", None)
                if fn is None:
                    continue
                try:
                    args = json.loads(fn.arguments) if fn.arguments else {}
                except json.JSONDecodeError:
                    # Malformed arguments from the model degrade to "no args".
                    args = {}
                tool_calls.append(
                    ToolCall(
                        id=tc.id,
                        name=fn.name,
                        arguments=args if isinstance(args, dict) else {},
                    )
                )

        return Completion(
            text=choice.message.content or "",
            model=response.model,
            usage=usage,
            tool_calls=tool_calls,
        )

    async def embed(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[list[float]]:
        """Embed ``texts`` and return one vector per input, in input order.

        Each text is sanitised with ``_sanitize_text`` and truncated to
        ``_EMBED_MAX_CHARS`` characters. Inputs are sent in batches of at most
        ``_EMBED_BATCH_SIZE`` items, the endpoint's per-request array limit.
        An empty ``texts`` list returns ``[]`` without calling the API, which
        would otherwise reject the empty ``input`` array.

        Args:
            texts: Strings to embed.
            model: Per-call override of the default embedding model.
        """
        if not texts:
            return []

        use_model = model or self._embedding_model
        safe_texts = [_sanitize_text(t)[: self._EMBED_MAX_CHARS] for t in texts]

        vectors: list[list[float]] = []
        for start in range(0, len(safe_texts), self._EMBED_BATCH_SIZE):
            batch = safe_texts[start : start + self._EMBED_BATCH_SIZE]
            response = await self._client.embeddings.create(
                model=use_model,
                input=batch,
            )
            # Sort by index to guarantee order matches input within the batch;
            # batches are appended in order, so overall order is preserved.
            ordered = sorted(response.data, key=lambda d: d.index)
            vectors.extend(d.embedding for d in ordered)
        return vectors

228 

229 

230def _sanitize_text(text: str) -> str: 

231 """Remove control characters that break the OpenAI API's JSON parser.""" 

232 # Remove null bytes and other C0 control chars except \t \n \r 

233 return "".join(ch for ch in text if ch in ("\t", "\n", "\r") or (ord(ch) >= 32)) 

234 

235 

236def _to_oai_message(msg: Message) -> dict: 

237 """Convert an Astrocyte Message to an OpenAI API message dict. 

238 

239 Tool-call round-trip support (Hindsight-parity agentic reflect): 

240 - ``role="tool"`` messages carry ``tool_call_id`` (required by 

241 OpenAI) and the tool's serialized output as ``content``. 

242 - ``role="assistant"`` messages with ``tool_calls`` are translated 

243 back into the OpenAI ``tool_calls`` array so the model sees its 

244 own prior calls when the loop continues. 

245 """ 

246 if msg.role == "tool": 

247 return { 

248 "role": "tool", 

249 "tool_call_id": msg.tool_call_id or "", 

250 "content": _sanitize_text(msg.content) if isinstance(msg.content, str) else "", 

251 } 

252 if msg.role == "assistant" and msg.tool_calls: 

253 return { 

254 "role": "assistant", 

255 "content": _sanitize_text(msg.content) if isinstance(msg.content, str) else None, 

256 "tool_calls": [ 

257 { 

258 "id": tc.id, 

259 "type": "function", 

260 "function": { 

261 "name": tc.name, 

262 "arguments": json.dumps(tc.arguments), 

263 }, 

264 } 

265 for tc in msg.tool_calls 

266 ], 

267 } 

268 if isinstance(msg.content, str): 

269 return {"role": msg.role, "content": _sanitize_text(msg.content)} 

270 

271 # Multimodal: list of ContentPart 

272 parts = [] 

273 for part in msg.content: 

274 if part.type == "text" and part.text: 

275 parts.append({"type": "text", "text": _sanitize_text(part.text)}) 

276 elif part.type == "image_url" and part.image_url: 

277 parts.append( 

278 { 

279 "type": "image_url", 

280 "image_url": {"url": part.image_url}, 

281 } 

282 ) 

283 elif part.type == "image_base64" and part.image_base64: 

284 parts.append( 

285 { 

286 "type": "image_url", 

287 "image_url": {"url": f"data:image/png;base64,{part.image_base64}"}, 

288 } 

289 ) 

290 

291 return {"role": msg.role, "content": parts}