Coverage for astrocyte/recall/oauth.py: 83%
114 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""OAuth2 for proxy recall — client credentials, refresh (with rotation), authorization code exchange."""
from __future__ import annotations

import asyncio
import base64
import logging
import time
from dataclasses import dataclass
from typing import Any

import httpx
logger = logging.getLogger(name=__name__)

# Process-wide, in-memory token cache. Keys come from _client_credentials_cache_key
# ("cc|..." or "<token_url>|<client_id>|<scope>") and _refresh_cache_key ("rt|...");
# values are _TokenState records (declared below; the forward reference is fine because
# of ``from __future__ import annotations``).
_TOKEN_CACHE: dict[str, _TokenState] = dict()
@dataclass
class _TokenState:
    """One cached token entry in ``_TOKEN_CACHE``.

    ``expires_at`` is a ``time.monotonic()`` deadline, not wall-clock time.
    ``refresh_token`` is filled in by the refresh-token flow, which overwrites it with
    the latest value the issuer returns (rotation); client-credentials entries store
    ``None``.
    """

    # Bearer access token returned by the token endpoint.
    access_token: str
    # Monotonic-clock deadline after which the cached token is no longer served.
    expires_at: float
    # Most recent refresh token for this entry, replaced when the issuer rotates it.
    refresh_token: str | None = None
def _client_credentials_cache_key(auth: dict[str, str | int | float | bool | None]) -> str:
    """Build the ``_TOKEN_CACHE`` key for a client-credentials token.

    A non-blank ``_oauth_cache_id`` takes precedence and yields ``"cc|<id>"`` (the id is
    used as given, not stripped). Otherwise the key is
    ``"<token_url>|<client_id>|<scope>"``, with missing or falsy fields rendered as ``""``.
    """
    namespace = auth.get("_oauth_cache_id")
    has_namespace = namespace is not None and bool(str(namespace).strip())
    if has_namespace:
        return "cc|" + str(namespace)
    parts = (str(auth.get(field) or "") for field in ("token_url", "client_id", "scope"))
    return "|".join(parts)
def _refresh_cache_key(auth: dict[str, str | int | float | bool | None]) -> str:
    """Return the ``_TOKEN_CACHE`` key (``"rt|<_oauth_cache_id>"``) for a refresh-token entry.

    The key is taken only from ``_oauth_cache_id`` (set by proxy recall via
    :func:`auth_with_oauth_cache_namespace`), never from ``refresh_token``. This keeps
    entries for different sources apart and keeps bearer secrets out of cache keys, which
    static analysis treats like password material.

    Raises:
        ValueError: if ``_oauth_cache_id`` is missing, ``None``, or blank.
    """
    namespace = auth.get("_oauth_cache_id")
    if namespace is None or not str(namespace).strip():
        raise ValueError(
            "oauth2 refresh requires _oauth_cache_id for in-memory token caching; "
            "use auth_with_oauth_cache_namespace(auth, unique_id) or set _oauth_cache_id on auth.",
        )
    return "rt|" + str(namespace)
def _normalize_auth_method(raw: str | None) -> str:
    """Collapse a configured token-endpoint auth method into one of two canonical names.

    ``basic``, ``client_secret_basic`` and ``http_basic`` (case-insensitive, surrounding
    whitespace ignored) map to ``"client_secret_basic"``. Every other value, including
    ``None`` and ``""``, maps to ``"client_secret_post"``.
    """
    candidate = (raw or "client_secret_post").strip().lower()
    is_basic = candidate in {"basic", "client_secret_basic", "http_basic"}
    return "client_secret_basic" if is_basic else "client_secret_post"
def _log_token_endpoint_error(token_url: str, response: httpx.Response) -> None:
    """Log the OAuth ``error`` / ``error_description`` fields of a failed token response.

    Only the URL, HTTP status and those two response fields are logged, never the
    request's credentials.
    """
    try:
        payload = response.json()
    except ValueError:  # non-JSON / undecodable error page; JSONDecodeError is a ValueError
        payload = None
    if not isinstance(payload, dict):
        payload = {}
    logger.warning(
        "OAuth2 token endpoint %s returned HTTP %s (error=%r, error_description=%r)",
        token_url,
        response.status_code,
        payload.get("error"),
        payload.get("error_description"),
    )


async def post_oauth2_token_endpoint(
    token_url: str,
    form: dict[str, str],
    *,
    client_id: str,
    client_secret: str,
    token_endpoint_auth_method: str,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """POST ``application/x-www-form-urlencoded`` to the token endpoint and return its JSON object.

    ``client_secret_basic`` (RFC 6749 §2.3.1): the credentials go in an
    ``Authorization: Basic`` header, and any ``client_id`` / ``client_secret`` keys are
    removed from the form body.
    ``client_secret_post`` (used for every other method name): a non-empty ``client_id`` /
    ``client_secret`` is added to the form unless the form already has that key.

    ``form`` itself is not mutated; a copy is sent.

    Raises:
        httpx.HTTPStatusError: from ``raise_for_status`` on an unsuccessful response. For
            4xx/5xx responses, the OAuth ``error`` / ``error_description`` are logged first.
        ValueError: if the response body is not JSON, or is JSON but not an object.
    """
    method = _normalize_auth_method(token_endpoint_auth_method)
    headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
    data = dict(form)  # copy so the caller's form is left untouched

    if method == "client_secret_basic":
        # NOTE(review): RFC 6749 §2.3.1 asks for client_id / client_secret to be
        # form-urlencoded before they are joined; they are sent raw here. Changing that alters
        # the header for credentials containing reserved characters, so confirm issuer
        # behaviour before switching.
        raw = f"{client_id}:{client_secret}".encode()
        headers["Authorization"] = "Basic " + base64.b64encode(raw).decode("ascii")
        # With Basic auth, credentials travel only in the header.
        data.pop("client_id", None)
        data.pop("client_secret", None)
    else:
        if client_id and "client_id" not in data:
            data["client_id"] = client_id
        if client_secret and "client_secret" not in data:
            data["client_secret"] = client_secret

    async with httpx.AsyncClient(timeout=timeout) as client:
        response = await client.post(token_url, data=data, headers=headers)
        if response.is_error:
            # The HTTPStatusError message omits the body, which is where OAuth puts the
            # reason (invalid_grant, invalid_client, ...), so surface it before raising.
            _log_token_endpoint_error(token_url, response)
        response.raise_for_status()
        payload = response.json()

    # Callers call .get() on the result, so reject lists / scalars here with a clear error.
    if not isinstance(payload, dict):
        raise ValueError("OAuth2 token response is not a JSON object")
    return payload
def _coerce_expires_in(raw: Any, default: float = 3600.0) -> float:
    """Return ``raw`` as a positive, finite lifetime in seconds, or ``default`` if it is not one.

    Numbers and numeric strings (some issuers send ``"expires_in": "3600"``) are accepted.
    ``bool`` is rejected even though it subclasses ``int``.
    """
    if isinstance(raw, bool) or not isinstance(raw, (int, float, str)):
        return default
    try:
        ttl = float(raw)
    except (ValueError, OverflowError):
        return default
    # The chained comparison is False for NaN, infinities, zero and negatives.
    if not 0.0 < ttl < float("inf"):
        return default
    return ttl


def _parse_expires(body: dict[str, Any], now: float) -> tuple[str, float]:
    """Extract the access token from a token response and compute its cache deadline.

    Args:
        body: Parsed JSON object returned by the token endpoint.
        now: ``time.monotonic()`` reading taken by the caller before sending the request.

    Returns:
        ``(access_token, expires_at)``. ``access_token`` is stripped of surrounding
        whitespace, and ``expires_at = now + ttl - margin``. ``ttl`` is ``expires_in``
        (number or numeric string), or 3600 s when it is absent or unusable. ``margin`` is
        30 s, but capped at half the TTL. Without that cap, tokens living under 30 s would
        get a deadline already in the past and be refetched on every call.

    Raises:
        ValueError: if ``access_token`` is missing, not a string, or blank.
    """
    token = body.get("access_token")
    if not isinstance(token, str) or not token.strip():
        raise ValueError("OAuth2 token response missing access_token")
    ttl = _coerce_expires_in(body.get("expires_in"))
    margin = min(30.0, ttl / 2.0)
    return token.strip(), now + ttl - margin
async def fetch_oauth2_client_credentials_token(
    auth: dict[str, str | int | float | bool | None],
    *,
    timeout: float = 30.0,
) -> str:
    """Return an access token from the ``client_credentials`` grant (RFC 6749 §4.4).

    Reads ``token_url``, ``client_id``, ``client_secret``, ``scope`` and
    ``token_endpoint_auth_method`` (default ``client_secret_post``) from ``auth``. A token
    cached in ``_TOKEN_CACHE`` is returned while the current monotonic time is before its
    ``expires_at``. Otherwise a new token is requested and cached.

    Raises:
        ValueError: if ``token_url`` or ``client_id`` is missing/blank, or the response
            has no usable ``access_token``.
    """

    def _field(name: str) -> str:
        return str(auth.get(name) or "").strip()

    token_url = _field("token_url")
    client_id = _field("client_id")
    client_secret = _field("client_secret")
    if not (token_url and client_id):
        raise ValueError("oauth2_client_credentials requires token_url and client_id")

    cache_key = _client_credentials_cache_key(auth)
    started = time.monotonic()
    hit = _TOKEN_CACHE.get(cache_key)
    if hit is not None and started < hit.expires_at:
        return hit.access_token

    form: dict[str, str] = {"grant_type": "client_credentials"}
    requested_scope = _field("scope")
    if requested_scope:
        form["scope"] = requested_scope
    # Not stripped here; _normalize_auth_method handles whitespace/case.
    method = str(auth.get("token_endpoint_auth_method") or "client_secret_post")

    response_body = await post_oauth2_token_endpoint(
        token_url,
        form,
        client_id=client_id,
        client_secret=client_secret,
        token_endpoint_auth_method=method,
        timeout=timeout,
    )
    access_token, expires_at = _parse_expires(response_body, started)
    _TOKEN_CACHE[cache_key] = _TokenState(
        access_token=access_token,
        expires_at=expires_at,
        refresh_token=None,
    )
    return access_token
# One lock per refresh cache key. Locks are created lazily inside the running event loop
# and serialise refresh_token grants for that key.
_REFRESH_LOCKS: dict[str, asyncio.Lock] = {}


def _refresh_lock(key: str) -> asyncio.Lock:
    """Return the lock guarding refresh grants for ``key``, creating it on first use."""
    lock = _REFRESH_LOCKS.get(key)
    if lock is None:
        # No await between the lookup and the insert, so two coroutines on the same loop
        # cannot each create a different lock for one key.
        lock = _REFRESH_LOCKS[key] = asyncio.Lock()
    return lock


async def fetch_oauth2_refresh_access_token(
    auth: dict[str, str | int | float | bool | None],
    *,
    timeout: float = 30.0,
) -> str:
    """Obtain an access token via ``grant_type=refresh_token`` (RFC 6749 §6).

    A cached access token is returned while it is before its ``expires_at``. Otherwise
    the function posts a refresh token to ``token_url``. It uses the refresh token stored
    in the cache entry when there is one, else ``auth["refresh_token"]``. If the issuer
    returns a new ``refresh_token``, it replaces the stored one (**rotation**), so later
    refreshes use the latest secret.

    Refreshes for the same cache key are serialised with a per-key :class:`asyncio.Lock`,
    and the cache is re-checked once the lock is held. Without this, concurrent callers
    that all see an expired entry would each present the same refresh token. Issuers that
    rotate refresh tokens may reject the second use, and some revoke the whole grant on
    reuse.

    ``auth`` must include ``_oauth_cache_id`` (see :func:`auth_with_oauth_cache_namespace`)
    so the cache is partitioned without deriving keys from ``refresh_token``.

    Raises:
        ValueError: missing ``token_url`` / ``client_id`` / ``_oauth_cache_id``, no refresh
            token available, or a response without ``access_token``.
    """
    token_url = str(auth.get("token_url") or "").strip()
    client_id = str(auth.get("client_id") or "").strip()
    client_secret = str(auth.get("client_secret") or "").strip()
    if not token_url or not client_id:
        raise ValueError("oauth2 refresh flow requires token_url and client_id")

    key = _refresh_cache_key(auth)

    # Fast path: a valid cached token needs no lock.
    state = _TOKEN_CACHE.get(key)
    if state is not None and time.monotonic() < state.expires_at:
        return state.access_token

    async with _refresh_lock(key):
        # Re-check under the lock: another coroutine may have refreshed (and rotated the
        # refresh token) while this one was waiting.
        now = time.monotonic()
        state = _TOKEN_CACHE.get(key)
        if state is not None and now < state.expires_at:
            return state.access_token

        refresh = str(auth.get("refresh_token") or "").strip()
        if state is not None and state.refresh_token:
            # A rotated token from a previous refresh supersedes the configured one.
            refresh = state.refresh_token
        if not refresh:
            raise ValueError("oauth2 refresh flow requires refresh_token (in config or from prior rotation)")

        scope = str(auth.get("scope") or "").strip()
        auth_method = str(auth.get("token_endpoint_auth_method") or "client_secret_post")
        data: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": refresh,
        }
        if scope:
            data["scope"] = scope

        body = await post_oauth2_token_endpoint(
            token_url,
            data,
            client_id=client_id,
            client_secret=client_secret,
            token_endpoint_auth_method=auth_method,
            timeout=timeout,
        )
        access, exp = _parse_expires(body, now)

        # Keep the issuer's new refresh token when it rotates; otherwise reuse the current one.
        new_rt = body.get("refresh_token")
        if isinstance(new_rt, str) and new_rt.strip():
            stored_refresh = new_rt.strip()
        else:
            stored_refresh = refresh

        _TOKEN_CACHE[key] = _TokenState(
            access_token=access,
            expires_at=exp,
            refresh_token=stored_refresh,
        )
        return access
async def exchange_oauth2_authorization_code(
    *,
    token_url: str,
    client_id: str,
    client_secret: str,
    code: str,
    redirect_uri: str,
    scope: str | None = None,
    token_endpoint_auth_method: str = "client_secret_post",
    timeout: float = 30.0,
    code_verifier: str | None = None,
) -> dict[str, Any]:
    """Exchange an authorization code for tokens (RFC 6749 §4.1.3).

    All string inputs are stripped of surrounding whitespace before use.

    Args:
        token_url: Token endpoint URL.
        client_id, client_secret: Client credentials, sent by
            :func:`post_oauth2_token_endpoint` according to ``token_endpoint_auth_method``.
        code: Authorization code received on the redirect.
        redirect_uri: Should match the value sent in the authorization request. A blank
            value is omitted rather than sent empty, because §4.1.3 requires it only when
            the authorization request included one.
        scope: Optional; sent only when non-blank.
        token_endpoint_auth_method: ``client_secret_post`` (default) or a Basic alias.
        timeout: HTTP timeout in seconds.
        code_verifier: Optional PKCE verifier (RFC 7636 §4.5); sent only when non-blank.

    Returns the parsed JSON body (typically ``access_token``, ``expires_in``, ``refresh_token``, …).
    Does not use the in-memory cache — callers should persist ``refresh_token`` and configure
    ``type: oauth2_refresh`` (or ``oauth2`` with ``grant_type: refresh_token``) for ongoing API access.
    """
    data: dict[str, str] = {
        "grant_type": "authorization_code",
        "code": code.strip(),
    }
    redirect = redirect_uri.strip()
    if redirect:
        data["redirect_uri"] = redirect
    if code_verifier and code_verifier.strip():
        data["code_verifier"] = code_verifier.strip()
    if scope and scope.strip():
        data["scope"] = scope.strip()

    return await post_oauth2_token_endpoint(
        token_url.strip(),
        data,
        client_id=client_id.strip(),
        client_secret=client_secret.strip(),
        token_endpoint_auth_method=token_endpoint_auth_method,
        timeout=timeout,
    )
def clear_oauth2_token_cache_for_tests() -> None:
    """Empty the in-memory OAuth token cache (test-isolation helper).

    Drops every cached access token and every stored (possibly rotated) refresh token.
    The next fetch for any key therefore calls the token endpoint again. The dict is
    cleared in place, so existing references to it see the change.
    """
    cache = _TOKEN_CACHE
    cache.clear()