Coverage for astrocyte/_mcp_identity.py: 75%
71 statements
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
« prev ^ index » next coverage.py v7.15.0, created at 2026-07-04 05:24 +0000
1"""JWT identity middleware wiring for the Astrocyte MCP server.
3This module is the **transport layer** counterpart to the pure classifier
4in :mod:`astrocyte.identity_jwt`. It accepts a Bearer token from an
5incoming request, validates it, extracts the claims, and builds an
6:class:`AstrocyteContext` for the request.
8The decoder is **injectable** so the wiring can be tested without PyJWT
9as a runtime dependency (unit tests feed pre-decoded claim dicts); the
10default decoder uses PyJWT + a JWKS client when available.
12See:
14- :mod:`astrocyte.identity_jwt` — the pure claim-to-identity classifier
15- ``docs/_plugins/jwt-identity-middleware.md`` — operator guide
16- ``docs/_design/astrocyte_identity_spec.md`` §3 Gap 1 — architectural spec
17"""
19from __future__ import annotations
21import logging
22from dataclasses import dataclass
23from typing import Any, Callable, Protocol
25from astrocyte.config import JwtMiddlewareConfig
26from astrocyte.errors import AuthorizationError
27from astrocyte.identity import format_principal
28from astrocyte.identity_jwt import classify_jwt_claims
29from astrocyte.types import AstrocyteContext
logger = logging.getLogger("astrocyte.identity.mcp")

#: Claim mapping produced by a :data:`Decoder` (claim name -> claim value).
_Claims = dict[str, Any]

#: Decoder strategy: takes the raw JWT string and returns its decoded,
#: signature-validated claims. Any failure should be raised; the
#: middleware maps such exceptions to :class:`AuthorizationError`.
Decoder = Callable[[str], _Claims]
39# ---------------------------------------------------------------------------
40# Header extraction
41# ---------------------------------------------------------------------------
44def extract_bearer_token(headers: dict[str, str]) -> str | None:
45 """Pull the Bearer token from an ``Authorization`` header, if any.
47 Header lookup is case-insensitive. Returns None when the header is
48 absent or doesn't start with ``Bearer ``; the middleware decides how
49 to handle that (fail closed vs anonymous fallback).
50 """
51 if not headers:
52 return None
53 # Case-insensitive lookup — HTTP header names aren't case-sensitive.
54 for key, value in headers.items():
55 if key.lower() == "authorization":
56 if not isinstance(value, str):
57 return None
58 stripped = value.strip()
59 if stripped.lower().startswith("bearer "):
60 # Slice past "Bearer " (7 chars) preserving the token.
61 return stripped[7:].strip() or None
62 return None
63 return None
# ---------------------------------------------------------------------------
# Middleware core — transport-independent
# ---------------------------------------------------------------------------


class JwtIdentityMiddleware:
    """Resolve an :class:`AstrocyteContext` from inbound request headers.

    The middleware orchestrates four steps:

    1. Extract the Bearer token from the Authorization header.
    2. If no token is present, apply the anonymous / fail-closed policy.
    3. Decode + validate the token via the injected :data:`Decoder`.
    4. Classify the claims via :func:`classify_jwt_claims` and build an
       :class:`AstrocyteContext` whose ``actor`` is the resolved identity
       and whose ``principal`` is produced by :func:`format_principal`.

    The decoder is injected so tests (and embedded deployments with
    custom token formats) can replace PyJWT with a minimal substitute.
    Production deployments use :func:`make_pyjwt_decoder`, which wires up
    a JWKS client against the configured endpoint.
    """

    def __init__(
        self,
        config: JwtMiddlewareConfig,
        decoder: Decoder,
    ) -> None:
        # resolve() reads the ``fail_closed`` and ``allow_anonymous`` flags.
        self._config = config
        # Turns a raw token string into a validated claim dict.
        self._decoder = decoder

    def resolve(self, headers: dict[str, str]) -> AstrocyteContext:
        """Resolve caller identity from inbound headers.

        Returns a fully-formed :class:`AstrocyteContext`. ``principal`` is
        always populated so downstream access control and audit logs have
        a label even for anonymous callers. When no token is presented and
        the policy permits it, the context carries only
        ``principal="anonymous"``.

        Raises:
            AuthorizationError: A Bearer token was presented but the
                decoder rejected it (signature, expiry, audience, issuer,
                ...), or its claims could not be classified. Also raised
                when no token was presented while ``fail_closed`` is true
                and ``allow_anonymous`` is false.
        """
        token = extract_bearer_token(headers)

        if token is None:
            # No credential presented. Policy decides whether that's OK.
            if self._config.fail_closed and not self._config.allow_anonymous:
                raise AuthorizationError(
                    "No Bearer token presented and anonymous access is "
                    "disabled (identity.jwt_middleware.allow_anonymous=false)."
                )
            # Anonymous fallback — principal label makes this visible in logs.
            return AstrocyteContext(principal="anonymous")

        # Token presented — validate and classify. Any failure here is
        # fail-closed regardless of allow_anonymous: a broken token is
        # never silently downgraded to anonymous (that would let an
        # attacker bypass identity-aware MIP rules by malforming their
        # token).
        try:
            claims = self._decoder(token)
        except AuthorizationError:
            raise
        except Exception as exc:  # pragma: no cover — decoder-specific
            logger.warning("JWT decode failed: %s", exc)
            raise AuthorizationError(
                f"Bearer token decode/validation failed: {exc}. Request rejected to prevent cross-user data leakage."
            ) from exc

        # Classification must also fail closed. Claims that are missing or
        # malformed, or a custom decoder returning an unexpected shape,
        # must surface as an authorization failure, not a raw error.
        try:
            identity = classify_jwt_claims(claims)
        except AuthorizationError:
            raise
        except Exception as exc:
            logger.warning("JWT claim classification failed: %s", exc)
            raise AuthorizationError(
                f"Bearer token claims could not be classified: {exc}. Request rejected."
            ) from exc

        return AstrocyteContext(
            principal=format_principal(identity),
            actor=identity,
        )
145# ---------------------------------------------------------------------------
146# Default PyJWT-backed decoder (optional)
147# ---------------------------------------------------------------------------
150class _JWKSClient(Protocol):
151 def get_signing_key_from_jwt(self, token: str) -> Any: # pragma: no cover
152 """Return the signing key for ``token`` (PyJWT-compatible shape)."""
@dataclass(init=True, repr=True, eq=True)
class _PyJWTContext:
    """Validation settings captured once, when the default decoder is built."""

    # Client used to look up the signing key for each token.
    jwks_client: _JWKSClient
    # Expected ``aud`` value; when None, no audience is passed to PyJWT.
    audience: str | None
    # Expected ``iss`` value; when None, no issuer is passed to PyJWT.
    issuer: str | None
    # Allowlisted signing algorithms handed to ``jwt.decode``.
    algorithms: list[str]
def make_pyjwt_decoder(config: JwtMiddlewareConfig) -> Decoder:
    """Create the default PyJWT + JWKS :data:`Decoder` for ``config``.

    PyJWT is imported lazily, so it stays an optional dependency. Only
    deployments that actually build this decoder need it installed.

    Every token given to the returned decoder must have:

    - a signature made with a key resolved from ``config.jwks_uri``;
    - an algorithm in ``config.algorithms``;
    - an ``aud`` claim matching ``config.token_audience``;
    - an ``iss`` claim matching ``config.token_issuer`` (only when set);
    - ``exp`` and ``iat`` claims (tokens lacking either are rejected).

    Raises:
        ImportError: PyJWT is not installed.
        ValueError: ``jwks_uri`` or ``token_audience`` is not configured.
    """
    try:
        import jwt  # PyJWT # type: ignore[import-not-found]
    except ImportError as exc:  # pragma: no cover — optional dep
        raise ImportError(
            "JWT identity middleware requires PyJWT. Install with "
            "'pip install pyjwt[cryptography]' or enable the 'identity' "
            "extra on the astrocyte package."
        ) from exc

    # Reject incomplete configuration up front instead of at first request.
    if not config.jwks_uri:
        raise ValueError("identity.jwt_middleware.jwks_uri must be set when enabled.")
    if not config.token_audience:
        raise ValueError(
            "identity.jwt_middleware.token_audience must be set when "
            "enabled — accepting any audience risks cross-tenant token "
            "reuse."
        )

    # Config is expressed in hours; PyJWKClient wants seconds.
    refresh_seconds = config.jwks_refresh_interval_hours * 3600
    state = _PyJWTContext(
        jwks_client=jwt.PyJWKClient(
            config.jwks_uri,
            cache_keys=True,
            lifespan=refresh_seconds,
        ),
        audience=config.token_audience,
        issuer=config.token_issuer,
        algorithms=list(config.algorithms),
    )

    def decode(token: str) -> dict[str, Any]:
        signing_key = state.jwks_client.get_signing_key_from_jwt(token)
        checks: dict[str, Any] = {
            "algorithms": state.algorithms,
            "options": {"require": ["exp", "iat"]},
        }
        # Forward only the aud/iss expectations that are configured.
        checks.update(
            (param, expected)
            for param, expected in (
                ("audience", state.audience),
                ("issuer", state.issuer),
            )
            if expected is not None
        )
        return jwt.decode(token, signing_key.key, **checks)

    return decode
# ---------------------------------------------------------------------------
# Top-level factory
# ---------------------------------------------------------------------------


def build_jwt_middleware(
    config: JwtMiddlewareConfig,
    decoder: Decoder | None = None,
) -> JwtIdentityMiddleware | None:
    """Return a configured :class:`JwtIdentityMiddleware`, or None if disabled.

    Args:
        config: The ``identity.jwt_middleware`` section of AstrocyteConfig.
        decoder: Token decoder to use. Leave as None to build the default
            PyJWT/JWKS decoder from ``config`` via :func:`make_pyjwt_decoder`.
            Supply a callable to substitute a test double or a non-PyJWT
            implementation.

    Returns:
        None when ``config.enabled`` is falsy. Otherwise, a middleware bound
        to ``config`` and the chosen decoder.

    Raises:
        ImportError, ValueError: Propagated from :func:`make_pyjwt_decoder`
            when ``decoder`` is None and PyJWT or required config is missing.
    """
    if config.enabled:
        # The default decoder is only built when none was supplied, so
        # PyJWT is never imported for custom-decoder deployments.
        chosen = decoder if decoder is not None else make_pyjwt_decoder(config)
        return JwtIdentityMiddleware(config, chosen)
    return None