Coverage for astrocyte/recall/proxy.py: 89%
276 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""HTTP proxy recall — fetch remote hits and merge with local RRF (M4.1)."""
3from __future__ import annotations
5import asyncio
6import ipaddress
7import json
8import logging
9import re
10import socket
11from typing import TYPE_CHECKING, Any
12from urllib.parse import ParseResult, parse_qsl, quote, urlparse, urlunparse
14import httpx
16from astrocyte.config import SourceConfig
17from astrocyte.policy.observability import MetricsCollector, span, timed
18from astrocyte.types import MemoryHit, Metadata
20if TYPE_CHECKING:
21 from astrocyte.config import AstrocyteConfig
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Default for the ``timeout`` parameter of ``fetch_proxy_recall_hits``; it is
# forwarded to ``httpx.AsyncClient(timeout=...)``.
_DEFAULT_TIMEOUT = 15.0

# JSON body placeholders (POST) — resolved to ``query`` / ``bank_id`` strings
PLACE_QUERY = "__astrocyte.query__"
PLACE_BANK = "__astrocyte.bank_id__"

# Metric names used by ``_record_proxy_metrics``: an attempt counter and a
# duration histogram.
_COUNTER = "astrocyte_proxy_recall_total"
_HIST = "astrocyte_proxy_recall_duration_seconds"
35def _expand_proxy_url(template: str, query: str) -> str:
36 if "{query}" in template:
37 return template.replace("{query}", quote(query, safe=""))
38 sep = "&" if "?" in template else "?"
39 return f"{template}{sep}q={quote(query, safe='')}"
42def _forbidden_proxy_target_ip_obj(ip: ipaddress.IPv4Address | ipaddress.IPv6Address) -> bool:
43 """True if this address must not be used for outbound proxy recall (SSRF mitigation)."""
44 return bool(
45 ip.is_private or ip.is_loopback or ip.is_link_local or ip.is_reserved or ip.is_multicast or ip.is_unspecified
46 )
49def _unsafe_literal_ip(host: str) -> bool:
50 """True if *host* parses as an IPv4/IPv6 address that must not be used for proxy recall."""
51 try:
52 ip = ipaddress.ip_address(host)
53 except ValueError:
54 return False
55 return _forbidden_proxy_target_ip_obj(ip)
def validate_proxy_recall_url(url: str) -> None:
    """Raise ``ValueError`` for URLs that would enable SSRF.

    Rejected: empty URLs, schemes other than ``http``/``https``, URLs without
    a host, ``localhost`` (and any ``*.localhost`` name), and IP literals that
    ``_unsafe_literal_ip`` classifies as forbidden.

    Call this on the **fully expanded** request URL (including encoded user
    ``query`` fragments). Hostnames still require
    :func:`validate_proxy_recall_dns` before HTTP (rebinding-safe check).
    """
    candidate = (url or "").strip()
    if not candidate:
        raise ValueError("proxy recall URL is empty")

    parts = urlparse(candidate)
    scheme = parts.scheme
    if scheme != "http" and scheme != "https":
        raise ValueError(f"proxy recall URL scheme not allowed: {parts.scheme!r}")

    raw_host = parts.hostname
    if not raw_host:
        raise ValueError("proxy recall URL has no host")

    # Normalise case and drop the trailing root dot before comparing names.
    name = raw_host.lower().rstrip(".")
    if name == "localhost" or name.endswith(".localhost"):
        raise ValueError("proxy recall URL must not target localhost")
    if _unsafe_literal_ip(name):
        raise ValueError(
            "proxy recall URL must not target loopback, private, link-local, or reserved addresses",
        )
def _sync_dns_validate_and_first_public_ip(
    host: str,
    port: int,
) -> ipaddress.IPv4Address | ipaddress.IPv6Address:
    """Resolve *host* and return the first resolved address, for connection pinning.

    Every address returned by the resolver is checked; a single forbidden
    address fails the whole lookup. Entries whose address text does not parse
    as an IP are skipped.

    Raises:
        ValueError: if resolution fails, returns nothing, returns a forbidden
            address, or returns no parseable address.
    """
    try:
        resolved = socket.getaddrinfo(host, port, type=socket.SOCK_STREAM)
    except socket.gaierror as err:
        raise ValueError(f"proxy recall DNS resolution failed for {host!r}: {err}") from err
    if not resolved:
        raise ValueError(f"proxy recall DNS returned no addresses for {host!r}")

    first_ok: ipaddress.IPv4Address | ipaddress.IPv6Address | None = None
    for _family, _kind, _protocol, _canonical, sockaddr in resolved:
        addr_s = sockaddr[0]
        try:
            candidate = ipaddress.ip_address(addr_s)
        except ValueError:
            continue
        if _forbidden_proxy_target_ip_obj(candidate):
            raise ValueError(
                f"proxy recall DNS resolved to a forbidden address ({addr_s!r} for host {host!r})",
            )
        # Keep scanning after the first hit: later entries must be validated too.
        if first_ok is None:
            first_ok = candidate

    if first_ok is None:
        raise ValueError(f"proxy recall DNS returned no usable addresses for {host!r}")
    return first_ok
109def _is_literal_ip_host(host: str) -> bool:
110 """True if *host* is an IPv4/IPv6 literal (no DNS name)."""
111 try:
112 ipaddress.ip_address(host)
113 except ValueError:
114 return False
115 return True
118def _proxy_recall_host_header_value(original_hostname: str, port: int, scheme: str) -> str:
119 """``Host`` header for the original server name (include port when not default)."""
120 sch = (scheme or "").lower()
121 default = 443 if sch == "https" else 80
122 h = original_hostname.lower().rstrip(".")
123 if port != default:
124 return f"{h}:{port}"
125 return h
128def _rebuild_request_url_pinned_to_ip(
129 parsed: ParseResult,
130 *,
131 pinned: ipaddress.IPv4Address | ipaddress.IPv6Address,
132 port: int,
133) -> str:
134 """Replace authority with pinned IP so the TCP/TLS layer cannot re-resolve differently."""
135 if isinstance(pinned, ipaddress.IPv6Address):
136 hostpart = f"[{pinned.compressed}]"
137 else:
138 hostpart = str(pinned)
139 sch = (parsed.scheme or "").lower()
140 default = 443 if sch == "https" else 80
141 if port != default:
142 netloc = f"{hostpart}:{port}"
143 else:
144 netloc = hostpart
145 return urlunparse((parsed.scheme, netloc, parsed.path, parsed.params, parsed.query, parsed.fragment))
148def _httpx_url_and_query_params(url: str) -> tuple[str, list[tuple[str, str]] | None]:
149 """Split *url* into a path-only URL for the httpx *url* argument and query pairs for *params*.
151 User-controlled search text must not be concatenated into the URL string passed to httpx
152 (SSRF static analysis + clearer separation of authority vs query).
153 """
154 p = urlparse((url or "").strip())
155 without_query = urlunparse((p.scheme, p.netloc, p.path, p.params, "", p.fragment))
156 if not p.query:
157 return without_query, None
158 pairs = parse_qsl(p.query, keep_blank_values=True)
159 if not pairs:
160 return without_query, None
161 return without_query, list(pairs)
async def validate_proxy_recall_dns(host: str, port: int) -> None:
    """Validate every DNS result for *host* without blocking the event loop.

    The blocking lookup and address checks run in a worker thread via
    ``asyncio.to_thread``; a ``ValueError`` raised there propagates to the
    caller.
    """
    # Only the validation matters here; the resolved address is discarded.
    pending = asyncio.to_thread(_sync_dns_validate_and_first_public_ip, host, port)
    await pending
def auth_with_oauth_cache_namespace(
    auth: dict[str, str | int | float | bool | None] | None,
    source_id: str,
) -> dict[str, str | int | float | bool | None] | None:
    """Return a copy of *auth* tagged with ``_oauth_cache_id`` set to *source_id*.

    The tag keeps OAuth token caches from colliding across proxy sources.
    A missing or empty *auth* yields ``None``; the input mapping is never
    modified.
    """
    if not auth:
        return None
    namespaced = dict(auth)
    namespaced["_oauth_cache_id"] = source_id
    return namespaced
async def build_proxy_headers(
    auth: dict[str, str | int | float | bool | None] | None,
) -> dict[str, str]:
    """Build HTTP headers (Bearer, API key, OAuth2 client_credentials / refresh, optional static ``headers``).

    Static entries from ``auth["headers"]`` are copied first (scalar values
    only, stringified); the scheme selected by ``auth["type"]`` (and
    ``grant_type`` for OAuth2) is applied afterwards and may overwrite them.

    Args:
        auth: Auth configuration mapping, or ``None`` / empty for no auth.

    Returns:
        The request headers to send; empty when *auth* is missing or empty.

    Raises:
        ValueError: for ``api_key`` auth when the header name contains
            anything other than ASCII letters, digits and ``-``, or when the
            value contains CR/LF characters.
    """
    out: dict[str, str] = {}
    if not auth:
        return out

    extra = auth.get("headers")
    if isinstance(extra, dict):
        for k, v in extra.items():
            # Non-scalar values (None, lists, nested dicts) are skipped.
            if isinstance(v, (str, int, float, bool)):
                out[str(k)] = str(v)

    t = (str(auth.get("type") or "")).strip().lower()
    grant = (str(auth.get("grant_type") or "")).strip().lower()
    if t in ("oauth2", "oauth2_client_credentials") and grant in ("", "client_credentials"):
        # Imported inside the branch, so the oauth module loads only when used.
        from astrocyte.recall.oauth import fetch_oauth2_client_credentials_token

        token = await fetch_oauth2_client_credentials_token(auth)
        out["Authorization"] = f"Bearer {token}"
    elif t == "oauth2_refresh" or (t in ("oauth2",) and grant == "refresh_token"):
        from astrocyte.recall.oauth import fetch_oauth2_refresh_access_token

        token = await fetch_oauth2_refresh_access_token(auth)
        out["Authorization"] = f"Bearer {token}"
    elif t == "bearer":
        token = auth.get("token")
        if token is not None and str(token).strip():
            out["Authorization"] = f"Bearer {token}"
    elif t == "api_key":
        header_name = str(auth.get("header") or "X-API-Key")
        # fullmatch rather than match with "$": "$" also matches just before a
        # trailing newline, which would let a name such as "X-Key\n" through.
        if not re.fullmatch(r"[A-Za-z0-9\-]+", header_name):
            raise ValueError(f"Invalid header name in proxy auth config: {header_name!r}")
        # "value" takes precedence; "token" is accepted as a fallback key.
        val = auth.get("value") if auth.get("value") is not None else auth.get("token")
        if val is not None and str(val).strip():
            val_str = str(val)
            if "\r" in val_str or "\n" in val_str:
                raise ValueError("Header value contains CRLF characters (possible injection)")
            out[header_name] = val_str
    return out
def _deep_replace_placeholders(obj: Any, query: str, bank_id: str) -> Any:
    """Return *obj* with placeholder values swapped for *query* / *bank_id*.

    A value equal to ``PLACE_QUERY`` becomes *query* and a value equal to
    ``PLACE_BANK`` becomes *bank_id*. Lists and dicts are rebuilt recursively
    (dict keys are kept as they are); any other value is returned unchanged.
    """
    for marker, replacement in ((PLACE_QUERY, query), (PLACE_BANK, bank_id)):
        if obj == marker:
            return replacement
    if isinstance(obj, list):
        return [_deep_replace_placeholders(item, query, bank_id) for item in obj]
    if isinstance(obj, dict):
        rebuilt = {}
        for key, value in obj.items():
            rebuilt[key] = _deep_replace_placeholders(value, query, bank_id)
        return rebuilt
    return obj
232def _resolve_post_json(source: SourceConfig, query: str, bank_id: str) -> dict[str, Any]:
233 rb = source.recall_body
234 if rb is None:
235 return {"query": query, "bank_id": bank_id}
236 if isinstance(rb, dict):
237 resolved = _deep_replace_placeholders(rb, query, bank_id)
238 if isinstance(resolved, dict):
239 return resolved
240 return {"query": query, "bank_id": bank_id}
241 if isinstance(rb, str):
242 try:
243 data = json.loads(rb)
244 except json.JSONDecodeError:
245 return {"query": query, "bank_id": bank_id}
246 resolved = _deep_replace_placeholders(data, query, bank_id)
247 if isinstance(resolved, dict):
248 return resolved
249 return {"query": query, "bank_id": bank_id}
250 return {"query": query, "bank_id": bank_id}
def _row_to_hit(source_id: str, row: dict[str, Any]) -> MemoryHit | None:
    """Convert one remote result row into a ``MemoryHit``.

    Returns ``None`` when ``text`` is missing, not a string, or blank. The
    score falls back to 0.5 when it is not numeric and is clamped to
    ``[0.0, 1.0]``. Only scalar or ``None`` metadata values are kept, tags are
    stringified, and the hit's ``source`` is ``proxy:<source_id>``.
    """
    text = row.get("text")
    if not (isinstance(text, str) and text.strip()):
        return None

    raw_score = row.get("score")
    score = float(raw_score) if isinstance(raw_score, (int, float)) else 0.5
    clamped = min(1.0, max(0.0, score))

    meta_raw = row.get("metadata")
    metadata: Metadata | None = None
    if isinstance(meta_raw, dict):
        metadata = {
            str(key): value
            for key, value in meta_raw.items()
            if value is None or isinstance(value, (str, int, float, bool))
        }

    raw_tags = row.get("tags")
    tags = [str(tag) for tag in raw_tags] if isinstance(raw_tags, list) else None

    fact_type = row.get("fact_type")
    memory_id = row.get("memory_id")
    return MemoryHit(
        text=text,
        score=clamped,
        fact_type=None if fact_type is None else str(fact_type),
        metadata=metadata,
        tags=tags,
        memory_id=None if memory_id is None else str(memory_id),
        source=f"proxy:{source_id}",
    )
285def _parse_hits_payload(data: Any) -> list[Any]:
286 raw_hits = data.get("hits") if isinstance(data, dict) else None
287 if raw_hits is None and isinstance(data, dict):
288 raw_hits = data.get("results")
289 if not isinstance(raw_hits, list):
290 return []
291 return raw_hits
294def _record_proxy_metrics(
295 metrics: MetricsCollector | None,
296 *,
297 source_id: str,
298 status: str,
299 duration_s: float | None,
300) -> None:
301 if not metrics:
302 return
303 metrics.inc_counter(
304 _COUNTER,
305 {"source_id": source_id, "status": status},
306 "Proxy recall attempts by source and status",
307 )
308 if duration_s is not None and status == "ok":
309 metrics.observe_histogram(
310 _HIST,
311 duration_s,
312 {"source_id": source_id},
313 "Proxy recall HTTP duration in seconds",
314 )
async def fetch_proxy_recall_hits(
    source_id: str,
    source: SourceConfig,
    *,
    query: str,
    bank_id: str,
    timeout: float = _DEFAULT_TIMEOUT,
    metrics: MetricsCollector | None = None,
) -> list[MemoryHit]:
    """Call ``source.url`` (GET or POST) and parse JSON ``hits`` / ``results`` arrays.

    The expanded URL is validated, its host is resolved and checked in a
    worker thread, and for non-literal hosts the request is sent to the
    resolved IP with the original name carried in the ``Host`` header (and as
    the ``sni_hostname`` extension for ``https``).

    Args:
        source_id: Identifier of the proxy source; used in logs, metrics
            labels, the OAuth cache namespace and each hit's ``source``.
        source: Source configuration; ``url``, ``recall_method``, ``auth``
            and (for POST) ``recall_body`` are read.
        query: Search text sent to the remote endpoint.
        bank_id: Bank identifier; added to the span attributes and to the
            POST body.
        timeout: Passed to ``httpx.AsyncClient``.
        metrics: Optional collector that receives the attempt counter and,
            on success, the duration histogram.

    Returns:
        The hits built from the response rows; an empty list when
        ``source.url`` is empty or blank (no request is made in that case).

    Raises:
        Exception: anything raised while building headers, validating or
            resolving the URL, performing the request, or decoding the
            response is re-raised after an ``"error"`` metric is recorded.
    """
    url_t = source.url or ""
    if not url_t.strip():
        return []

    method = (source.recall_method or "GET").strip().upper()
    if method not in ("GET", "POST"):
        logger.warning("proxy source %s: unknown recall_method %r, using GET", source_id, method)
        method = "GET"

    base_url = url_t.strip()

    with span(
        "astrocyte.proxy_recall",
        {"source_id": source_id, "method": method, "bank_id": bank_id},
    ):
        with timed() as t:
            try:
                headers = await build_proxy_headers(
                    auth_with_oauth_cache_namespace(source.auth, source_id),
                )
                if method == "POST":
                    # For POST the query travels in the JSON body; the URL is
                    # expanded only if it explicitly contains ``{query}``, so
                    # no ``q=`` parameter is appended.
                    request_url = _expand_proxy_url(base_url, query) if "{query}" in base_url else base_url
                    body = _resolve_post_json(source, query, bank_id)
                else:
                    request_url = _expand_proxy_url(base_url, query)
                    body = None
                # Validate the fully expanded URL, i.e. after the user query
                # has been substituted in.
                validate_proxy_recall_url(request_url)
                parsed_req = urlparse(request_url)
                req_host = parsed_req.hostname
                if not req_host:
                    raise ValueError("proxy recall URL has no host")
                req_port = parsed_req.port or (443 if parsed_req.scheme == "https" else 80)
                sch = (parsed_req.scheme or "").lower()

                # Blocking DNS resolution and address validation run in a
                # worker thread; the returned address is used for pinning.
                pinned = await asyncio.to_thread(
                    _sync_dns_validate_and_first_public_ip,
                    req_host,
                    req_port,
                )

                if _is_literal_ip_host(req_host):
                    # The URL already names an IP address: nothing to pin.
                    effective_url = request_url
                    out_headers = headers
                    req_extensions: dict[str, Any] | None = None
                else:
                    # Connect to the validated IP, while still presenting the
                    # original name via the Host header and, for https, the
                    # sni_hostname request extension.
                    effective_url = _rebuild_request_url_pinned_to_ip(
                        parsed_req,
                        pinned=pinned,
                        port=req_port,
                    )
                    out_headers = {
                        **headers,
                        "Host": _proxy_recall_host_header_value(req_host, req_port, sch),
                    }
                    req_extensions = {"sni_hostname": req_host} if sch == "https" else None

                # Query pairs are passed separately as ``params`` instead of
                # being left inside the URL string.
                httpx_url, httpx_params = _httpx_url_and_query_params(effective_url)

                # Redirects are not followed — presumably because a redirect
                # target would bypass the validation above; confirm intent.
                async with httpx.AsyncClient(timeout=timeout, follow_redirects=False) as client:
                    if method == "POST":
                        r = await client.request(
                            "POST",
                            httpx_url,
                            params=httpx_params,
                            headers=out_headers,
                            json=body,
                            extensions=req_extensions,
                        )
                    else:
                        r = await client.request(
                            "GET",
                            httpx_url,
                            params=httpx_params,
                            headers=out_headers,
                            extensions=req_extensions,
                        )
                    r.raise_for_status()
                    data = r.json()
            except Exception:
                # Count the failed attempt, then let the caller see the error.
                _record_proxy_metrics(metrics, source_id=source_id, status="error", duration_s=None)
                raise
        duration_s = t["elapsed_ms"] / 1000.0

        _record_proxy_metrics(metrics, source_id=source_id, status="ok", duration_s=duration_s)

    raw_hits = _parse_hits_payload(data)
    out: list[MemoryHit] = []
    for row in raw_hits:
        # Rows that are not JSON objects are ignored.
        if not isinstance(row, dict):
            continue
        hit = _row_to_hit(source_id, row)
        if hit:
            out.append(hit)
    return out
async def gather_proxy_hits_for_bank(
    config: AstrocyteConfig | Any,
    *,
    query: str,
    bank_id: str,
    metrics: MetricsCollector | None = None,
) -> list[MemoryHit]:
    """Collect hits from every ``type: proxy`` source whose ``target_bank`` equals *bank_id*.

    Sources are queried one after another in configuration order. A source
    that raises is logged as a warning and skipped, so one failing source does
    not discard the hits from the others.
    """

    def _is_proxy_for_bank(candidate: Any) -> bool:
        if not isinstance(candidate, SourceConfig):
            return False
        if (candidate.type or "").strip().lower() != "proxy":
            return False
        return (candidate.target_bank or "").strip() == bank_id

    collected: list[MemoryHit] = []
    configured = getattr(config, "sources", None) or {}
    for sid, src in configured.items():
        if not _is_proxy_for_bank(src):
            continue
        try:
            collected.extend(
                await fetch_proxy_recall_hits(sid, src, query=query, bank_id=bank_id, metrics=metrics)
            )
        except Exception as e:
            logger.warning("proxy recall failed for source %s: %s", sid, e)
    return collected
async def merge_manual_and_proxy_hits(
    config: AstrocyteConfig | Any,
    *,
    query: str,
    bank_id: str,
    manual: list[MemoryHit] | None,
    metrics: MetricsCollector | None = None,
) -> list[MemoryHit] | None:
    """Merge caller-supplied ``external_context`` hits with proxy hits for this bank.

    Returns a new list with the *manual* hits first, followed by the hits
    fetched from the configured proxy sources, or ``None`` when both are
    empty.
    """
    fetched = await gather_proxy_hits_for_bank(config, query=query, bank_id=bank_id, metrics=metrics)
    if manual or fetched:
        return [*(manual or []), *fetched]
    return None