"""HTTP authentication middleware for MCP server.""" import logging from typing import Awaitable, Callable from starlette.middleware.base import BaseHTTPMiddleware from starlette.requests import Request from starlette.responses import JSONResponse, Response logger = logging.getLogger(__name__) class BearerAuthMiddleware(BaseHTTPMiddleware): """ Middleware to enforce Bearer token authentication on HTTP requests. This middleware validates the Authorization header for all requests. If a token is configured, requests must include "Authorization: Bearer ". If no token is configured, all requests are allowed (open access). """ def __init__(self, app, auth_token: str | None = None): """ Initialize authentication middleware. Args: app: ASGI application to wrap. auth_token: Optional Bearer token for authentication. If None, authentication is disabled. """ super().__init__(app) self.auth_token = auth_token self.auth_enabled = auth_token is not None if self.auth_enabled: logger.info("Bearer authentication enabled") else: logger.warning("Bearer authentication disabled - server is open access") async def dispatch( self, request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: """ Process request and enforce authentication if enabled. Args: request: Incoming HTTP request. call_next: Next middleware or route handler. Returns: Response from downstream handler or 401/403 error. """ # Skip authentication if disabled if not self.auth_enabled: return await call_next(request) # Skip authentication if marked by HealthCheckBypassMiddleware if getattr(request.state, "skip_auth", False): return await call_next(request) # Extract Authorization header auth_header = request.headers.get("Authorization") # Check if header is present if not auth_header: logger.warning(f"Missing Authorization header from {request.client.host}") return JSONResponse( status_code=401, content={ "error": "Unauthorized", "message": "Missing Authorization header", }, ) # Check if header format is correct if not auth_header.startswith("Bearer "): logger.warning(f"Invalid Authorization format from {request.client.host}") return JSONResponse( status_code=401, content={ "error": "Unauthorized", "message": "Authorization header must use Bearer scheme", }, ) # Extract token provided_token = auth_header[7:] # Remove "Bearer " prefix # Validate token if provided_token != self.auth_token: logger.warning(f"Invalid token from {request.client.host}") return JSONResponse( status_code=403, content={ "error": "Forbidden", "message": "Invalid authentication token", }, ) # Token is valid, proceed to next handler logger.debug(f"Authenticated request from {request.client.host}") return await call_next(request) class HealthCheckBypassMiddleware(BaseHTTPMiddleware): """ Middleware to bypass authentication for health check endpoints. This allows monitoring systems to check server health without authentication. """ def __init__(self, app, health_check_paths: list[str] | None = None): """ Initialize health check bypass middleware. Args: app: ASGI application to wrap. health_check_paths: List of paths to bypass authentication. Defaults to ["/health", "/healthz", "/ping"]. """ super().__init__(app) self.health_check_paths = health_check_paths or ["/health", "/healthz", "/ping"] async def dispatch( self, request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: """ Process request and bypass authentication for health checks. Args: request: Incoming HTTP request. call_next: Next middleware or route handler. Returns: Response from downstream handler. """ # Check if request is for a health check endpoint if request.url.path in self.health_check_paths: logger.debug(f"Bypassing auth for health check: {request.url.path}") # Mark request to skip authentication in BearerAuthMiddleware request.state.skip_auth = True # Continue to next middleware return await call_next(request)