Coverage for astrocyte/_mcp_identity.py: 75%

71 statements  

« prev     ^ index     » next       coverage.py v7.15.0, created at 2026-07-04 05:24 +0000

1"""JWT identity middleware wiring for the Astrocyte MCP server. 

2 

3This module is the **transport layer** counterpart to the pure classifier 

4in :mod:`astrocyte.identity_jwt`. It accepts a Bearer token from an 

5incoming request, validates it, extracts the claims, and builds an 

6:class:`AstrocyteContext` for the request. 

7 

8The decoder is **injectable** so the wiring can be tested without PyJWT 

9as a runtime dependency (unit tests feed pre-decoded claim dicts); the 

10default decoder uses PyJWT + a JWKS client when available. 

11 

12See: 

13 

14- :mod:`astrocyte.identity_jwt` — the pure claim-to-identity classifier 

15- ``docs/_plugins/jwt-identity-middleware.md`` — operator guide 

16- ``docs/_design/astrocyte_identity_spec.md`` §3 Gap 1 — architectural spec 

17""" 

18 

19from __future__ import annotations 

20 

21import logging 

22from dataclasses import dataclass 

23from typing import Any, Callable, Protocol 

24 

25from astrocyte.config import JwtMiddlewareConfig 

26from astrocyte.errors import AuthorizationError 

27from astrocyte.identity import format_principal 

28from astrocyte.identity_jwt import classify_jwt_claims 

29from astrocyte.types import AstrocyteContext 

30 

#: Decoder strategy type. A decoder receives the raw JWT string and returns
#: its claims as a dict once the token has passed the decoder's signature
#: and validation checks. It signals any failure by raising;
#: :class:`JwtIdentityMiddleware` converts such exceptions into
#: :class:`AuthorizationError`.
Decoder = Callable[[str], dict[str, Any]]

# Dedicated logger for MCP identity resolution events.
logger = logging.getLogger("astrocyte.identity.mcp")

37 

38 

39# --------------------------------------------------------------------------- 

40# Header extraction 

41# --------------------------------------------------------------------------- 

42 

43 

def extract_bearer_token(headers: dict[str, str]) -> str | None:
    """Return the Bearer credential carried by the ``Authorization`` header.

    The header name is matched case-insensitively, and only the first
    matching entry is considered.

    Returns ``None`` when:

    - ``headers`` is empty;
    - there is no ``Authorization`` entry;
    - its value is not a ``str``;
    - it uses a scheme other than ``Bearer``;
    - nothing but whitespace follows the scheme.

    Deciding what a missing token means (fail closed vs anonymous
    fallback) is left to the caller.
    """
    if not headers:
        return None

    # HTTP header names are case-insensitive; stop at the first match.
    raw_value = next(
        (value for name, value in headers.items() if name.lower() == "authorization"),
        None,
    )
    if not isinstance(raw_value, str):
        # Header absent, or carrying a non-text value we refuse to interpret.
        return None

    credentials = raw_value.strip()
    if not credentials.lower().startswith("bearer "):
        return None

    # Drop the 7-character "Bearer " prefix; a blank remainder means no token.
    token = credentials[len("bearer "):].strip()
    return token if token else None

64 

65 

66# --------------------------------------------------------------------------- 

67# Middleware core — transport-independent 

68# --------------------------------------------------------------------------- 

69 

70 

class JwtIdentityMiddleware:
    """Resolve an :class:`AstrocyteContext` from inbound request headers.

    The middleware orchestrates four steps:

    1. Extract the Bearer token from the Authorization header.
    2. If no token is present, apply the anonymous / fail-closed policy.
    3. Decode + validate the token via the injected :data:`Decoder`.
    4. Classify the claims via :func:`classify_jwt_claims` and build an
       :class:`AstrocyteContext` whose ``actor`` is the resolved identity
       and whose ``principal`` is the canonical ``{type}:{id}`` form
       produced by :func:`format_principal`.

    The decoder is injected so tests (and embedded deployments with
    custom token formats) can replace PyJWT with a minimal substitute.
    Production deployments use :func:`make_pyjwt_decoder`, which wires up
    a JWKS client against the configured endpoint.
    """

    def __init__(
        self,
        config: JwtMiddlewareConfig,
        decoder: Decoder,
    ) -> None:
        # resolve() reads ``fail_closed`` and ``allow_anonymous`` from here.
        self._config = config
        # Turns a raw token string into a validated claims dict, or raises.
        self._decoder = decoder

    def resolve(self, headers: dict[str, str]) -> AstrocyteContext:
        """Resolve caller identity from inbound headers.

        Returns a fully-formed :class:`AstrocyteContext`. ``principal`` is
        always populated (``"anonymous"`` when no token is presented) so
        downstream access control and audit logs always have a label.

        Raises:
            AuthorizationError: in any of these cases:

                - No Bearer token was presented (no ``Authorization``
                  header, or a non-Bearer scheme), the config has
                  ``fail_closed`` set, and ``allow_anonymous`` unset.
                - A Bearer token was presented but the decoder rejected
                  it. With :func:`make_pyjwt_decoder` this covers
                  signature, expiry, audience and issuer failures.
                - The decoded claims could not be classified into an
                  identity.
        """
        token = extract_bearer_token(headers)

        if token is None:
            # No credential presented. Policy decides whether that's OK.
            if self._config.fail_closed and not self._config.allow_anonymous:
                raise AuthorizationError(
                    "No Bearer token presented and anonymous access is "
                    "disabled (identity.jwt_middleware.allow_anonymous=false)."
                )
            # Anonymous fallback — principal label makes this visible in logs.
            return AstrocyteContext(principal="anonymous")

        # Token presented — validate and classify. Any failure here is
        # fail-closed regardless of allow_anonymous: a broken token is
        # never silently downgraded to anonymous (that would let an
        # attacker bypass identity-aware MIP rules by malforming their
        # token).
        try:
            claims = self._decoder(token)
        except AuthorizationError:
            raise
        except Exception as exc:  # pragma: no cover — decoder-specific
            logger.warning("JWT decode failed: %s", exc)
            raise AuthorizationError(
                f"Bearer token decode/validation failed: {exc}. Request rejected to prevent cross-user data leakage."
            ) from exc

        # Classification failures get the same treatment as decode failures,
        # so callers only ever see AuthorizationError for a bad token.
        try:
            identity = classify_jwt_claims(claims)
        except AuthorizationError:
            raise
        except Exception as exc:
            logger.warning("JWT claim classification failed: %s", exc)
            raise AuthorizationError(
                f"Bearer token claims could not be classified: {exc}. Request rejected."
            ) from exc

        return AstrocyteContext(
            principal=format_principal(identity),
            actor=identity,
        )

143 

144 

145# --------------------------------------------------------------------------- 

146# Default PyJWT-backed decoder (optional) 

147# --------------------------------------------------------------------------- 

148 

149 

class _JWKSClient(Protocol):
    """Structural type for the JWKS client held by the default decoder.

    Describes the subset of PyJWT's ``PyJWKClient`` API that
    :func:`make_pyjwt_decoder` calls.
    """

    def get_signing_key_from_jwt(self, token: str) -> Any:  # pragma: no cover
        """Return the signing key for ``token`` (PyJWT-compatible shape)."""
        ...

154 

@dataclass(frozen=True)
class _PyJWTContext:
    """Captured state for the default PyJWT decoder.

    Frozen: a single instance is closed over by the decoder and reused for
    every request, so its settings must not be rebound after construction.
    """

    # Resolves the verification key for each incoming token.
    jwks_client: _JWKSClient
    # Expected ``aud`` claim, passed to ``jwt.decode`` when not None.
    audience: str | None
    # Expected ``iss`` claim, passed to ``jwt.decode`` when not None.
    issuer: str | None
    # Allowlist of JWS algorithm names accepted by ``jwt.decode``.
    algorithms: list[str]

163 

164 

def _normalize_algorithms(algorithms: Any) -> list[str]:
    """Return the configured JWS algorithm allowlist as a validated list.

    A bare string (e.g. ``"RS256"`` from an env var or YAML scalar) is
    treated as a single algorithm instead of being split into characters
    by ``list()``.

    Raises:
        ValueError: the allowlist is empty, contains a blank or non-string
            entry, or contains ``"none"`` (an unsigned token can never be
            verified against a JWKS-resolved key).
    """
    if isinstance(algorithms, str):
        names = [algorithms]
    else:
        names = list(algorithms or ())
    if not names:
        raise ValueError(
            "identity.jwt_middleware.algorithms must list at least one "
            "signing algorithm."
        )
    for name in names:
        if not isinstance(name, str) or not name.strip():
            raise ValueError(
                f"identity.jwt_middleware.algorithms contains an invalid entry: {name!r}."
            )
        if name.strip().lower() == "none":
            raise ValueError(
                "identity.jwt_middleware.algorithms must not include 'none' — "
                "unsigned tokens cannot be verified."
            )
    return names


def make_pyjwt_decoder(config: JwtMiddlewareConfig) -> Decoder:
    """Build a PyJWT-backed :data:`Decoder` from middleware config.

    PyJWT is imported lazily so it stays an optional dependency —
    deployments that don't enable JWT middleware don't pay for it.

    The returned decoder enforces:

    - Signature via a JWKS-resolved public key (``jwt.PyJWKClient`` with
      key caching; cache lifespan derived from
      ``jwks_refresh_interval_hours``)
    - Algorithm allowlist from ``config.algorithms`` (documented default:
      asymmetric — RS256, ES256)
    - ``aud`` claim match against ``token_audience`` (required)
    - ``iss`` claim match against ``token_issuer`` (when configured)
    - ``exp`` and ``iat`` required claims — tokens missing them are rejected

    Raises:
        ValueError: ``jwks_uri`` or ``token_audience`` is unset, or the
            algorithm allowlist is empty, has a blank entry, or includes
            ``"none"``.
        ImportError: PyJWT isn't installed.
    """
    # Validate configuration before touching the optional dependency so
    # misconfiguration is reported the same way whether or not PyJWT exists.
    if not config.jwks_uri:
        raise ValueError("identity.jwt_middleware.jwks_uri must be set when enabled.")
    if not config.token_audience:
        raise ValueError(
            "identity.jwt_middleware.token_audience must be set when "
            "enabled — accepting any audience risks cross-tenant token "
            "reuse."
        )
    algorithms = _normalize_algorithms(config.algorithms)

    try:
        import jwt  # PyJWT  # type: ignore[import-not-found]
    except ImportError as exc:  # pragma: no cover — optional dep
        raise ImportError(
            "JWT identity middleware requires PyJWT. Install with "
            "'pip install pyjwt[cryptography]' or enable the 'identity' "
            "extra on the astrocyte package."
        ) from exc

    jwks_client = jwt.PyJWKClient(
        config.jwks_uri,
        cache_keys=True,
        lifespan=config.jwks_refresh_interval_hours * 3600,
    )
    ctx = _PyJWTContext(
        jwks_client=jwks_client,
        audience=config.token_audience,
        issuer=config.token_issuer,
        algorithms=algorithms,
    )

    def decode(token: str) -> dict[str, Any]:
        # Resolve the verification key for this token from the JWKS endpoint.
        signing_key = ctx.jwks_client.get_signing_key_from_jwt(token)
        # "require" makes a token that lacks exp/iat fail validation.
        options: dict[str, Any] = {"require": ["exp", "iat"]}
        kwargs: dict[str, Any] = {
            "algorithms": ctx.algorithms,
            "options": options,
        }
        if ctx.audience is not None:
            kwargs["audience"] = ctx.audience
        if ctx.issuer is not None:
            kwargs["issuer"] = ctx.issuer
        return jwt.decode(token, signing_key.key, **kwargs)

    return decode

223 

224 

225# --------------------------------------------------------------------------- 

226# Top-level factory 

227# --------------------------------------------------------------------------- 

228 

229 

def build_jwt_middleware(
    config: JwtMiddlewareConfig,
    decoder: Decoder | None = None,
) -> JwtIdentityMiddleware | None:
    """Return a configured JWT middleware, or ``None`` when it is disabled.

    Args:
        config: The ``identity.jwt_middleware`` section of AstrocyteConfig.
        decoder: Decoder to use. When omitted (``None``), one is built from
            ``config`` with :func:`make_pyjwt_decoder`. Pass a callable for
            testing or for deployments that don't use PyJWT.

    Returns:
        A :class:`JwtIdentityMiddleware`, or ``None`` if ``config.enabled``
        is falsy.
    """
    if config.enabled:
        active_decoder = make_pyjwt_decoder(config) if decoder is None else decoder
        return JwtIdentityMiddleware(config, active_decoder)
    return None