feat: add data-platform plugin (v4.0.0)

Add new data-platform plugin for data engineering workflows with:

MCP Server (32 tools):
- pandas operations (14 tools): read_csv, read_parquet, read_json,
  to_csv, to_parquet, describe, head, tail, filter, select, groupby,
  join, list_data, drop_data
- PostgreSQL/PostGIS (10 tools): pg_connect, pg_query, pg_execute,
  pg_tables, pg_columns, pg_schemas, st_tables, st_geometry_type,
  st_srid, st_extent
- dbt integration (8 tools): dbt_parse, dbt_run, dbt_test, dbt_build,
  dbt_compile, dbt_ls, dbt_docs_generate, dbt_lineage

Plugin Features:
- Arrow IPC data_ref system for DataFrame persistence across tool calls
- Pre-execution validation for dbt with `dbt parse`
- SessionStart hook for PostgreSQL connectivity check (non-blocking)
- Hybrid configuration (system ~/.config/claude/postgres.env + project .env)
- Memory management with 100k row limit and chunking support

Commands: /initial-setup, /ingest, /profile, /schema, /explain, /lineage, /run
Agents: data-ingestion, data-analysis

Test suite: 71 tests covering config, data store, pandas, postgres, dbt tools

Addresses data workflow issues from personal-portfolio project:
- Lost data after multiple interactions (solved by Arrow IPC data_ref)
- dbt 1.9+ syntax deprecation (solved by pre-execution validation)
- Ungraceful PostgreSQL error handling (solved by SessionStart hook)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
2026-01-25 14:24:03 -05:00
parent 6a267d074b
commit 89f0354ccc
39 changed files with 5413 additions and 6 deletions

View File

@@ -0,0 +1,3 @@
"""
Tests for Data Platform MCP Server.
"""

View File

@@ -0,0 +1,239 @@
"""
Unit tests for configuration loader.
"""
import pytest
from pathlib import Path
import os
def test_load_system_config(tmp_path, monkeypatch):
"""Test loading system-level PostgreSQL configuration"""
# Import here to avoid import errors before setup
from mcp_server.config import DataPlatformConfig
# Mock home directory
config_dir = tmp_path / '.config' / 'claude'
config_dir.mkdir(parents=True)
config_file = config_dir / 'postgres.env'
config_file.write_text(
"POSTGRES_URL=postgresql://user:pass@localhost:5432/testdb\n"
)
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(tmp_path)
config = DataPlatformConfig()
result = config.load()
assert result['postgres_url'] == 'postgresql://user:pass@localhost:5432/testdb'
assert result['postgres_available'] is True
def test_postgres_optional(tmp_path, monkeypatch):
"""Test that PostgreSQL configuration is optional"""
from mcp_server.config import DataPlatformConfig
# No postgres.env file
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(tmp_path)
# Clear any existing env vars
monkeypatch.delenv('POSTGRES_URL', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['postgres_url'] is None
assert result['postgres_available'] is False
def test_project_config_override(tmp_path, monkeypatch):
"""Test that project config overrides system config"""
from mcp_server.config import DataPlatformConfig
# Set up system config
system_config_dir = tmp_path / '.config' / 'claude'
system_config_dir.mkdir(parents=True)
system_config = system_config_dir / 'postgres.env'
system_config.write_text(
"POSTGRES_URL=postgresql://system:pass@localhost:5432/systemdb\n"
)
# Set up project config
project_dir = tmp_path / 'project'
project_dir.mkdir()
project_config = project_dir / '.env'
project_config.write_text(
"POSTGRES_URL=postgresql://project:pass@localhost:5432/projectdb\n"
"DBT_PROJECT_DIR=/path/to/dbt\n"
)
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
config = DataPlatformConfig()
result = config.load()
# Project config should override
assert result['postgres_url'] == 'postgresql://project:pass@localhost:5432/projectdb'
assert result['dbt_project_dir'] == '/path/to/dbt'
def test_max_rows_config(tmp_path, monkeypatch):
"""Test max rows configuration"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'project'
project_dir.mkdir()
project_config = project_dir / '.env'
project_config.write_text("DATA_PLATFORM_MAX_ROWS=50000\n")
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
config = DataPlatformConfig()
result = config.load()
assert result['max_rows'] == 50000
def test_default_max_rows(tmp_path, monkeypatch):
"""Test default max rows value"""
from mcp_server.config import DataPlatformConfig
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(tmp_path)
# Clear any existing env vars
monkeypatch.delenv('DATA_PLATFORM_MAX_ROWS', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['max_rows'] == 100_000 # Default value
def test_dbt_auto_detection(tmp_path, monkeypatch):
"""Test automatic dbt project detection"""
from mcp_server.config import DataPlatformConfig
# Create project with dbt_project.yml
project_dir = tmp_path / 'project'
project_dir.mkdir()
(project_dir / 'dbt_project.yml').write_text("name: test_project\n")
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
# Clear PWD and DBT_PROJECT_DIR to ensure auto-detection
monkeypatch.delenv('PWD', raising=False)
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['dbt_project_dir'] == str(project_dir)
assert result['dbt_available'] is True
def test_dbt_subdirectory_detection(tmp_path, monkeypatch):
"""Test dbt project detection in subdirectory"""
from mcp_server.config import DataPlatformConfig
# Create project with dbt in subdirectory
project_dir = tmp_path / 'project'
project_dir.mkdir()
# Need a marker file for _find_project_directory to find the project
(project_dir / '.git').mkdir()
dbt_dir = project_dir / 'transform'
dbt_dir.mkdir()
(dbt_dir / 'dbt_project.yml').write_text("name: test_project\n")
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
# Clear env vars to ensure auto-detection
monkeypatch.delenv('PWD', raising=False)
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['dbt_project_dir'] == str(dbt_dir)
assert result['dbt_available'] is True
def test_no_dbt_project(tmp_path, monkeypatch):
"""Test when no dbt project exists"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'project'
project_dir.mkdir()
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
# Clear any existing env vars
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['dbt_project_dir'] is None
assert result['dbt_available'] is False
def test_find_project_directory_from_env(tmp_path, monkeypatch):
"""Test finding project directory from CLAUDE_PROJECT_DIR env var"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'my-project'
project_dir.mkdir()
(project_dir / '.git').mkdir()
monkeypatch.setenv('CLAUDE_PROJECT_DIR', str(project_dir))
config = DataPlatformConfig()
result = config._find_project_directory()
assert result == project_dir
def test_find_project_directory_from_cwd(tmp_path, monkeypatch):
"""Test finding project directory from cwd with .env file"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'project'
project_dir.mkdir()
(project_dir / '.env').write_text("TEST=value")
monkeypatch.chdir(project_dir)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
monkeypatch.delenv('PWD', raising=False)
config = DataPlatformConfig()
result = config._find_project_directory()
assert result == project_dir
def test_find_project_directory_none_when_no_markers(tmp_path, monkeypatch):
"""Test returns None when no project markers found"""
from mcp_server.config import DataPlatformConfig
empty_dir = tmp_path / 'empty'
empty_dir.mkdir()
monkeypatch.chdir(empty_dir)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
monkeypatch.delenv('PWD', raising=False)
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config._find_project_directory()
assert result is None

View File

@@ -0,0 +1,240 @@
"""
Unit tests for Arrow IPC DataFrame registry.
"""
import pytest
import pandas as pd
import pyarrow as pa
def test_store_pandas_dataframe():
"""Test storing pandas DataFrame"""
from mcp_server.data_store import DataStore
# Create fresh instance for test
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
data_ref = store.store(df, name='test_df')
assert data_ref == 'test_df'
assert 'test_df' in store._dataframes
assert store._metadata['test_df'].rows == 3
assert store._metadata['test_df'].columns == 2
def test_store_arrow_table():
"""Test storing Arrow Table directly"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
table = pa.table({'x': [1, 2, 3], 'y': [4, 5, 6]})
data_ref = store.store(table, name='arrow_test')
assert data_ref == 'arrow_test'
assert store._dataframes['arrow_test'].num_rows == 3
def test_store_auto_name():
"""Test auto-generated data_ref names"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2]})
data_ref = store.store(df)
assert data_ref.startswith('df_')
assert len(data_ref) == 11 # df_ + 8 hex chars
def test_get_dataframe():
"""Test retrieving stored DataFrame"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3]})
store.store(df, name='get_test')
result = store.get('get_test')
assert result is not None
assert result.num_rows == 3
def test_get_pandas():
"""Test retrieving as pandas DataFrame"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
store.store(df, name='pandas_test')
result = store.get_pandas('pandas_test')
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == ['a', 'b']
assert len(result) == 3
def test_get_nonexistent():
"""Test getting nonexistent data_ref returns None"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
assert store.get('nonexistent') is None
assert store.get_pandas('nonexistent') is None
def test_list_refs():
"""Test listing all stored DataFrames"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': [1, 2]}), name='df1')
store.store(pd.DataFrame({'b': [3, 4, 5]}), name='df2')
refs = store.list_refs()
assert len(refs) == 2
ref_names = [r['ref'] for r in refs]
assert 'df1' in ref_names
assert 'df2' in ref_names
def test_drop_dataframe():
"""Test dropping a DataFrame"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': [1]}), name='drop_test')
assert store.get('drop_test') is not None
result = store.drop('drop_test')
assert result is True
assert store.get('drop_test') is None
def test_drop_nonexistent():
"""Test dropping nonexistent data_ref"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
result = store.drop('nonexistent')
assert result is False
def test_clear():
"""Test clearing all DataFrames"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': [1]}), name='df1')
store.store(pd.DataFrame({'b': [2]}), name='df2')
store.clear()
assert len(store.list_refs()) == 0
def test_get_info():
"""Test getting DataFrame metadata"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
store.store(df, name='info_test', source='test source')
info = store.get_info('info_test')
assert info.ref == 'info_test'
assert info.rows == 3
assert info.columns == 2
assert info.column_names == ['a', 'b']
assert info.source == 'test source'
assert info.memory_bytes > 0
def test_total_memory():
"""Test total memory calculation"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': range(100)}), name='df1')
store.store(pd.DataFrame({'b': range(200)}), name='df2')
total = store.total_memory_bytes()
assert total > 0
total_mb = store.total_memory_mb()
assert total_mb >= 0
def test_check_row_limit():
"""Test row limit checking"""
from mcp_server.data_store import DataStore
store = DataStore()
store._max_rows = 100
# Under limit
result = store.check_row_limit(50)
assert result['exceeded'] is False
# Over limit
result = store.check_row_limit(150)
assert result['exceeded'] is True
assert 'suggestion' in result
def test_metadata_dtypes():
"""Test that dtypes are correctly recorded"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({
'int_col': [1, 2, 3],
'float_col': [1.1, 2.2, 3.3],
'str_col': ['a', 'b', 'c']
})
store.store(df, name='dtype_test')
info = store.get_info('dtype_test')
assert 'int_col' in info.dtypes
assert 'float_col' in info.dtypes
assert 'str_col' in info.dtypes

View File

@@ -0,0 +1,318 @@
"""
Unit tests for dbt MCP tools.
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
import subprocess
import json
import tempfile
import os
@pytest.fixture
def mock_config(tmp_path):
"""Mock configuration with dbt project"""
dbt_dir = tmp_path / 'dbt_project'
dbt_dir.mkdir()
(dbt_dir / 'dbt_project.yml').write_text('name: test_project\n')
return {
'dbt_project_dir': str(dbt_dir),
'dbt_profiles_dir': str(tmp_path / '.dbt')
}
@pytest.fixture
def dbt_tools(mock_config):
"""Create DbtTools instance with mocked config"""
with patch('mcp_server.dbt_tools.load_config', return_value=mock_config):
from mcp_server.dbt_tools import DbtTools
tools = DbtTools()
tools.project_dir = mock_config['dbt_project_dir']
tools.profiles_dir = mock_config['dbt_profiles_dir']
return tools
@pytest.mark.asyncio
async def test_dbt_parse_success(dbt_tools):
"""Test successful dbt parse"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'Parsed successfully'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_parse()
assert result['valid'] is True
@pytest.mark.asyncio
async def test_dbt_parse_failure(dbt_tools):
"""Test dbt parse with errors"""
mock_result = MagicMock()
mock_result.returncode = 1
mock_result.stdout = ''
mock_result.stderr = 'Compilation error: deprecated syntax'
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_parse()
assert result['valid'] is False
assert 'deprecated' in str(result.get('details', '')).lower() or len(result.get('errors', [])) > 0
@pytest.mark.asyncio
async def test_dbt_run_with_prevalidation(dbt_tools):
"""Test dbt run includes pre-validation"""
# First call is parse, second is run
mock_parse = MagicMock()
mock_parse.returncode = 0
mock_parse.stdout = 'OK'
mock_parse.stderr = ''
mock_run = MagicMock()
mock_run.returncode = 0
mock_run.stdout = 'Completed successfully'
mock_run.stderr = ''
with patch('subprocess.run', side_effect=[mock_parse, mock_run]):
result = await dbt_tools.dbt_run()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_run_fails_validation(dbt_tools):
"""Test dbt run fails if validation fails"""
mock_parse = MagicMock()
mock_parse.returncode = 1
mock_parse.stdout = ''
mock_parse.stderr = 'Parse error'
with patch('subprocess.run', return_value=mock_parse):
result = await dbt_tools.dbt_run()
assert 'error' in result
assert 'Pre-validation failed' in result['error']
@pytest.mark.asyncio
async def test_dbt_run_with_selection(dbt_tools):
"""Test dbt run with model selection"""
mock_parse = MagicMock()
mock_parse.returncode = 0
mock_parse.stdout = 'OK'
mock_parse.stderr = ''
mock_run = MagicMock()
mock_run.returncode = 0
mock_run.stdout = 'Completed'
mock_run.stderr = ''
calls = []
def track_calls(*args, **kwargs):
calls.append(args[0] if args else kwargs.get('args', []))
if len(calls) == 1:
return mock_parse
return mock_run
with patch('subprocess.run', side_effect=track_calls):
result = await dbt_tools.dbt_run(select='dim_customers')
# Verify --select was passed
assert any('--select' in str(call) for call in calls)
@pytest.mark.asyncio
async def test_dbt_test(dbt_tools):
"""Test dbt test"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'All tests passed'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_test()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_build(dbt_tools):
"""Test dbt build with pre-validation"""
mock_parse = MagicMock()
mock_parse.returncode = 0
mock_parse.stdout = 'OK'
mock_parse.stderr = ''
mock_build = MagicMock()
mock_build.returncode = 0
mock_build.stdout = 'Build complete'
mock_build.stderr = ''
with patch('subprocess.run', side_effect=[mock_parse, mock_build]):
result = await dbt_tools.dbt_build()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_compile(dbt_tools):
"""Test dbt compile"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'Compiled'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_compile()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_ls(dbt_tools):
"""Test dbt ls"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'dim_customers\ndim_products\nfct_orders\n'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_ls()
assert result['success'] is True
assert result['count'] == 3
assert 'dim_customers' in result['resources']
@pytest.mark.asyncio
async def test_dbt_docs_generate(dbt_tools, tmp_path):
"""Test dbt docs generate"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'Done'
mock_result.stderr = ''
# Create fake target directory
target_dir = tmp_path / 'dbt_project' / 'target'
target_dir.mkdir(parents=True)
(target_dir / 'catalog.json').write_text('{}')
(target_dir / 'manifest.json').write_text('{}')
dbt_tools.project_dir = str(tmp_path / 'dbt_project')
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_docs_generate()
assert result['success'] is True
assert result['catalog_generated'] is True
assert result['manifest_generated'] is True
@pytest.mark.asyncio
async def test_dbt_lineage(dbt_tools, tmp_path):
"""Test dbt lineage"""
# Create manifest
target_dir = tmp_path / 'dbt_project' / 'target'
target_dir.mkdir(parents=True)
manifest = {
'nodes': {
'model.test.dim_customers': {
'name': 'dim_customers',
'resource_type': 'model',
'schema': 'public',
'database': 'testdb',
'description': 'Customer dimension',
'tags': ['daily'],
'config': {'materialized': 'table'},
'depends_on': {
'nodes': ['model.test.stg_customers']
}
},
'model.test.stg_customers': {
'name': 'stg_customers',
'resource_type': 'model',
'depends_on': {'nodes': []}
},
'model.test.fct_orders': {
'name': 'fct_orders',
'resource_type': 'model',
'depends_on': {
'nodes': ['model.test.dim_customers']
}
}
}
}
(target_dir / 'manifest.json').write_text(json.dumps(manifest))
dbt_tools.project_dir = str(tmp_path / 'dbt_project')
result = await dbt_tools.dbt_lineage('dim_customers')
assert result['model'] == 'dim_customers'
assert 'model.test.stg_customers' in result['upstream']
assert 'model.test.fct_orders' in result['downstream']
@pytest.mark.asyncio
async def test_dbt_lineage_model_not_found(dbt_tools, tmp_path):
"""Test dbt lineage with nonexistent model"""
target_dir = tmp_path / 'dbt_project' / 'target'
target_dir.mkdir(parents=True)
manifest = {
'nodes': {
'model.test.dim_customers': {
'name': 'dim_customers',
'resource_type': 'model'
}
}
}
(target_dir / 'manifest.json').write_text(json.dumps(manifest))
dbt_tools.project_dir = str(tmp_path / 'dbt_project')
result = await dbt_tools.dbt_lineage('nonexistent_model')
assert 'error' in result
assert 'not found' in result['error'].lower()
@pytest.mark.asyncio
async def test_dbt_no_project():
"""Test dbt tools when no project configured"""
with patch('mcp_server.dbt_tools.load_config', return_value={'dbt_project_dir': None}):
from mcp_server.dbt_tools import DbtTools
tools = DbtTools()
tools.project_dir = None
result = await tools.dbt_run()
assert 'error' in result
assert 'not found' in result['error'].lower()
@pytest.mark.asyncio
async def test_dbt_timeout(dbt_tools):
"""Test dbt command timeout handling"""
with patch('subprocess.run', side_effect=subprocess.TimeoutExpired('dbt', 300)):
result = await dbt_tools.dbt_parse()
assert 'error' in result
assert 'timed out' in result['error'].lower()
@pytest.mark.asyncio
async def test_dbt_not_installed(dbt_tools):
"""Test handling when dbt is not installed"""
with patch('subprocess.run', side_effect=FileNotFoundError()):
result = await dbt_tools.dbt_parse()
assert 'error' in result
assert 'not found' in result['error'].lower()

View File

@@ -0,0 +1,301 @@
"""
Unit tests for pandas MCP tools.
"""
import pytest
import pandas as pd
import tempfile
import os
from pathlib import Path
@pytest.fixture
def temp_csv(tmp_path):
"""Create a temporary CSV file for testing"""
csv_path = tmp_path / 'test.csv'
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'],
'value': [10.5, 20.0, 30.5, 40.0, 50.5]
})
df.to_csv(csv_path, index=False)
return str(csv_path)
@pytest.fixture
def temp_parquet(tmp_path):
"""Create a temporary Parquet file for testing"""
parquet_path = tmp_path / 'test.parquet'
df = pd.DataFrame({
'id': [1, 2, 3],
'data': ['a', 'b', 'c']
})
df.to_parquet(parquet_path)
return str(parquet_path)
@pytest.fixture
def temp_json(tmp_path):
"""Create a temporary JSON file for testing"""
json_path = tmp_path / 'test.json'
df = pd.DataFrame({
'x': [1, 2],
'y': [3, 4]
})
df.to_json(json_path, orient='records')
return str(json_path)
@pytest.fixture
def pandas_tools():
"""Create PandasTools instance with fresh store"""
from mcp_server.pandas_tools import PandasTools
from mcp_server.data_store import DataStore
# Reset store for test isolation
store = DataStore.get_instance()
store._dataframes = {}
store._metadata = {}
return PandasTools()
@pytest.mark.asyncio
async def test_read_csv(pandas_tools, temp_csv):
"""Test reading CSV file"""
result = await pandas_tools.read_csv(temp_csv, name='csv_test')
assert 'data_ref' in result
assert result['data_ref'] == 'csv_test'
assert result['rows'] == 5
assert 'id' in result['columns']
assert 'name' in result['columns']
@pytest.mark.asyncio
async def test_read_csv_nonexistent(pandas_tools):
"""Test reading nonexistent CSV file"""
result = await pandas_tools.read_csv('/nonexistent/path.csv')
assert 'error' in result
assert 'not found' in result['error'].lower()
@pytest.mark.asyncio
async def test_read_parquet(pandas_tools, temp_parquet):
"""Test reading Parquet file"""
result = await pandas_tools.read_parquet(temp_parquet, name='parquet_test')
assert 'data_ref' in result
assert result['rows'] == 3
@pytest.mark.asyncio
async def test_read_json(pandas_tools, temp_json):
"""Test reading JSON file"""
result = await pandas_tools.read_json(temp_json, name='json_test')
assert 'data_ref' in result
assert result['rows'] == 2
@pytest.mark.asyncio
async def test_to_csv(pandas_tools, temp_csv, tmp_path):
"""Test exporting to CSV"""
# First load some data
await pandas_tools.read_csv(temp_csv, name='export_test')
# Export to new file
output_path = str(tmp_path / 'output.csv')
result = await pandas_tools.to_csv('export_test', output_path)
assert result['success'] is True
assert os.path.exists(output_path)
@pytest.mark.asyncio
async def test_to_parquet(pandas_tools, temp_csv, tmp_path):
"""Test exporting to Parquet"""
await pandas_tools.read_csv(temp_csv, name='parquet_export')
output_path = str(tmp_path / 'output.parquet')
result = await pandas_tools.to_parquet('parquet_export', output_path)
assert result['success'] is True
assert os.path.exists(output_path)
@pytest.mark.asyncio
async def test_describe(pandas_tools, temp_csv):
"""Test describe statistics"""
await pandas_tools.read_csv(temp_csv, name='describe_test')
result = await pandas_tools.describe('describe_test')
assert 'data_ref' in result
assert 'shape' in result
assert result['shape']['rows'] == 5
assert 'statistics' in result
assert 'null_counts' in result
@pytest.mark.asyncio
async def test_head(pandas_tools, temp_csv):
"""Test getting first N rows"""
await pandas_tools.read_csv(temp_csv, name='head_test')
result = await pandas_tools.head('head_test', n=3)
assert result['returned_rows'] == 3
assert len(result['data']) == 3
@pytest.mark.asyncio
async def test_tail(pandas_tools, temp_csv):
"""Test getting last N rows"""
await pandas_tools.read_csv(temp_csv, name='tail_test')
result = await pandas_tools.tail('tail_test', n=2)
assert result['returned_rows'] == 2
@pytest.mark.asyncio
async def test_filter(pandas_tools, temp_csv):
"""Test filtering rows"""
await pandas_tools.read_csv(temp_csv, name='filter_test')
result = await pandas_tools.filter('filter_test', 'value > 25')
assert 'data_ref' in result
assert result['rows'] == 3 # 30.5, 40.0, 50.5
@pytest.mark.asyncio
async def test_filter_invalid_condition(pandas_tools, temp_csv):
"""Test filter with invalid condition"""
await pandas_tools.read_csv(temp_csv, name='filter_error')
result = await pandas_tools.filter('filter_error', 'invalid_column > 0')
assert 'error' in result
@pytest.mark.asyncio
async def test_select(pandas_tools, temp_csv):
"""Test selecting columns"""
await pandas_tools.read_csv(temp_csv, name='select_test')
result = await pandas_tools.select('select_test', ['id', 'name'])
assert 'data_ref' in result
assert result['columns'] == ['id', 'name']
@pytest.mark.asyncio
async def test_select_invalid_column(pandas_tools, temp_csv):
"""Test select with invalid column"""
await pandas_tools.read_csv(temp_csv, name='select_error')
result = await pandas_tools.select('select_error', ['id', 'nonexistent'])
assert 'error' in result
assert 'available_columns' in result
@pytest.mark.asyncio
async def test_groupby(pandas_tools, tmp_path):
"""Test groupby aggregation"""
# Create test data with groups
csv_path = tmp_path / 'groupby.csv'
df = pd.DataFrame({
'category': ['A', 'A', 'B', 'B'],
'value': [10, 20, 30, 40]
})
df.to_csv(csv_path, index=False)
await pandas_tools.read_csv(str(csv_path), name='groupby_test')
result = await pandas_tools.groupby(
'groupby_test',
by='category',
agg={'value': 'sum'}
)
assert 'data_ref' in result
assert result['rows'] == 2 # Two groups: A, B
@pytest.mark.asyncio
async def test_join(pandas_tools, tmp_path):
"""Test joining DataFrames"""
# Create left table
left_path = tmp_path / 'left.csv'
pd.DataFrame({
'id': [1, 2, 3],
'name': ['A', 'B', 'C']
}).to_csv(left_path, index=False)
# Create right table
right_path = tmp_path / 'right.csv'
pd.DataFrame({
'id': [1, 2, 4],
'value': [100, 200, 400]
}).to_csv(right_path, index=False)
await pandas_tools.read_csv(str(left_path), name='left')
await pandas_tools.read_csv(str(right_path), name='right')
result = await pandas_tools.join('left', 'right', on='id', how='inner')
assert 'data_ref' in result
assert result['rows'] == 2 # Only id 1 and 2 match
@pytest.mark.asyncio
async def test_list_data(pandas_tools, temp_csv):
"""Test listing all DataFrames"""
await pandas_tools.read_csv(temp_csv, name='list_test1')
await pandas_tools.read_csv(temp_csv, name='list_test2')
result = await pandas_tools.list_data()
assert result['count'] == 2
refs = [df['ref'] for df in result['dataframes']]
assert 'list_test1' in refs
assert 'list_test2' in refs
@pytest.mark.asyncio
async def test_drop_data(pandas_tools, temp_csv):
"""Test dropping DataFrame"""
await pandas_tools.read_csv(temp_csv, name='drop_test')
result = await pandas_tools.drop_data('drop_test')
assert result['success'] is True
# Verify it's gone
list_result = await pandas_tools.list_data()
refs = [df['ref'] for df in list_result['dataframes']]
assert 'drop_test' not in refs
@pytest.mark.asyncio
async def test_drop_nonexistent(pandas_tools):
"""Test dropping nonexistent DataFrame"""
result = await pandas_tools.drop_data('nonexistent')
assert 'error' in result
@pytest.mark.asyncio
async def test_operations_on_nonexistent(pandas_tools):
"""Test operations on nonexistent data_ref"""
result = await pandas_tools.describe('nonexistent')
assert 'error' in result
result = await pandas_tools.head('nonexistent')
assert 'error' in result
result = await pandas_tools.filter('nonexistent', 'x > 0')
assert 'error' in result

View File

@@ -0,0 +1,338 @@
"""
Unit tests for PostgreSQL MCP tools.
"""
import pytest
from unittest.mock import Mock, AsyncMock, patch, MagicMock
import pandas as pd
@pytest.fixture
def mock_config():
"""Mock configuration"""
return {
'postgres_url': 'postgresql://test:test@localhost:5432/testdb',
'max_rows': 100000
}
@pytest.fixture
def postgres_tools(mock_config):
"""Create PostgresTools instance with mocked config"""
with patch('mcp_server.postgres_tools.load_config', return_value=mock_config):
from mcp_server.postgres_tools import PostgresTools
from mcp_server.data_store import DataStore
# Reset store
store = DataStore.get_instance()
store._dataframes = {}
store._metadata = {}
tools = PostgresTools()
tools.config = mock_config
return tools
@pytest.mark.asyncio
async def test_pg_connect_no_config():
"""Test pg_connect when no PostgreSQL configured"""
with patch('mcp_server.postgres_tools.load_config', return_value={'postgres_url': None}):
from mcp_server.postgres_tools import PostgresTools
tools = PostgresTools()
tools.config = {'postgres_url': None}
result = await tools.pg_connect()
assert result['connected'] is False
assert 'not configured' in result['error'].lower()
@pytest.mark.asyncio
async def test_pg_connect_success(postgres_tools):
"""Test successful pg_connect"""
mock_conn = AsyncMock()
mock_conn.fetchval = AsyncMock(side_effect=[
'PostgreSQL 15.1', # version
'testdb', # database name
'testuser', # user
None # PostGIS check fails
])
mock_conn.close = AsyncMock()
# Create proper async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_cm)
# Use AsyncMock for create_pool since it's awaited
with patch('asyncpg.create_pool', new=AsyncMock(return_value=mock_pool)):
postgres_tools.pool = None
result = await postgres_tools.pg_connect()
assert result['connected'] is True
assert result['database'] == 'testdb'
@pytest.mark.asyncio
async def test_pg_query_success(postgres_tools):
"""Test successful pg_query"""
mock_rows = [
{'id': 1, 'name': 'Alice'},
{'id': 2, 'name': 'Bob'}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_query('SELECT * FROM users', name='users_data')
assert 'data_ref' in result
assert result['rows'] == 2
@pytest.mark.asyncio
async def test_pg_query_empty_result(postgres_tools):
"""Test pg_query with no results"""
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=[])
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_query('SELECT * FROM empty_table')
assert result['data_ref'] is None
assert result['rows'] == 0
@pytest.mark.asyncio
async def test_pg_execute_success(postgres_tools):
"""Test successful pg_execute"""
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock(return_value='INSERT 0 3')
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_execute('INSERT INTO users VALUES (1, 2, 3)')
assert result['success'] is True
assert result['affected_rows'] == 3
assert result['command'] == 'INSERT'
@pytest.mark.asyncio
async def test_pg_tables(postgres_tools):
"""Test listing tables"""
mock_rows = [
{'table_name': 'users', 'table_type': 'BASE TABLE', 'column_count': 5},
{'table_name': 'orders', 'table_type': 'BASE TABLE', 'column_count': 8}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_tables(schema='public')
assert result['schema'] == 'public'
assert result['count'] == 2
assert len(result['tables']) == 2
@pytest.mark.asyncio
async def test_pg_columns(postgres_tools):
"""Test getting column info"""
mock_rows = [
{
'column_name': 'id',
'data_type': 'integer',
'udt_name': 'int4',
'is_nullable': 'NO',
'column_default': "nextval('users_id_seq'::regclass)",
'character_maximum_length': None,
'numeric_precision': 32
},
{
'column_name': 'name',
'data_type': 'character varying',
'udt_name': 'varchar',
'is_nullable': 'YES',
'column_default': None,
'character_maximum_length': 255,
'numeric_precision': None
}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_columns(table='users')
assert result['table'] == 'public.users'
assert result['column_count'] == 2
assert result['columns'][0]['name'] == 'id'
assert result['columns'][0]['nullable'] is False
@pytest.mark.asyncio
async def test_pg_schemas(postgres_tools):
"""Test listing schemas"""
mock_rows = [
{'schema_name': 'public'},
{'schema_name': 'app'}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_schemas()
assert result['count'] == 2
assert 'public' in result['schemas']
@pytest.mark.asyncio
async def test_st_tables(postgres_tools):
"""Test listing PostGIS tables"""
mock_rows = [
{
'table_name': 'locations',
'geometry_column': 'geom',
'geometry_type': 'POINT',
'srid': 4326,
'coord_dimension': 2
}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.st_tables()
assert result['count'] == 1
assert result['postgis_tables'][0]['table'] == 'locations'
assert result['postgis_tables'][0]['srid'] == 4326
@pytest.mark.asyncio
async def test_st_tables_no_postgis(postgres_tools):
"""Test st_tables when PostGIS not installed"""
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(side_effect=Exception("relation \"geometry_columns\" does not exist"))
# Create proper async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_cm)
postgres_tools.pool = mock_pool
result = await postgres_tools.st_tables()
assert 'error' in result
assert 'PostGIS' in result['error']
@pytest.mark.asyncio
async def test_st_extent(postgres_tools):
"""Test getting geometry bounding box"""
mock_row = {
'xmin': -122.5,
'ymin': 37.5,
'xmax': -122.0,
'ymax': 38.0
}
mock_conn = AsyncMock()
mock_conn.fetchrow = AsyncMock(return_value=mock_row)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.st_extent(table='locations', column='geom')
assert result['bbox']['xmin'] == -122.5
assert result['bbox']['ymax'] == 38.0
@pytest.mark.asyncio
async def test_error_handling(postgres_tools):
"""Test error handling for database errors"""
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(side_effect=Exception("Connection refused"))
# Create proper async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_cm)
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_query('SELECT 1')
assert 'error' in result
assert 'Connection refused' in result['error']