generated from personal-projects/leo-claude-mktplace
Merge feat/17: Create test suite for wrapper functionality
This commit is contained in:
@@ -37,6 +37,8 @@ dev = [
|
|||||||
"pytest>=7.0.0",
|
"pytest>=7.0.0",
|
||||||
"pytest-asyncio>=0.21.0",
|
"pytest-asyncio>=0.21.0",
|
||||||
"pytest-cov>=4.0.0",
|
"pytest-cov>=4.0.0",
|
||||||
|
"httpx>=0.24.0",
|
||||||
|
"starlette>=0.36.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.scripts]
|
[project.scripts]
|
||||||
|
|||||||
18
pytest.ini
Normal file
18
pytest.ini
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
[pytest]
|
||||||
|
testpaths = src/gitea_http_wrapper/tests
|
||||||
|
python_files = test_*.py
|
||||||
|
python_classes = Test*
|
||||||
|
python_functions = test_*
|
||||||
|
asyncio_mode = auto
|
||||||
|
|
||||||
|
# Coverage options
|
||||||
|
addopts =
|
||||||
|
--verbose
|
||||||
|
--strict-markers
|
||||||
|
--tb=short
|
||||||
|
|
||||||
|
# Markers for test categorization
|
||||||
|
markers =
|
||||||
|
unit: Unit tests (fast, no external dependencies)
|
||||||
|
integration: Integration tests (may require external services)
|
||||||
|
slow: Slow-running tests
|
||||||
@@ -1,3 +1,9 @@
|
|||||||
"""Test suite for HTTP wrapper functionality."""
|
"""Test suite for HTTP wrapper functionality."""
|
||||||
|
|
||||||
|
# This package contains tests for:
|
||||||
|
# - config: Configuration loader and validation
|
||||||
|
# - filtering: Tool filtering for Claude Desktop compatibility
|
||||||
|
# - middleware: HTTP authentication middleware
|
||||||
|
# - server: Core HTTP MCP server (integration tests would go here)
|
||||||
|
|
||||||
__all__ = []
|
__all__ = []
|
||||||
|
|||||||
59
src/gitea_http_wrapper/tests/conftest.py
Normal file
59
src/gitea_http_wrapper/tests/conftest.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Pytest configuration and shared fixtures for test suite."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_gitea_config():
|
||||||
|
"""Provide sample Gitea configuration for tests."""
|
||||||
|
return {
|
||||||
|
"gitea_url": "https://gitea.test.com",
|
||||||
|
"gitea_token": "test_token_123",
|
||||||
|
"gitea_owner": "test_owner",
|
||||||
|
"gitea_repo": "test_repo",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_tools_list():
|
||||||
|
"""Provide sample MCP tools list for testing."""
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "list_issues",
|
||||||
|
"description": "List issues in repository",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"state": {"type": "string", "enum": ["open", "closed", "all"]},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "create_issue",
|
||||||
|
"description": "Create a new issue",
|
||||||
|
"inputSchema": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"title": {"type": "string"},
|
||||||
|
"body": {"type": "string"},
|
||||||
|
},
|
||||||
|
"required": ["title"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "list_labels",
|
||||||
|
"description": "List labels in repository",
|
||||||
|
"inputSchema": {"type": "object", "properties": {}},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def sample_mcp_response(sample_tools_list):
|
||||||
|
"""Provide sample MCP list_tools response."""
|
||||||
|
return {
|
||||||
|
"tools": sample_tools_list,
|
||||||
|
"meta": {
|
||||||
|
"version": "1.0",
|
||||||
|
},
|
||||||
|
}
|
||||||
210
src/gitea_http_wrapper/tests/test_config.py
Normal file
210
src/gitea_http_wrapper/tests/test_config.py
Normal file
@@ -0,0 +1,210 @@
|
|||||||
|
"""Tests for configuration loader module."""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from gitea_http_wrapper.config import GiteaSettings, load_settings
|
||||||
|
|
||||||
|
|
||||||
|
class TestGiteaSettings:
|
||||||
|
"""Test GiteaSettings configuration class."""
|
||||||
|
|
||||||
|
def test_required_fields(self):
|
||||||
|
"""Test that required fields are enforced."""
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
GiteaSettings()
|
||||||
|
|
||||||
|
errors = exc_info.value.errors()
|
||||||
|
required_fields = {"gitea_url", "gitea_token", "gitea_owner", "gitea_repo"}
|
||||||
|
error_fields = {error["loc"][0] for error in errors}
|
||||||
|
assert required_fields.issubset(error_fields)
|
||||||
|
|
||||||
|
def test_valid_configuration(self):
|
||||||
|
"""Test valid configuration creation."""
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="test_token",
|
||||||
|
gitea_owner="test_owner",
|
||||||
|
gitea_repo="test_repo",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert settings.gitea_url == "https://gitea.example.com"
|
||||||
|
assert settings.gitea_token == "test_token"
|
||||||
|
assert settings.gitea_owner == "test_owner"
|
||||||
|
assert settings.gitea_repo == "test_repo"
|
||||||
|
assert settings.http_host == "127.0.0.1"
|
||||||
|
assert settings.http_port == 8000
|
||||||
|
assert settings.auth_token is None
|
||||||
|
|
||||||
|
def test_gitea_url_validation(self):
|
||||||
|
"""Test Gitea URL validation."""
|
||||||
|
# Valid URLs
|
||||||
|
valid_urls = [
|
||||||
|
"http://gitea.local",
|
||||||
|
"https://gitea.example.com",
|
||||||
|
"http://192.168.1.1:3000",
|
||||||
|
]
|
||||||
|
|
||||||
|
for url in valid_urls:
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url=url,
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
)
|
||||||
|
assert settings.gitea_url == url.rstrip("/")
|
||||||
|
|
||||||
|
# Invalid URL (no protocol)
|
||||||
|
with pytest.raises(ValidationError) as exc_info:
|
||||||
|
GiteaSettings(
|
||||||
|
gitea_url="gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
)
|
||||||
|
assert "must start with http://" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_gitea_url_trailing_slash_removed(self):
|
||||||
|
"""Test that trailing slashes are removed from Gitea URL."""
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com/",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
)
|
||||||
|
assert settings.gitea_url == "https://gitea.example.com"
|
||||||
|
|
||||||
|
def test_http_port_validation(self):
|
||||||
|
"""Test HTTP port validation."""
|
||||||
|
# Valid port
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
http_port=9000,
|
||||||
|
)
|
||||||
|
assert settings.http_port == 9000
|
||||||
|
|
||||||
|
# Invalid port (too high)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
http_port=70000,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Invalid port (too low)
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
http_port=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_enabled_tools_list_parsing(self):
|
||||||
|
"""Test enabled_tools string parsing to list."""
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
enabled_tools="tool1,tool2,tool3",
|
||||||
|
)
|
||||||
|
assert settings.enabled_tools_list == ["tool1", "tool2", "tool3"]
|
||||||
|
|
||||||
|
# Test with spaces
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
enabled_tools="tool1, tool2 , tool3",
|
||||||
|
)
|
||||||
|
assert settings.enabled_tools_list == ["tool1", "tool2", "tool3"]
|
||||||
|
|
||||||
|
# Test empty string
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
enabled_tools="",
|
||||||
|
)
|
||||||
|
assert settings.enabled_tools_list is None
|
||||||
|
|
||||||
|
def test_disabled_tools_list_parsing(self):
|
||||||
|
"""Test disabled_tools string parsing to list."""
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="token",
|
||||||
|
gitea_owner="owner",
|
||||||
|
gitea_repo="repo",
|
||||||
|
disabled_tools="tool1,tool2",
|
||||||
|
)
|
||||||
|
assert settings.disabled_tools_list == ["tool1", "tool2"]
|
||||||
|
|
||||||
|
def test_get_gitea_mcp_env(self):
|
||||||
|
"""Test environment variable generation for wrapped MCP server."""
|
||||||
|
settings = GiteaSettings(
|
||||||
|
gitea_url="https://gitea.example.com",
|
||||||
|
gitea_token="test_token",
|
||||||
|
gitea_owner="test_owner",
|
||||||
|
gitea_repo="test_repo",
|
||||||
|
)
|
||||||
|
|
||||||
|
env = settings.get_gitea_mcp_env()
|
||||||
|
|
||||||
|
assert env["GITEA_BASE_URL"] == "https://gitea.example.com"
|
||||||
|
assert env["GITEA_API_TOKEN"] == "test_token"
|
||||||
|
assert env["GITEA_DEFAULT_OWNER"] == "test_owner"
|
||||||
|
assert env["GITEA_DEFAULT_REPO"] == "test_repo"
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoadSettings:
|
||||||
|
"""Test load_settings factory function."""
|
||||||
|
|
||||||
|
def test_load_from_env_file(self, tmp_path):
|
||||||
|
"""Test loading settings from a .env file."""
|
||||||
|
env_file = tmp_path / ".env"
|
||||||
|
env_file.write_text(
|
||||||
|
"""
|
||||||
|
GITEA_URL=https://gitea.test.com
|
||||||
|
GITEA_TOKEN=test_token_123
|
||||||
|
GITEA_OWNER=test_owner
|
||||||
|
GITEA_REPO=test_repo
|
||||||
|
HTTP_PORT=9000
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
settings = load_settings(env_file)
|
||||||
|
|
||||||
|
assert settings.gitea_url == "https://gitea.test.com"
|
||||||
|
assert settings.gitea_token == "test_token_123"
|
||||||
|
assert settings.gitea_owner == "test_owner"
|
||||||
|
assert settings.gitea_repo == "test_repo"
|
||||||
|
assert settings.http_port == 9000
|
||||||
|
|
||||||
|
def test_load_from_environment(self, monkeypatch):
|
||||||
|
"""Test loading settings from environment variables."""
|
||||||
|
monkeypatch.setenv("GITEA_URL", "https://env.gitea.com")
|
||||||
|
monkeypatch.setenv("GITEA_TOKEN", "env_token")
|
||||||
|
monkeypatch.setenv("GITEA_OWNER", "env_owner")
|
||||||
|
monkeypatch.setenv("GITEA_REPO", "env_repo")
|
||||||
|
monkeypatch.setenv("HTTP_PORT", "8080")
|
||||||
|
|
||||||
|
# Mock _env_file to prevent loading actual .env
|
||||||
|
settings = GiteaSettings()
|
||||||
|
|
||||||
|
assert settings.gitea_url == "https://env.gitea.com"
|
||||||
|
assert settings.gitea_token == "env_token"
|
||||||
|
assert settings.gitea_owner == "env_owner"
|
||||||
|
assert settings.gitea_repo == "env_repo"
|
||||||
|
assert settings.http_port == 8080
|
||||||
143
src/gitea_http_wrapper/tests/test_filtering.py
Normal file
143
src/gitea_http_wrapper/tests/test_filtering.py
Normal file
@@ -0,0 +1,143 @@
|
|||||||
|
"""Tests for tool filtering module."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from gitea_http_wrapper.filtering import ToolFilter
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolFilter:
|
||||||
|
"""Test ToolFilter class."""
|
||||||
|
|
||||||
|
def test_init_with_both_lists_raises(self):
|
||||||
|
"""Test that specifying both enabled and disabled lists raises error."""
|
||||||
|
with pytest.raises(ValueError) as exc_info:
|
||||||
|
ToolFilter(enabled_tools=["tool1"], disabled_tools=["tool2"])
|
||||||
|
|
||||||
|
assert "Cannot specify both" in str(exc_info.value)
|
||||||
|
|
||||||
|
def test_passthrough_mode(self):
|
||||||
|
"""Test passthrough mode (no filtering)."""
|
||||||
|
filter = ToolFilter()
|
||||||
|
|
||||||
|
assert filter.should_include_tool("any_tool")
|
||||||
|
assert filter.should_include_tool("another_tool")
|
||||||
|
|
||||||
|
stats = filter.get_filter_stats()
|
||||||
|
assert stats["mode"] == "passthrough"
|
||||||
|
|
||||||
|
def test_whitelist_mode(self):
|
||||||
|
"""Test whitelist mode (enabled_tools)."""
|
||||||
|
filter = ToolFilter(enabled_tools=["tool1", "tool2"])
|
||||||
|
|
||||||
|
assert filter.should_include_tool("tool1")
|
||||||
|
assert filter.should_include_tool("tool2")
|
||||||
|
assert not filter.should_include_tool("tool3")
|
||||||
|
assert not filter.should_include_tool("tool4")
|
||||||
|
|
||||||
|
stats = filter.get_filter_stats()
|
||||||
|
assert stats["mode"] == "whitelist"
|
||||||
|
assert stats["enabled_count"] == 2
|
||||||
|
assert "tool1" in stats["enabled_tools"]
|
||||||
|
assert "tool2" in stats["enabled_tools"]
|
||||||
|
|
||||||
|
def test_blacklist_mode(self):
|
||||||
|
"""Test blacklist mode (disabled_tools)."""
|
||||||
|
filter = ToolFilter(disabled_tools=["tool1", "tool2"])
|
||||||
|
|
||||||
|
assert not filter.should_include_tool("tool1")
|
||||||
|
assert not filter.should_include_tool("tool2")
|
||||||
|
assert filter.should_include_tool("tool3")
|
||||||
|
assert filter.should_include_tool("tool4")
|
||||||
|
|
||||||
|
stats = filter.get_filter_stats()
|
||||||
|
assert stats["mode"] == "blacklist"
|
||||||
|
assert stats["disabled_count"] == 2
|
||||||
|
assert "tool1" in stats["disabled_tools"]
|
||||||
|
assert "tool2" in stats["disabled_tools"]
|
||||||
|
|
||||||
|
def test_filter_tools_list(self):
|
||||||
|
"""Test filtering a list of tool definitions."""
|
||||||
|
filter = ToolFilter(enabled_tools=["tool1", "tool3"])
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{"name": "tool1", "description": "First tool"},
|
||||||
|
{"name": "tool2", "description": "Second tool"},
|
||||||
|
{"name": "tool3", "description": "Third tool"},
|
||||||
|
{"name": "tool4", "description": "Fourth tool"},
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered = filter.filter_tools_list(tools)
|
||||||
|
|
||||||
|
assert len(filtered) == 2
|
||||||
|
assert filtered[0]["name"] == "tool1"
|
||||||
|
assert filtered[1]["name"] == "tool3"
|
||||||
|
|
||||||
|
def test_filter_tools_response(self):
|
||||||
|
"""Test filtering an MCP list_tools response."""
|
||||||
|
filter = ToolFilter(disabled_tools=["tool2"])
|
||||||
|
|
||||||
|
response = {
|
||||||
|
"tools": [
|
||||||
|
{"name": "tool1", "description": "First tool"},
|
||||||
|
{"name": "tool2", "description": "Second tool"},
|
||||||
|
{"name": "tool3", "description": "Third tool"},
|
||||||
|
],
|
||||||
|
"other_data": "preserved",
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered = filter.filter_tools_response(response)
|
||||||
|
|
||||||
|
assert len(filtered["tools"]) == 2
|
||||||
|
assert filtered["tools"][0]["name"] == "tool1"
|
||||||
|
assert filtered["tools"][1]["name"] == "tool3"
|
||||||
|
assert filtered["other_data"] == "preserved"
|
||||||
|
|
||||||
|
def test_filter_tools_response_no_tools_key(self):
|
||||||
|
"""Test filtering response without 'tools' key."""
|
||||||
|
filter = ToolFilter(enabled_tools=["tool1"])
|
||||||
|
|
||||||
|
response = {"other_data": "value"}
|
||||||
|
filtered = filter.filter_tools_response(response)
|
||||||
|
|
||||||
|
assert filtered == response
|
||||||
|
|
||||||
|
def test_filter_tools_response_immutable(self):
|
||||||
|
"""Test that original response is not mutated."""
|
||||||
|
filter = ToolFilter(enabled_tools=["tool1"])
|
||||||
|
|
||||||
|
original = {
|
||||||
|
"tools": [
|
||||||
|
{"name": "tool1"},
|
||||||
|
{"name": "tool2"},
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
filtered = filter.filter_tools_response(original)
|
||||||
|
|
||||||
|
# Original should still have 2 tools
|
||||||
|
assert len(original["tools"]) == 2
|
||||||
|
# Filtered should have 1 tool
|
||||||
|
assert len(filtered["tools"]) == 1
|
||||||
|
|
||||||
|
def test_empty_tool_list(self):
|
||||||
|
"""Test filtering empty tool list."""
|
||||||
|
filter = ToolFilter(enabled_tools=["tool1"])
|
||||||
|
|
||||||
|
result = filter.filter_tools_list([])
|
||||||
|
assert result == []
|
||||||
|
|
||||||
|
def test_tool_with_no_name(self):
|
||||||
|
"""Test handling tool without name field."""
|
||||||
|
filter = ToolFilter(enabled_tools=["tool1"])
|
||||||
|
|
||||||
|
tools = [
|
||||||
|
{"name": "tool1"},
|
||||||
|
{"description": "No name"},
|
||||||
|
{"name": "tool2"},
|
||||||
|
]
|
||||||
|
|
||||||
|
filtered = filter.filter_tools_list(tools)
|
||||||
|
|
||||||
|
# Only tool1 should match, tool without name is excluded
|
||||||
|
assert len(filtered) == 1
|
||||||
|
assert filtered[0]["name"] == "tool1"
|
||||||
162
src/gitea_http_wrapper/tests/test_middleware.py
Normal file
162
src/gitea_http_wrapper/tests/test_middleware.py
Normal file
@@ -0,0 +1,162 @@
|
|||||||
|
"""Tests for HTTP authentication middleware."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.responses import JSONResponse
|
||||||
|
from starlette.routing import Route
|
||||||
|
from starlette.testclient import TestClient
|
||||||
|
|
||||||
|
from gitea_http_wrapper.middleware import (
|
||||||
|
BearerAuthMiddleware,
|
||||||
|
HealthCheckBypassMiddleware,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Test application endpoint
|
||||||
|
async def test_endpoint(request):
|
||||||
|
return JSONResponse({"message": "success"})
|
||||||
|
|
||||||
|
|
||||||
|
class TestBearerAuthMiddleware:
|
||||||
|
"""Test BearerAuthMiddleware."""
|
||||||
|
|
||||||
|
def test_no_auth_configured(self):
|
||||||
|
"""Test that requests pass through when no auth token is configured."""
|
||||||
|
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token=None)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/test")
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["message"] == "success"
|
||||||
|
|
||||||
|
def test_auth_configured_valid_token(self):
|
||||||
|
"""Test successful authentication with valid token."""
|
||||||
|
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/test", headers={"Authorization": "Bearer secret_token"})
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.json()["message"] == "success"
|
||||||
|
|
||||||
|
def test_auth_configured_missing_header(self):
|
||||||
|
"""Test rejection when Authorization header is missing."""
|
||||||
|
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/test")
|
||||||
|
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Missing Authorization header" in response.json()["message"]
|
||||||
|
|
||||||
|
def test_auth_configured_invalid_format(self):
|
||||||
|
"""Test rejection when Authorization header has wrong format."""
|
||||||
|
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Test with wrong scheme
|
||||||
|
response = client.get("/test", headers={"Authorization": "Basic secret_token"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
assert "Bearer scheme" in response.json()["message"]
|
||||||
|
|
||||||
|
# Test with no scheme
|
||||||
|
response = client.get("/test", headers={"Authorization": "secret_token"})
|
||||||
|
assert response.status_code == 401
|
||||||
|
|
||||||
|
def test_auth_configured_invalid_token(self):
|
||||||
|
"""Test rejection when token is invalid."""
|
||||||
|
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/test", headers={"Authorization": "Bearer wrong_token"})
|
||||||
|
|
||||||
|
assert response.status_code == 403
|
||||||
|
assert "Invalid authentication token" in response.json()["message"]
|
||||||
|
|
||||||
|
def test_auth_case_sensitive_token(self):
|
||||||
|
"""Test that token comparison is case-sensitive."""
|
||||||
|
app = Starlette(routes=[Route("/test", test_endpoint)])
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="Secret_Token")
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Correct case
|
||||||
|
response = client.get("/test", headers={"Authorization": "Bearer Secret_Token"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Wrong case
|
||||||
|
response = client.get("/test", headers={"Authorization": "Bearer secret_token"})
|
||||||
|
assert response.status_code == 403
|
||||||
|
|
||||||
|
|
||||||
|
class TestHealthCheckBypassMiddleware:
|
||||||
|
"""Test HealthCheckBypassMiddleware."""
|
||||||
|
|
||||||
|
def test_default_health_check_paths(self):
|
||||||
|
"""Test that default health check paths bypass auth."""
|
||||||
|
app = Starlette(
|
||||||
|
routes=[
|
||||||
|
Route("/health", test_endpoint),
|
||||||
|
Route("/healthz", test_endpoint),
|
||||||
|
Route("/ping", test_endpoint),
|
||||||
|
Route("/test", test_endpoint),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
app.add_middleware(HealthCheckBypassMiddleware)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Health checks should work without auth
|
||||||
|
assert client.get("/health").status_code == 200
|
||||||
|
assert client.get("/healthz").status_code == 200
|
||||||
|
assert client.get("/ping").status_code == 200
|
||||||
|
|
||||||
|
# Regular endpoint should require auth
|
||||||
|
assert client.get("/test").status_code == 401
|
||||||
|
|
||||||
|
def test_custom_health_check_paths(self):
|
||||||
|
"""Test custom health check paths."""
|
||||||
|
app = Starlette(
|
||||||
|
routes=[
|
||||||
|
Route("/custom-health", test_endpoint),
|
||||||
|
Route("/test", test_endpoint),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
app.add_middleware(
|
||||||
|
HealthCheckBypassMiddleware,
|
||||||
|
health_check_paths=["/custom-health"],
|
||||||
|
)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Custom health check should work without auth
|
||||||
|
assert client.get("/custom-health").status_code == 200
|
||||||
|
|
||||||
|
# Regular endpoint should require auth
|
||||||
|
assert client.get("/test").status_code == 401
|
||||||
|
|
||||||
|
def test_middleware_order(self):
|
||||||
|
"""Test that middleware order is correct."""
|
||||||
|
# HealthCheckBypass should be added BEFORE BearerAuth
|
||||||
|
# so it can bypass the auth check
|
||||||
|
|
||||||
|
app = Starlette(routes=[Route("/health", test_endpoint)])
|
||||||
|
|
||||||
|
# Correct order: HealthCheck bypass first, then Auth
|
||||||
|
app.add_middleware(BearerAuthMiddleware, auth_token="secret_token")
|
||||||
|
app.add_middleware(HealthCheckBypassMiddleware)
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
response = client.get("/health")
|
||||||
|
|
||||||
|
# Should succeed without auth
|
||||||
|
assert response.status_code == 200
|
||||||
Reference in New Issue
Block a user