Coverage for astrocyte/recall/proxy.py: 89%

276 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""HTTP proxy recall — fetch remote hits and merge with local RRF (M4.1).""" 

2 

3from __future__ import annotations 

4 

5import asyncio 

6import ipaddress 

7import json 

8import logging 

9import re 

10import socket 

11from typing import TYPE_CHECKING, Any 

12from urllib.parse import ParseResult, parse_qsl, quote, urlparse, urlunparse 

13 

14import httpx 

15 

16from astrocyte.config import SourceConfig 

17from astrocyte.policy.observability import MetricsCollector, span, timed 

18from astrocyte.types import MemoryHit, Metadata 

19 

20if TYPE_CHECKING: 

21 from astrocyte.config import AstrocyteConfig 

22 

logger = logging.getLogger(__name__)

# Default per-request timeout, in seconds, used by ``fetch_proxy_recall_hits``.
_DEFAULT_TIMEOUT = 15.0

# POST body placeholders: any value in ``recall_body`` that is exactly equal to one of
# these strings is replaced with the request's ``query`` / ``bank_id`` before sending.
PLACE_QUERY, PLACE_BANK = "__astrocyte.query__", "__astrocyte.bank_id__"

# Metric names (attempt counter, duration histogram) recorded per proxy recall call.
_COUNTER, _HIST = (
    "astrocyte_proxy_recall_total",
    "astrocyte_proxy_recall_duration_seconds",
)

33 

34 

def _expand_proxy_url(template: str, query: str) -> str:
    """Return *template* with the fully percent-encoded *query* inserted.

    Every ``{query}`` placeholder is substituted. Without a placeholder, the query is
    appended as a ``q`` parameter, joined with ``&`` when the template already contains
    a ``?`` and with ``?`` otherwise.
    """
    encoded = quote(query, safe="")
    if "{query}" not in template:
        joiner = "&" if "?" in template else "?"
        return template + joiner + "q=" + encoded
    return template.replace("{query}", encoded)

40 

41 

42def _forbidden_proxy_target_ip_obj(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool: 

43 """True if this address must not be used for outbound proxy recall (SSRF mitigation).""" 

44 return bool( 

45 ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved or ip.is_multicast or ip.is_unspecified 

46 ) 

47 

48 

def _unsafe_literal_ip(host: str) -> bool:
    """Return True when *host* is an IP literal that proxy recall must refuse.

    Anything ``ipaddress.ip_address`` cannot parse (i.e. a host name) yields False;
    host names are not checked by this function.
    """
    try:
        literal = ipaddress.ip_address(host)
    except ValueError:
        literal = None
    return literal is not None and _forbidden_proxy_target_ip_obj(literal)

56 

57 

def validate_proxy_recall_url(url: str) -> None:
    """Raise ``ValueError`` if *url* is not an acceptable proxy recall target (SSRF guard).

    The URL must be non-blank and use ``http`` or ``https``. It must also name a host
    that is neither ``localhost`` nor a ``*.localhost`` subdomain (a trailing dot is
    ignored) and is not a forbidden IP literal.

    Call this on the **fully expanded** request URL (including encoded user ``query`` fragments).
    Hostnames still require :func:`validate_proxy_recall_dns` before HTTP (rebinding-safe check).
    """
    cleaned = (url or "").strip()
    if not cleaned:
        raise ValueError("proxy recall URL is empty")

    parts = urlparse(cleaned)
    scheme = parts.scheme
    if scheme not in ("http", "https"):
        raise ValueError(f"proxy recall URL scheme not allowed: {scheme!r}")

    hostname = parts.hostname
    if not hostname:
        raise ValueError("proxy recall URL has no host")

    normalized = hostname.lower().rstrip(".")
    # The last dot-separated label is "localhost" for both "localhost" and "x.localhost".
    if normalized.rpartition(".")[2] == "localhost":
        raise ValueError("proxy recall URL must not target localhost")
    if _unsafe_literal_ip(normalized):
        raise ValueError(
            "proxy recall URL must not target loopback, private, link-local, or reserved addresses",
        )

79 

80 

def _sync_dns_validate_and_first_public_ip(
    host: str,
    port: int,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
    """Resolve *host* and return the first resolved address, for connection pinning.

    Every address returned by ``getaddrinfo`` is checked, so a single forbidden address
    rejects the host. Entries that do not parse as IP addresses are skipped. The
    returned address is the first usable one in resolver order.

    Raises:
        ValueError: on resolution failure, an empty result, any forbidden address,
            or when no entry parses as an IP address.
    """
    try:
        results = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    except socket.gaierror as exc:
        raise ValueError(f"proxy recall DNS resolution failed for {host!r}: {exc}") from exc
    if not results:
        raise ValueError(f"proxy recall DNS returned no addresses for {host!r}")

    first_usable: ipaddress.IPv4Address | ipaddress.IPv6Address | None = None
    for *_unused, sockaddr in results:
        raw_addr = sockaddr[0]
        try:
            candidate = ipaddress.ip_address(raw_addr)
        except ValueError:
            continue
        if _forbidden_proxy_target_ip_obj(candidate):
            raise ValueError(
                f"proxy recall DNS resolved to a forbidden address ({raw_addr!r} for host {host!r})",
            )
        if first_usable is None:
            first_usable = candidate

    if first_usable is None:
        raise ValueError(f"proxy recall DNS returned no usable addresses for {host!r}")
    return first_usable

107 

108 

def _is_literal_ip_host(host: str) -> bool:
    """Return True when *host* parses as an IPv4 or IPv6 address rather than a DNS name."""
    try:
        return ipaddress.ip_address(host) is not None
    except ValueError:
        return False

116 

117 

def _proxy_recall_host_header_value(original_hostname: str, port: int, scheme: str) -> str:
    """Build the ``Host`` header value for the original server name.

    The name is lower-cased with any trailing dot removed. ``:port`` is appended only
    when *port* differs from the scheme default (443 for https, 80 for anything else).
    """
    default_port = {"https": 443}.get((scheme or "").lower(), 80)
    name = original_hostname.lower().rstrip(".")
    return name if port == default_port else f"{name}:{port}"

126 

127 

128def _rebuild_request_url_pinned_to_ip( 

129 parsed: ParseResult, 

130 *, 

131 pinned: ipaddress.IPv4Address | ipaddress.IPv6Address, 

132 port: int, 

133) -> str: 

134 """Replace authority with pinned IP so the TCP/TLS layer cannot re-resolve differently.""" 

135 if isinstance(pinned, ipaddress.IPv6Address): 

136 hostpart = f"[{pinned.compressed}]" 

137 else: 

138 hostpart = str(pinned) 

139 sch = (parsed.scheme or "").lower() 

140 default = 443 if sch == "https" else 80 

141 if port != default: 

142 netloc = f"{hostpart}:{port}" 

143 else: 

144 netloc = hostpart 

145 return urlunparse((parsed.scheme, netloc, parsed.path, parsed.params, parsed.query, parsed.fragment)) 

146 

147 

148def _httpx_url_and_query_params(url: str) -> tuple[str, list[tuple[str, str]] | None]: 

149 """Split *url* into a path-only URL for the httpx *url* argument and query pairs for *params*. 

150 

151 User-controlled search text must not be concatenated into the URL string passed to httpx 

152 (SSRF static analysis + clearer separation of authority vs query). 

153 """ 

154 p = urlparse((url or "").strip()) 

155 without_query = urlunparse((p.scheme, p.netloc, p.path, p.params, "", p.fragment)) 

156 if not p.query: 

157 return without_query, None 

158 pairs = parse_qsl(p.query, keep_blank_values=True) 

159 if not pairs: 

160 return without_query, None 

161 return without_query, list(pairs) 

162 

163 

async def validate_proxy_recall_dns(host: str, port: int) -> None:
    """Resolve *host* off the event loop and raise ``ValueError`` if any address is forbidden.

    The blocking resolver call runs via :func:`asyncio.to_thread`. The pinned address
    it returns is not needed here and is discarded.
    """
    resolve_and_check = _sync_dns_validate_and_first_public_ip
    await asyncio.to_thread(resolve_and_check, host, port)

167 

168 

169def auth_with_oauth_cache_namespace( 

170 auth: dict[str, str | int | float | bool | None] | None, 

171 source_id: str, 

172) -> dict[str, str | int | float | bool | None] | None: 

173 """Attach ``_oauth_cache_id`` so OAuth token caches do not collide across proxy sources.""" 

174 if not auth: 

175 return None 

176 return {**auth, "_oauth_cache_id": source_id} 

177 

178 

179async def build_proxy_headers( 

180 auth: dict[str, str | int | float | bool | None] | None, 

181) -> dict[str, str]: 

182 """Build HTTP headers (Bearer, API key, OAuth2 client_credentials / refresh, optional static ``headers``).""" 

183 out: dict[str, str] = {} 

184 if not auth: 

185 return out 

186 extra = auth.get("headers") 

187 if isinstance(extra, dict): 

188 for k, v in extra.items(): 

189 if isinstance(v, (str, int, float, bool)): 

190 out[str(k)] = str(v) 

191 t = (str(auth.get("type") or "")).strip().lower() 

192 grant = (str(auth.get("grant_type") or "")).strip().lower() 

193 if t in ("oauth2", "oauth2_client_credentials") and grant in ("", "client_credentials"): 

194 from astrocyte.recall.oauth import fetch_oauth2_client_credentials_token 

195 

196 token = await fetch_oauth2_client_credentials_token(auth) 

197 out["Authorization"] = f"Bearer {token}" 

198 elif t == "oauth2_refresh" or (t in ("oauth2",) and grant == "refresh_token"): 

199 from astrocyte.recall.oauth import fetch_oauth2_refresh_access_token 

200 

201 token = await fetch_oauth2_refresh_access_token(auth) 

202 out["Authorization"] = f"Bearer {token}" 

203 elif t == "bearer": 

204 token = auth.get("token") 

205 if token is not None and str(token).strip(): 

206 out["Authorization"] = f"Bearer {token}" 

207 elif t == "api_key": 

208 header_name = str(auth.get("header") or "X-API-Key") 

209 if not re.match(r"^[A-Za-z0-9\-]+$", header_name): 

210 raise ValueError(f"Invalid header name in proxy auth config: {header_name!r}") 

211 val = auth.get("value") if auth.get("value") is not None else auth.get("token") 

212 if val is not None and str(val).strip(): 

213 val_str = str(val) 

214 if "\r" in val_str or "\n" in val_str: 

215 raise ValueError("Header value contains CRLF characters (possible injection)") 

216 out[header_name] = val_str 

217 return out 

218 

219 

def _deep_replace_placeholders(obj: Any, query: str, bank_id: str) -> Any:
    """Return a copy of *obj* with placeholder values substituted, recursing into containers.

    A value equal to ``PLACE_QUERY`` becomes *query* and one equal to ``PLACE_BANK``
    becomes *bank_id*. Dicts and lists are rebuilt recursively (dict keys are left as-is).
    Every other value is returned unchanged.
    """
    for placeholder, replacement in ((PLACE_QUERY, query), (PLACE_BANK, bank_id)):
        if obj == placeholder:
            return replacement
    if isinstance(obj, dict):
        return {key: _deep_replace_placeholders(val, query, bank_id) for key, val in obj.items()}
    if isinstance(obj, list):
        return [_deep_replace_placeholders(item, query, bank_id) for item in obj]
    return obj

230 

231 

def _resolve_post_json(source: SourceConfig, query: str, bank_id: str) -> dict[str, Any]:
    """Build the JSON body for a POST proxy recall.

    ``source.recall_body`` may be a dict or a JSON string. In either case its placeholders
    are substituted. The default ``{"query": ..., "bank_id": ...}`` body is used when the
    template is missing, of another type, unparseable JSON, or does not resolve to a dict.
    """
    default_body = {"query": query, "bank_id": bank_id}
    template = source.recall_body
    if isinstance(template, str):
        try:
            template = json.loads(template)
        except json.JSONDecodeError:
            return default_body
    elif not isinstance(template, dict):
        return default_body
    resolved = _deep_replace_placeholders(template, query, bank_id)
    return resolved if isinstance(resolved, dict) else default_body

251 

252 

def _row_to_hit(source_id: str, row: dict[str, Any]) -> MemoryHit | None:
    """Convert one remote result row into a :class:`MemoryHit`, or ``None`` if unusable.

    Rows without a non-blank string ``text`` are dropped. ``score`` defaults to 0.5 when it
    is missing or not an int/float, and is clamped to ``[0, 1]``. ``metadata`` keeps only
    str/int/float/bool/None values, with keys coerced to ``str``. A ``tags`` list and
    ``memory_id`` / ``fact_type`` are stringified. The hit's ``source`` is
    ``proxy:<source_id>``.
    """
    text = row.get("text")
    if not (isinstance(text, str) and text.strip()):
        return None

    raw_score = row.get("score")
    numeric = float(raw_score) if isinstance(raw_score, (int, float)) else 0.5

    raw_meta = row.get("metadata")
    metadata: Metadata | None = None
    if isinstance(raw_meta, dict):
        metadata = {
            str(key): value
            for key, value in raw_meta.items()
            if value is None or isinstance(value, (str, int, float, bool))
        }

    raw_tags = row.get("tags")
    tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else None
    raw_id = row.get("memory_id")
    raw_fact_type = row.get("fact_type")

    return MemoryHit(
        text=text,
        score=min(1.0, max(0.0, numeric)),
        fact_type=None if raw_fact_type is None else str(raw_fact_type),
        metadata=metadata,
        tags=tags,
        memory_id=None if raw_id is None else str(raw_id),
        source=f"proxy:{source_id}",
    )

283 

284 

def _parse_hits_payload(data: Any) -> list[Any]:
    """Extract the raw hit list from a decoded JSON response.

    Reads ``data["hits"]``, falling back to ``data["results"]`` only when ``hits`` is absent
    or ``None``. Returns ``[]`` for non-dict payloads or when the chosen value is not a list.
    """
    if not isinstance(data, dict):
        return []
    candidates = data.get("hits")
    if candidates is None:
        candidates = data.get("results")
    return candidates if isinstance(candidates, list) else []

def _record_proxy_metrics(
    metrics: MetricsCollector | None,
    *,
    source_id: str,
    status: str,
    duration_s: float | None,
) -> None:
    """Record one proxy recall attempt, if a (truthy) *metrics* collector is provided.

    The attempt counter is always incremented, labelled by source and status. The
    duration histogram is observed only for ``status == "ok"`` with a known duration.
    """
    if metrics:
        metrics.inc_counter(
            _COUNTER,
            {"source_id": source_id, "status": status},
            "Proxy recall attempts by source and status",
        )
        if status == "ok" and duration_s is not None:
            metrics.observe_histogram(
                _HIST,
                duration_s,
                {"source_id": source_id},
                "Proxy recall HTTP duration in seconds",
            )

315 

316 

async def _pinned_request_target(
    request_url: str,
    headers: dict[str, str],
) -> tuple[str, dict[str, str], dict[str, Any] | None]:
    """Validate *request_url* and pin its host to a vetted IP.

    Returns ``(effective_url, headers, extensions)``. For a DNS host name, the URL's
    authority is replaced by the first resolved address; every resolved address must
    pass the SSRF check. The original name is restored via the ``Host`` header and,
    for https, the ``sni_hostname`` request extension. IP-literal hosts are used as-is
    with *headers* unchanged.
    """
    validate_proxy_recall_url(request_url)
    parsed = urlparse(request_url)
    host = parsed.hostname
    if not host:
        raise ValueError("proxy recall URL has no host")
    port = parsed.port or (443 if parsed.scheme == "https" else 80)
    scheme = (parsed.scheme or "").lower()

    # Blocking resolution runs in a worker thread; raises ValueError on any forbidden address.
    pinned = await asyncio.to_thread(_sync_dns_validate_and_first_public_ip, host, port)

    if _is_literal_ip_host(host):
        return request_url, headers, None

    pinned_url = _rebuild_request_url_pinned_to_ip(parsed, pinned=pinned, port=port)
    pinned_headers = {**headers, "Host": _proxy_recall_host_header_value(host, port, scheme)}
    extensions: dict[str, Any] | None = {"sni_hostname": host} if scheme == "https" else None
    return pinned_url, pinned_headers, extensions


async def fetch_proxy_recall_hits(
    source_id: str,
    source: SourceConfig,
    *,
    query: str,
    bank_id: str,
    timeout: float = _DEFAULT_TIMEOUT,
    metrics: MetricsCollector | None = None,
) -> list[MemoryHit]:
    """Fetch recall hits from one proxy ``source`` over HTTP(S).

    GET requests carry the query in the URL, via a ``{query}`` placeholder or an appended
    ``q`` parameter. POST requests send the JSON body built from ``source.recall_body``
    and expand the URL only when it contains ``{query}``. Unknown methods fall back to
    GET with a warning.

    Before the request, the target URL is SSRF-checked and its host pinned to a vetted
    IP; redirects are not followed. The response's ``hits`` (or ``results``) array is
    converted to :class:`MemoryHit` objects, and non-dict or text-less rows are skipped.

    Returns ``[]`` when ``source.url`` is blank. Any failure (auth, URL validation, DNS,
    HTTP status, JSON decoding) is recorded as an ``error`` metric and re-raised.
    """
    url_template = (source.url or "").strip()
    if not url_template:
        return []

    method = (source.recall_method or "GET").strip().upper()
    if method not in ("GET", "POST"):
        logger.warning("proxy source %s: unknown recall_method %r, using GET", source_id, method)
        method = "GET"
    is_post = method == "POST"

    with span(
        "astrocyte.proxy_recall",
        {"source_id": source_id, "method": method, "bank_id": bank_id},
    ):
        with timed() as t:
            try:
                headers = await build_proxy_headers(
                    auth_with_oauth_cache_namespace(source.auth, source_id),
                )
                # POST only touches the URL when it explicitly asks for the query.
                if is_post and "{query}" not in url_template:
                    request_url = url_template
                else:
                    request_url = _expand_proxy_url(url_template, query)
                body = _resolve_post_json(source, query, bank_id) if is_post else None

                target_url, send_headers, extensions = await _pinned_request_target(request_url, headers)
                httpx_url, httpx_params = _httpx_url_and_query_params(target_url)

                async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
                    # json=None / extensions=None are httpx's defaults, so GET sends no body.
                    response = await client.request(
                        method,
                        httpx_url,
                        params=httpx_params,
                        headers=send_headers,
                        json=body,
                        extensions=extensions,
                    )
                response.raise_for_status()
                data = response.json()
            except Exception:
                _record_proxy_metrics(metrics, source_id=source_id, status="error", duration_s=None)
                raise
        elapsed_s = t["elapsed_ms"] / 1000.0

    _record_proxy_metrics(metrics, source_id=source_id, status="ok", duration_s=elapsed_s)

    return [
        hit
        for row in _parse_hits_payload(data)
        if isinstance(row, dict) and (hit := _row_to_hit(source_id, row))
    ]

421 

422 

async def gather_proxy_hits_for_bank(
    config: AstrocyteConfig | Any,
    *,
    query: str,
    bank_id: str,
    metrics: MetricsCollector | None = None,
) -> list[MemoryHit]:
    """Fetch hits from all ``type: proxy`` sources whose ``target_bank`` matches ``bank_id``.

    Matching sources are queried concurrently, so total latency is bounded by the slowest
    source rather than the sum of all of them. Hits are returned grouped in
    ``config.sources`` iteration order. A source that raises is logged as a warning and
    contributes no hits; it never fails the whole gather.
    """
    sources = getattr(config, "sources", None) or {}
    selected = [
        (sid, src)
        for sid, src in sources.items()
        if isinstance(src, SourceConfig)
        and (src.type or "").strip().lower() == "proxy"
        and (src.target_bank or "").strip() == bank_id
    ]
    if not selected:
        return []

    async def _fetch_one(sid: str, src: SourceConfig) -> list[MemoryHit]:
        # Per-source isolation: one failing proxy must not drop the others' results.
        try:
            return await fetch_proxy_recall_hits(sid, src, query=query, bank_id=bank_id, metrics=metrics)
        except Exception as e:
            logger.warning("proxy recall failed for source %s: %s", sid, e)
            return []

    # gather() preserves argument order in its results, keeping output deterministic.
    batches = await asyncio.gather(*(_fetch_one(sid, src) for sid, src in selected))
    return [hit for batch in batches for hit in batch]

446 

447 

async def merge_manual_and_proxy_hits(
    config: AstrocyteConfig | Any,
    *,
    query: str,
    bank_id: str,
    manual: list[MemoryHit] | None,
    metrics: MetricsCollector | None = None,
) -> list[MemoryHit] | None:
    """Combine caller ``external_context`` hits with configured proxy sources for this bank.

    Returns the caller's *manual* hits followed by proxy hits as a new list, or ``None``
    when both are empty.
    """
    proxy_hits = await gather_proxy_hits_for_bank(config, query=query, bank_id=bank_id, metrics=metrics)
    if manual or proxy_hits:
        return [*(manual or []), *proxy_hits]
    return None