Coverage for astrocyte/recall/oauth.py: 83%

114 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""OAuth2 for proxy recall — client credentials, refresh (with rotation), authorization code exchange.""" 

2 

3from __future__ import annotations 

4 

5import base64 

6import logging 

7import time 

8from dataclasses import dataclass 

9from typing import Any 

10 

11import httpx 

12 

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# In-memory token cache shared by the fetch helpers in this module. Keys are the strings built by
# ``_client_credentials_cache_key`` / ``_refresh_cache_key``; values are ``_TokenState`` entries.
# The annotation names ``_TokenState`` before the class is defined, which is fine because
# ``from __future__ import annotations`` keeps annotations unevaluated.
_TOKEN_CACHE: dict[str, _TokenState] = {}

16 

17 

18@dataclass 

19class _TokenState: 

20 access_token: str 

21 expires_at: float # monotonic 

22 refresh_token: str | None = None # updated when issuer rotates refresh tokens 

23 

24 

25def _client_credentials_cache_key(auth: dict[str, str | int | float | bool | None]) -> str: 

26 ns = auth.get("_oauth_cache_id") 

27 if ns is not None and str(ns).strip(): 

28 return f"cc|{ns}" 

29 tid = str(auth.get("token_url") or "") 

30 cid = str(auth.get("client_id") or "") 

31 sc = str(auth.get("scope") or "") 

32 return f"{tid}|{cid}|{sc}" 

33 

34 

35def _refresh_cache_key(auth: dict[str, str | int | float | bool | None]) -> str: 

36 """Partition key for refresh-token cache entries. 

37 

38 Requires ``_oauth_cache_id`` (proxy recall sets this via :func:`auth_with_oauth_cache_namespace`) 

39 so we never hash ``refresh_token`` for keying — avoids collisions across sources and satisfies 

40 static analysis that treats bearer secrets like password material. 

41 """ 

42 ns = auth.get("_oauth_cache_id") 

43 if ns is not None and str(ns).strip(): 

44 return f"rt|{ns}" 

45 raise ValueError( 

46 "oauth2 refresh requires _oauth_cache_id for in-memory token caching; " 

47 "use auth_with_oauth_cache_namespace(auth, unique_id) or set _oauth_cache_id on auth.", 

48 ) 

49 

50 

51def _normalize_auth_method(raw: str | None) -> str: 

52 v = (raw or "client_secret_post").strip().lower() 

53 if v in ("basic", "client_secret_basic", "http_basic"): 

54 return "client_secret_basic" 

55 return "client_secret_post" 

56 

57 

async def post_oauth2_token_endpoint(
    token_url: str,
    form: dict[str, str],
    *,
    client_id: str,
    client_secret: str,
    token_endpoint_auth_method: str,
    timeout: float = 30.0,
) -> dict[str, Any]:
    """POST ``application/x-www-form-urlencoded`` to the token endpoint and return the JSON object.

    ``client_secret_basic`` (RFC 6749 §2.3.1): credentials go in ``Authorization: Basic``; any
    ``client_id`` / ``client_secret`` entries are removed from the form body so they are not
    duplicated.
    ``client_secret_post``: ``client_id`` / ``client_secret`` are added to the form body when they
    are non-empty and not already present.

    Args:
        token_url: Token endpoint URL.
        form: Grant-specific form fields. Copied before modification; the caller's dict is untouched.
        client_id: OAuth client identifier.
        client_secret: OAuth client secret (may be empty).
        token_endpoint_auth_method: Raw configured method; normalized by ``_normalize_auth_method``.
        timeout: Timeout passed to ``httpx.AsyncClient``.

    Returns:
        The parsed JSON response body.

    Raises:
        httpx.HTTPStatusError: The endpoint answered with an error status (``raise_for_status``).
        ValueError: The response body parsed as JSON but is not a JSON object.
    """
    method = _normalize_auth_method(token_endpoint_auth_method)
    headers: dict[str, str] = {"Content-Type": "application/x-www-form-urlencoded"}
    data = dict(form)  # work on a copy so the caller's form is never mutated

    if method == "client_secret_basic":
        # NOTE(review): RFC 6749 §2.3.1 says id/secret are form-urlencoded before Base64; they are
        # sent verbatim here, as before. Confirm with the issuer before changing this.
        raw = f"{client_id}:{client_secret}".encode()
        headers["Authorization"] = "Basic " + base64.b64encode(raw).decode("ascii")
        data.pop("client_id", None)
        data.pop("client_secret", None)
    else:
        if client_id and "client_id" not in data:
            data["client_id"] = client_id
        if client_secret and "client_secret" not in data:
            data["client_secret"] = client_secret

    async with httpx.AsyncClient(timeout=timeout) as client:
        r = await client.post(token_url, data=data, headers=headers)
        r.raise_for_status()
        body = r.json()

    # A JSON array/string/null would otherwise reach callers typed as a dict and fail later with
    # an unrelated AttributeError on ``body.get``; reject it here with a clear message.
    if not isinstance(body, dict):
        raise ValueError("OAuth2 token response is not a JSON object")
    return body

92 

93 

94def _parse_expires(body: dict[str, Any], now: float) -> tuple[str, float]: 

95 token = body.get("access_token") 

96 if not isinstance(token, str) or not token.strip(): 

97 raise ValueError("OAuth2 token response missing access_token") 

98 expires_in = body.get("expires_in") 

99 if isinstance(expires_in, (int, float)) and expires_in > 0: 

100 ttl = float(expires_in) 

101 else: 

102 ttl = 3600.0 

103 return token.strip(), now + ttl - 30.0 

104 

105 

async def fetch_oauth2_client_credentials_token(
    auth: dict[str, str | int | float | bool | None],
    *,
    timeout: float = 30.0,
) -> str:
    """Return an access token for the RFC 6749 ``client_credentials`` grant.

    A cached token is returned while its deadline on the monotonic clock has not passed.
    Otherwise the token endpoint is called and the result stored under the key from
    ``_client_credentials_cache_key``.

    Raises:
        ValueError: ``token_url`` or ``client_id`` is missing/blank in ``auth``.
    """

    def _field(name: str) -> str:
        # Config values may be None or non-strings; normalize to a trimmed string.
        return str(auth.get(name) or "").strip()

    token_url = _field("token_url")
    client_id = _field("client_id")
    client_secret = _field("client_secret")
    if not (token_url and client_id):
        raise ValueError("oauth2_client_credentials requires token_url and client_id")

    cache_key = _client_credentials_cache_key(auth)
    started = time.monotonic()
    entry = _TOKEN_CACHE.get(cache_key)
    if entry is not None and started < entry.expires_at:
        return entry.access_token

    form: dict[str, str] = {"grant_type": "client_credentials"}
    requested_scope = _field("scope")
    if requested_scope:
        form["scope"] = requested_scope

    response_body = await post_oauth2_token_endpoint(
        token_url,
        form,
        client_id=client_id,
        client_secret=client_secret,
        token_endpoint_auth_method=str(auth.get("token_endpoint_auth_method") or "client_secret_post"),
        timeout=timeout,
    )
    # The deadline is measured from before the request was sent, so it errs on the early side.
    access_token, deadline = _parse_expires(response_body, started)
    _TOKEN_CACHE[cache_key] = _TokenState(
        access_token=access_token,
        expires_at=deadline,
        refresh_token=None,
    )
    return access_token

141 

142 

async def fetch_oauth2_refresh_access_token(
    auth: dict[str, str | int | float | bool | None],
    *,
    timeout: float = 30.0,
) -> str:
    """Return an access token obtained with ``grant_type=refresh_token`` (RFC 6749).

    The access token is cached until its deadline. When the issuer answers with a new
    ``refresh_token`` (**rotation**), that value replaces the one stored in the cache entry and is
    used for the next refresh; otherwise the refresh token just sent is kept.

    ``auth`` must carry ``_oauth_cache_id`` (see :func:`auth_with_oauth_cache_namespace`) so the
    cache is partitioned without deriving keys from ``refresh_token``.

    Raises:
        ValueError: ``token_url`` / ``client_id`` missing, ``_oauth_cache_id`` missing, or no
            refresh token available from either the config or a previous rotation.
    """

    def _field(name: str) -> str:
        # Config values may be None or non-strings; normalize to a trimmed string.
        return str(auth.get(name) or "").strip()

    token_url = _field("token_url")
    client_id = _field("client_id")
    client_secret = _field("client_secret")
    if not (token_url and client_id):
        raise ValueError("oauth2 refresh flow requires token_url and client_id")

    cache_key = _refresh_cache_key(auth)
    started = time.monotonic()
    entry = _TOKEN_CACHE.get(cache_key)
    if entry is not None and started < entry.expires_at:
        return entry.access_token

    # A refresh token remembered from an earlier rotation wins over the configured one.
    if entry is not None and entry.refresh_token:
        current_refresh = entry.refresh_token
    else:
        current_refresh = _field("refresh_token")
    if not current_refresh:
        raise ValueError("oauth2 refresh flow requires refresh_token (in config or from prior rotation)")

    form: dict[str, str] = {"grant_type": "refresh_token", "refresh_token": current_refresh}
    requested_scope = _field("scope")
    if requested_scope:
        form["scope"] = requested_scope

    response_body = await post_oauth2_token_endpoint(
        token_url,
        form,
        client_id=client_id,
        client_secret=client_secret,
        token_endpoint_auth_method=str(auth.get("token_endpoint_auth_method") or "client_secret_post"),
        timeout=timeout,
    )
    access_token, deadline = _parse_expires(response_body, started)

    rotated = response_body.get("refresh_token")
    next_refresh = rotated.strip() if isinstance(rotated, str) and rotated.strip() else current_refresh

    _TOKEN_CACHE[cache_key] = _TokenState(
        access_token=access_token,
        expires_at=deadline,
        refresh_token=next_refresh,
    )
    return access_token

206 

207 

async def exchange_oauth2_authorization_code(
    *,
    token_url: str,
    client_id: str,
    client_secret: str,
    code: str,
    redirect_uri: str,
    scope: str | None = None,
    token_endpoint_auth_method: str = "client_secret_post",
    timeout: float = 30.0,
) -> dict[str, Any]:
    """Trade an authorization code for tokens (RFC 6749 §4.1.3).

    The parsed JSON body is returned as-is (typically ``access_token``, ``expires_in``,
    ``refresh_token``, …). Nothing is written to the in-memory cache — callers should persist
    ``refresh_token`` and configure ``type: oauth2_refresh`` (or ``oauth2`` with
    ``grant_type: refresh_token``) for ongoing API access.

    All string arguments are stripped of surrounding whitespace before being sent; ``scope`` is
    included only when it is non-blank.
    """
    form: dict[str, str] = {"grant_type": "authorization_code"}
    form["code"] = code.strip()
    form["redirect_uri"] = redirect_uri.strip()

    trimmed_scope = scope.strip() if scope else ""
    if trimmed_scope:
        form["scope"] = trimmed_scope

    endpoint = token_url.strip()
    credentials = {"client_id": client_id.strip(), "client_secret": client_secret.strip()}
    return await post_oauth2_token_endpoint(
        endpoint,
        form,
        token_endpoint_auth_method=token_endpoint_auth_method,
        timeout=timeout,
        **credentials,
    )

241 

242 

def clear_oauth2_token_cache_for_tests() -> None:
    """Empty the module-level in-memory OAuth token cache (intended for test isolation)."""
    _TOKEN_CACHE.clear()