generated from personal-projects/leo-claude-mktplace
Compare commits
13 Commits
developmen
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 30e28dd09e | |||
| 608b488763 | |||
| 49f2d0bdbb | |||
| f2cba079eb | |||
| 4e81b9bb96 | |||
| d21f85545b | |||
| 3d1fd2e2a6 | |||
| 2fc43ff5c3 | |||
| d11649071e | |||
| 42d625c27f | |||
| 6beb8026df | |||
| acacefeaed | |||
| 604661f096 |
17
src/gitea_http_wrapper/__init__.py
Normal file
17
src/gitea_http_wrapper/__init__.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""
|
||||
Gitea HTTP MCP Wrapper
|
||||
|
||||
This package provides an HTTP transport wrapper around the official Gitea MCP server.
|
||||
It handles configuration loading, tool filtering, and HTTP authentication middleware.
|
||||
|
||||
Architecture:
|
||||
- config/: Configuration loader module
|
||||
- middleware/: HTTP authentication middleware
|
||||
- filtering/: Tool filtering for Claude Desktop compatibility
|
||||
- server.py: Main HTTP MCP server implementation
|
||||
"""
|
||||
|
||||
from .server import GiteaMCPWrapper, create_app, main
|
||||
|
||||
__version__ = "0.1.0"
|
||||
__all__ = ["__version__", "GiteaMCPWrapper", "create_app", "main"]
|
||||
5
src/gitea_http_wrapper/config/__init__.py
Normal file
5
src/gitea_http_wrapper/config/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Configuration loader module."""
|
||||
|
||||
from .settings import GiteaSettings, load_settings
|
||||
|
||||
__all__ = ["GiteaSettings", "load_settings"]
|
||||
113
src/gitea_http_wrapper/config/settings.py
Normal file
113
src/gitea_http_wrapper/config/settings.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Configuration settings for Gitea HTTP MCP wrapper."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class GiteaSettings(BaseSettings):
|
||||
"""Configuration settings loaded from environment or .env file."""
|
||||
|
||||
model_config = SettingsConfigDict(
|
||||
env_file=".env",
|
||||
env_file_encoding="utf-8",
|
||||
case_sensitive=False,
|
||||
extra="ignore",
|
||||
)
|
||||
|
||||
# Gitea Configuration
|
||||
gitea_url: str = Field(
|
||||
...,
|
||||
description="Gitea instance URL (e.g., https://git.example.com)",
|
||||
)
|
||||
gitea_token: str = Field(
|
||||
...,
|
||||
description="Gitea API token for authentication",
|
||||
)
|
||||
gitea_owner: str = Field(
|
||||
...,
|
||||
description="Default repository owner/organization",
|
||||
)
|
||||
gitea_repo: str = Field(
|
||||
...,
|
||||
description="Default repository name",
|
||||
)
|
||||
|
||||
# HTTP Server Configuration
|
||||
http_host: str = Field(
|
||||
default="127.0.0.1",
|
||||
description="HTTP server bind address",
|
||||
)
|
||||
http_port: int = Field(
|
||||
default=8000,
|
||||
ge=1,
|
||||
le=65535,
|
||||
description="HTTP server port",
|
||||
)
|
||||
|
||||
# Authentication Configuration
|
||||
auth_token: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Bearer token for HTTP authentication (optional)",
|
||||
)
|
||||
|
||||
# Tool Filtering Configuration
|
||||
enabled_tools: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated list of enabled tools (optional, enables all if not set)",
|
||||
)
|
||||
disabled_tools: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Comma-separated list of disabled tools (optional)",
|
||||
)
|
||||
|
||||
@field_validator("gitea_url")
|
||||
@classmethod
|
||||
def validate_gitea_url(cls, v: str) -> str:
|
||||
"""Ensure Gitea URL is properly formatted."""
|
||||
if not v.startswith(("http://", "https://")):
|
||||
raise ValueError("gitea_url must start with http:// or https://")
|
||||
return v.rstrip("/")
|
||||
|
||||
@property
|
||||
def enabled_tools_list(self) -> Optional[list[str]]:
|
||||
"""Parse enabled_tools into a list."""
|
||||
if not self.enabled_tools:
|
||||
return None
|
||||
return [tool.strip() for tool in self.enabled_tools.split(",") if tool.strip()]
|
||||
|
||||
@property
|
||||
def disabled_tools_list(self) -> Optional[list[str]]:
|
||||
"""Parse disabled_tools into a list."""
|
||||
if not self.disabled_tools:
|
||||
return None
|
||||
return [tool.strip() for tool in self.disabled_tools.split(",") if tool.strip()]
|
||||
|
||||
def get_gitea_mcp_env(self) -> dict[str, str]:
|
||||
"""Get environment variables for the wrapped Gitea MCP server."""
|
||||
return {
|
||||
"GITEA_BASE_URL": self.gitea_url,
|
||||
"GITEA_API_TOKEN": self.gitea_token,
|
||||
"GITEA_DEFAULT_OWNER": self.gitea_owner,
|
||||
"GITEA_DEFAULT_REPO": self.gitea_repo,
|
||||
}
|
||||
|
||||
|
||||
def load_settings(env_file: Optional[Path] = None) -> GiteaSettings:
|
||||
"""
|
||||
Load settings from environment or .env file.
|
||||
|
||||
Args:
|
||||
env_file: Optional path to .env file. If not provided, searches for .env in current directory.
|
||||
|
||||
Returns:
|
||||
GiteaSettings instance with loaded configuration.
|
||||
|
||||
Raises:
|
||||
ValidationError: If required settings are missing or invalid.
|
||||
"""
|
||||
if env_file:
|
||||
return GiteaSettings(_env_file=env_file)
|
||||
return GiteaSettings()
|
||||
5
src/gitea_http_wrapper/filtering/__init__.py
Normal file
5
src/gitea_http_wrapper/filtering/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Tool filtering module for Claude Desktop compatibility."""
|
||||
|
||||
from .filter import ToolFilter
|
||||
|
||||
__all__ = ["ToolFilter"]
|
||||
108
src/gitea_http_wrapper/filtering/filter.py
Normal file
108
src/gitea_http_wrapper/filtering/filter.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Tool filtering for Claude Desktop compatibility."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class ToolFilter:
|
||||
"""
|
||||
Filter MCP tools based on enabled/disabled lists.
|
||||
|
||||
This class handles tool filtering to ensure only compatible tools are exposed
|
||||
to Claude Desktop, preventing crashes from unsupported tool schemas.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
enabled_tools: list[str] | None = None,
|
||||
disabled_tools: list[str] | None = None,
|
||||
):
|
||||
"""
|
||||
Initialize tool filter.
|
||||
|
||||
Args:
|
||||
enabled_tools: List of tool names to enable. If None, all tools are enabled.
|
||||
disabled_tools: List of tool names to disable. Takes precedence over enabled_tools.
|
||||
|
||||
Raises:
|
||||
ValueError: If both enabled_tools and disabled_tools are specified.
|
||||
"""
|
||||
if enabled_tools is not None and disabled_tools is not None:
|
||||
raise ValueError(
|
||||
"Cannot specify both enabled_tools and disabled_tools. Choose one filtering mode."
|
||||
)
|
||||
|
||||
self.enabled_tools = set(enabled_tools) if enabled_tools else None
|
||||
self.disabled_tools = set(disabled_tools) if disabled_tools else None
|
||||
|
||||
def should_include_tool(self, tool_name: str) -> bool:
|
||||
"""
|
||||
Determine if a tool should be included based on filter rules.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the tool to check.
|
||||
|
||||
Returns:
|
||||
True if tool should be included, False otherwise.
|
||||
"""
|
||||
# If disabled list is specified, exclude disabled tools
|
||||
if self.disabled_tools is not None:
|
||||
return tool_name not in self.disabled_tools
|
||||
|
||||
# If enabled list is specified, only include enabled tools
|
||||
if self.enabled_tools is not None:
|
||||
return tool_name in self.enabled_tools
|
||||
|
||||
# If no filters specified, include all tools
|
||||
return True
|
||||
|
||||
def filter_tools_list(self, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter a list of tool definitions.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions (dicts with at least a 'name' field).
|
||||
|
||||
Returns:
|
||||
Filtered list of tool definitions.
|
||||
"""
|
||||
return [tool for tool in tools if self.should_include_tool(tool.get("name", ""))]
|
||||
|
||||
def filter_tools_response(self, response: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Filter tools from an MCP list_tools response.
|
||||
|
||||
Args:
|
||||
response: MCP response dict containing 'tools' list.
|
||||
|
||||
Returns:
|
||||
Filtered response with tools list updated.
|
||||
"""
|
||||
if "tools" in response and isinstance(response["tools"], list):
|
||||
response = response.copy()
|
||||
response["tools"] = self.filter_tools_list(response["tools"])
|
||||
return response
|
||||
|
||||
def get_filter_stats(self) -> dict[str, Any]:
|
||||
"""
|
||||
Get statistics about the filter configuration.
|
||||
|
||||
Returns:
|
||||
Dict containing filter mode and tool counts.
|
||||
"""
|
||||
if self.disabled_tools is not None:
|
||||
return {
|
||||
"mode": "blacklist",
|
||||
"disabled_count": len(self.disabled_tools),
|
||||
"disabled_tools": sorted(self.disabled_tools),
|
||||
}
|
||||
elif self.enabled_tools is not None:
|
||||
return {
|
||||
"mode": "whitelist",
|
||||
"enabled_count": len(self.enabled_tools),
|
||||
"enabled_tools": sorted(self.enabled_tools),
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"mode": "passthrough",
|
||||
"message": "All tools enabled",
|
||||
}
|
||||
5
src/gitea_http_wrapper/middleware/__init__.py
Normal file
5
src/gitea_http_wrapper/middleware/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""HTTP authentication middleware module."""
|
||||
|
||||
from .auth import BearerAuthMiddleware, HealthCheckBypassMiddleware
|
||||
|
||||
__all__ = ["BearerAuthMiddleware", "HealthCheckBypassMiddleware"]
|
||||
144
src/gitea_http_wrapper/middleware/auth.py
Normal file
144
src/gitea_http_wrapper/middleware/auth.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""HTTP authentication middleware for MCP server."""
|
||||
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse, Response
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BearerAuthMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to enforce Bearer token authentication on HTTP requests.
|
||||
|
||||
This middleware validates the Authorization header for all requests.
|
||||
If a token is configured, requests must include "Authorization: Bearer <token>".
|
||||
If no token is configured, all requests are allowed (open access).
|
||||
"""
|
||||
|
||||
def __init__(self, app, auth_token: str | None = None):
|
||||
"""
|
||||
Initialize authentication middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application to wrap.
|
||||
auth_token: Optional Bearer token for authentication.
|
||||
If None, authentication is disabled.
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.auth_token = auth_token
|
||||
self.auth_enabled = auth_token is not None
|
||||
|
||||
if self.auth_enabled:
|
||||
logger.info("Bearer authentication enabled")
|
||||
else:
|
||||
logger.warning("Bearer authentication disabled - server is open access")
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""
|
||||
Process request and enforce authentication if enabled.
|
||||
|
||||
Args:
|
||||
request: Incoming HTTP request.
|
||||
call_next: Next middleware or route handler.
|
||||
|
||||
Returns:
|
||||
Response from downstream handler or 401/403 error.
|
||||
"""
|
||||
# Skip authentication if disabled
|
||||
if not self.auth_enabled:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip authentication if marked by HealthCheckBypassMiddleware
|
||||
if getattr(request.state, "skip_auth", False):
|
||||
return await call_next(request)
|
||||
|
||||
# Extract Authorization header
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
# Check if header is present
|
||||
if not auth_header:
|
||||
logger.warning(f"Missing Authorization header from {request.client.host}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Missing Authorization header",
|
||||
},
|
||||
)
|
||||
|
||||
# Check if header format is correct
|
||||
if not auth_header.startswith("Bearer "):
|
||||
logger.warning(f"Invalid Authorization format from {request.client.host}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={
|
||||
"error": "Unauthorized",
|
||||
"message": "Authorization header must use Bearer scheme",
|
||||
},
|
||||
)
|
||||
|
||||
# Extract token
|
||||
provided_token = auth_header[7:] # Remove "Bearer " prefix
|
||||
|
||||
# Validate token
|
||||
if provided_token != self.auth_token:
|
||||
logger.warning(f"Invalid token from {request.client.host}")
|
||||
return JSONResponse(
|
||||
status_code=403,
|
||||
content={
|
||||
"error": "Forbidden",
|
||||
"message": "Invalid authentication token",
|
||||
},
|
||||
)
|
||||
|
||||
# Token is valid, proceed to next handler
|
||||
logger.debug(f"Authenticated request from {request.client.host}")
|
||||
return await call_next(request)
|
||||
|
||||
|
||||
class HealthCheckBypassMiddleware(BaseHTTPMiddleware):
|
||||
"""
|
||||
Middleware to bypass authentication for health check endpoints.
|
||||
|
||||
This allows monitoring systems to check server health without authentication.
|
||||
"""
|
||||
|
||||
def __init__(self, app, health_check_paths: list[str] | None = None):
|
||||
"""
|
||||
Initialize health check bypass middleware.
|
||||
|
||||
Args:
|
||||
app: ASGI application to wrap.
|
||||
health_check_paths: List of paths to bypass authentication.
|
||||
Defaults to ["/health", "/healthz", "/ping"].
|
||||
"""
|
||||
super().__init__(app)
|
||||
self.health_check_paths = health_check_paths or ["/health", "/healthz", "/ping"]
|
||||
|
||||
async def dispatch(
|
||||
self, request: Request, call_next: Callable[[Request], Awaitable[Response]]
|
||||
) -> Response:
|
||||
"""
|
||||
Process request and bypass authentication for health checks.
|
||||
|
||||
Args:
|
||||
request: Incoming HTTP request.
|
||||
call_next: Next middleware or route handler.
|
||||
|
||||
Returns:
|
||||
Response from downstream handler.
|
||||
"""
|
||||
# Check if request is for a health check endpoint
|
||||
if request.url.path in self.health_check_paths:
|
||||
logger.debug(f"Bypassing auth for health check: {request.url.path}")
|
||||
# Mark request to skip authentication in BearerAuthMiddleware
|
||||
request.state.skip_auth = True
|
||||
|
||||
# Continue to next middleware
|
||||
return await call_next(request)
|
||||
309
src/gitea_http_wrapper/server.py
Normal file
309
src/gitea_http_wrapper/server.py
Normal file
@@ -0,0 +1,309 @@
|
||||
"""HTTP MCP server implementation wrapping Gitea MCP."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
from mcp.server import Server
|
||||
from mcp.server.stdio import stdio_server
|
||||
from starlette.applications import Starlette
|
||||
from starlette.requests import Request
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
|
||||
from gitea_http_wrapper.config import GiteaSettings, load_settings
|
||||
from gitea_http_wrapper.filtering import ToolFilter
|
||||
from gitea_http_wrapper.middleware import (
|
||||
BearerAuthMiddleware,
|
||||
HealthCheckBypassMiddleware,
|
||||
)
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GiteaMCPWrapper:
|
||||
"""
|
||||
HTTP wrapper around the official Gitea MCP server.
|
||||
|
||||
This class manages:
|
||||
1. Starting the Gitea MCP server as a subprocess with stdio transport
|
||||
2. Proxying HTTP requests to the MCP server
|
||||
3. Filtering tools based on configuration
|
||||
4. Handling responses and errors
|
||||
"""
|
||||
|
||||
def __init__(self, settings: GiteaSettings):
|
||||
"""
|
||||
Initialize the MCP wrapper.
|
||||
|
||||
Args:
|
||||
settings: Configuration settings for Gitea and HTTP server.
|
||||
"""
|
||||
self.settings = settings
|
||||
self.tool_filter = ToolFilter(
|
||||
enabled_tools=settings.enabled_tools_list,
|
||||
disabled_tools=settings.disabled_tools_list,
|
||||
)
|
||||
self.process = None
|
||||
self.reader = None
|
||||
self.writer = None
|
||||
|
||||
async def start_gitea_mcp(self) -> None:
|
||||
"""
|
||||
Start the Gitea MCP server as a subprocess.
|
||||
|
||||
The server runs with stdio transport, and we communicate via stdin/stdout.
|
||||
"""
|
||||
logger.info("Starting Gitea MCP server subprocess")
|
||||
|
||||
# Set environment variables for Gitea MCP
|
||||
env = os.environ.copy()
|
||||
env.update(self.settings.get_gitea_mcp_env())
|
||||
|
||||
# Start the process
|
||||
# Note: This assumes gitea-mcp-server is installed and on PATH
|
||||
# In production Docker, this should be guaranteed
|
||||
try:
|
||||
self.process = await asyncio.create_subprocess_exec(
|
||||
"gitea-mcp-server",
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
env=env,
|
||||
)
|
||||
self.reader = self.process.stdout
|
||||
self.writer = self.process.stdin
|
||||
logger.info("Gitea MCP server started successfully")
|
||||
except FileNotFoundError:
|
||||
logger.error("gitea-mcp-server not found in PATH")
|
||||
raise RuntimeError(
|
||||
"gitea-mcp-server not found. Ensure it's installed: pip install gitea-mcp-server"
|
||||
)
|
||||
|
||||
async def stop_gitea_mcp(self) -> None:
|
||||
"""Stop the Gitea MCP server subprocess."""
|
||||
if self.process:
|
||||
logger.info("Stopping Gitea MCP server subprocess")
|
||||
self.process.terminate()
|
||||
await self.process.wait()
|
||||
logger.info("Gitea MCP server stopped")
|
||||
|
||||
async def send_mcp_request(self, method: str, params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Send a JSON-RPC request to the MCP server.
|
||||
|
||||
Args:
|
||||
method: MCP method name (e.g., "tools/list", "tools/call").
|
||||
params: Method parameters.
|
||||
|
||||
Returns:
|
||||
JSON-RPC response from MCP server.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If MCP server is not running or communication fails.
|
||||
"""
|
||||
if not self.writer or not self.reader:
|
||||
raise RuntimeError("MCP server not started")
|
||||
|
||||
# Build JSON-RPC request
|
||||
request = {
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": method,
|
||||
"params": params,
|
||||
}
|
||||
|
||||
# Send request
|
||||
request_json = json.dumps(request) + "\n"
|
||||
self.writer.write(request_json.encode())
|
||||
await self.writer.drain()
|
||||
|
||||
# Read response
|
||||
response_line = await self.reader.readline()
|
||||
response = json.loads(response_line.decode())
|
||||
|
||||
# Check for JSON-RPC error
|
||||
if "error" in response:
|
||||
logger.error(f"MCP error: {response['error']}")
|
||||
raise RuntimeError(f"MCP error: {response['error']}")
|
||||
|
||||
return response.get("result", {})
|
||||
|
||||
async def list_tools(self) -> dict[str, Any]:
|
||||
"""
|
||||
List available tools from MCP server with filtering applied.
|
||||
|
||||
Returns:
|
||||
Filtered tools list response.
|
||||
"""
|
||||
response = await self.send_mcp_request("tools/list", {})
|
||||
filtered_response = self.tool_filter.filter_tools_response(response)
|
||||
|
||||
logger.info(
|
||||
f"Listed {len(filtered_response.get('tools', []))} tools "
|
||||
f"(filter: {self.tool_filter.get_filter_stats()['mode']})"
|
||||
)
|
||||
return filtered_response
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Call a tool on the MCP server.
|
||||
|
||||
Args:
|
||||
tool_name: Name of tool to call.
|
||||
arguments: Tool arguments.
|
||||
|
||||
Returns:
|
||||
Tool execution result.
|
||||
|
||||
Raises:
|
||||
ValueError: If tool is filtered out.
|
||||
"""
|
||||
# Check if tool is allowed
|
||||
if not self.tool_filter.should_include_tool(tool_name):
|
||||
raise ValueError(f"Tool '{tool_name}' is not available (filtered)")
|
||||
|
||||
logger.info(f"Calling tool: {tool_name}")
|
||||
result = await self.send_mcp_request(
|
||||
"tools/call",
|
||||
{"name": tool_name, "arguments": arguments},
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
# Global wrapper instance
|
||||
wrapper: GiteaMCPWrapper | None = None
|
||||
|
||||
|
||||
async def health_check(request: Request) -> JSONResponse:
|
||||
"""Health check endpoint."""
|
||||
return JSONResponse({"status": "healthy"})
|
||||
|
||||
|
||||
async def list_tools_endpoint(request: Request) -> JSONResponse:
|
||||
"""List available tools."""
|
||||
try:
|
||||
tools = await wrapper.list_tools()
|
||||
return JSONResponse(tools)
|
||||
except Exception as e:
|
||||
logger.exception("Error listing tools")
|
||||
return JSONResponse(
|
||||
{"error": str(e)},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
async def call_tool_endpoint(request: Request) -> JSONResponse:
|
||||
"""Call a tool."""
|
||||
try:
|
||||
body = await request.json()
|
||||
tool_name = body.get("name")
|
||||
arguments = body.get("arguments", {})
|
||||
|
||||
if not tool_name:
|
||||
return JSONResponse(
|
||||
{"error": "Missing 'name' field"},
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
result = await wrapper.call_tool(tool_name, arguments)
|
||||
return JSONResponse(result)
|
||||
except ValueError as e:
|
||||
# Tool filtered
|
||||
return JSONResponse(
|
||||
{"error": str(e)},
|
||||
status_code=403,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Error calling tool")
|
||||
return JSONResponse(
|
||||
{"error": str(e)},
|
||||
status_code=500,
|
||||
)
|
||||
|
||||
|
||||
async def startup() -> None:
|
||||
"""Application startup handler."""
|
||||
global wrapper
|
||||
settings = load_settings()
|
||||
wrapper = GiteaMCPWrapper(settings)
|
||||
await wrapper.start_gitea_mcp()
|
||||
logger.info(f"HTTP MCP server starting on {settings.http_host}:{settings.http_port}")
|
||||
|
||||
|
||||
async def shutdown() -> None:
|
||||
"""Application shutdown handler."""
|
||||
global wrapper
|
||||
if wrapper:
|
||||
await wrapper.stop_gitea_mcp()
|
||||
|
||||
|
||||
# Define routes
|
||||
routes = [
|
||||
Route("/health", health_check, methods=["GET"]),
|
||||
Route("/healthz", health_check, methods=["GET"]),
|
||||
Route("/ping", health_check, methods=["GET"]),
|
||||
Route("/tools/list", list_tools_endpoint, methods=["POST"]),
|
||||
Route("/tools/call", call_tool_endpoint, methods=["POST"]),
|
||||
]
|
||||
|
||||
# Create Starlette app
|
||||
app = Starlette(
|
||||
routes=routes,
|
||||
on_startup=[startup],
|
||||
on_shutdown=[shutdown],
|
||||
)
|
||||
|
||||
|
||||
def create_app(settings: GiteaSettings | None = None) -> Starlette:
|
||||
"""
|
||||
Create and configure the Starlette application.
|
||||
|
||||
Args:
|
||||
settings: Optional settings override for testing.
|
||||
|
||||
Returns:
|
||||
Configured Starlette application.
|
||||
"""
|
||||
if settings is None:
|
||||
settings = load_settings()
|
||||
|
||||
# Add middleware
|
||||
app.add_middleware(HealthCheckBypassMiddleware)
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token=settings.auth_token)
|
||||
|
||||
return app
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for the HTTP MCP server."""
|
||||
settings = load_settings()
|
||||
|
||||
# Log filter configuration
|
||||
filter_stats = ToolFilter(
|
||||
enabled_tools=settings.enabled_tools_list,
|
||||
disabled_tools=settings.disabled_tools_list,
|
||||
).get_filter_stats()
|
||||
logger.info(f"Tool filtering: {filter_stats}")
|
||||
|
||||
# Run server
|
||||
uvicorn.run(
|
||||
"gitea_http_wrapper.server:app",
|
||||
host=settings.http_host,
|
||||
port=settings.http_port,
|
||||
log_level="info",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
9
src/gitea_http_wrapper/tests/__init__.py
Normal file
9
src/gitea_http_wrapper/tests/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Test suite for HTTP wrapper functionality."""
|
||||
|
||||
# This package contains tests for:
|
||||
# - config: Configuration loader and validation
|
||||
# - filtering: Tool filtering for Claude Desktop compatibility
|
||||
# - middleware: HTTP authentication middleware
|
||||
# - server: Core HTTP MCP server (integration tests would go here)
|
||||
|
||||
__all__ = []
|
||||
59
src/gitea_http_wrapper/tests/conftest.py
Normal file
59
src/gitea_http_wrapper/tests/conftest.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Pytest configuration and shared fixtures for test suite."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_gitea_config():
|
||||
"""Provide sample Gitea configuration for tests."""
|
||||
return {
|
||||
"gitea_url": "https://gitea.test.com",
|
||||
"gitea_token": "test_token_123",
|
||||
"gitea_owner": "test_owner",
|
||||
"gitea_repo": "test_repo",
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tools_list():
|
||||
"""Provide sample MCP tools list for testing."""
|
||||
return [
|
||||
{
|
||||
"name": "list_issues",
|
||||
"description": "List issues in repository",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"state": {"type": "string", "enum": ["open", "closed", "all"]},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "create_issue",
|
||||
"description": "Create a new issue",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"body": {"type": "string"},
|
||||
},
|
||||
"required": ["title"],
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": "list_labels",
|
||||
"description": "List labels in repository",
|
||||
"inputSchema": {"type": "object", "properties": {}},
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_mcp_response(sample_tools_list):
|
||||
"""Provide sample MCP list_tools response."""
|
||||
return {
|
||||
"tools": sample_tools_list,
|
||||
"meta": {
|
||||
"version": "1.0",
|
||||
},
|
||||
}
|
||||
211
src/gitea_http_wrapper/tests/test_config.py
Normal file
211
src/gitea_http_wrapper/tests/test_config.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""Tests for configuration loader module."""
|
||||
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from gitea_http_wrapper.config import GiteaSettings, load_settings
|
||||
|
||||
|
||||
class TestGiteaSettings:
|
||||
"""Test GiteaSettings configuration class."""
|
||||
|
||||
def test_required_fields(self):
|
||||
"""Test that required fields are enforced."""
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
GiteaSettings()
|
||||
|
||||
errors = exc_info.value.errors()
|
||||
# Note: gitea_repo is optional (for PMO mode)
|
||||
required_fields = {"gitea_url", "gitea_token", "gitea_owner"}
|
||||
error_fields = {error["loc"][0] for error in errors}
|
||||
assert required_fields.issubset(error_fields)
|
||||
|
||||
def test_valid_configuration(self):
|
||||
"""Test valid configuration creation."""
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="test_token",
|
||||
gitea_owner="test_owner",
|
||||
gitea_repo="test_repo",
|
||||
)
|
||||
|
||||
assert settings.gitea_url == "https://gitea.example.com"
|
||||
assert settings.gitea_token == "test_token"
|
||||
assert settings.gitea_owner == "test_owner"
|
||||
assert settings.gitea_repo == "test_repo"
|
||||
assert settings.http_host == "127.0.0.1"
|
||||
assert settings.http_port == 8000
|
||||
assert settings.auth_token is None
|
||||
|
||||
def test_gitea_url_validation(self):
|
||||
"""Test Gitea URL validation."""
|
||||
# Valid URLs
|
||||
valid_urls = [
|
||||
"http://gitea.local",
|
||||
"https://gitea.example.com",
|
||||
"http://192.168.1.1:3000",
|
||||
]
|
||||
|
||||
for url in valid_urls:
|
||||
settings = GiteaSettings(
|
||||
gitea_url=url,
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
)
|
||||
assert settings.gitea_url == url.rstrip("/")
|
||||
|
||||
# Invalid URL (no protocol)
|
||||
with pytest.raises(ValidationError) as exc_info:
|
||||
GiteaSettings(
|
||||
gitea_url="gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
)
|
||||
assert "must start with http://" in str(exc_info.value)
|
||||
|
||||
def test_gitea_url_trailing_slash_removed(self):
|
||||
"""Test that trailing slashes are removed from Gitea URL."""
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com/",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
)
|
||||
assert settings.gitea_url == "https://gitea.example.com"
|
||||
|
||||
def test_http_port_validation(self):
|
||||
"""Test HTTP port validation."""
|
||||
# Valid port
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
http_port=9000,
|
||||
)
|
||||
assert settings.http_port == 9000
|
||||
|
||||
# Invalid port (too high)
|
||||
with pytest.raises(ValidationError):
|
||||
GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
http_port=70000,
|
||||
)
|
||||
|
||||
# Invalid port (too low)
|
||||
with pytest.raises(ValidationError):
|
||||
GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
http_port=0,
|
||||
)
|
||||
|
||||
def test_enabled_tools_list_parsing(self):
|
||||
"""Test enabled_tools string parsing to list."""
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
enabled_tools="tool1,tool2,tool3",
|
||||
)
|
||||
assert settings.enabled_tools_list == ["tool1", "tool2", "tool3"]
|
||||
|
||||
# Test with spaces
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
enabled_tools="tool1, tool2 , tool3",
|
||||
)
|
||||
assert settings.enabled_tools_list == ["tool1", "tool2", "tool3"]
|
||||
|
||||
# Test empty string
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
enabled_tools="",
|
||||
)
|
||||
assert settings.enabled_tools_list is None
|
||||
|
||||
def test_disabled_tools_list_parsing(self):
|
||||
"""Test disabled_tools string parsing to list."""
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="token",
|
||||
gitea_owner="owner",
|
||||
gitea_repo="repo",
|
||||
disabled_tools="tool1,tool2",
|
||||
)
|
||||
assert settings.disabled_tools_list == ["tool1", "tool2"]
|
||||
|
||||
def test_get_gitea_mcp_env(self):
|
||||
"""Test environment variable generation for wrapped MCP server."""
|
||||
settings = GiteaSettings(
|
||||
gitea_url="https://gitea.example.com",
|
||||
gitea_token="test_token",
|
||||
gitea_owner="test_owner",
|
||||
gitea_repo="test_repo",
|
||||
)
|
||||
|
||||
env = settings.get_gitea_mcp_env()
|
||||
|
||||
assert env["GITEA_BASE_URL"] == "https://gitea.example.com"
|
||||
assert env["GITEA_API_TOKEN"] == "test_token"
|
||||
assert env["GITEA_DEFAULT_OWNER"] == "test_owner"
|
||||
assert env["GITEA_DEFAULT_REPO"] == "test_repo"
|
||||
|
||||
|
||||
class TestLoadSettings:
|
||||
"""Test load_settings factory function."""
|
||||
|
||||
def test_load_from_env_file(self, tmp_path):
|
||||
"""Test loading settings from a .env file."""
|
||||
env_file = tmp_path / ".env"
|
||||
env_file.write_text(
|
||||
"""
|
||||
GITEA_URL=https://gitea.test.com
|
||||
GITEA_TOKEN=test_token_123
|
||||
GITEA_OWNER=test_owner
|
||||
GITEA_REPO=test_repo
|
||||
HTTP_PORT=9000
|
||||
"""
|
||||
)
|
||||
|
||||
settings = load_settings(env_file)
|
||||
|
||||
assert settings.gitea_url == "https://gitea.test.com"
|
||||
assert settings.gitea_token == "test_token_123"
|
||||
assert settings.gitea_owner == "test_owner"
|
||||
assert settings.gitea_repo == "test_repo"
|
||||
assert settings.http_port == 9000
|
||||
|
||||
def test_load_from_environment(self, monkeypatch):
|
||||
"""Test loading settings from environment variables."""
|
||||
monkeypatch.setenv("GITEA_URL", "https://env.gitea.com")
|
||||
monkeypatch.setenv("GITEA_TOKEN", "env_token")
|
||||
monkeypatch.setenv("GITEA_OWNER", "env_owner")
|
||||
monkeypatch.setenv("GITEA_REPO", "env_repo")
|
||||
monkeypatch.setenv("HTTP_PORT", "8080")
|
||||
|
||||
# Mock _env_file to prevent loading actual .env
|
||||
settings = GiteaSettings()
|
||||
|
||||
assert settings.gitea_url == "https://env.gitea.com"
|
||||
assert settings.gitea_token == "env_token"
|
||||
assert settings.gitea_owner == "env_owner"
|
||||
assert settings.gitea_repo == "env_repo"
|
||||
assert settings.http_port == 8080
|
||||
143
src/gitea_http_wrapper/tests/test_filtering.py
Normal file
143
src/gitea_http_wrapper/tests/test_filtering.py
Normal file
@@ -0,0 +1,143 @@
|
||||
"""Tests for tool filtering module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from gitea_http_wrapper.filtering import ToolFilter
|
||||
|
||||
|
||||
class TestToolFilter:
|
||||
"""Test ToolFilter class."""
|
||||
|
||||
def test_init_with_both_lists_raises(self):
|
||||
"""Test that specifying both enabled and disabled lists raises error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
ToolFilter(enabled_tools=["tool1"], disabled_tools=["tool2"])
|
||||
|
||||
assert "Cannot specify both" in str(exc_info.value)
|
||||
|
||||
def test_passthrough_mode(self):
|
||||
"""Test passthrough mode (no filtering)."""
|
||||
filter = ToolFilter()
|
||||
|
||||
assert filter.should_include_tool("any_tool")
|
||||
assert filter.should_include_tool("another_tool")
|
||||
|
||||
stats = filter.get_filter_stats()
|
||||
assert stats["mode"] == "passthrough"
|
||||
|
||||
def test_whitelist_mode(self):
|
||||
"""Test whitelist mode (enabled_tools)."""
|
||||
filter = ToolFilter(enabled_tools=["tool1", "tool2"])
|
||||
|
||||
assert filter.should_include_tool("tool1")
|
||||
assert filter.should_include_tool("tool2")
|
||||
assert not filter.should_include_tool("tool3")
|
||||
assert not filter.should_include_tool("tool4")
|
||||
|
||||
stats = filter.get_filter_stats()
|
||||
assert stats["mode"] == "whitelist"
|
||||
assert stats["enabled_count"] == 2
|
||||
assert "tool1" in stats["enabled_tools"]
|
||||
assert "tool2" in stats["enabled_tools"]
|
||||
|
||||
def test_blacklist_mode(self):
|
||||
"""Test blacklist mode (disabled_tools)."""
|
||||
filter = ToolFilter(disabled_tools=["tool1", "tool2"])
|
||||
|
||||
assert not filter.should_include_tool("tool1")
|
||||
assert not filter.should_include_tool("tool2")
|
||||
assert filter.should_include_tool("tool3")
|
||||
assert filter.should_include_tool("tool4")
|
||||
|
||||
stats = filter.get_filter_stats()
|
||||
assert stats["mode"] == "blacklist"
|
||||
assert stats["disabled_count"] == 2
|
||||
assert "tool1" in stats["disabled_tools"]
|
||||
assert "tool2" in stats["disabled_tools"]
|
||||
|
||||
def test_filter_tools_list(self):
|
||||
"""Test filtering a list of tool definitions."""
|
||||
filter = ToolFilter(enabled_tools=["tool1", "tool3"])
|
||||
|
||||
tools = [
|
||||
{"name": "tool1", "description": "First tool"},
|
||||
{"name": "tool2", "description": "Second tool"},
|
||||
{"name": "tool3", "description": "Third tool"},
|
||||
{"name": "tool4", "description": "Fourth tool"},
|
||||
]
|
||||
|
||||
filtered = filter.filter_tools_list(tools)
|
||||
|
||||
assert len(filtered) == 2
|
||||
assert filtered[0]["name"] == "tool1"
|
||||
assert filtered[1]["name"] == "tool3"
|
||||
|
||||
def test_filter_tools_response(self):
|
||||
"""Test filtering an MCP list_tools response."""
|
||||
filter = ToolFilter(disabled_tools=["tool2"])
|
||||
|
||||
response = {
|
||||
"tools": [
|
||||
{"name": "tool1", "description": "First tool"},
|
||||
{"name": "tool2", "description": "Second tool"},
|
||||
{"name": "tool3", "description": "Third tool"},
|
||||
],
|
||||
"other_data": "preserved",
|
||||
}
|
||||
|
||||
filtered = filter.filter_tools_response(response)
|
||||
|
||||
assert len(filtered["tools"]) == 2
|
||||
assert filtered["tools"][0]["name"] == "tool1"
|
||||
assert filtered["tools"][1]["name"] == "tool3"
|
||||
assert filtered["other_data"] == "preserved"
|
||||
|
||||
def test_filter_tools_response_no_tools_key(self):
|
||||
"""Test filtering response without 'tools' key."""
|
||||
filter = ToolFilter(enabled_tools=["tool1"])
|
||||
|
||||
response = {"other_data": "value"}
|
||||
filtered = filter.filter_tools_response(response)
|
||||
|
||||
assert filtered == response
|
||||
|
||||
def test_filter_tools_response_immutable(self):
|
||||
"""Test that original response is not mutated."""
|
||||
filter = ToolFilter(enabled_tools=["tool1"])
|
||||
|
||||
original = {
|
||||
"tools": [
|
||||
{"name": "tool1"},
|
||||
{"name": "tool2"},
|
||||
]
|
||||
}
|
||||
|
||||
filtered = filter.filter_tools_response(original)
|
||||
|
||||
# Original should still have 2 tools
|
||||
assert len(original["tools"]) == 2
|
||||
# Filtered should have 1 tool
|
||||
assert len(filtered["tools"]) == 1
|
||||
|
||||
def test_empty_tool_list(self):
|
||||
"""Test filtering empty tool list."""
|
||||
filter = ToolFilter(enabled_tools=["tool1"])
|
||||
|
||||
result = filter.filter_tools_list([])
|
||||
assert result == []
|
||||
|
||||
def test_tool_with_no_name(self):
|
||||
"""Test handling tool without name field."""
|
||||
filter = ToolFilter(enabled_tools=["tool1"])
|
||||
|
||||
tools = [
|
||||
{"name": "tool1"},
|
||||
{"description": "No name"},
|
||||
{"name": "tool2"},
|
||||
]
|
||||
|
||||
filtered = filter.filter_tools_list(tools)
|
||||
|
||||
# Only tool1 should match, tool without name is excluded
|
||||
assert len(filtered) == 1
|
||||
assert filtered[0]["name"] == "tool1"
|
||||
162
src/gitea_http_wrapper/tests/test_middleware.py
Normal file
162
src/gitea_http_wrapper/tests/test_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
||||
"""Tests for HTTP authentication middleware."""
|
||||
|
||||
import pytest
|
||||
from starlette.applications import Starlette
|
||||
from starlette.responses import JSONResponse
|
||||
from starlette.routing import Route
|
||||
from starlette.testclient import TestClient
|
||||
|
||||
from gitea_http_wrapper.middleware import (
|
||||
BearerAuthMiddleware,
|
||||
HealthCheckBypassMiddleware,
|
||||
)
|
||||
|
||||
|
||||
# Test application endpoint
|
||||
async def test_endpoint(request):
|
||||
return JSONResponse({"message": "success"})
|
||||
|
||||
|
||||
class TestBearerAuthMiddleware:
|
||||
"""Test BearerAuthMiddleware."""
|
||||
|
||||
def test_no_auth_configured(self):
|
||||
"""Test that requests pass through when no auth token is configured."""
|
||||
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token=None)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "success"
|
||||
|
||||
def test_auth_configured_valid_token(self):
|
||||
"""Test successful authentication with valid token."""
|
||||
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test", headers={"Authorization": "Bearer secret_token"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert response.json()["message"] == "success"
|
||||
|
||||
def test_auth_configured_missing_header(self):
|
||||
"""Test rejection when Authorization header is missing."""
|
||||
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test")
|
||||
|
||||
assert response.status_code == 401
|
||||
assert "Missing Authorization header" in response.json()["message"]
|
||||
|
||||
def test_auth_configured_invalid_format(self):
|
||||
"""Test rejection when Authorization header has wrong format."""
|
||||
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Test with wrong scheme
|
||||
response = client.get("/test", headers={"Authorization": "Basic secret_token"})
|
||||
assert response.status_code == 401
|
||||
assert "Bearer scheme" in response.json()["message"]
|
||||
|
||||
# Test with no scheme
|
||||
response = client.get("/test", headers={"Authorization": "secret_token"})
|
||||
assert response.status_code == 401
|
||||
|
||||
def test_auth_configured_invalid_token(self):
|
||||
"""Test rejection when token is invalid."""
|
||||
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/test", headers={"Authorization": "Bearer wrong_token"})
|
||||
|
||||
assert response.status_code == 403
|
||||
assert "Invalid authentication token" in response.json()["message"]
|
||||
|
||||
def test_auth_case_sensitive_token(self):
|
||||
"""Test that token comparison is case-sensitive."""
|
||||
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="Secret_Token")
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Correct case
|
||||
response = client.get("/test", headers={"Authorization": "Bearer Secret_Token"})
|
||||
assert response.status_code == 200
|
||||
|
||||
# Wrong case
|
||||
response = client.get("/test", headers={"Authorization": "Bearer secret_token"})
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
class TestHealthCheckBypassMiddleware:
|
||||
"""Test HealthCheckBypassMiddleware."""
|
||||
|
||||
def test_default_health_check_paths(self):
|
||||
"""Test that default health check paths bypass auth."""
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/health", test_endpoint),
|
||||
Route("/healthz", test_endpoint),
|
||||
Route("/ping", test_endpoint),
|
||||
Route("/test", test_endpoint),
|
||||
]
|
||||
)
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
app.add_middleware(HealthCheckBypassMiddleware)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Health checks should work without auth
|
||||
assert client.get("/health").status_code == 200
|
||||
assert client.get("/healthz").status_code == 200
|
||||
assert client.get("/ping").status_code == 200
|
||||
|
||||
# Regular endpoint should require auth
|
||||
assert client.get("/test").status_code == 401
|
||||
|
||||
def test_custom_health_check_paths(self):
|
||||
"""Test custom health check paths."""
|
||||
app = Starlette(
|
||||
routes=[
|
||||
Route("/custom-health", test_endpoint),
|
||||
Route("/test", test_endpoint),
|
||||
]
|
||||
)
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
app.add_middleware(
|
||||
HealthCheckBypassMiddleware,
|
||||
health_check_paths=["/custom-health"],
|
||||
)
|
||||
|
||||
client = TestClient(app)
|
||||
|
||||
# Custom health check should work without auth
|
||||
assert client.get("/custom-health").status_code == 200
|
||||
|
||||
# Regular endpoint should require auth
|
||||
assert client.get("/test").status_code == 401
|
||||
|
||||
def test_middleware_order(self):
|
||||
"""Test that middleware order is correct."""
|
||||
# HealthCheckBypass should be added BEFORE BearerAuth
|
||||
# so it can bypass the auth check
|
||||
|
||||
app = Starlette(routes=[Route("/health", test_endpoint)])
|
||||
|
||||
# Correct order: HealthCheck bypass first, then Auth
|
||||
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||
app.add_middleware(HealthCheckBypassMiddleware)
|
||||
|
||||
client = TestClient(app)
|
||||
response = client.get("/health")
|
||||
|
||||
# Should succeed without auth
|
||||
assert response.status_code == 200
|
||||
Reference in New Issue
Block a user