This guide provides a comprehensive walkthrough for implementing new Data Loaders in the Project amp Python client library.
- Overview
- Architecture
- Getting Started
- Implementation Guide
- Configuration
- Metadata Methods
- Testing
- Best Practices
- Examples
Data Loaders are plugins that enable loading Arrow data into various storage systems. The architecture is designed for:
- Zero-copy operations using PyArrow for performance
- Auto-discovery mechanism via __init_subclass__
- Standardized interfaces across all loaders
- Type-safe configuration with dataclasses
- Comprehensive error handling and metadata collection
DataLoader[TConfig] (ABC, Generic)
├── PostgreSQLLoader[PostgreSQLConfig]
├── RedisLoader[RedisConfig]
├── SnowflakeLoader[SnowflakeConnectionConfig]
├── DeltaLakeLoader[DeltaLakeStorageConfig]
├── IcebergLoader[IcebergStorageConfig]
└── LMDBLoader[LMDBConfig]
- DataLoader: Generic base class with common functionality
- LoadMode: Enum for load operations (APPEND, OVERWRITE, UPSERT, MERGE)
- LoadResult: Standardized result object with metadata
- Auto-discovery: Automatic registration via class inheritance
Create a new file in src/amp/loaders/implementations/ following the naming pattern {system}_loader.py:
# src/amp/loaders/implementations/example_loader.py
from dataclasses import dataclass
from typing import Any, Dict
import pyarrow as pa
from ..base import DataLoader, LoadMode
@dataclass
class ExampleConfig:
    """Connection settings for the example loader.

    Required (non-default) fields are declared before defaulted fields —
    dataclasses raise TypeError at class-creation time otherwise. Instances
    are built with keyword arguments (ExampleConfig(**config)), so this
    ordering is backward-compatible for callers.
    """
    host: str          # target host (required)
    database: str      # target database name (required)
    port: int = 5432   # PostgreSQL-style default port
    timeout: int = 30  # connection timeout in seconds
class ExampleLoader(DataLoader[ExampleConfig]):
"""Example loader implementation"""
# Declare supported capabilities
SUPPORTED_MODES = {LoadMode.APPEND, LoadMode.OVERWRITE}
REQUIRES_SCHEMA_MATCH = False
SUPPORTS_TRANSACTIONS = True
def _parse_config(self, config: Dict[str, Any]) -> ExampleConfig:
return ExampleConfig(**config)
def _get_required_config_fields(self) -> list[str]:
return ['host', 'database']
def connect(self) -> None:
# Implementation here
self._is_connected = True
def disconnect(self) -> None:
# Implementation here
self._is_connected = False
def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
# Implementation here - return number of rows loaded
return batch.num_rows
Add import to src/amp/loaders/implementations/__init__.py:
try:
from .example_loader import ExampleLoader
except ImportError:
ExampleLoader = None
if ExampleLoader:
__all__.append('ExampleLoader')
The loader will automatically be registered and available as 'example'.
Parse configuration into typed format:
def _parse_config(self, config: Dict[str, Any]) -> ExampleConfig:
try:
return ExampleConfig(**config)
except (TypeError, KeyError) as e:
raise ValueError(f"Invalid configuration: {e}")
Establish connection to your target system:
def connect(self) -> None:
try:
self._connection = create_connection(
host=self.config.host,
port=self.config.port,
database=self.config.database
)
# Test connection and log info
info = self._connection.get_info()
self.logger.info(f"Connected to {info['system']} v{info['version']}")
self._is_connected = True
except Exception as e:
self.logger.error(f"Failed to connect: {e}")
raise
Clean up connections:
def disconnect(self) -> None:
if self._connection:
self._connection.close()
self._connection = None
self._is_connected = False
self.logger.info("Disconnected")
Core loading logic - the base class handles everything else:
def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
# Base class already handled:
# - Connection checking
# - Mode validation
# - Table creation (_create_table_from_schema)
# - Overwrite clearing (_clear_table)
# - Error handling and LoadResult creation
# - Timing and metadata collection
# Just implement the actual data loading
data_dict = batch.to_pydict()
rows_written = 0
for i in range(batch.num_rows):
# Process each row using zero-copy Arrow operations
row_data = {col: data_dict[col][i] for col in data_dict.keys()}
self._connection.insert(table_name, row_data)
rows_written += 1
return rows_written
def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
"""Create table from Arrow schema"""
columns = []
for field in schema:
if pa.types.is_timestamp(field.type):
sql_type = 'TIMESTAMP'
elif pa.types.is_int64(field.type):
sql_type = 'BIGINT'
elif pa.types.is_string(field.type):
sql_type = 'VARCHAR'
else:
sql_type = 'VARCHAR' # Safe fallback
nullable = '' if field.nullable else ' NOT NULL'
columns.append(f'"{field.name}" {sql_type}{nullable}')
sql = f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(columns)})"
self._connection.execute(sql)
def _clear_table(self, table_name: str) -> None:
"""Clear table for overwrite mode"""
self._connection.execute(f"DELETE FROM {table_name}")
def get_table_info(self, table_name: str) -> Optional[Dict[str, Any]]:
"""Get table information"""
try:
return {
'table_name': table_name,
'row_count': self._get_row_count(table_name),
'columns': self._get_column_info(table_name),
'size_bytes': self._get_table_size(table_name)
}
except Exception as e:
self.logger.error(f"Failed to get table info: {e}")
return None
from dataclasses import dataclass
from typing import Optional
@dataclass
class ExampleConfig:
    """Typed configuration for the example loader.

    Required (non-default) fields precede defaulted fields, as the
    dataclass field-ordering rules demand; instantiation stays
    keyword-based (ExampleConfig(**config)), so this is backward-compatible.
    """
    host: str                      # required
    database: str                  # required
    user: str                      # required
    password: str                  # required — prefer sourcing from environment
    port: int = 5432               # default port
    timeout: Optional[int] = None  # no timeout unless configured
    max_connections: int = 10      # connection-pool upper bound
class ExampleLoader(DataLoader[ExampleConfig]):
    def _parse_config(self, config: Dict[str, Any]) -> ExampleConfig:
        """Convert the raw config dict into a typed ExampleConfig.

        Raises:
            ValueError: if the dict has unknown or missing keys
                (surfaced from the dataclass TypeError/KeyError).
        """
        try:
            return ExampleConfig(**config)
        except (TypeError, KeyError) as e:
            raise ValueError(f"Invalid configuration: {e}")
def _get_required_config_fields(self) -> list[str]:
return ['host', 'database', 'user', 'password']
Both metadata methods are required and must include specific fields for consistency across loaders.
def _get_batch_metadata(self, batch: pa.RecordBatch, duration: float, **kwargs) -> Dict[str, Any]:
"""Get metadata for batch operation"""
return {
'operation': 'load_batch', # REQUIRED
'batch_size': batch.num_rows,
'schema_fields': len(batch.schema),
'throughput_rows_per_sec': round(batch.num_rows / duration, 2) if duration > 0 else 0,
# Add loader-specific fields
'loading_method': self.config.method,
'connection_pool_size': self.config.max_connections
}
_get_table_metadata(self, table: pa.Table, duration: float, batch_count: int, **kwargs) -> Dict[str, Any]
def _get_table_metadata(self, table: pa.Table, duration: float, batch_count: int, **kwargs) -> Dict[str, Any]:
"""Get metadata for table operation"""
return {
'operation': 'load_table', # REQUIRED
'batch_count': batch_count,
'batches_processed': batch_count, # REQUIRED for some tests
'total_rows': table.num_rows,
'schema_fields': len(table.schema),
'avg_batch_size': round(table.num_rows / batch_count, 2) if batch_count > 0 else 0,
'table_size_mb': round(table.nbytes / 1024 / 1024, 2),
'throughput_rows_per_sec': round(table.num_rows / duration, 2) if duration > 0 else 0,
# Add loader-specific fields
'loading_method': self.config.method
}The project uses a generalized test infrastructure that eliminates code duplication across loader tests. Instead of writing standalone tests for each loader, you inherit from shared base test classes.
tests/integration/loaders/
├── conftest.py # Base classes and fixtures
├── test_base_loader.py # 7 core tests (all loaders inherit)
├── test_base_streaming.py # 5 streaming tests (for loaders with reorg support)
└── backends/
├── test_postgresql.py # PostgreSQL-specific config + tests
├── test_redis.py # Redis-specific config + tests
└── test_example.py # Your loader tests here
Add your loader's configuration fixture to tests/conftest.py:
@pytest.fixture(scope='session')
def example_test_config(request):
"""Example loader configuration from testcontainer or environment"""
# Use testcontainers for CI, or fall back to environment variables
if TESTCONTAINERS_AVAILABLE and USE_TESTCONTAINERS:
# Set up testcontainer (if applicable)
example_container = request.getfixturevalue('example_container')
return {
'host': example_container.get_container_host_ip(),
'port': example_container.get_exposed_port(5432),
'database': 'test_db',
'user': 'test_user',
'password': 'test_pass',
}
else:
# Fall back to environment variables
return {
'host': os.getenv('EXAMPLE_HOST', 'localhost'),
'port': int(os.getenv('EXAMPLE_PORT', '5432')),
'database': os.getenv('EXAMPLE_DB', 'test_db'),
'user': os.getenv('EXAMPLE_USER', 'test_user'),
'password': os.getenv('EXAMPLE_PASSWORD', 'test_pass'),
}Create tests/integration/loaders/backends/test_example.py:
"""
Example loader integration tests using generalized test infrastructure.
"""
from typing import Any, Dict, List, Optional
import pytest
from src.amp.loaders.implementations.example_loader import ExampleLoader
from tests.integration.loaders.conftest import LoaderTestConfig
from tests.integration.loaders.test_base_loader import BaseLoaderTests
from tests.integration.loaders.test_base_streaming import BaseStreamingTests
class ExampleTestConfig(LoaderTestConfig):
"""Example-specific test configuration"""
loader_class = ExampleLoader
config_fixture_name = 'example_test_config'
# Declare loader capabilities
supports_overwrite = True
supports_streaming = True # Set to False if no streaming support
supports_multi_network = True # For blockchain loaders with reorg
supports_null_values = True
def get_row_count(self, loader: ExampleLoader, table_name: str) -> int:
"""Get row count from table"""
# Implement using your loader's API
return loader._connection.query(f"SELECT COUNT(*) FROM {table_name}")[0]['count']
def query_rows(
self,
loader: ExampleLoader,
table_name: str,
where: Optional[str] = None,
order_by: Optional[str] = None
) -> List[Dict[str, Any]]:
"""Query rows from table"""
query = f"SELECT * FROM {table_name}"
if where:
query += f" WHERE {where}"
if order_by:
query += f" ORDER BY {order_by}"
return loader._connection.query(query)
def cleanup_table(self, loader: ExampleLoader, table_name: str) -> None:
"""Drop table"""
loader._connection.execute(f"DROP TABLE IF EXISTS {table_name}")
def get_column_names(self, loader: ExampleLoader, table_name: str) -> List[str]:
"""Get column names from table"""
result = loader._connection.query(
f"SELECT column_name FROM information_schema.columns WHERE table_name = '{table_name}'"
)
return [row['column_name'] for row in result]
# Core tests - ALL loaders must inherit these
class TestExampleCore(BaseLoaderTests):
"""Inherits 7 core tests: connection, context manager, batching, modes, null handling, errors"""
config = ExampleTestConfig()
# Streaming tests - Only for loaders with streaming/reorg support
class TestExampleStreaming(BaseStreamingTests):
"""Inherits 5 streaming tests: metadata columns, reorg deletion, overlapping ranges, multi-network, microbatch dedup"""
config = ExampleTestConfig()
# Loader-specific tests
@pytest.mark.integration
@pytest.mark.example
class TestExampleSpecific:
"""Example-specific functionality tests"""
config = ExampleTestConfig()
def test_custom_feature(self, loader, test_table_name, cleanup_tables):
"""Test example-specific functionality"""
cleanup_tables.append(test_table_name)
with loader:
# Test your loader's unique features
result = loader.some_custom_method(test_table_name)
assert result.success
By inheriting from the base test classes, you automatically get:
From BaseLoaderTests (7 core tests):
- test_connection - Connection establishment and disconnection
- test_context_manager - Context manager functionality
- test_batch_loading - Basic batch loading
- test_append_mode - Append mode operations
- test_overwrite_mode - Overwrite mode operations
- test_null_handling - Null value handling
- test_error_handling - Error scenarios
From BaseStreamingTests (5 streaming tests):
- test_streaming_metadata_columns - Metadata column creation
- test_reorg_deletion - Blockchain reorganization handling
- test_reorg_overlapping_ranges - Overlapping range invalidation
- test_reorg_multi_network - Multi-network reorg isolation
- test_microbatch_deduplication - Microbatch duplicate detection
You must implement these four methods in your LoaderTestConfig subclass:
def get_row_count(self, loader, table_name: str) -> int:
"""Return number of rows in table"""
def query_rows(self, loader, table_name: str, where=None, order_by=None) -> List[Dict]:
"""Query and return rows as list of dicts"""
def cleanup_table(self, loader, table_name: str) -> None:
"""Drop/delete the table"""
def get_column_names(self, loader, table_name: str) -> List[str]:
"""Return list of column names"""Set these flags in your LoaderTestConfig to control which tests run:
supports_overwrite = True # Can overwrite existing data
supports_streaming = True # Supports streaming with metadata
supports_multi_network = True # Supports multi-network isolation (blockchain loaders)
supports_null_values = True # Handles NULL values correctly
# Run all tests for your loader
uv run pytest tests/integration/loaders/backends/test_example.py -v
# Run only core tests
uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleCore -v
# Run only streaming tests
uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleStreaming -v
# Run specific test
uv run pytest tests/integration/loaders/backends/test_example.py::TestExampleCore::test_connection -v
- Use Arrow directly: Avoid unnecessary pandas conversions
- Batch operations: Minimize network round trips
- Zero-copy when possible: Use batch.to_pydict() for efficient conversion
- Connection pooling: Reuse connections for multiple operations
def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
rows_loaded = 0
errors = []
try:
for i in range(batch.num_rows):
try:
# Process row
rows_loaded += 1
except Exception as e:
errors.append(f"Row {i}: {e}")
if len(errors) > 100: # Reasonable limit
raise Exception(f"Too many errors: {len(errors)}")
if errors:
self.logger.warning(f"Completed with {len(errors)} errors")
# Important: Report failure if no rows loaded but errors exist
if rows_loaded == 0 and errors:
error_summary = errors[:5] # Show first 5 errors
if len(errors) > 5:
error_summary.append(f"... and {len(errors) - 5} more errors")
raise Exception(f"Failed to load any rows. Errors: {'; '.join(error_summary)}")
return rows_loaded
except Exception as e:
self.logger.error(f"Loading failed: {e}")
raise
- Use dataclasses for type safety and validation
- Provide sensible defaults for optional parameters
- Support environment variables for sensitive data
- Validate early in the constructor
def connect(self) -> None:
try:
self._connection = create_connection(self.config)
# Always test the connection
self._connection.ping()
self.logger.info(f"Connected to {self.config.host}:{self.config.port}")
self._is_connected = True
except Exception as e:
self.logger.error(f"Connection failed: {e}")
raise
def disconnect(self) -> None:
if self._connection:
self._connection.close()
self._connection = None
self._is_connected = False
self.logger.info("Disconnected")
from dataclasses import dataclass
from typing import Any, Dict, Optional
import pyarrow as pa
from ..base import DataLoader, LoadMode
@dataclass
class ExampleConfig:
    """Connection settings for the complete example loader.

    Non-default fields come first — declaring `database`/`user`/`password`
    after `port: int = 5432` is a TypeError at class-creation time.
    Keyword construction (ExampleConfig(**config)) is unaffected by the order.
    """
    host: str          # required
    database: str      # required
    user: str          # required
    password: str      # required
    port: int = 5432   # default port
    timeout: int = 30  # connection timeout in seconds
class ExampleLoader(DataLoader[ExampleConfig]):
"""Complete example loader with all required methods"""
SUPPORTED_MODES = {LoadMode.APPEND, LoadMode.OVERWRITE}
REQUIRES_SCHEMA_MATCH = False
SUPPORTS_TRANSACTIONS = True
def __init__(self, config: Dict[str, Any]):
    """Initialize via the DataLoader base, then prepare the lazy connection handle."""
    super().__init__(config)
    self._connection = None  # set in connect(), cleared in disconnect()
def _parse_config(self, config: Dict[str, Any]) -> ExampleConfig:
    """Build the typed config; the dataclass raises TypeError on bad/missing keys."""
    return ExampleConfig(**config)

def _get_required_config_fields(self) -> list[str]:
    """Fields the base class validates before connect() is attempted."""
    return ['host', 'database', 'user', 'password']
def connect(self) -> None:
    """Open a connection to the target system.

    Marks the loader connected only after the connection is created;
    logs and re-raises any failure so the caller sees the original error.
    """
    try:
        params = dict(
            host=self.config.host,
            port=self.config.port,
            database=self.config.database,
            user=self.config.user,
            password=self.config.password,
            timeout=self.config.timeout,
        )
        self._connection = create_connection(**params)
        self._is_connected = True
        self.logger.info(f"Connected to {self.config.host}:{self.config.port}")
    except Exception as e:
        self.logger.error(f"Failed to connect: {e}")
        raise
def disconnect(self) -> None:
    """Close the connection if one is open and mark the loader disconnected."""
    conn = self._connection
    if conn:
        conn.close()
        self._connection = None  # drop the handle so connect() starts fresh
    self._is_connected = False
    self.logger.info("Disconnected")
def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
    """Insert each row of the batch into table_name; returns rows written."""
    columns = batch.to_pydict()
    written = 0
    for idx in range(batch.num_rows):
        record = {name: values[idx] for name, values in columns.items()}
        self._connection.insert(table_name, record)
        written += 1
    return written
def _get_batch_metadata(self, batch: pa.RecordBatch, duration: float, **kwargs) -> Dict[str, Any]:
    """Per-batch metadata; 'operation' is the required base-class key."""
    rows = batch.num_rows
    throughput = round(rows / duration, 2) if duration > 0 else 0
    return {
        'operation': 'load_batch',
        'batch_size': rows,
        'schema_fields': len(batch.schema),
        'throughput_rows_per_sec': throughput,
        'host': self.config.host,
        'database': self.config.database
    }
def _get_table_metadata(self, table: pa.Table, duration: float, batch_count: int, **kwargs) -> Dict[str, Any]:
    """Per-table metadata aggregated over all processed batches."""
    rows = table.num_rows
    meta = {'operation': 'load_table'}
    meta['batch_count'] = batch_count
    meta['batches_processed'] = batch_count  # duplicate key kept for test compatibility
    meta['total_rows'] = rows
    meta['schema_fields'] = len(table.schema)
    meta['avg_batch_size'] = round(rows / batch_count, 2) if batch_count > 0 else 0
    meta['table_size_mb'] = round(table.nbytes / 1024 / 1024, 2)
    meta['throughput_rows_per_sec'] = round(rows / duration, 2) if duration > 0 else 0
    meta['host'] = self.config.host
    meta['database'] = self.config.database
    return meta
# Optional: Enhanced functionality
def _create_table_from_schema(self, schema: pa.Schema, table_name: str) -> None:
    """Issue CREATE TABLE IF NOT EXISTS derived from the Arrow schema.

    Maps timestamp -> TIMESTAMP and int64 -> BIGINT; every other Arrow
    type (including string) falls back to VARCHAR.
    """
    col_defs = []
    for field in schema:
        if pa.types.is_timestamp(field.type):
            sql_type = 'TIMESTAMP'
        elif pa.types.is_int64(field.type):
            sql_type = 'BIGINT'
        else:
            sql_type = 'VARCHAR'  # string and all unmapped types
        suffix = '' if field.nullable else ' NOT NULL'
        col_defs.append(f'"{field.name}" {sql_type}{suffix}')
    self._connection.execute(
        f"CREATE TABLE IF NOT EXISTS {table_name} ({', '.join(col_defs)})"
    )
def _clear_table(self, table_name: str) -> None:
    # Overwrite mode: DELETE keeps the table and schema in place (unlike DROP).
    # NOTE(review): table_name is interpolated directly into the SQL string —
    # identifiers cannot be bound as parameters, so callers must pass trusted names.
    self._connection.execute(f"DELETE FROM {table_name}")
def get_table_info(self, table_name: str) -> Optional[Dict[str, Any]]:
try:
result = self._connection.query(f"SELECT COUNT(*) FROM {table_name}")
return {
'table_name': table_name,
'row_count': result[0]['count'],
'exists': True
}
except Exception:
return None
@dataclass
class KeyValueConfig:
host: str
port: int = 6379
database: int = 0
class KeyValueLoader(DataLoader[KeyValueConfig]):
    """Simple key-value store loader.

    Each row is stored under a key of the form "<table_name>:<first-column
    value>" with the whole row (as a dict) as the value.
    """

    SUPPORTED_MODES = {LoadMode.APPEND, LoadMode.OVERWRITE}

    def _parse_config(self, config: Dict[str, Any]) -> KeyValueConfig:
        """Build the typed config from the raw dict."""
        return KeyValueConfig(**config)

    def _get_required_config_fields(self) -> list[str]:
        """Only 'host' is mandatory; port and database have defaults."""
        return ['host']

    def connect(self) -> None:
        """Create the client and mark the loader connected."""
        self._client = KeyValueClient(
            host=self.config.host,
            port=self.config.port,
            db=self.config.database
        )
        self._is_connected = True

    def disconnect(self) -> None:
        """Close the client and reset state so a later connect() starts clean."""
        if self._client:
            self._client.close()
            # Fix: drop the stale handle — the other loader examples in this
            # guide null the connection on disconnect, and leaving it set
            # would make a closed client look usable after disconnect().
            self._client = None
        self._is_connected = False

    def _load_batch_impl(self, batch: pa.RecordBatch, table_name: str, **kwargs) -> int:
        """Write one key/value pair per row; returns the number of rows stored."""
        data_dict = batch.to_pydict()
        # Assume the first schema column holds the key
        key_col = batch.schema[0].name
        keys = data_dict[key_col]
        rows_loaded = 0
        for i in range(batch.num_rows):
            key = f"{table_name}:{keys[i]}"
            value = {col: data_dict[col][i] for col in data_dict.keys()}
            self._client.set(key, value)
            rows_loaded += 1
        return rows_loaded

    def _get_batch_metadata(self, batch: pa.RecordBatch, duration: float, **kwargs) -> Dict[str, Any]:
        """Per-batch metadata; 'operation' is the required base-class key."""
        return {
            'operation': 'load_batch',
            'batch_size': batch.num_rows,
            'schema_fields': len(batch.schema),
            'throughput_rows_per_sec': round(batch.num_rows / duration, 2) if duration > 0 else 0,
            'host': self.config.host,
            'database': self.config.database
        }

    def _get_table_metadata(self, table: pa.Table, duration: float, batch_count: int, **kwargs) -> Dict[str, Any]:
        """Per-table metadata aggregated over all processed batches."""
        return {
            'operation': 'load_table',
            'batch_count': batch_count,
            'batches_processed': batch_count,
            'total_rows': table.num_rows,
            'schema_fields': len(table.schema),
            'throughput_rows_per_sec': round(table.num_rows / duration, 2) if duration > 0 else 0,
            'host': self.config.host,
            'database': self.config.database
        }