feat: add data-platform plugin (v4.0.0)

Add new data-platform plugin for data engineering workflows with:

MCP Server (32 tools):
- pandas operations (14 tools): read_csv, read_parquet, read_json,
  to_csv, to_parquet, describe, head, tail, filter, select, groupby,
  join, list_data, drop_data
- PostgreSQL/PostGIS (10 tools): pg_connect, pg_query, pg_execute,
  pg_tables, pg_columns, pg_schemas, st_tables, st_geometry_type,
  st_srid, st_extent
- dbt integration (8 tools): dbt_parse, dbt_run, dbt_test, dbt_build,
  dbt_compile, dbt_ls, dbt_docs_generate, dbt_lineage

Plugin Features:
- Arrow IPC data_ref system for DataFrame persistence across tool calls
- Pre-execution validation for dbt with `dbt parse`
- SessionStart hook for PostgreSQL connectivity check (non-blocking)
- Hybrid configuration (system ~/.config/claude/postgres.env + project .env)
- Memory management with 100k row limit and chunking support

Commands: /initial-setup, /ingest, /profile, /schema, /explain, /lineage, /run
Agents: data-ingestion, data-analysis

Test suite: 71 tests covering config, data store, pandas, postgres, dbt tools

Addresses data workflow issues from the personal-portfolio project:
- Data lost across multiple interactions (solved by the Arrow IPC data_ref store)
- dbt 1.9+ syntax deprecations (caught by pre-execution validation)
- Ungraceful PostgreSQL error handling (mitigated by the non-blocking SessionStart hook)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-01-25 14:24:03 -05:00
parent 6a267d074b
commit 89f0354ccc
39 changed files with 5413 additions and 6 deletions


@@ -0,0 +1,131 @@
# Data Platform MCP Server
MCP Server providing pandas, PostgreSQL/PostGIS, and dbt tools for Claude Code.
## Features
- **pandas Tools**: DataFrame operations with Arrow IPC data_ref persistence
- **PostgreSQL Tools**: Database queries with asyncpg connection pooling
- **PostGIS Tools**: Spatial data operations
- **dbt Tools**: Build tool wrapper with pre-execution validation
## Installation
```bash
cd mcp-servers/data-platform
python -m venv .venv
source .venv/bin/activate # On Windows: .venv\Scripts\activate
pip install -r requirements.txt
```
## Configuration
### System-Level (PostgreSQL credentials)
Create `~/.config/claude/postgres.env`:
```env
POSTGRES_URL=postgresql://user:password@host:5432/database
```
### Project-Level (dbt paths)
Create `.env` in your project root:
```env
DBT_PROJECT_DIR=/path/to/dbt/project
DBT_PROFILES_DIR=/path/to/.dbt
DATA_PLATFORM_MAX_ROWS=100000
```
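At startup the loader reads the system file first, lets the project `.env` override it, and auto-detects the dbt project directory if it is not set. A minimal sketch of the resulting dictionary (package path assumed from `python -m mcp_server.server`; values illustrative):
```python
from mcp_server.config import load_config  # package path assumed

config = load_config()
# config resembles (values illustrative):
# {
#     "postgres_url": "postgresql://user:password@host:5432/database",  # None if unset
#     "dbt_project_dir": "/path/to/dbt/project",  # auto-detected when possible
#     "dbt_profiles_dir": "/home/you/.dbt",       # defaults to ~/.dbt if it exists
#     "max_rows": 100000,
#     "postgres_available": True,
#     "dbt_available": True,
# }
print(config["postgres_available"], config["dbt_available"])
```
PostgreSQL is optional: without `POSTGRES_URL` the server still starts, and the `pg_*` tools return a configuration error instead of failing the session.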
## Tools
### pandas Tools (14 tools)
| Tool | Description |
|------|-------------|
| `read_csv` | Load CSV file into DataFrame |
| `read_parquet` | Load Parquet file into DataFrame |
| `read_json` | Load JSON/JSONL file into DataFrame |
| `to_csv` | Export DataFrame to CSV file |
| `to_parquet` | Export DataFrame to Parquet file |
| `describe` | Get statistical summary of DataFrame |
| `head` | Get first N rows of DataFrame |
| `tail` | Get last N rows of DataFrame |
| `filter` | Filter DataFrame rows by condition |
| `select` | Select specific columns from DataFrame |
| `groupby` | Group DataFrame and aggregate |
| `join` | Join two DataFrames |
| `list_data` | List all stored DataFrames |
| `drop_data` | Remove a DataFrame from storage |
### PostgreSQL Tools (6 tools)
| Tool | Description |
|------|-------------|
| `pg_connect` | Test connection and return status |
| `pg_query` | Execute SELECT, return as data_ref |
| `pg_execute` | Execute INSERT/UPDATE/DELETE |
| `pg_tables` | List all tables in schema |
| `pg_columns` | Get column info for table |
| `pg_schemas` | List all schemas |
### PostGIS Tools (4 tools)
| Tool | Description |
|------|-------------|
| `st_tables` | List PostGIS-enabled tables |
| `st_geometry_type` | Get geometry type of column |
| `st_srid` | Get SRID of geometry column |
| `st_extent` | Get bounding box of geometries |
### dbt Tools (8 tools)
| Tool | Description |
|------|-------------|
| `dbt_parse` | Validate project (pre-execution) |
| `dbt_run` | Run models with selection |
| `dbt_test` | Run tests |
| `dbt_build` | Run + test |
| `dbt_compile` | Compile SQL without executing |
| `dbt_ls` | List resources |
| `dbt_docs_generate` | Generate documentation |
| `dbt_lineage` | Get model dependencies |
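
`dbt_run` and `dbt_build` always run `dbt_parse` first and refuse to execute if validation fails. A rough usage sketch (package path assumed from `python -m mcp_server.server`; the model name is illustrative):
```python
import asyncio

from mcp_server.dbt_tools import DbtTools  # package path assumed


async def main() -> None:
    dbt = DbtTools()
    # dbt_run calls `dbt parse` first; a failed parse blocks the run.
    result = await dbt.dbt_run(select="stg_orders")  # illustrative model name
    if "error" in result:
        print("blocked:", result.get("suggestion") or result["error"])
    else:
        print("success:", result["success"])


asyncio.run(main())
```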
## data_ref System
All DataFrame operations use a `data_ref` system to persist data across tool calls:
1. **Load data**: Returns a `data_ref` string (e.g., `"df_a1b2c3d4"`)
2. **Use data_ref**: Pass to other tools (filter, join, export)
3. **List data**: Use `list_data` to see all stored DataFrames
4. **Clean up**: Use `drop_data` when done
### Example Flow
```
read_csv("data.csv") → {"data_ref": "sales_data", "rows": 1000}
filter("sales_data", "amount > 100") → {"data_ref": "sales_data_filtered"}
describe("sales_data_filtered") → {statistics}
to_parquet("sales_data_filtered", "output.parquet") → {success}
```
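The same flow in Python, calling the pandas tool class directly (a sketch; file paths are illustrative and the package path is assumed from `python -m mcp_server.server`):
```python
import asyncio

from mcp_server.pandas_tools import PandasTools  # package path assumed


async def main() -> None:
    tools = PandasTools()
    loaded = await tools.read_csv("data.csv", name="sales_data")
    filtered = await tools.filter(loaded["data_ref"], "amount > 100")  # -> "sales_data_filtered"
    stats = await tools.describe(filtered["data_ref"])
    await tools.to_parquet(filtered["data_ref"], "output.parquet")
    print(stats["shape"], await tools.list_data())


asyncio.run(main())
```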
## Memory Management
- Default row limit: 100,000 rows per DataFrame
- Configure via `DATA_PLATFORM_MAX_ROWS` environment variable
- Use chunked processing for large files via the `chunk_size` parameter (see the sketch below)
- Monitor with `list_data` tool (shows memory usage)
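
When a file would exceed the row limit, load it in chunks; each chunk is stored under its own `data_ref`. A sketch of the chunked call shape (file name illustrative):
```python
import asyncio

from mcp_server.pandas_tools import PandasTools  # package path assumed


async def main() -> None:
    tools = PandasTools()
    # Each chunk gets its own ref: "events_0", "events_1", ...
    result = await tools.read_csv("large_events.csv", name="events", chunk_size=50_000)
    print(result["total_chunks"], [chunk["ref"] for chunk in result["chunks"]])


asyncio.run(main())
```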
## Running
```bash
python -m mcp_server.server
```
## Development
```bash
pip install -e ".[dev]"
pytest
```


@@ -0,0 +1,7 @@
"""
Data Platform MCP Server.
Provides pandas, PostgreSQL/PostGIS, and dbt tools to Claude Code via MCP.
"""
__version__ = "1.0.0"


@@ -0,0 +1,195 @@
"""
Configuration loader for Data Platform MCP Server.
Implements hybrid configuration system:
- System-level: ~/.config/claude/postgres.env (credentials)
- Project-level: .env (dbt project paths, overrides)
- Auto-detection: dbt_project.yml discovery
"""
from pathlib import Path
from dotenv import load_dotenv
import os
import logging
from typing import Any, Dict, Optional
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DataPlatformConfig:
"""Hybrid configuration loader for data platform tools"""
def __init__(self):
self.postgres_url: Optional[str] = None
self.dbt_project_dir: Optional[str] = None
self.dbt_profiles_dir: Optional[str] = None
self.max_rows: int = 100_000
def load(self) -> Dict[str, Optional[str]]:
"""
Load configuration from system and project levels.
Returns:
Dict containing postgres_url, dbt_project_dir, dbt_profiles_dir, max_rows
Note:
PostgreSQL credentials are optional - server can run in pandas-only mode.
"""
# Load system config (PostgreSQL credentials)
system_config = Path.home() / '.config' / 'claude' / 'postgres.env'
if system_config.exists():
load_dotenv(system_config)
logger.info(f"Loaded system configuration from {system_config}")
else:
logger.info(
f"System config not found: {system_config} - "
"PostgreSQL tools will be unavailable"
)
# Find project directory
project_dir = self._find_project_directory()
# Load project config (overrides system)
if project_dir:
project_config = project_dir / '.env'
if project_config.exists():
load_dotenv(project_config, override=True)
logger.info(f"Loaded project configuration from {project_config}")
# Extract values
self.postgres_url = os.getenv('POSTGRES_URL')
self.dbt_project_dir = os.getenv('DBT_PROJECT_DIR')
self.dbt_profiles_dir = os.getenv('DBT_PROFILES_DIR')
self.max_rows = int(os.getenv('DATA_PLATFORM_MAX_ROWS', '100000'))
# Auto-detect dbt project if not specified
if not self.dbt_project_dir and project_dir:
self.dbt_project_dir = self._find_dbt_project(project_dir)
if self.dbt_project_dir:
logger.info(f"Auto-detected dbt project: {self.dbt_project_dir}")
# Default dbt profiles dir to ~/.dbt
if not self.dbt_profiles_dir:
default_profiles = Path.home() / '.dbt'
if default_profiles.exists():
self.dbt_profiles_dir = str(default_profiles)
return {
'postgres_url': self.postgres_url,
'dbt_project_dir': self.dbt_project_dir,
'dbt_profiles_dir': self.dbt_profiles_dir,
'max_rows': self.max_rows,
'postgres_available': self.postgres_url is not None,
'dbt_available': self.dbt_project_dir is not None
}
def _find_project_directory(self) -> Optional[Path]:
"""
Find the user's project directory.
Returns:
Path to project directory, or None if not found
"""
# Strategy 1: Check CLAUDE_PROJECT_DIR environment variable
project_dir = os.getenv('CLAUDE_PROJECT_DIR')
if project_dir:
path = Path(project_dir)
if path.exists():
logger.info(f"Found project directory from CLAUDE_PROJECT_DIR: {path}")
return path
# Strategy 2: Check PWD
pwd = os.getenv('PWD')
if pwd:
path = Path(pwd)
if path.exists() and (
(path / '.git').exists() or
(path / '.env').exists() or
(path / 'dbt_project.yml').exists()
):
logger.info(f"Found project directory from PWD: {path}")
return path
# Strategy 3: Check current working directory
cwd = Path.cwd()
if (cwd / '.git').exists() or (cwd / '.env').exists() or (cwd / 'dbt_project.yml').exists():
logger.info(f"Found project directory from cwd: {cwd}")
return cwd
logger.debug("Could not determine project directory")
return None
def _find_dbt_project(self, start_dir: Path) -> Optional[str]:
"""
Find dbt_project.yml in the project or its subdirectories.
Args:
start_dir: Directory to start searching from
Returns:
Path to dbt project directory, or None if not found
"""
# Check root
if (start_dir / 'dbt_project.yml').exists():
return str(start_dir)
# Check common subdirectories
for subdir in ['dbt', 'transform', 'analytics', 'models']:
candidate = start_dir / subdir
if (candidate / 'dbt_project.yml').exists():
return str(candidate)
# Search one level deep
for item in start_dir.iterdir():
if item.is_dir() and not item.name.startswith('.'):
if (item / 'dbt_project.yml').exists():
return str(item)
return None
def load_config() -> Dict[str, Optional[str]]:
"""
Convenience function to load configuration.
Returns:
Configuration dictionary
"""
config = DataPlatformConfig()
return config.load()
def check_postgres_connection() -> Dict[str, Any]:
"""
Check PostgreSQL connection status for SessionStart hook.
Returns:
Dict with connection status and message
"""
import asyncio
config = load_config()
if not config.get('postgres_url'):
return {
'connected': False,
'message': 'PostgreSQL not configured (POSTGRES_URL not set)'
}
async def test_connection():
try:
import asyncpg
conn = await asyncpg.connect(config['postgres_url'], timeout=5)
version = await conn.fetchval('SELECT version()')
await conn.close()
return {
'connected': True,
'message': 'Connected to PostgreSQL',
'version': version.split(',')[0] if version else 'Unknown'
}
except Exception as e:
return {
'connected': False,
'message': f'PostgreSQL connection failed: {str(e)}'
}
return asyncio.run(test_connection())


@@ -0,0 +1,219 @@
"""
Arrow IPC DataFrame Registry.
Provides persistent storage for DataFrames across tool calls using Apache Arrow
for efficient memory management and serialization.
"""
import pyarrow as pa
import pandas as pd
import uuid
import logging
from typing import Dict, Optional, List, Union
from dataclasses import dataclass
from datetime import datetime
logger = logging.getLogger(__name__)
@dataclass
class DataFrameInfo:
"""Metadata about a stored DataFrame"""
ref: str
rows: int
columns: int
column_names: List[str]
dtypes: Dict[str, str]
memory_bytes: int
created_at: datetime
source: Optional[str] = None
class DataStore:
"""
Singleton registry for Arrow Tables (DataFrames).
Uses Arrow IPC format for efficient memory usage and supports
data_ref based retrieval across multiple tool calls.
"""
_instance = None
_dataframes: Dict[str, pa.Table] = {}
_metadata: Dict[str, DataFrameInfo] = {}
_max_rows: int = 100_000
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._dataframes = {}
cls._metadata = {}
return cls._instance
@classmethod
def get_instance(cls) -> 'DataStore':
"""Get the singleton instance"""
if cls._instance is None:
cls._instance = cls()
return cls._instance
@classmethod
def set_max_rows(cls, max_rows: int):
"""Set the maximum rows limit"""
cls._max_rows = max_rows
def store(
self,
data: Union[pa.Table, pd.DataFrame],
name: Optional[str] = None,
source: Optional[str] = None
) -> str:
"""
Store a DataFrame and return its reference.
Args:
data: Arrow Table or pandas DataFrame
name: Optional name for the reference (auto-generated if not provided)
source: Optional source description (e.g., file path, query)
Returns:
data_ref string to retrieve the DataFrame later
"""
# Convert pandas to Arrow if needed
if isinstance(data, pd.DataFrame):
table = pa.Table.from_pandas(data)
else:
table = data
# Generate reference
data_ref = name or f"df_{uuid.uuid4().hex[:8]}"
# Ensure unique reference
if data_ref in self._dataframes and name is None:
data_ref = f"{data_ref}_{uuid.uuid4().hex[:4]}"
# Store table
self._dataframes[data_ref] = table
# Store metadata
schema = table.schema
self._metadata[data_ref] = DataFrameInfo(
ref=data_ref,
rows=table.num_rows,
columns=table.num_columns,
column_names=[f.name for f in schema],
dtypes={f.name: str(f.type) for f in schema},
memory_bytes=table.nbytes,
created_at=datetime.now(),
source=source
)
logger.info(f"Stored DataFrame '{data_ref}': {table.num_rows} rows, {table.num_columns} cols")
return data_ref
def get(self, data_ref: str) -> Optional[pa.Table]:
"""
Retrieve an Arrow Table by reference.
Args:
data_ref: Reference string from store()
Returns:
Arrow Table or None if not found
"""
return self._dataframes.get(data_ref)
def get_pandas(self, data_ref: str) -> Optional[pd.DataFrame]:
"""
Retrieve a DataFrame as pandas.
Args:
data_ref: Reference string from store()
Returns:
pandas DataFrame or None if not found
"""
table = self.get(data_ref)
if table is not None:
return table.to_pandas()
return None
def get_info(self, data_ref: str) -> Optional[DataFrameInfo]:
"""
Get metadata about a stored DataFrame.
Args:
data_ref: Reference string
Returns:
DataFrameInfo or None if not found
"""
return self._metadata.get(data_ref)
def list_refs(self) -> List[Dict]:
"""
List all stored DataFrame references with metadata.
Returns:
List of dicts with ref, rows, columns, memory info
"""
result = []
for ref, info in self._metadata.items():
result.append({
'ref': ref,
'rows': info.rows,
'columns': info.columns,
'column_names': info.column_names,
'memory_mb': round(info.memory_bytes / (1024 * 1024), 2),
'source': info.source,
'created_at': info.created_at.isoformat()
})
return result
def drop(self, data_ref: str) -> bool:
"""
Remove a DataFrame from the store.
Args:
data_ref: Reference string
Returns:
True if removed, False if not found
"""
if data_ref in self._dataframes:
del self._dataframes[data_ref]
del self._metadata[data_ref]
logger.info(f"Dropped DataFrame '{data_ref}'")
return True
return False
def clear(self):
"""Remove all stored DataFrames"""
count = len(self._dataframes)
self._dataframes.clear()
self._metadata.clear()
logger.info(f"Cleared {count} DataFrames from store")
def total_memory_bytes(self) -> int:
"""Get total memory used by all stored DataFrames"""
return sum(info.memory_bytes for info in self._metadata.values())
def total_memory_mb(self) -> float:
"""Get total memory in MB"""
return round(self.total_memory_bytes() / (1024 * 1024), 2)
def check_row_limit(self, row_count: int) -> Dict:
"""
Check if row count exceeds limit.
Args:
row_count: Number of rows
Returns:
Dict with 'exceeded' bool and 'message' if exceeded
"""
if row_count > self._max_rows:
return {
'exceeded': True,
'message': f"Row count ({row_count:,}) exceeds limit ({self._max_rows:,})",
'suggestion': f"Use chunked processing or filter data first",
'limit': self._max_rows
}
return {'exceeded': False}


@@ -0,0 +1,387 @@
"""
dbt MCP Tools.
Provides dbt CLI wrapper with pre-execution validation.
"""
import subprocess
import json
import logging
import os
from pathlib import Path
from typing import Dict, List, Optional, Any
from .config import load_config
logger = logging.getLogger(__name__)
class DbtTools:
"""dbt CLI wrapper tools with pre-validation"""
def __init__(self):
self.config = load_config()
self.project_dir = self.config.get('dbt_project_dir')
self.profiles_dir = self.config.get('dbt_profiles_dir')
def _get_dbt_command(self, cmd: List[str]) -> List[str]:
"""Build dbt command with project and profiles directories"""
base = ['dbt']
if self.project_dir:
base.extend(['--project-dir', self.project_dir])
if self.profiles_dir:
base.extend(['--profiles-dir', self.profiles_dir])
base.extend(cmd)
return base
def _run_dbt(
self,
cmd: List[str],
timeout: int = 300,
capture_json: bool = False
) -> Dict:
"""
Run dbt command and return result.
Args:
cmd: dbt subcommand and arguments
timeout: Command timeout in seconds
capture_json: If True, parse JSON output
Returns:
Dict with command result
"""
if not self.project_dir:
return {
'error': 'dbt project not found',
'suggestion': 'Set DBT_PROJECT_DIR in project .env or ensure dbt_project.yml exists'
}
full_cmd = self._get_dbt_command(cmd)
logger.info(f"Running: {' '.join(full_cmd)}")
try:
env = os.environ.copy()
# Disable dbt analytics/tracking
env['DBT_SEND_ANONYMOUS_USAGE_STATS'] = 'false'
result = subprocess.run(
full_cmd,
capture_output=True,
text=True,
timeout=timeout,
cwd=self.project_dir,
env=env
)
output = {
'success': result.returncode == 0,
'command': ' '.join(cmd),
'stdout': result.stdout,
'stderr': result.stderr if result.returncode != 0 else None
}
if capture_json and result.returncode == 0:
try:
output['data'] = json.loads(result.stdout)
except json.JSONDecodeError:
pass
return output
except subprocess.TimeoutExpired:
return {
'error': f'Command timed out after {timeout}s',
'command': ' '.join(cmd)
}
except FileNotFoundError:
return {
'error': 'dbt not found in PATH',
'suggestion': 'Install dbt: pip install dbt-core dbt-postgres'
}
except Exception as e:
logger.error(f"dbt command failed: {e}")
return {'error': str(e)}
async def dbt_parse(self) -> Dict:
"""
Validate dbt project without executing (pre-flight check).
Returns:
Dict with validation result and any errors
"""
result = self._run_dbt(['parse'])
# Check if _run_dbt returned an error (e.g., project not found, timeout, dbt not installed)
if 'error' in result:
return result
if not result.get('success'):
# Extract useful error info from stderr
stderr = result.get('stderr', '') or result.get('stdout', '')
errors = []
# Look for common dbt 1.9+ deprecation warnings
if 'deprecated' in stderr.lower():
errors.append({
'type': 'deprecation',
'message': 'Deprecated syntax found - check dbt 1.9+ migration guide'
})
# Look for compilation errors
if 'compilation error' in stderr.lower():
errors.append({
'type': 'compilation',
'message': 'SQL compilation error - check model syntax'
})
return {
'valid': False,
'errors': errors,
'details': stderr[:2000] if stderr else None,
'suggestion': 'Fix issues before running dbt models'
}
return {
'valid': True,
'message': 'dbt project validation passed'
}
async def dbt_run(
self,
select: Optional[str] = None,
exclude: Optional[str] = None,
full_refresh: bool = False
) -> Dict:
"""
Run dbt models with pre-validation.
Args:
select: Model selection (e.g., "model_name", "+model_name", "tag:daily")
exclude: Models to exclude
full_refresh: If True, rebuild incremental models
Returns:
Dict with run result
"""
# ALWAYS validate first
parse_result = await self.dbt_parse()
if not parse_result.get('valid'):
return {
'error': 'Pre-validation failed',
**parse_result
}
cmd = ['run']
if select:
cmd.extend(['--select', select])
if exclude:
cmd.extend(['--exclude', exclude])
if full_refresh:
cmd.append('--full-refresh')
return self._run_dbt(cmd)
async def dbt_test(
self,
select: Optional[str] = None,
exclude: Optional[str] = None
) -> Dict:
"""
Run dbt tests.
Args:
select: Test selection
exclude: Tests to exclude
Returns:
Dict with test results
"""
cmd = ['test']
if select:
cmd.extend(['--select', select])
if exclude:
cmd.extend(['--exclude', exclude])
return self._run_dbt(cmd)
async def dbt_build(
self,
select: Optional[str] = None,
exclude: Optional[str] = None,
full_refresh: bool = False
) -> Dict:
"""
Run dbt build (run + test) with pre-validation.
Args:
select: Model/test selection
exclude: Resources to exclude
full_refresh: If True, rebuild incremental models
Returns:
Dict with build result
"""
# ALWAYS validate first
parse_result = await self.dbt_parse()
if not parse_result.get('valid'):
return {
'error': 'Pre-validation failed',
**parse_result
}
cmd = ['build']
if select:
cmd.extend(['--select', select])
if exclude:
cmd.extend(['--exclude', exclude])
if full_refresh:
cmd.append('--full-refresh')
return self._run_dbt(cmd)
async def dbt_compile(
self,
select: Optional[str] = None
) -> Dict:
"""
Compile dbt models to SQL without executing.
Args:
select: Model selection
Returns:
Dict with compiled SQL info
"""
cmd = ['compile']
if select:
cmd.extend(['--select', select])
return self._run_dbt(cmd)
async def dbt_ls(
self,
select: Optional[str] = None,
resource_type: Optional[str] = None,
output: str = 'name'
) -> Dict:
"""
List dbt resources.
Args:
select: Resource selection
resource_type: Filter by type (model, test, seed, snapshot, source)
output: Output format ('name', 'path', 'json')
Returns:
Dict with list of resources
"""
cmd = ['ls', '--output', output]
if select:
cmd.extend(['--select', select])
if resource_type:
cmd.extend(['--resource-type', resource_type])
result = self._run_dbt(cmd)
if result.get('success') and result.get('stdout'):
lines = [l.strip() for l in result['stdout'].split('\n') if l.strip()]
result['resources'] = lines
result['count'] = len(lines)
return result
async def dbt_docs_generate(self) -> Dict:
"""
Generate dbt documentation.
Returns:
Dict with generation result
"""
result = self._run_dbt(['docs', 'generate'])
if result.get('success') and self.project_dir:
# Check for generated catalog
catalog_path = Path(self.project_dir) / 'target' / 'catalog.json'
manifest_path = Path(self.project_dir) / 'target' / 'manifest.json'
result['catalog_generated'] = catalog_path.exists()
result['manifest_generated'] = manifest_path.exists()
return result
async def dbt_lineage(self, model: str) -> Dict:
"""
Get model dependencies and lineage.
Args:
model: Model name to analyze
Returns:
Dict with upstream and downstream dependencies
"""
if not self.project_dir:
return {'error': 'dbt project not found'}
manifest_path = Path(self.project_dir) / 'target' / 'manifest.json'
# Generate manifest if not exists
if not manifest_path.exists():
compile_result = await self.dbt_compile(select=model)
if not compile_result.get('success'):
return {
'error': 'Failed to compile manifest',
'details': compile_result
}
if not manifest_path.exists():
return {
'error': 'Manifest not found',
'suggestion': 'Run dbt compile first'
}
try:
with open(manifest_path) as f:
manifest = json.load(f)
# Find the model node
model_key = None
for key in manifest.get('nodes', {}):
if key.endswith(f'.{model}') or manifest['nodes'][key].get('name') == model:
model_key = key
break
if not model_key:
return {
'error': f'Model not found: {model}',
'available_models': [
n.get('name') for n in manifest.get('nodes', {}).values()
if n.get('resource_type') == 'model'
][:20]
}
node = manifest['nodes'][model_key]
# Get upstream (depends_on)
upstream = node.get('depends_on', {}).get('nodes', [])
# Get downstream (find nodes that depend on this one)
downstream = []
for key, other_node in manifest.get('nodes', {}).items():
deps = other_node.get('depends_on', {}).get('nodes', [])
if model_key in deps:
downstream.append(key)
return {
'model': model,
'unique_id': model_key,
'materialization': node.get('config', {}).get('materialized'),
'schema': node.get('schema'),
'database': node.get('database'),
'upstream': upstream,
'downstream': downstream,
'description': node.get('description'),
'tags': node.get('tags', [])
}
except Exception as e:
logger.error(f"dbt_lineage failed: {e}")
return {'error': str(e)}


@@ -0,0 +1,500 @@
"""
pandas MCP Tools.
Provides DataFrame operations with Arrow IPC data_ref persistence.
"""
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Any, Union
from .data_store import DataStore
from .config import load_config
logger = logging.getLogger(__name__)
class PandasTools:
"""pandas data manipulation tools with data_ref persistence"""
def __init__(self):
self.store = DataStore.get_instance()
config = load_config()
self.max_rows = config.get('max_rows', 100_000)
self.store.set_max_rows(self.max_rows)
def _check_and_store(
self,
df: pd.DataFrame,
name: Optional[str] = None,
source: Optional[str] = None
) -> Dict:
"""Check row limit and store DataFrame if within limits"""
check = self.store.check_row_limit(len(df))
if check['exceeded']:
return {
'error': 'row_limit_exceeded',
**check,
'preview': df.head(100).to_dict(orient='records')
}
data_ref = self.store.store(df, name=name, source=source)
return {
'data_ref': data_ref,
'rows': len(df),
'columns': list(df.columns),
'dtypes': {col: str(dtype) for col, dtype in df.dtypes.items()}
}
async def read_csv(
self,
file_path: str,
name: Optional[str] = None,
chunk_size: Optional[int] = None,
**kwargs
) -> Dict:
"""
Load CSV file into DataFrame.
Args:
file_path: Path to CSV file
name: Optional name for data_ref
chunk_size: If provided, process in chunks
**kwargs: Additional pandas read_csv arguments
Returns:
Dict with data_ref or error info
"""
path = Path(file_path)
if not path.exists():
return {'error': f'File not found: {file_path}'}
try:
if chunk_size:
# Chunked processing - return iterator info
chunks = []
for i, chunk in enumerate(pd.read_csv(path, chunksize=chunk_size, **kwargs)):
chunk_ref = self.store.store(chunk, name=f"{name or 'chunk'}_{i}", source=file_path)
chunks.append({'ref': chunk_ref, 'rows': len(chunk)})
return {
'chunked': True,
'chunks': chunks,
'total_chunks': len(chunks)
}
df = pd.read_csv(path, **kwargs)
return self._check_and_store(df, name=name, source=file_path)
except Exception as e:
logger.error(f"read_csv failed: {e}")
return {'error': str(e)}
async def read_parquet(
self,
file_path: str,
name: Optional[str] = None,
columns: Optional[List[str]] = None
) -> Dict:
"""
Load Parquet file into DataFrame.
Args:
file_path: Path to Parquet file
name: Optional name for data_ref
columns: Optional list of columns to load
Returns:
Dict with data_ref or error info
"""
path = Path(file_path)
if not path.exists():
return {'error': f'File not found: {file_path}'}
try:
table = pq.read_table(path, columns=columns)
df = table.to_pandas()
return self._check_and_store(df, name=name, source=file_path)
except Exception as e:
logger.error(f"read_parquet failed: {e}")
return {'error': str(e)}
async def read_json(
self,
file_path: str,
name: Optional[str] = None,
lines: bool = False,
**kwargs
) -> Dict:
"""
Load JSON/JSONL file into DataFrame.
Args:
file_path: Path to JSON file
name: Optional name for data_ref
lines: If True, read as JSON Lines format
**kwargs: Additional pandas read_json arguments
Returns:
Dict with data_ref or error info
"""
path = Path(file_path)
if not path.exists():
return {'error': f'File not found: {file_path}'}
try:
df = pd.read_json(path, lines=lines, **kwargs)
return self._check_and_store(df, name=name, source=file_path)
except Exception as e:
logger.error(f"read_json failed: {e}")
return {'error': str(e)}
async def to_csv(
self,
data_ref: str,
file_path: str,
index: bool = False,
**kwargs
) -> Dict:
"""
Export DataFrame to CSV file.
Args:
data_ref: Reference to stored DataFrame
file_path: Output file path
index: Whether to include index
**kwargs: Additional pandas to_csv arguments
Returns:
Dict with success status
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
df.to_csv(file_path, index=index, **kwargs)
return {
'success': True,
'file_path': file_path,
'rows': len(df),
'size_bytes': Path(file_path).stat().st_size
}
except Exception as e:
logger.error(f"to_csv failed: {e}")
return {'error': str(e)}
async def to_parquet(
self,
data_ref: str,
file_path: str,
compression: str = 'snappy'
) -> Dict:
"""
Export DataFrame to Parquet file.
Args:
data_ref: Reference to stored DataFrame
file_path: Output file path
compression: Compression codec
Returns:
Dict with success status
"""
table = self.store.get(data_ref)
if table is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
pq.write_table(table, file_path, compression=compression)
return {
'success': True,
'file_path': file_path,
'rows': table.num_rows,
'size_bytes': Path(file_path).stat().st_size
}
except Exception as e:
logger.error(f"to_parquet failed: {e}")
return {'error': str(e)}
async def describe(self, data_ref: str) -> Dict:
"""
Get statistical summary of DataFrame.
Args:
data_ref: Reference to stored DataFrame
Returns:
Dict with statistical summary
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
desc = df.describe(include='all')
info = self.store.get_info(data_ref)
return {
'data_ref': data_ref,
'shape': {'rows': len(df), 'columns': len(df.columns)},
'columns': list(df.columns),
'dtypes': {col: str(dtype) for col, dtype in df.dtypes.items()},
'memory_mb': info.memory_bytes / (1024 * 1024) if info else None,
'null_counts': df.isnull().sum().to_dict(),
'statistics': desc.to_dict()
}
except Exception as e:
logger.error(f"describe failed: {e}")
return {'error': str(e)}
async def head(self, data_ref: str, n: int = 10) -> Dict:
"""
Get first N rows of DataFrame.
Args:
data_ref: Reference to stored DataFrame
n: Number of rows
Returns:
Dict with rows as records
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
head_df = df.head(n)
return {
'data_ref': data_ref,
'total_rows': len(df),
'returned_rows': len(head_df),
'columns': list(df.columns),
'data': head_df.to_dict(orient='records')
}
except Exception as e:
logger.error(f"head failed: {e}")
return {'error': str(e)}
async def tail(self, data_ref: str, n: int = 10) -> Dict:
"""
Get last N rows of DataFrame.
Args:
data_ref: Reference to stored DataFrame
n: Number of rows
Returns:
Dict with rows as records
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
tail_df = df.tail(n)
return {
'data_ref': data_ref,
'total_rows': len(df),
'returned_rows': len(tail_df),
'columns': list(df.columns),
'data': tail_df.to_dict(orient='records')
}
except Exception as e:
logger.error(f"tail failed: {e}")
return {'error': str(e)}
async def filter(
self,
data_ref: str,
condition: str,
name: Optional[str] = None
) -> Dict:
"""
Filter DataFrame rows by condition.
Args:
data_ref: Reference to stored DataFrame
condition: pandas query string (e.g., "age > 30 and city == 'NYC'")
name: Optional name for result data_ref
Returns:
Dict with new data_ref for filtered result
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
filtered = df.query(condition)
result_name = name or f"{data_ref}_filtered"
return self._check_and_store(
filtered,
name=result_name,
source=f"filter({data_ref}, '{condition}')"
)
except Exception as e:
logger.error(f"filter failed: {e}")
return {'error': str(e)}
async def select(
self,
data_ref: str,
columns: List[str],
name: Optional[str] = None
) -> Dict:
"""
Select specific columns from DataFrame.
Args:
data_ref: Reference to stored DataFrame
columns: List of column names to select
name: Optional name for result data_ref
Returns:
Dict with new data_ref for selected columns
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
# Validate columns exist
missing = [c for c in columns if c not in df.columns]
if missing:
return {
'error': f'Columns not found: {missing}',
'available_columns': list(df.columns)
}
selected = df[columns]
result_name = name or f"{data_ref}_select"
return self._check_and_store(
selected,
name=result_name,
source=f"select({data_ref}, {columns})"
)
except Exception as e:
logger.error(f"select failed: {e}")
return {'error': str(e)}
async def groupby(
self,
data_ref: str,
by: Union[str, List[str]],
agg: Dict[str, Union[str, List[str]]],
name: Optional[str] = None
) -> Dict:
"""
Group DataFrame and aggregate.
Args:
data_ref: Reference to stored DataFrame
by: Column(s) to group by
agg: Aggregation dict (e.g., {"sales": "sum", "count": "mean"})
name: Optional name for result data_ref
Returns:
Dict with new data_ref for aggregated result
"""
df = self.store.get_pandas(data_ref)
if df is None:
return {'error': f'DataFrame not found: {data_ref}'}
try:
grouped = df.groupby(by).agg(agg).reset_index()
# Flatten column names if multi-level
if isinstance(grouped.columns, pd.MultiIndex):
grouped.columns = ['_'.join(col).strip('_') for col in grouped.columns]
result_name = name or f"{data_ref}_grouped"
return self._check_and_store(
grouped,
name=result_name,
source=f"groupby({data_ref}, by={by})"
)
except Exception as e:
logger.error(f"groupby failed: {e}")
return {'error': str(e)}
async def join(
self,
left_ref: str,
right_ref: str,
on: Optional[Union[str, List[str]]] = None,
left_on: Optional[Union[str, List[str]]] = None,
right_on: Optional[Union[str, List[str]]] = None,
how: str = 'inner',
name: Optional[str] = None
) -> Dict:
"""
Join two DataFrames.
Args:
left_ref: Reference to left DataFrame
right_ref: Reference to right DataFrame
on: Column(s) to join on (if same name in both)
left_on: Left join column(s)
right_on: Right join column(s)
how: Join type ('inner', 'left', 'right', 'outer')
name: Optional name for result data_ref
Returns:
Dict with new data_ref for joined result
"""
left_df = self.store.get_pandas(left_ref)
right_df = self.store.get_pandas(right_ref)
if left_df is None:
return {'error': f'DataFrame not found: {left_ref}'}
if right_df is None:
return {'error': f'DataFrame not found: {right_ref}'}
try:
joined = pd.merge(
left_df, right_df,
on=on, left_on=left_on, right_on=right_on,
how=how
)
result_name = name or f"{left_ref}_{right_ref}_joined"
return self._check_and_store(
joined,
name=result_name,
source=f"join({left_ref}, {right_ref}, how={how})"
)
except Exception as e:
logger.error(f"join failed: {e}")
return {'error': str(e)}
async def list_data(self) -> Dict:
"""
List all stored DataFrames.
Returns:
Dict with list of stored DataFrames and their info
"""
refs = self.store.list_refs()
return {
'count': len(refs),
'total_memory_mb': self.store.total_memory_mb(),
'max_rows_limit': self.max_rows,
'dataframes': refs
}
async def drop_data(self, data_ref: str) -> Dict:
"""
Remove a DataFrame from storage.
Args:
data_ref: Reference to drop
Returns:
Dict with success status
"""
if self.store.drop(data_ref):
return {'success': True, 'dropped': data_ref}
return {'error': f'DataFrame not found: {data_ref}'}


@@ -0,0 +1,538 @@
"""
PostgreSQL/PostGIS MCP Tools.
Provides database operations with connection pooling and PostGIS support.
"""
import asyncio
import logging
from typing import Dict, List, Optional, Any
import json
from .data_store import DataStore
from .config import load_config
logger = logging.getLogger(__name__)
# Optional imports - gracefully handle missing dependencies
try:
import asyncpg
ASYNCPG_AVAILABLE = True
except ImportError:
ASYNCPG_AVAILABLE = False
logger.warning("asyncpg not available - PostgreSQL tools will be disabled")
try:
import pandas as pd
PANDAS_AVAILABLE = True
except ImportError:
PANDAS_AVAILABLE = False
class PostgresTools:
"""PostgreSQL/PostGIS database tools"""
def __init__(self):
self.store = DataStore.get_instance()
self.config = load_config()
self.pool: Optional[Any] = None
self.max_rows = self.config.get('max_rows', 100_000)
async def _get_pool(self):
"""Get or create connection pool"""
if not ASYNCPG_AVAILABLE:
raise RuntimeError("asyncpg not installed - run: pip install asyncpg")
if self.pool is None:
postgres_url = self.config.get('postgres_url')
if not postgres_url:
raise RuntimeError(
"PostgreSQL not configured. Set POSTGRES_URL in "
"~/.config/claude/postgres.env"
)
self.pool = await asyncpg.create_pool(postgres_url, min_size=1, max_size=5)
return self.pool
async def pg_connect(self) -> Dict:
"""
Test PostgreSQL connection and return status.
Returns:
Dict with connection status, version, and database info
"""
if not ASYNCPG_AVAILABLE:
return {
'connected': False,
'error': 'asyncpg not installed',
'suggestion': 'pip install asyncpg'
}
postgres_url = self.config.get('postgres_url')
if not postgres_url:
return {
'connected': False,
'error': 'POSTGRES_URL not configured',
'suggestion': 'Create ~/.config/claude/postgres.env with POSTGRES_URL=postgresql://...'
}
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
version = await conn.fetchval('SELECT version()')
db_name = await conn.fetchval('SELECT current_database()')
user = await conn.fetchval('SELECT current_user')
# Check for PostGIS
postgis_version = None
try:
postgis_version = await conn.fetchval('SELECT PostGIS_Version()')
except Exception:
pass
return {
'connected': True,
'database': db_name,
'user': user,
'version': version.split(',')[0] if version else 'Unknown',
'postgis_version': postgis_version,
'postgis_available': postgis_version is not None
}
except Exception as e:
logger.error(f"pg_connect failed: {e}")
return {
'connected': False,
'error': str(e)
}
async def pg_query(
self,
query: str,
params: Optional[List] = None,
name: Optional[str] = None
) -> Dict:
"""
Execute SELECT query and return results as data_ref.
Args:
query: SQL SELECT query
params: Query parameters (positional, use $1, $2, etc.)
name: Optional name for result data_ref
Returns:
Dict with data_ref for results or error
"""
if not PANDAS_AVAILABLE:
return {'error': 'pandas not available'}
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
if params:
rows = await conn.fetch(query, *params)
else:
rows = await conn.fetch(query)
if not rows:
return {
'data_ref': None,
'rows': 0,
'message': 'Query returned no results'
}
# Convert to DataFrame
df = pd.DataFrame([dict(r) for r in rows])
# Check row limit
check = self.store.check_row_limit(len(df))
if check['exceeded']:
return {
'error': 'row_limit_exceeded',
**check,
'preview': df.head(100).to_dict(orient='records')
}
# Store result
data_ref = self.store.store(df, name=name, source=f"pg_query: {query[:100]}...")
return {
'data_ref': data_ref,
'rows': len(df),
'columns': list(df.columns)
}
except Exception as e:
logger.error(f"pg_query failed: {e}")
return {'error': str(e)}
async def pg_execute(
self,
query: str,
params: Optional[List] = None
) -> Dict:
"""
Execute INSERT/UPDATE/DELETE query.
Args:
query: SQL DML query
params: Query parameters
Returns:
Dict with affected rows count
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
if params:
result = await conn.execute(query, *params)
else:
result = await conn.execute(query)
# Parse result (e.g., "INSERT 0 1" or "UPDATE 5")
parts = result.split()
affected = int(parts[-1]) if parts else 0
return {
'success': True,
'command': parts[0] if parts else 'UNKNOWN',
'affected_rows': affected
}
except Exception as e:
logger.error(f"pg_execute failed: {e}")
return {'error': str(e)}
async def pg_tables(self, schema: str = 'public') -> Dict:
"""
List all tables in schema.
Args:
schema: Schema name (default: public)
Returns:
Dict with list of tables
"""
query = """
SELECT
table_name,
table_type,
(SELECT count(*) FROM information_schema.columns c
WHERE c.table_schema = t.table_schema
AND c.table_name = t.table_name) as column_count
FROM information_schema.tables t
WHERE table_schema = $1
ORDER BY table_name
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(query, schema)
tables = [
{
'name': r['table_name'],
'type': r['table_type'],
'columns': r['column_count']
}
for r in rows
]
return {
'schema': schema,
'count': len(tables),
'tables': tables
}
except Exception as e:
logger.error(f"pg_tables failed: {e}")
return {'error': str(e)}
async def pg_columns(self, table: str, schema: str = 'public') -> Dict:
"""
Get column information for a table.
Args:
table: Table name
schema: Schema name (default: public)
Returns:
Dict with column details
"""
query = """
SELECT
column_name,
data_type,
udt_name,
is_nullable,
column_default,
character_maximum_length,
numeric_precision
FROM information_schema.columns
WHERE table_schema = $1 AND table_name = $2
ORDER BY ordinal_position
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(query, schema, table)
columns = [
{
'name': r['column_name'],
'type': r['data_type'],
'udt': r['udt_name'],
'nullable': r['is_nullable'] == 'YES',
'default': r['column_default'],
'max_length': r['character_maximum_length'],
'precision': r['numeric_precision']
}
for r in rows
]
return {
'table': f'{schema}.{table}',
'column_count': len(columns),
'columns': columns
}
except Exception as e:
logger.error(f"pg_columns failed: {e}")
return {'error': str(e)}
async def pg_schemas(self) -> Dict:
"""
List all schemas in database.
Returns:
Dict with list of schemas
"""
query = """
SELECT schema_name
FROM information_schema.schemata
WHERE schema_name NOT IN ('pg_catalog', 'information_schema', 'pg_toast')
ORDER BY schema_name
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(query)
schemas = [r['schema_name'] for r in rows]
return {
'count': len(schemas),
'schemas': schemas
}
except Exception as e:
logger.error(f"pg_schemas failed: {e}")
return {'error': str(e)}
async def st_tables(self, schema: str = 'public') -> Dict:
"""
List PostGIS-enabled tables.
Args:
schema: Schema name (default: public)
Returns:
Dict with list of tables with geometry columns
"""
query = """
SELECT
f_table_name as table_name,
f_geometry_column as geometry_column,
type as geometry_type,
srid,
coord_dimension
FROM geometry_columns
WHERE f_table_schema = $1
ORDER BY f_table_name
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(query, schema)
tables = [
{
'table': r['table_name'],
'geometry_column': r['geometry_column'],
'geometry_type': r['geometry_type'],
'srid': r['srid'],
'dimensions': r['coord_dimension']
}
for r in rows
]
return {
'schema': schema,
'count': len(tables),
'postgis_tables': tables
}
except Exception as e:
if 'geometry_columns' in str(e):
return {
'error': 'PostGIS not installed or extension not enabled',
'suggestion': 'Run: CREATE EXTENSION IF NOT EXISTS postgis;'
}
logger.error(f"st_tables failed: {e}")
return {'error': str(e)}
async def st_geometry_type(self, table: str, column: str, schema: str = 'public') -> Dict:
"""
Get geometry type of a column.
Args:
table: Table name
column: Geometry column name
schema: Schema name
Returns:
Dict with geometry type information
"""
query = f"""
SELECT DISTINCT ST_GeometryType({column}) as geom_type
FROM {schema}.{table}
WHERE {column} IS NOT NULL
LIMIT 10
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(query)
types = [r['geom_type'] for r in rows]
return {
'table': f'{schema}.{table}',
'column': column,
'geometry_types': types
}
except Exception as e:
logger.error(f"st_geometry_type failed: {e}")
return {'error': str(e)}
async def st_srid(self, table: str, column: str, schema: str = 'public') -> Dict:
"""
Get SRID of geometry column.
Args:
table: Table name
column: Geometry column name
schema: Schema name
Returns:
Dict with SRID information
"""
query = f"""
SELECT DISTINCT ST_SRID({column}) as srid
FROM {schema}.{table}
WHERE {column} IS NOT NULL
LIMIT 1
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(query)
srid = row['srid'] if row else None
# Get SRID description
srid_info = None
if srid:
srid_query = """
SELECT srtext, proj4text
FROM spatial_ref_sys
WHERE srid = $1
"""
srid_row = await conn.fetchrow(srid_query, srid)
if srid_row:
srid_info = {
'description': srid_row['srtext'][:200] if srid_row['srtext'] else None,
'proj4': srid_row['proj4text']
}
return {
'table': f'{schema}.{table}',
'column': column,
'srid': srid,
'info': srid_info
}
except Exception as e:
logger.error(f"st_srid failed: {e}")
return {'error': str(e)}
async def st_extent(self, table: str, column: str, schema: str = 'public') -> Dict:
"""
Get bounding box of all geometries.
Args:
table: Table name
column: Geometry column name
schema: Schema name
Returns:
Dict with bounding box coordinates
"""
query = f"""
SELECT
ST_XMin(extent) as xmin,
ST_YMin(extent) as ymin,
ST_XMax(extent) as xmax,
ST_YMax(extent) as ymax
FROM (
SELECT ST_Extent({column}) as extent
FROM {schema}.{table}
) sub
"""
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(query)
if row and row['xmin'] is not None:
return {
'table': f'{schema}.{table}',
'column': column,
'bbox': {
'xmin': float(row['xmin']),
'ymin': float(row['ymin']),
'xmax': float(row['xmax']),
'ymax': float(row['ymax'])
}
}
return {
'table': f'{schema}.{table}',
'column': column,
'bbox': None,
'message': 'No geometries found or all NULL'
}
except Exception as e:
logger.error(f"st_extent failed: {e}")
return {'error': str(e)}
async def close(self):
"""Close connection pool"""
if self.pool:
await self.pool.close()
self.pool = None
def check_connection() -> None:
"""
Check PostgreSQL connection for SessionStart hook.
Prints warning to stderr if connection fails.
"""
import sys
config = load_config()
if not config.get('postgres_url'):
print(
"[data-platform] PostgreSQL not configured (POSTGRES_URL not set)",
file=sys.stderr
)
return
async def test():
try:
if not ASYNCPG_AVAILABLE:
print(
"[data-platform] asyncpg not installed - PostgreSQL tools unavailable",
file=sys.stderr
)
return
conn = await asyncpg.connect(config['postgres_url'], timeout=5)
await conn.close()
print("[data-platform] PostgreSQL connection OK", file=sys.stderr)
except Exception as e:
print(
f"[data-platform] PostgreSQL connection failed: {e}",
file=sys.stderr
)
asyncio.run(test())


@@ -0,0 +1,795 @@
"""
MCP Server entry point for Data Platform integration.
Provides pandas, PostgreSQL/PostGIS, and dbt tools to Claude Code via JSON-RPC 2.0 over stdio.
"""
import asyncio
import logging
import json
from mcp.server import Server
from mcp.server.stdio import stdio_server
from mcp.types import Tool, TextContent
from .config import DataPlatformConfig
from .data_store import DataStore
from .pandas_tools import PandasTools
from .postgres_tools import PostgresTools
from .dbt_tools import DbtTools
# Suppress noisy MCP validation warnings on stderr
logging.basicConfig(level=logging.INFO)
logging.getLogger("root").setLevel(logging.ERROR)
logging.getLogger("mcp").setLevel(logging.ERROR)
logger = logging.getLogger(__name__)
class DataPlatformMCPServer:
"""MCP Server for data platform integration"""
def __init__(self):
self.server = Server("data-platform-mcp")
self.config = None
self.pandas_tools = None
self.postgres_tools = None
self.dbt_tools = None
async def initialize(self):
"""Initialize server and load configuration."""
try:
config_loader = DataPlatformConfig()
self.config = config_loader.load()
self.pandas_tools = PandasTools()
self.postgres_tools = PostgresTools()
self.dbt_tools = DbtTools()
# Log available capabilities
caps = []
caps.append("pandas")
if self.config.get('postgres_available'):
caps.append("PostgreSQL")
if self.config.get('dbt_available'):
caps.append("dbt")
logger.info(f"Data Platform MCP Server initialized with: {', '.join(caps)}")
except Exception as e:
logger.error(f"Failed to initialize: {e}")
raise
def setup_tools(self):
"""Register all available tools with the MCP server"""
@self.server.list_tools()
async def list_tools() -> list[Tool]:
"""Return list of available tools"""
tools = [
# pandas tools - always available
Tool(
name="read_csv",
description="Load CSV file into DataFrame",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to CSV file"
},
"name": {
"type": "string",
"description": "Optional name for data_ref"
},
"chunk_size": {
"type": "integer",
"description": "Process in chunks of this size"
}
},
"required": ["file_path"]
}
),
Tool(
name="read_parquet",
description="Load Parquet file into DataFrame",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to Parquet file"
},
"name": {
"type": "string",
"description": "Optional name for data_ref"
},
"columns": {
"type": "array",
"items": {"type": "string"},
"description": "Optional list of columns to load"
}
},
"required": ["file_path"]
}
),
Tool(
name="read_json",
description="Load JSON/JSONL file into DataFrame",
inputSchema={
"type": "object",
"properties": {
"file_path": {
"type": "string",
"description": "Path to JSON file"
},
"name": {
"type": "string",
"description": "Optional name for data_ref"
},
"lines": {
"type": "boolean",
"default": False,
"description": "Read as JSON Lines format"
}
},
"required": ["file_path"]
}
),
Tool(
name="to_csv",
description="Export DataFrame to CSV file",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"file_path": {
"type": "string",
"description": "Output file path"
},
"index": {
"type": "boolean",
"default": False,
"description": "Include index column"
}
},
"required": ["data_ref", "file_path"]
}
),
Tool(
name="to_parquet",
description="Export DataFrame to Parquet file",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"file_path": {
"type": "string",
"description": "Output file path"
},
"compression": {
"type": "string",
"default": "snappy",
"description": "Compression codec"
}
},
"required": ["data_ref", "file_path"]
}
),
Tool(
name="describe",
description="Get statistical summary of DataFrame",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
}
},
"required": ["data_ref"]
}
),
Tool(
name="head",
description="Get first N rows of DataFrame",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"n": {
"type": "integer",
"default": 10,
"description": "Number of rows"
}
},
"required": ["data_ref"]
}
),
Tool(
name="tail",
description="Get last N rows of DataFrame",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"n": {
"type": "integer",
"default": 10,
"description": "Number of rows"
}
},
"required": ["data_ref"]
}
),
Tool(
name="filter",
description="Filter DataFrame rows by condition",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"condition": {
"type": "string",
"description": "pandas query string (e.g., 'age > 30 and city == \"NYC\"')"
},
"name": {
"type": "string",
"description": "Optional name for result data_ref"
}
},
"required": ["data_ref", "condition"]
}
),
Tool(
name="select",
description="Select specific columns from DataFrame",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"columns": {
"type": "array",
"items": {"type": "string"},
"description": "List of column names to select"
},
"name": {
"type": "string",
"description": "Optional name for result data_ref"
}
},
"required": ["data_ref", "columns"]
}
),
Tool(
name="groupby",
description="Group DataFrame and aggregate",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to stored DataFrame"
},
"by": {
"oneOf": [
{"type": "string"},
{"type": "array", "items": {"type": "string"}}
],
"description": "Column(s) to group by"
},
"agg": {
"type": "object",
"description": "Aggregation dict (e.g., {\"sales\": \"sum\", \"count\": \"mean\"})"
},
"name": {
"type": "string",
"description": "Optional name for result data_ref"
}
},
"required": ["data_ref", "by", "agg"]
}
),
Tool(
name="join",
description="Join two DataFrames",
inputSchema={
"type": "object",
"properties": {
"left_ref": {
"type": "string",
"description": "Reference to left DataFrame"
},
"right_ref": {
"type": "string",
"description": "Reference to right DataFrame"
},
"on": {
"oneOf": [
{"type": "string"},
{"type": "array", "items": {"type": "string"}}
],
"description": "Column(s) to join on (if same name in both)"
},
"left_on": {
"oneOf": [
{"type": "string"},
{"type": "array", "items": {"type": "string"}}
],
"description": "Left join column(s)"
},
"right_on": {
"oneOf": [
{"type": "string"},
{"type": "array", "items": {"type": "string"}}
],
"description": "Right join column(s)"
},
"how": {
"type": "string",
"enum": ["inner", "left", "right", "outer"],
"default": "inner",
"description": "Join type"
},
"name": {
"type": "string",
"description": "Optional name for result data_ref"
}
},
"required": ["left_ref", "right_ref"]
}
),
Tool(
name="list_data",
description="List all stored DataFrames",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool(
name="drop_data",
description="Remove a DataFrame from storage",
inputSchema={
"type": "object",
"properties": {
"data_ref": {
"type": "string",
"description": "Reference to drop"
}
},
"required": ["data_ref"]
}
),
# PostgreSQL tools
Tool(
name="pg_connect",
description="Test PostgreSQL connection and return status",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool(
name="pg_query",
description="Execute SELECT query and return results as data_ref",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SQL SELECT query"
},
"params": {
"type": "array",
"items": {},
"description": "Query parameters (use $1, $2, etc.)"
},
"name": {
"type": "string",
"description": "Optional name for result data_ref"
}
},
"required": ["query"]
}
),
Tool(
name="pg_execute",
description="Execute INSERT/UPDATE/DELETE query",
inputSchema={
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "SQL DML query"
},
"params": {
"type": "array",
"items": {},
"description": "Query parameters"
}
},
"required": ["query"]
}
),
Tool(
name="pg_tables",
description="List all tables in schema",
inputSchema={
"type": "object",
"properties": {
"schema": {
"type": "string",
"default": "public",
"description": "Schema name"
}
}
}
),
Tool(
name="pg_columns",
description="Get column information for a table",
inputSchema={
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Table name"
},
"schema": {
"type": "string",
"default": "public",
"description": "Schema name"
}
},
"required": ["table"]
}
),
Tool(
name="pg_schemas",
description="List all schemas in database",
inputSchema={
"type": "object",
"properties": {}
}
),
# PostGIS tools
Tool(
name="st_tables",
description="List PostGIS-enabled tables",
inputSchema={
"type": "object",
"properties": {
"schema": {
"type": "string",
"default": "public",
"description": "Schema name"
}
}
}
),
Tool(
name="st_geometry_type",
description="Get geometry type of a column",
inputSchema={
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Table name"
},
"column": {
"type": "string",
"description": "Geometry column name"
},
"schema": {
"type": "string",
"default": "public",
"description": "Schema name"
}
},
"required": ["table", "column"]
}
),
Tool(
name="st_srid",
description="Get SRID of geometry column",
inputSchema={
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Table name"
},
"column": {
"type": "string",
"description": "Geometry column name"
},
"schema": {
"type": "string",
"default": "public",
"description": "Schema name"
}
},
"required": ["table", "column"]
}
),
Tool(
name="st_extent",
description="Get bounding box of all geometries",
inputSchema={
"type": "object",
"properties": {
"table": {
"type": "string",
"description": "Table name"
},
"column": {
"type": "string",
"description": "Geometry column name"
},
"schema": {
"type": "string",
"default": "public",
"description": "Schema name"
}
},
"required": ["table", "column"]
}
),
# dbt tools
Tool(
name="dbt_parse",
description="Validate dbt project (pre-flight check)",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool(
name="dbt_run",
description="Run dbt models with pre-validation",
inputSchema={
"type": "object",
"properties": {
"select": {
"type": "string",
"description": "Model selection (e.g., 'model_name', '+model_name', 'tag:daily')"
},
"exclude": {
"type": "string",
"description": "Models to exclude"
},
"full_refresh": {
"type": "boolean",
"default": False,
"description": "Rebuild incremental models"
}
}
}
),
Tool(
name="dbt_test",
description="Run dbt tests",
inputSchema={
"type": "object",
"properties": {
"select": {
"type": "string",
"description": "Test selection"
},
"exclude": {
"type": "string",
"description": "Tests to exclude"
}
}
}
),
Tool(
name="dbt_build",
description="Run dbt build (run + test) with pre-validation",
inputSchema={
"type": "object",
"properties": {
"select": {
"type": "string",
"description": "Model/test selection"
},
"exclude": {
"type": "string",
"description": "Resources to exclude"
},
"full_refresh": {
"type": "boolean",
"default": False,
"description": "Rebuild incremental models"
}
}
}
),
Tool(
name="dbt_compile",
description="Compile dbt models to SQL without executing",
inputSchema={
"type": "object",
"properties": {
"select": {
"type": "string",
"description": "Model selection"
}
}
}
),
Tool(
name="dbt_ls",
description="List dbt resources",
inputSchema={
"type": "object",
"properties": {
"select": {
"type": "string",
"description": "Resource selection"
},
"resource_type": {
"type": "string",
"enum": ["model", "test", "seed", "snapshot", "source"],
"description": "Filter by type"
},
"output": {
"type": "string",
"enum": ["name", "path", "json"],
"default": "name",
"description": "Output format"
}
}
}
),
Tool(
name="dbt_docs_generate",
description="Generate dbt documentation",
inputSchema={
"type": "object",
"properties": {}
}
),
Tool(
name="dbt_lineage",
description="Get model dependencies and lineage",
inputSchema={
"type": "object",
"properties": {
"model": {
"type": "string",
"description": "Model name to analyze"
}
},
"required": ["model"]
}
)
]
return tools
@self.server.call_tool()
async def call_tool(name: str, arguments: dict) -> list[TextContent]:
"""Handle tool invocation."""
try:
# Route to appropriate tool handler
# pandas tools
if name == "read_csv":
result = await self.pandas_tools.read_csv(**arguments)
elif name == "read_parquet":
result = await self.pandas_tools.read_parquet(**arguments)
elif name == "read_json":
result = await self.pandas_tools.read_json(**arguments)
elif name == "to_csv":
result = await self.pandas_tools.to_csv(**arguments)
elif name == "to_parquet":
result = await self.pandas_tools.to_parquet(**arguments)
elif name == "describe":
result = await self.pandas_tools.describe(**arguments)
elif name == "head":
result = await self.pandas_tools.head(**arguments)
elif name == "tail":
result = await self.pandas_tools.tail(**arguments)
elif name == "filter":
result = await self.pandas_tools.filter(**arguments)
elif name == "select":
result = await self.pandas_tools.select(**arguments)
elif name == "groupby":
result = await self.pandas_tools.groupby(**arguments)
elif name == "join":
result = await self.pandas_tools.join(**arguments)
elif name == "list_data":
result = await self.pandas_tools.list_data()
elif name == "drop_data":
result = await self.pandas_tools.drop_data(**arguments)
# PostgreSQL tools
elif name == "pg_connect":
result = await self.postgres_tools.pg_connect()
elif name == "pg_query":
result = await self.postgres_tools.pg_query(**arguments)
elif name == "pg_execute":
result = await self.postgres_tools.pg_execute(**arguments)
elif name == "pg_tables":
result = await self.postgres_tools.pg_tables(**arguments)
elif name == "pg_columns":
result = await self.postgres_tools.pg_columns(**arguments)
elif name == "pg_schemas":
result = await self.postgres_tools.pg_schemas()
# PostGIS tools
elif name == "st_tables":
result = await self.postgres_tools.st_tables(**arguments)
elif name == "st_geometry_type":
result = await self.postgres_tools.st_geometry_type(**arguments)
elif name == "st_srid":
result = await self.postgres_tools.st_srid(**arguments)
elif name == "st_extent":
result = await self.postgres_tools.st_extent(**arguments)
# dbt tools
elif name == "dbt_parse":
result = await self.dbt_tools.dbt_parse()
elif name == "dbt_run":
result = await self.dbt_tools.dbt_run(**arguments)
elif name == "dbt_test":
result = await self.dbt_tools.dbt_test(**arguments)
elif name == "dbt_build":
result = await self.dbt_tools.dbt_build(**arguments)
elif name == "dbt_compile":
result = await self.dbt_tools.dbt_compile(**arguments)
elif name == "dbt_ls":
result = await self.dbt_tools.dbt_ls(**arguments)
elif name == "dbt_docs_generate":
result = await self.dbt_tools.dbt_docs_generate()
elif name == "dbt_lineage":
result = await self.dbt_tools.dbt_lineage(**arguments)
else:
raise ValueError(f"Unknown tool: {name}")
return [TextContent(
type="text",
text=json.dumps(result, indent=2, default=str)
)]
except Exception as e:
logger.error(f"Tool {name} failed: {e}")
return [TextContent(
type="text",
text=json.dumps({"error": str(e)}, indent=2)
)]
async def run(self):
"""Run the MCP server"""
await self.initialize()
self.setup_tools()
async with stdio_server() as (read_stream, write_stream):
await self.server.run(
read_stream,
write_stream,
self.server.create_initialization_options()
)
async def main():
"""Main entry point"""
server = DataPlatformMCPServer()
await server.run()
if __name__ == "__main__":
asyncio.run(main())
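
A minimal stdio smoke test for this server, for reference. It assumes the standard MCP Python SDK client helpers (`StdioServerParameters`, `stdio_client`, `ClientSession`); the server module path and the `read_csv` parameter name are guesses, so adjust both to the actual entry point.

```python
import asyncio

from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client


async def smoke_test():
    # Launch the server over stdio, the same transport Claude Code uses.
    params = StdioServerParameters(command="python", args=["-m", "mcp_server.server"])
    async with stdio_client(params) as (read, write):
        async with ClientSession(read, write) as session:
            await session.initialize()
            # Hypothetical calls: load a CSV, then preview it via its data_ref.
            load = await session.call_tool("read_csv", {"path": "data.csv", "name": "demo"})
            preview = await session.call_tool("head", {"data_ref": "demo", "n": 5})
            print(load, preview)


if __name__ == "__main__":
    asyncio.run(smoke_test())
```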


@@ -0,0 +1,49 @@
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"
[project]
name = "data-platform-mcp"
version = "1.0.0"
description = "MCP Server for data engineering with pandas, PostgreSQL/PostGIS, and dbt"
readme = "README.md"
license = {text = "MIT"}
requires-python = ">=3.10"
authors = [
{name = "Leo Miranda"}
]
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dependencies = [
"mcp>=0.9.0",
"pandas>=2.0.0",
"pyarrow>=14.0.0",
"asyncpg>=0.29.0",
"geoalchemy2>=0.14.0",
"shapely>=2.0.0",
"dbt-core>=1.9.0",
"dbt-postgres>=1.9.0",
"python-dotenv>=1.0.0",
"pydantic>=2.5.0",
]
[project.optional-dependencies]
dev = [
"pytest>=7.4.3",
"pytest-asyncio>=0.23.0",
]
[tool.setuptools.packages.find]
where = ["."]
include = ["mcp_server*"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]


@@ -0,0 +1,23 @@
# MCP SDK
mcp>=0.9.0
# Data Processing
pandas>=2.0.0
pyarrow>=14.0.0
# PostgreSQL/PostGIS
asyncpg>=0.29.0
geoalchemy2>=0.14.0
shapely>=2.0.0
# dbt
dbt-core>=1.9.0
dbt-postgres>=1.9.0
# Utilities
python-dotenv>=1.0.0
pydantic>=2.5.0
# Testing
pytest>=7.4.3
pytest-asyncio>=0.23.0


@@ -0,0 +1,3 @@
"""
Tests for Data Platform MCP Server.
"""


@@ -0,0 +1,239 @@
"""
Unit tests for configuration loader.
"""
import pytest
def test_load_system_config(tmp_path, monkeypatch):
"""Test loading system-level PostgreSQL configuration"""
# Import here to avoid import errors before setup
from mcp_server.config import DataPlatformConfig
# Mock home directory
config_dir = tmp_path / '.config' / 'claude'
config_dir.mkdir(parents=True)
config_file = config_dir / 'postgres.env'
config_file.write_text(
"POSTGRES_URL=postgresql://user:pass@localhost:5432/testdb\n"
)
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(tmp_path)
config = DataPlatformConfig()
result = config.load()
assert result['postgres_url'] == 'postgresql://user:pass@localhost:5432/testdb'
assert result['postgres_available'] is True
def test_postgres_optional(tmp_path, monkeypatch):
"""Test that PostgreSQL configuration is optional"""
from mcp_server.config import DataPlatformConfig
# No postgres.env file
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(tmp_path)
# Clear any existing env vars
monkeypatch.delenv('POSTGRES_URL', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['postgres_url'] is None
assert result['postgres_available'] is False
def test_project_config_override(tmp_path, monkeypatch):
"""Test that project config overrides system config"""
from mcp_server.config import DataPlatformConfig
# Set up system config
system_config_dir = tmp_path / '.config' / 'claude'
system_config_dir.mkdir(parents=True)
system_config = system_config_dir / 'postgres.env'
system_config.write_text(
"POSTGRES_URL=postgresql://system:pass@localhost:5432/systemdb\n"
)
# Set up project config
project_dir = tmp_path / 'project'
project_dir.mkdir()
project_config = project_dir / '.env'
project_config.write_text(
"POSTGRES_URL=postgresql://project:pass@localhost:5432/projectdb\n"
"DBT_PROJECT_DIR=/path/to/dbt\n"
)
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
config = DataPlatformConfig()
result = config.load()
# Project config should override
assert result['postgres_url'] == 'postgresql://project:pass@localhost:5432/projectdb'
assert result['dbt_project_dir'] == '/path/to/dbt'
def test_max_rows_config(tmp_path, monkeypatch):
"""Test max rows configuration"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'project'
project_dir.mkdir()
project_config = project_dir / '.env'
project_config.write_text("DATA_PLATFORM_MAX_ROWS=50000\n")
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
config = DataPlatformConfig()
result = config.load()
assert result['max_rows'] == 50000
def test_default_max_rows(tmp_path, monkeypatch):
"""Test default max rows value"""
from mcp_server.config import DataPlatformConfig
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(tmp_path)
# Clear any existing env vars
monkeypatch.delenv('DATA_PLATFORM_MAX_ROWS', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['max_rows'] == 100_000 # Default value
def test_dbt_auto_detection(tmp_path, monkeypatch):
"""Test automatic dbt project detection"""
from mcp_server.config import DataPlatformConfig
# Create project with dbt_project.yml
project_dir = tmp_path / 'project'
project_dir.mkdir()
(project_dir / 'dbt_project.yml').write_text("name: test_project\n")
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
# Clear PWD and DBT_PROJECT_DIR to ensure auto-detection
monkeypatch.delenv('PWD', raising=False)
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['dbt_project_dir'] == str(project_dir)
assert result['dbt_available'] is True
def test_dbt_subdirectory_detection(tmp_path, monkeypatch):
"""Test dbt project detection in subdirectory"""
from mcp_server.config import DataPlatformConfig
# Create project with dbt in subdirectory
project_dir = tmp_path / 'project'
project_dir.mkdir()
# Need a marker file for _find_project_directory to find the project
(project_dir / '.git').mkdir()
dbt_dir = project_dir / 'transform'
dbt_dir.mkdir()
(dbt_dir / 'dbt_project.yml').write_text("name: test_project\n")
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
# Clear env vars to ensure auto-detection
monkeypatch.delenv('PWD', raising=False)
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['dbt_project_dir'] == str(dbt_dir)
assert result['dbt_available'] is True
def test_no_dbt_project(tmp_path, monkeypatch):
"""Test when no dbt project exists"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'project'
project_dir.mkdir()
monkeypatch.setenv('HOME', str(tmp_path))
monkeypatch.chdir(project_dir)
# Clear any existing env vars
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config.load()
assert result['dbt_project_dir'] is None
assert result['dbt_available'] is False
def test_find_project_directory_from_env(tmp_path, monkeypatch):
"""Test finding project directory from CLAUDE_PROJECT_DIR env var"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'my-project'
project_dir.mkdir()
(project_dir / '.git').mkdir()
monkeypatch.setenv('CLAUDE_PROJECT_DIR', str(project_dir))
config = DataPlatformConfig()
result = config._find_project_directory()
assert result == project_dir
def test_find_project_directory_from_cwd(tmp_path, monkeypatch):
"""Test finding project directory from cwd with .env file"""
from mcp_server.config import DataPlatformConfig
project_dir = tmp_path / 'project'
project_dir.mkdir()
(project_dir / '.env').write_text("TEST=value")
monkeypatch.chdir(project_dir)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
monkeypatch.delenv('PWD', raising=False)
config = DataPlatformConfig()
result = config._find_project_directory()
assert result == project_dir
def test_find_project_directory_none_when_no_markers(tmp_path, monkeypatch):
"""Test returns None when no project markers found"""
from mcp_server.config import DataPlatformConfig
empty_dir = tmp_path / 'empty'
empty_dir.mkdir()
monkeypatch.chdir(empty_dir)
monkeypatch.delenv('CLAUDE_PROJECT_DIR', raising=False)
monkeypatch.delenv('PWD', raising=False)
monkeypatch.delenv('DBT_PROJECT_DIR', raising=False)
config = DataPlatformConfig()
result = config._find_project_directory()
assert result is None
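
The precedence these tests pin down (project `.env` overrides the system-level `postgres.env`) can be expressed with python-dotenv in a few lines. A minimal sketch, assuming the real `DataPlatformConfig` merges the two sources in roughly this order:

```python
from pathlib import Path

from dotenv import dotenv_values


def load_hybrid_config(project_dir: Path | None) -> dict:
    """Toy merge of system and project env files (illustrative only)."""
    merged: dict = {}
    # System-level defaults first (PostgreSQL credentials).
    system_env = Path.home() / ".config" / "claude" / "postgres.env"
    if system_env.exists():
        merged.update(dotenv_values(system_env))
    # Project-level .env wins on conflicts (e.g. POSTGRES_URL, DBT_PROJECT_DIR).
    if project_dir and (project_dir / ".env").exists():
        merged.update(dotenv_values(project_dir / ".env"))
    return merged
```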


@@ -0,0 +1,240 @@
"""
Unit tests for Arrow IPC DataFrame registry.
"""
import pytest
import pandas as pd
import pyarrow as pa
def test_store_pandas_dataframe():
"""Test storing pandas DataFrame"""
from mcp_server.data_store import DataStore
# Create fresh instance for test
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
data_ref = store.store(df, name='test_df')
assert data_ref == 'test_df'
assert 'test_df' in store._dataframes
assert store._metadata['test_df'].rows == 3
assert store._metadata['test_df'].columns == 2
def test_store_arrow_table():
"""Test storing Arrow Table directly"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
table = pa.table({'x': [1, 2, 3], 'y': [4, 5, 6]})
data_ref = store.store(table, name='arrow_test')
assert data_ref == 'arrow_test'
assert store._dataframes['arrow_test'].num_rows == 3
def test_store_auto_name():
"""Test auto-generated data_ref names"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2]})
data_ref = store.store(df)
assert data_ref.startswith('df_')
assert len(data_ref) == 11 # df_ + 8 hex chars
def test_get_dataframe():
"""Test retrieving stored DataFrame"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3]})
store.store(df, name='get_test')
result = store.get('get_test')
assert result is not None
assert result.num_rows == 3
def test_get_pandas():
"""Test retrieving as pandas DataFrame"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
store.store(df, name='pandas_test')
result = store.get_pandas('pandas_test')
assert isinstance(result, pd.DataFrame)
assert list(result.columns) == ['a', 'b']
assert len(result) == 3
def test_get_nonexistent():
"""Test getting nonexistent data_ref returns None"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
assert store.get('nonexistent') is None
assert store.get_pandas('nonexistent') is None
def test_list_refs():
"""Test listing all stored DataFrames"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': [1, 2]}), name='df1')
store.store(pd.DataFrame({'b': [3, 4, 5]}), name='df2')
refs = store.list_refs()
assert len(refs) == 2
ref_names = [r['ref'] for r in refs]
assert 'df1' in ref_names
assert 'df2' in ref_names
def test_drop_dataframe():
"""Test dropping a DataFrame"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': [1]}), name='drop_test')
assert store.get('drop_test') is not None
result = store.drop('drop_test')
assert result is True
assert store.get('drop_test') is None
def test_drop_nonexistent():
"""Test dropping nonexistent data_ref"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
result = store.drop('nonexistent')
assert result is False
def test_clear():
"""Test clearing all DataFrames"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': [1]}), name='df1')
store.store(pd.DataFrame({'b': [2]}), name='df2')
store.clear()
assert len(store.list_refs()) == 0
def test_get_info():
"""Test getting DataFrame metadata"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({'a': [1, 2, 3], 'b': ['x', 'y', 'z']})
store.store(df, name='info_test', source='test source')
info = store.get_info('info_test')
assert info.ref == 'info_test'
assert info.rows == 3
assert info.columns == 2
assert info.column_names == ['a', 'b']
assert info.source == 'test source'
assert info.memory_bytes > 0
def test_total_memory():
"""Test total memory calculation"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
store.store(pd.DataFrame({'a': range(100)}), name='df1')
store.store(pd.DataFrame({'b': range(200)}), name='df2')
total = store.total_memory_bytes()
assert total > 0
total_mb = store.total_memory_mb()
assert total_mb >= 0
def test_check_row_limit():
"""Test row limit checking"""
from mcp_server.data_store import DataStore
store = DataStore()
store._max_rows = 100
# Under limit
result = store.check_row_limit(50)
assert result['exceeded'] is False
# Over limit
result = store.check_row_limit(150)
assert result['exceeded'] is True
assert 'suggestion' in result
def test_metadata_dtypes():
"""Test that dtypes are correctly recorded"""
from mcp_server.data_store import DataStore
store = DataStore()
store._dataframes = {}
store._metadata = {}
df = pd.DataFrame({
'int_col': [1, 2, 3],
'float_col': [1.1, 2.2, 3.3],
'str_col': ['a', 'b', 'c']
})
store.store(df, name='dtype_test')
info = store.get_info('dtype_test')
assert 'int_col' in info.dtypes
assert 'float_col' in info.dtypes
assert 'str_col' in info.dtypes
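
For context, the Arrow-backed registry these tests exercise can be approximated in a few lines. This is a sketch only, not the shipped `DataStore` (the real class is a singleton with per-ref metadata, row limits, and memory accounting):

```python
import uuid

import pandas as pd
import pyarrow as pa


class MiniDataStore:
    """Toy Arrow IPC-style registry (illustrative only)."""

    def __init__(self):
        self._tables: dict[str, pa.Table] = {}

    def store(self, data, name: str | None = None) -> str:
        # Auto-generated refs follow the df_ + 8 hex chars convention the tests assert.
        ref = name or f"df_{uuid.uuid4().hex[:8]}"
        # Keep everything as Arrow tables so DataFrames persist across tool calls.
        self._tables[ref] = data if isinstance(data, pa.Table) else pa.Table.from_pandas(data)
        return ref

    def get_pandas(self, ref: str) -> pd.DataFrame | None:
        table = self._tables.get(ref)
        return table.to_pandas() if table is not None else None
```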


@@ -0,0 +1,318 @@
"""
Unit tests for dbt MCP tools.
"""
import pytest
from unittest.mock import patch, MagicMock
import subprocess
import json
@pytest.fixture
def mock_config(tmp_path):
"""Mock configuration with dbt project"""
dbt_dir = tmp_path / 'dbt_project'
dbt_dir.mkdir()
(dbt_dir / 'dbt_project.yml').write_text('name: test_project\n')
return {
'dbt_project_dir': str(dbt_dir),
'dbt_profiles_dir': str(tmp_path / '.dbt')
}
@pytest.fixture
def dbt_tools(mock_config):
"""Create DbtTools instance with mocked config"""
with patch('mcp_server.dbt_tools.load_config', return_value=mock_config):
from mcp_server.dbt_tools import DbtTools
tools = DbtTools()
tools.project_dir = mock_config['dbt_project_dir']
tools.profiles_dir = mock_config['dbt_profiles_dir']
return tools
@pytest.mark.asyncio
async def test_dbt_parse_success(dbt_tools):
"""Test successful dbt parse"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'Parsed successfully'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_parse()
assert result['valid'] is True
@pytest.mark.asyncio
async def test_dbt_parse_failure(dbt_tools):
"""Test dbt parse with errors"""
mock_result = MagicMock()
mock_result.returncode = 1
mock_result.stdout = ''
mock_result.stderr = 'Compilation error: deprecated syntax'
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_parse()
assert result['valid'] is False
assert 'deprecated' in str(result.get('details', '')).lower() or len(result.get('errors', [])) > 0
@pytest.mark.asyncio
async def test_dbt_run_with_prevalidation(dbt_tools):
"""Test dbt run includes pre-validation"""
# First call is parse, second is run
mock_parse = MagicMock()
mock_parse.returncode = 0
mock_parse.stdout = 'OK'
mock_parse.stderr = ''
mock_run = MagicMock()
mock_run.returncode = 0
mock_run.stdout = 'Completed successfully'
mock_run.stderr = ''
with patch('subprocess.run', side_effect=[mock_parse, mock_run]):
result = await dbt_tools.dbt_run()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_run_fails_validation(dbt_tools):
"""Test dbt run fails if validation fails"""
mock_parse = MagicMock()
mock_parse.returncode = 1
mock_parse.stdout = ''
mock_parse.stderr = 'Parse error'
with patch('subprocess.run', return_value=mock_parse):
result = await dbt_tools.dbt_run()
assert 'error' in result
assert 'Pre-validation failed' in result['error']
@pytest.mark.asyncio
async def test_dbt_run_with_selection(dbt_tools):
"""Test dbt run with model selection"""
mock_parse = MagicMock()
mock_parse.returncode = 0
mock_parse.stdout = 'OK'
mock_parse.stderr = ''
mock_run = MagicMock()
mock_run.returncode = 0
mock_run.stdout = 'Completed'
mock_run.stderr = ''
calls = []
def track_calls(*args, **kwargs):
calls.append(args[0] if args else kwargs.get('args', []))
if len(calls) == 1:
return mock_parse
return mock_run
with patch('subprocess.run', side_effect=track_calls):
result = await dbt_tools.dbt_run(select='dim_customers')
# Verify --select was passed
assert any('--select' in str(call) for call in calls)
@pytest.mark.asyncio
async def test_dbt_test(dbt_tools):
"""Test dbt test"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'All tests passed'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_test()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_build(dbt_tools):
"""Test dbt build with pre-validation"""
mock_parse = MagicMock()
mock_parse.returncode = 0
mock_parse.stdout = 'OK'
mock_parse.stderr = ''
mock_build = MagicMock()
mock_build.returncode = 0
mock_build.stdout = 'Build complete'
mock_build.stderr = ''
with patch('subprocess.run', side_effect=[mock_parse, mock_build]):
result = await dbt_tools.dbt_build()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_compile(dbt_tools):
"""Test dbt compile"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'Compiled'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_compile()
assert result['success'] is True
@pytest.mark.asyncio
async def test_dbt_ls(dbt_tools):
"""Test dbt ls"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'dim_customers\ndim_products\nfct_orders\n'
mock_result.stderr = ''
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_ls()
assert result['success'] is True
assert result['count'] == 3
assert 'dim_customers' in result['resources']
@pytest.mark.asyncio
async def test_dbt_docs_generate(dbt_tools, tmp_path):
"""Test dbt docs generate"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_result.stdout = 'Done'
mock_result.stderr = ''
# Create fake target directory
target_dir = tmp_path / 'dbt_project' / 'target'
target_dir.mkdir(parents=True)
(target_dir / 'catalog.json').write_text('{}')
(target_dir / 'manifest.json').write_text('{}')
dbt_tools.project_dir = str(tmp_path / 'dbt_project')
with patch('subprocess.run', return_value=mock_result):
result = await dbt_tools.dbt_docs_generate()
assert result['success'] is True
assert result['catalog_generated'] is True
assert result['manifest_generated'] is True
@pytest.mark.asyncio
async def test_dbt_lineage(dbt_tools, tmp_path):
"""Test dbt lineage"""
# Create manifest
target_dir = tmp_path / 'dbt_project' / 'target'
target_dir.mkdir(parents=True)
manifest = {
'nodes': {
'model.test.dim_customers': {
'name': 'dim_customers',
'resource_type': 'model',
'schema': 'public',
'database': 'testdb',
'description': 'Customer dimension',
'tags': ['daily'],
'config': {'materialized': 'table'},
'depends_on': {
'nodes': ['model.test.stg_customers']
}
},
'model.test.stg_customers': {
'name': 'stg_customers',
'resource_type': 'model',
'depends_on': {'nodes': []}
},
'model.test.fct_orders': {
'name': 'fct_orders',
'resource_type': 'model',
'depends_on': {
'nodes': ['model.test.dim_customers']
}
}
}
}
(target_dir / 'manifest.json').write_text(json.dumps(manifest))
dbt_tools.project_dir = str(tmp_path / 'dbt_project')
result = await dbt_tools.dbt_lineage('dim_customers')
assert result['model'] == 'dim_customers'
assert 'model.test.stg_customers' in result['upstream']
assert 'model.test.fct_orders' in result['downstream']
@pytest.mark.asyncio
async def test_dbt_lineage_model_not_found(dbt_tools, tmp_path):
"""Test dbt lineage with nonexistent model"""
target_dir = tmp_path / 'dbt_project' / 'target'
target_dir.mkdir(parents=True)
manifest = {
'nodes': {
'model.test.dim_customers': {
'name': 'dim_customers',
'resource_type': 'model'
}
}
}
(target_dir / 'manifest.json').write_text(json.dumps(manifest))
dbt_tools.project_dir = str(tmp_path / 'dbt_project')
result = await dbt_tools.dbt_lineage('nonexistent_model')
assert 'error' in result
assert 'not found' in result['error'].lower()
@pytest.mark.asyncio
async def test_dbt_no_project():
"""Test dbt tools when no project configured"""
with patch('mcp_server.dbt_tools.load_config', return_value={'dbt_project_dir': None}):
from mcp_server.dbt_tools import DbtTools
tools = DbtTools()
tools.project_dir = None
result = await tools.dbt_run()
assert 'error' in result
assert 'not found' in result['error'].lower()
@pytest.mark.asyncio
async def test_dbt_timeout(dbt_tools):
"""Test dbt command timeout handling"""
with patch('subprocess.run', side_effect=subprocess.TimeoutExpired('dbt', 300)):
result = await dbt_tools.dbt_parse()
assert 'error' in result
assert 'timed out' in result['error'].lower()
@pytest.mark.asyncio
async def test_dbt_not_installed(dbt_tools):
"""Test handling when dbt is not installed"""
with patch('subprocess.run', side_effect=FileNotFoundError()):
result = await dbt_tools.dbt_parse()
assert 'error' in result
assert 'not found' in result['error'].lower()
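
These tests encode the pre-validation contract: `dbt_run` and `dbt_build` only execute after a clean `dbt parse`. A minimal sketch of that pattern, assuming the real `DbtTools` shells out via `subprocess.run` as the mocks above suggest:

```python
import subprocess


def run_with_prevalidation(project_dir: str, select: str | None = None) -> dict:
    # Pre-flight: `dbt parse` surfaces deprecated syntax before any model runs.
    parse = subprocess.run(
        ["dbt", "parse", "--project-dir", project_dir],
        capture_output=True, text=True, timeout=300,
    )
    if parse.returncode != 0:
        return {"error": f"Pre-validation failed: {parse.stderr.strip()}"}

    cmd = ["dbt", "run", "--project-dir", project_dir]
    if select:
        cmd += ["--select", select]
    run = subprocess.run(cmd, capture_output=True, text=True, timeout=300)
    return {"success": run.returncode == 0, "output": run.stdout}
```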


@@ -0,0 +1,301 @@
"""
Unit tests for pandas MCP tools.
"""
import pytest
import pandas as pd
import os
@pytest.fixture
def temp_csv(tmp_path):
"""Create a temporary CSV file for testing"""
csv_path = tmp_path / 'test.csv'
df = pd.DataFrame({
'id': [1, 2, 3, 4, 5],
'name': ['Alice', 'Bob', 'Charlie', 'Diana', 'Eve'],
'value': [10.5, 20.0, 30.5, 40.0, 50.5]
})
df.to_csv(csv_path, index=False)
return str(csv_path)
@pytest.fixture
def temp_parquet(tmp_path):
"""Create a temporary Parquet file for testing"""
parquet_path = tmp_path / 'test.parquet'
df = pd.DataFrame({
'id': [1, 2, 3],
'data': ['a', 'b', 'c']
})
df.to_parquet(parquet_path)
return str(parquet_path)
@pytest.fixture
def temp_json(tmp_path):
"""Create a temporary JSON file for testing"""
json_path = tmp_path / 'test.json'
df = pd.DataFrame({
'x': [1, 2],
'y': [3, 4]
})
df.to_json(json_path, orient='records')
return str(json_path)
@pytest.fixture
def pandas_tools():
"""Create PandasTools instance with fresh store"""
from mcp_server.pandas_tools import PandasTools
from mcp_server.data_store import DataStore
# Reset store for test isolation
store = DataStore.get_instance()
store._dataframes = {}
store._metadata = {}
return PandasTools()
@pytest.mark.asyncio
async def test_read_csv(pandas_tools, temp_csv):
"""Test reading CSV file"""
result = await pandas_tools.read_csv(temp_csv, name='csv_test')
assert 'data_ref' in result
assert result['data_ref'] == 'csv_test'
assert result['rows'] == 5
assert 'id' in result['columns']
assert 'name' in result['columns']
@pytest.mark.asyncio
async def test_read_csv_nonexistent(pandas_tools):
"""Test reading nonexistent CSV file"""
result = await pandas_tools.read_csv('/nonexistent/path.csv')
assert 'error' in result
assert 'not found' in result['error'].lower()
@pytest.mark.asyncio
async def test_read_parquet(pandas_tools, temp_parquet):
"""Test reading Parquet file"""
result = await pandas_tools.read_parquet(temp_parquet, name='parquet_test')
assert 'data_ref' in result
assert result['rows'] == 3
@pytest.mark.asyncio
async def test_read_json(pandas_tools, temp_json):
"""Test reading JSON file"""
result = await pandas_tools.read_json(temp_json, name='json_test')
assert 'data_ref' in result
assert result['rows'] == 2
@pytest.mark.asyncio
async def test_to_csv(pandas_tools, temp_csv, tmp_path):
"""Test exporting to CSV"""
# First load some data
await pandas_tools.read_csv(temp_csv, name='export_test')
# Export to new file
output_path = str(tmp_path / 'output.csv')
result = await pandas_tools.to_csv('export_test', output_path)
assert result['success'] is True
assert os.path.exists(output_path)
@pytest.mark.asyncio
async def test_to_parquet(pandas_tools, temp_csv, tmp_path):
"""Test exporting to Parquet"""
await pandas_tools.read_csv(temp_csv, name='parquet_export')
output_path = str(tmp_path / 'output.parquet')
result = await pandas_tools.to_parquet('parquet_export', output_path)
assert result['success'] is True
assert os.path.exists(output_path)
@pytest.mark.asyncio
async def test_describe(pandas_tools, temp_csv):
"""Test describe statistics"""
await pandas_tools.read_csv(temp_csv, name='describe_test')
result = await pandas_tools.describe('describe_test')
assert 'data_ref' in result
assert 'shape' in result
assert result['shape']['rows'] == 5
assert 'statistics' in result
assert 'null_counts' in result
@pytest.mark.asyncio
async def test_head(pandas_tools, temp_csv):
"""Test getting first N rows"""
await pandas_tools.read_csv(temp_csv, name='head_test')
result = await pandas_tools.head('head_test', n=3)
assert result['returned_rows'] == 3
assert len(result['data']) == 3
@pytest.mark.asyncio
async def test_tail(pandas_tools, temp_csv):
"""Test getting last N rows"""
await pandas_tools.read_csv(temp_csv, name='tail_test')
result = await pandas_tools.tail('tail_test', n=2)
assert result['returned_rows'] == 2
@pytest.mark.asyncio
async def test_filter(pandas_tools, temp_csv):
"""Test filtering rows"""
await pandas_tools.read_csv(temp_csv, name='filter_test')
result = await pandas_tools.filter('filter_test', 'value > 25')
assert 'data_ref' in result
assert result['rows'] == 3 # 30.5, 40.0, 50.5
@pytest.mark.asyncio
async def test_filter_invalid_condition(pandas_tools, temp_csv):
"""Test filter with invalid condition"""
await pandas_tools.read_csv(temp_csv, name='filter_error')
result = await pandas_tools.filter('filter_error', 'invalid_column > 0')
assert 'error' in result
@pytest.mark.asyncio
async def test_select(pandas_tools, temp_csv):
"""Test selecting columns"""
await pandas_tools.read_csv(temp_csv, name='select_test')
result = await pandas_tools.select('select_test', ['id', 'name'])
assert 'data_ref' in result
assert result['columns'] == ['id', 'name']
@pytest.mark.asyncio
async def test_select_invalid_column(pandas_tools, temp_csv):
"""Test select with invalid column"""
await pandas_tools.read_csv(temp_csv, name='select_error')
result = await pandas_tools.select('select_error', ['id', 'nonexistent'])
assert 'error' in result
assert 'available_columns' in result
@pytest.mark.asyncio
async def test_groupby(pandas_tools, tmp_path):
"""Test groupby aggregation"""
# Create test data with groups
csv_path = tmp_path / 'groupby.csv'
df = pd.DataFrame({
'category': ['A', 'A', 'B', 'B'],
'value': [10, 20, 30, 40]
})
df.to_csv(csv_path, index=False)
await pandas_tools.read_csv(str(csv_path), name='groupby_test')
result = await pandas_tools.groupby(
'groupby_test',
by='category',
agg={'value': 'sum'}
)
assert 'data_ref' in result
assert result['rows'] == 2 # Two groups: A, B
@pytest.mark.asyncio
async def test_join(pandas_tools, tmp_path):
"""Test joining DataFrames"""
# Create left table
left_path = tmp_path / 'left.csv'
pd.DataFrame({
'id': [1, 2, 3],
'name': ['A', 'B', 'C']
}).to_csv(left_path, index=False)
# Create right table
right_path = tmp_path / 'right.csv'
pd.DataFrame({
'id': [1, 2, 4],
'value': [100, 200, 400]
}).to_csv(right_path, index=False)
await pandas_tools.read_csv(str(left_path), name='left')
await pandas_tools.read_csv(str(right_path), name='right')
result = await pandas_tools.join('left', 'right', on='id', how='inner')
assert 'data_ref' in result
assert result['rows'] == 2 # Only id 1 and 2 match
@pytest.mark.asyncio
async def test_list_data(pandas_tools, temp_csv):
"""Test listing all DataFrames"""
await pandas_tools.read_csv(temp_csv, name='list_test1')
await pandas_tools.read_csv(temp_csv, name='list_test2')
result = await pandas_tools.list_data()
assert result['count'] == 2
refs = [df['ref'] for df in result['dataframes']]
assert 'list_test1' in refs
assert 'list_test2' in refs
@pytest.mark.asyncio
async def test_drop_data(pandas_tools, temp_csv):
"""Test dropping DataFrame"""
await pandas_tools.read_csv(temp_csv, name='drop_test')
result = await pandas_tools.drop_data('drop_test')
assert result['success'] is True
# Verify it's gone
list_result = await pandas_tools.list_data()
refs = [df['ref'] for df in list_result['dataframes']]
assert 'drop_test' not in refs
@pytest.mark.asyncio
async def test_drop_nonexistent(pandas_tools):
"""Test dropping nonexistent DataFrame"""
result = await pandas_tools.drop_data('nonexistent')
assert 'error' in result
@pytest.mark.asyncio
async def test_operations_on_nonexistent(pandas_tools):
"""Test operations on nonexistent data_ref"""
result = await pandas_tools.describe('nonexistent')
assert 'error' in result
result = await pandas_tools.head('nonexistent')
assert 'error' in result
result = await pandas_tools.filter('nonexistent', 'x > 0')
assert 'error' in result
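
End to end, the pandas tools are meant to be chained through data_refs rather than re-reading files. A hypothetical session built from the same call signatures the fixtures above use (the CSV path is illustrative):

```python
from mcp_server.pandas_tools import PandasTools


async def demo(csv_path: str) -> dict:
    tools = PandasTools()
    await tools.read_csv(csv_path, name="sales")            # stored as data_ref "sales"
    filtered = await tools.filter("sales", "value > 25")    # derived data_ref
    return await tools.head(filtered["data_ref"], n=3)      # consumes the derived ref

# asyncio.run(demo("sales.csv"))
```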


@@ -0,0 +1,338 @@
"""
Unit tests for PostgreSQL MCP tools.
"""
import pytest
from unittest.mock import AsyncMock, patch, MagicMock
@pytest.fixture
def mock_config():
"""Mock configuration"""
return {
'postgres_url': 'postgresql://test:test@localhost:5432/testdb',
'max_rows': 100000
}
@pytest.fixture
def postgres_tools(mock_config):
"""Create PostgresTools instance with mocked config"""
with patch('mcp_server.postgres_tools.load_config', return_value=mock_config):
from mcp_server.postgres_tools import PostgresTools
from mcp_server.data_store import DataStore
# Reset store
store = DataStore.get_instance()
store._dataframes = {}
store._metadata = {}
tools = PostgresTools()
tools.config = mock_config
return tools
@pytest.mark.asyncio
async def test_pg_connect_no_config():
"""Test pg_connect when no PostgreSQL configured"""
with patch('mcp_server.postgres_tools.load_config', return_value={'postgres_url': None}):
from mcp_server.postgres_tools import PostgresTools
tools = PostgresTools()
tools.config = {'postgres_url': None}
result = await tools.pg_connect()
assert result['connected'] is False
assert 'not configured' in result['error'].lower()
@pytest.mark.asyncio
async def test_pg_connect_success(postgres_tools):
"""Test successful pg_connect"""
mock_conn = AsyncMock()
mock_conn.fetchval = AsyncMock(side_effect=[
'PostgreSQL 15.1', # version
'testdb', # database name
'testuser', # user
None # PostGIS check fails
])
mock_conn.close = AsyncMock()
# Create proper async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_cm)
# Use AsyncMock for create_pool since it's awaited
with patch('asyncpg.create_pool', new=AsyncMock(return_value=mock_pool)):
postgres_tools.pool = None
result = await postgres_tools.pg_connect()
assert result['connected'] is True
assert result['database'] == 'testdb'
@pytest.mark.asyncio
async def test_pg_query_success(postgres_tools):
"""Test successful pg_query"""
mock_rows = [
{'id': 1, 'name': 'Alice'},
{'id': 2, 'name': 'Bob'}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_query('SELECT * FROM users', name='users_data')
assert 'data_ref' in result
assert result['rows'] == 2
@pytest.mark.asyncio
async def test_pg_query_empty_result(postgres_tools):
"""Test pg_query with no results"""
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=[])
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_query('SELECT * FROM empty_table')
assert result['data_ref'] is None
assert result['rows'] == 0
@pytest.mark.asyncio
async def test_pg_execute_success(postgres_tools):
"""Test successful pg_execute"""
mock_conn = AsyncMock()
mock_conn.execute = AsyncMock(return_value='INSERT 0 3')
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_execute('INSERT INTO users VALUES (1, 2, 3)')
assert result['success'] is True
assert result['affected_rows'] == 3
assert result['command'] == 'INSERT'
@pytest.mark.asyncio
async def test_pg_tables(postgres_tools):
"""Test listing tables"""
mock_rows = [
{'table_name': 'users', 'table_type': 'BASE TABLE', 'column_count': 5},
{'table_name': 'orders', 'table_type': 'BASE TABLE', 'column_count': 8}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_tables(schema='public')
assert result['schema'] == 'public'
assert result['count'] == 2
assert len(result['tables']) == 2
@pytest.mark.asyncio
async def test_pg_columns(postgres_tools):
"""Test getting column info"""
mock_rows = [
{
'column_name': 'id',
'data_type': 'integer',
'udt_name': 'int4',
'is_nullable': 'NO',
'column_default': "nextval('users_id_seq'::regclass)",
'character_maximum_length': None,
'numeric_precision': 32
},
{
'column_name': 'name',
'data_type': 'character varying',
'udt_name': 'varchar',
'is_nullable': 'YES',
'column_default': None,
'character_maximum_length': 255,
'numeric_precision': None
}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_columns(table='users')
assert result['table'] == 'public.users'
assert result['column_count'] == 2
assert result['columns'][0]['name'] == 'id'
assert result['columns'][0]['nullable'] is False
@pytest.mark.asyncio
async def test_pg_schemas(postgres_tools):
"""Test listing schemas"""
mock_rows = [
{'schema_name': 'public'},
{'schema_name': 'app'}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_schemas()
assert result['count'] == 2
assert 'public' in result['schemas']
@pytest.mark.asyncio
async def test_st_tables(postgres_tools):
"""Test listing PostGIS tables"""
mock_rows = [
{
'table_name': 'locations',
'geometry_column': 'geom',
'geometry_type': 'POINT',
'srid': 4326,
'coord_dimension': 2
}
]
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(return_value=mock_rows)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.st_tables()
assert result['count'] == 1
assert result['postgis_tables'][0]['table'] == 'locations'
assert result['postgis_tables'][0]['srid'] == 4326
@pytest.mark.asyncio
async def test_st_tables_no_postgis(postgres_tools):
"""Test st_tables when PostGIS not installed"""
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(side_effect=Exception("relation \"geometry_columns\" does not exist"))
# Create proper async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_cm)
postgres_tools.pool = mock_pool
result = await postgres_tools.st_tables()
assert 'error' in result
assert 'PostGIS' in result['error']
@pytest.mark.asyncio
async def test_st_extent(postgres_tools):
"""Test getting geometry bounding box"""
mock_row = {
'xmin': -122.5,
'ymin': 37.5,
'xmax': -122.0,
'ymax': 38.0
}
mock_conn = AsyncMock()
mock_conn.fetchrow = AsyncMock(return_value=mock_row)
mock_pool = AsyncMock()
mock_pool.acquire = MagicMock(return_value=AsyncMock(
__aenter__=AsyncMock(return_value=mock_conn),
__aexit__=AsyncMock()
))
postgres_tools.pool = mock_pool
result = await postgres_tools.st_extent(table='locations', column='geom')
assert result['bbox']['xmin'] == -122.5
assert result['bbox']['ymax'] == 38.0
@pytest.mark.asyncio
async def test_error_handling(postgres_tools):
"""Test error handling for database errors"""
mock_conn = AsyncMock()
mock_conn.fetch = AsyncMock(side_effect=Exception("Connection refused"))
# Create proper async context manager
mock_cm = AsyncMock()
mock_cm.__aenter__ = AsyncMock(return_value=mock_conn)
mock_cm.__aexit__ = AsyncMock(return_value=None)
mock_pool = MagicMock()
mock_pool.acquire = MagicMock(return_value=mock_cm)
postgres_tools.pool = mock_pool
result = await postgres_tools.pg_query('SELECT 1')
assert 'error' in result
assert 'Connection refused' in result['error']
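
For orientation, the shape of `pg_query` implied by these mocks: fetch asyncpg Records, flatten them into a DataFrame, and register the result as a data_ref. A speculative sketch only; the real method also manages the pool, row limits, and error reporting:

```python
import pandas as pd

from mcp_server.data_store import DataStore


async def pg_query_sketch(pool, query: str, params=None, name=None) -> dict:
    async with pool.acquire() as conn:
        rows = await conn.fetch(query, *(params or []))
    if not rows:
        return {"data_ref": None, "rows": 0}
    # asyncpg Records behave like mappings, so dict(r) flattens them for pandas.
    df = pd.DataFrame([dict(r) for r in rows])
    ref = DataStore.get_instance().store(df, name=name)
    return {"data_ref": ref, "rows": len(df), "columns": list(df.columns)}
```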