Add data layer (DuckDB storage, fetcher) and spot adapter with tests
- Add DataStorage class for DuckDB-based market data persistence
- Add DataFetcher for historical candle backfill and sync operations
- Add BinanceSpotAdapter for spot wallet balance queries
- Add binance_spot_base_url to Settings for spot testnet support
- Add comprehensive unit tests (50 new tests, 82 total)
- Coverage increased from 62% to 86%
This commit is contained in:
@@ -5,9 +5,10 @@ connecting to cryptocurrency exchanges.
|
|||||||
|
|
||||||
Supported Exchanges:
|
Supported Exchanges:
|
||||||
- Binance USDⓈ-M Perpetual Futures (primary)
|
- Binance USDⓈ-M Perpetual Futures (primary)
|
||||||
|
- Binance Spot (balance checking only)
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
from tradefinder.adapters import BinanceUSDMAdapter
|
from tradefinder.adapters import BinanceUSDMAdapter, BinanceSpotAdapter
|
||||||
from tradefinder.adapters.types import OrderRequest, Side, PositionSide, OrderType
|
from tradefinder.adapters.types import OrderRequest, Side, PositionSide, OrderType
|
||||||
|
|
||||||
adapter = BinanceUSDMAdapter(settings)
|
adapter = BinanceUSDMAdapter(settings)
|
||||||
@@ -23,6 +24,7 @@ from tradefinder.adapters.base import (
|
|||||||
OrderValidationError,
|
OrderValidationError,
|
||||||
RateLimitError,
|
RateLimitError,
|
||||||
)
|
)
|
||||||
|
from tradefinder.adapters.binance_spot import BinanceSpotAdapter, SpotBalance
|
||||||
from tradefinder.adapters.binance_usdm import BinanceUSDMAdapter
|
from tradefinder.adapters.binance_usdm import BinanceUSDMAdapter
|
||||||
from tradefinder.adapters.types import (
|
from tradefinder.adapters.types import (
|
||||||
AccountBalance,
|
AccountBalance,
|
||||||
@@ -45,6 +47,7 @@ __all__ = [
|
|||||||
"ExchangeAdapter",
|
"ExchangeAdapter",
|
||||||
# Implementations
|
# Implementations
|
||||||
"BinanceUSDMAdapter",
|
"BinanceUSDMAdapter",
|
||||||
|
"BinanceSpotAdapter",
|
||||||
# Types
|
# Types
|
||||||
"AccountBalance",
|
"AccountBalance",
|
||||||
"Candle",
|
"Candle",
|
||||||
@@ -57,6 +60,7 @@ __all__ = [
|
|||||||
"Position",
|
"Position",
|
||||||
"PositionSide",
|
"PositionSide",
|
||||||
"Side",
|
"Side",
|
||||||
|
"SpotBalance",
|
||||||
"SymbolInfo",
|
"SymbolInfo",
|
||||||
"TimeInForce",
|
"TimeInForce",
|
||||||
# Exceptions
|
# Exceptions
|
||||||
|
|||||||
214
src/tradefinder/adapters/binance_spot.py
Normal file
214
src/tradefinder/adapters/binance_spot.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Binance Spot API adapter for balance checking.
|
||||||
|
|
||||||
|
Provides read-only access to spot wallet balances.
|
||||||
|
Uses the same API credentials as the futures adapter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from datetime import UTC, datetime
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from tradefinder.adapters.base import AuthenticationError, ExchangeError
|
||||||
|
from tradefinder.core.config import Settings
|
||||||
|
|
||||||
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SpotBalance:
    """Snapshot of a single asset's balance in the spot wallet."""

    asset: str
    free: Decimal
    locked: Decimal
    updated_at: datetime

    @property
    def total(self) -> Decimal:
        """Combined balance across free and locked funds."""
        return self.locked + self.free
||||||
|
|
||||||
|
|
||||||
|
class BinanceSpotAdapter:
    """Binance Spot API adapter for balance checking.

    This adapter connects to Binance Spot (either testnet or production)
    and provides read-only access to spot wallet balances.

    Usage:
        settings = get_settings()
        adapter = BinanceSpotAdapter(settings)
        await adapter.connect()
        balances = await adapter.get_all_balances()
        await adapter.disconnect()
    """

    def __init__(self, settings: Settings) -> None:
        """Initialize adapter with settings.

        Args:
            settings: Application settings containing API credentials
        """
        self.settings = settings
        # Base URL is resolved from the current trading mode (testnet vs prod).
        self.base_url = settings.binance_spot_base_url
        # Created lazily in connect(); None means "not connected".
        self._client: httpx.AsyncClient | None = None
        # recvWindow (ms) sent with signed requests; the exchange rejects
        # requests whose timestamp falls outside this window.
        self._recv_window = 5000

    @property
    def _api_key(self) -> str:
        """Get active API key.

        Raises:
            AuthenticationError: If no key is configured for the current mode.
        """
        key = self.settings.get_active_api_key()
        if key is None:
            raise AuthenticationError("No API key configured for current trading mode")
        return key.get_secret_value()

    @property
    def _secret(self) -> str:
        """Get active secret.

        Raises:
            AuthenticationError: If no secret is configured for the current mode.
        """
        secret = self.settings.get_active_secret()
        if secret is None:
            raise AuthenticationError("No secret configured for current trading mode")
        return secret.get_secret_value()

    def _sign(self, params: dict[str, Any]) -> str:
        """Generate HMAC-SHA256 signature for request.

        The signature is computed over the urlencoded query string, so the
        dict's insertion order must match what is ultimately sent on the wire.
        """
        query_string = urlencode(params)
        signature = hmac.new(
            self._secret.encode("utf-8"),
            query_string.encode("utf-8"),
            hashlib.sha256,
        ).hexdigest()
        return signature

    def _get_timestamp(self) -> int:
        """Get current timestamp in milliseconds."""
        return int(time.time() * 1000)

    async def _request(
        self,
        method: str,
        endpoint: str,
        params: dict[str, Any] | None = None,
        signed: bool = False,
    ) -> Any:
        """Make HTTP request to Binance Spot API.

        Args:
            method: HTTP method; only "GET" is currently supported.
            endpoint: API path appended to the base URL (e.g. "/api/v3/account").
            params: Query parameters; mutated in place when ``signed`` is True.
            signed: Whether to add timestamp/recvWindow and an HMAC signature.

        Raises:
            ExchangeError: If not connected, on transport failure, or on a
                non-auth API error response.
            AuthenticationError: On API error codes -1021/-1022 (timestamp /
                signature problems, per Binance error code docs — treated here
                as credential failures).
            ValueError: For unsupported HTTP methods.
        """
        if self._client is None:
            raise ExchangeError("Not connected. Call connect() first.")

        params = params or {}
        headers = {"X-MBX-APIKEY": self._api_key}

        if signed:
            # Signature must be appended last, after every signed field is set.
            params["timestamp"] = self._get_timestamp()
            params["recvWindow"] = self._recv_window
            params["signature"] = self._sign(params)

        url = f"{self.base_url}{endpoint}"

        try:
            if method == "GET":
                response = await self._client.get(url, params=params, headers=headers)
            else:
                raise ValueError(f"Unsupported HTTP method: {method}")

            # Error responses also carry a JSON body with "code"/"msg".
            data = response.json()

            if response.status_code >= 400:
                code = data.get("code", 0)
                msg = data.get("msg", "Unknown error")
                logger.error("API error", status_code=response.status_code, code=code, msg=msg)
                if code in (-1021, -1022):
                    raise AuthenticationError(msg)
                raise ExchangeError(f"[{code}] {msg}")

            return data

        except httpx.RequestError as e:
            # Transport-level failure (DNS, timeout, connection reset, ...).
            logger.error("Request failed", endpoint=endpoint, error=str(e))
            raise ExchangeError(f"Request failed: {e}") from e

    async def connect(self) -> None:
        """Establish connection and validate credentials.

        Raises:
            AuthenticationError: If the credential-validation request fails
                for any reason (the client is closed again before raising).
        """
        self._client = httpx.AsyncClient(timeout=30.0)

        try:
            # Validate credentials by fetching account info
            await self._request("GET", "/api/v3/account", signed=True)
            logger.info(
                "Connected to Binance Spot",
                mode=self.settings.trading_mode.value,
                base_url=self.base_url,
            )
        except Exception as e:
            # Don't leak a live client on failed auth.
            await self.disconnect()
            raise AuthenticationError(f"Failed to authenticate: {e}") from e

    async def disconnect(self) -> None:
        """Close HTTP client. Safe to call when already disconnected."""
        if self._client:
            await self._client.aclose()
            self._client = None
            logger.info("Disconnected from Binance Spot")

    async def get_balance(self, asset: str) -> SpotBalance:
        """Get spot wallet balance for a specific asset.

        Args:
            asset: Asset symbol (e.g., "BTC", "USDT")

        Returns:
            SpotBalance with free and locked amounts

        Raises:
            ExchangeError: If asset not found
        """
        data = await self._request("GET", "/api/v3/account", signed=True)

        # Linear scan over the full balance list returned by the account call.
        for balance in data.get("balances", []):
            if balance["asset"] == asset:
                return SpotBalance(
                    asset=asset,
                    free=Decimal(balance["free"]),
                    locked=Decimal(balance["locked"]),
                    updated_at=datetime.now(UTC),
                )

        raise ExchangeError(f"Asset {asset} not found in spot wallet")

    async def get_all_balances(self, min_balance: Decimal = Decimal("0")) -> list[SpotBalance]:
        """Get all spot wallet balances.

        Args:
            min_balance: Minimum total balance to include (default: 0, includes all)

        Returns:
            List of SpotBalance objects with non-zero balances
            (strictly greater than ``min_balance``).
        """
        data = await self._request("GET", "/api/v3/account", signed=True)
        # One shared timestamp for the whole snapshot.
        now = datetime.now(UTC)

        balances = []
        for balance in data.get("balances", []):
            free = Decimal(balance["free"])
            locked = Decimal(balance["locked"])
            total = free + locked

            # Strict comparison: with the default of 0, zero balances are skipped.
            if total > min_balance:
                balances.append(
                    SpotBalance(
                        asset=balance["asset"],
                        free=free,
                        locked=locked,
                        updated_at=now,
                    )
                )

        return balances
||||||
@@ -271,6 +271,20 @@ class Settings(BaseSettings):
|
|||||||
return "wss://stream.binancefuture.com"
|
return "wss://stream.binancefuture.com"
|
||||||
return "wss://fstream.binance.com"
|
return "wss://fstream.binance.com"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def binance_spot_base_url(self) -> str:
|
||||||
|
"""Get the appropriate Binance Spot API base URL for current mode."""
|
||||||
|
if self.trading_mode == TradingMode.TESTNET:
|
||||||
|
return "https://testnet.binance.vision"
|
||||||
|
return "https://api.binance.com"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def binance_spot_ws_url(self) -> str:
|
||||||
|
"""Get the appropriate Binance Spot WebSocket URL for current mode."""
|
||||||
|
if self.trading_mode == TradingMode.TESTNET:
|
||||||
|
return "wss://testnet.binance.vision"
|
||||||
|
return "wss://stream.binance.com:9443"
|
||||||
|
|
||||||
def get_active_api_key(self) -> SecretStr | None:
|
def get_active_api_key(self) -> SecretStr | None:
|
||||||
"""Get the API key for the current trading mode."""
|
"""Get the API key for the current trading mode."""
|
||||||
if self.trading_mode == TradingMode.TESTNET:
|
if self.trading_mode == TradingMode.TESTNET:
|
||||||
|
|||||||
@@ -0,0 +1,25 @@
|
|||||||
|
"""Data ingestion and storage module for TradeFinder.
|
||||||
|
|
||||||
|
This module provides functionality for fetching, storing, and
|
||||||
|
retrieving market data.
|
||||||
|
|
||||||
|
Components:
|
||||||
|
- DataStorage: DuckDB storage manager
|
||||||
|
- DataFetcher: Historical data fetcher
|
||||||
|
- schemas: Database table definitions
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
from tradefinder.data import DataStorage, DataFetcher
|
||||||
|
|
||||||
|
storage = DataStorage(Path("/data/tradefinder.duckdb"))
|
||||||
|
storage.connect()
|
||||||
|
storage.initialize_schema()
|
||||||
|
|
||||||
|
fetcher = DataFetcher(adapter, storage)
|
||||||
|
await fetcher.sync_candles("BTCUSDT", "4h")
|
||||||
|
"""
|
||||||
|
|
||||||
|
from tradefinder.data.fetcher import DataFetcher
|
||||||
|
from tradefinder.data.storage import DataStorage
|
||||||
|
|
||||||
|
__all__ = ["DataStorage", "DataFetcher"]
|
||||||
|
|||||||
207
src/tradefinder/data/fetcher.py
Normal file
207
src/tradefinder/data/fetcher.py
Normal file
@@ -0,0 +1,207 @@
|
|||||||
|
"""Historical data fetcher for market data.
|
||||||
|
|
||||||
|
Fetches historical OHLCV data from exchange and stores in database.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from tradefinder.adapters.base import ExchangeAdapter
|
||||||
|
from tradefinder.data.storage import DataStorage
|
||||||
|
|
||||||
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
# Timeframe to milliseconds mapping, derived from each timeframe's minute count
# so the per-entry arithmetic is stated once.
_TIMEFRAME_MINUTES = {
    "1m": 1,
    "3m": 3,
    "5m": 5,
    "15m": 15,
    "30m": 30,
    "1h": 60,
    "2h": 120,
    "4h": 240,
    "6h": 360,
    "8h": 480,
    "12h": 720,
    "1d": 1_440,
    "3d": 4_320,
    "1w": 10_080,
}
TIMEFRAME_MS = {tf: minutes * 60_000 for tf, minutes in _TIMEFRAME_MINUTES.items()}
||||||
|
|
||||||
|
|
||||||
|
class DataFetcher:
    """Fetches historical market data from exchange.

    Usage:
        fetcher = DataFetcher(adapter, storage)
        await fetcher.backfill_candles("BTCUSDT", "4h", start_date, end_date)
        await fetcher.sync_candles("BTCUSDT", "4h")
    """

    def __init__(self, adapter: ExchangeAdapter, storage: DataStorage) -> None:
        """Initialize fetcher with adapter and storage.

        Args:
            adapter: Exchange adapter for fetching data
            storage: Data storage for persisting data
        """
        self.adapter = adapter
        self.storage = storage

    async def backfill_candles(
        self,
        symbol: str,
        timeframe: str,
        start_date: datetime,
        end_date: datetime | None = None,
        batch_size: int = 1000,
    ) -> int:
        """Fetch and store historical candle data.

        Args:
            symbol: Trading symbol (e.g., "BTCUSDT")
            timeframe: Candle timeframe (e.g., "1h", "4h")
            start_date: Start date for backfill
            end_date: End date for backfill (default: now)
            batch_size: Number of candles per API request (capped at 1500)

        Returns:
            Total number of candles fetched and stored

        Raises:
            ValueError: If ``timeframe`` is not a known timeframe
        """
        if end_date is None:
            # NOTE(review): naive local time — assumed compatible with the
            # adapter's candle timestamps; confirm whether UTC should be used.
            end_date = datetime.now()

        tf_ms = TIMEFRAME_MS.get(timeframe)
        if tf_ms is None:
            raise ValueError(f"Unknown timeframe: {timeframe}")

        start_ms = int(start_date.timestamp() * 1000)
        end_ms = int(end_date.timestamp() * 1000)

        # The exchange caps klines requests at 1500 candles. Clamp once so the
        # partial-batch termination check below compares against what was
        # actually requested; comparing against an uncapped batch_size would
        # silently end the backfill after one batch whenever batch_size > 1500.
        effective_limit = min(batch_size, 1500)

        total_fetched = 0
        current_start = start_ms

        logger.info(
            "Starting backfill",
            symbol=symbol,
            timeframe=timeframe,
            start=start_date.isoformat(),
            end=end_date.isoformat(),
        )

        while current_start < end_ms:
            # Fetch batch of candles
            candles = await self.adapter.get_candles(
                symbol=symbol,
                timeframe=timeframe,
                limit=effective_limit,
                start_time=current_start,
                end_time=end_ms,
            )

            if not candles:
                break

            # Store candles
            inserted = self.storage.insert_candles(candles, symbol, timeframe)
            total_fetched += inserted

            # Move to next batch: advance to one interval past the last candle.
            last_candle_ts = int(candles[-1].timestamp.timestamp() * 1000)
            current_start = last_candle_ts + tf_ms

            logger.debug(
                "Fetched batch",
                symbol=symbol,
                timeframe=timeframe,
                count=len(candles),
                total=total_fetched,
            )

            # A partial batch means the exchange has no more data in range.
            if len(candles) < effective_limit:
                break

        logger.info(
            "Backfill complete",
            symbol=symbol,
            timeframe=timeframe,
            total_candles=total_fetched,
        )

        return total_fetched

    async def sync_candles(
        self,
        symbol: str,
        timeframe: str,
        lookback_days: int = 30,
    ) -> int:
        """Sync candles from last stored timestamp to now.

        If no candles exist, fetches from lookback_days ago.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            lookback_days: Days to look back if no existing data

        Returns:
            Number of new candles fetched
        """
        # Get latest stored candle timestamp
        latest_ts = self.storage.get_latest_candle_timestamp(symbol, timeframe)

        if latest_ts:
            # Start from next candle after latest. The 1h fallback only picks
            # the starting offset; backfill_candles still raises ValueError
            # for a genuinely unknown timeframe.
            tf_ms = TIMEFRAME_MS.get(timeframe, 3600000)
            start_date = latest_ts + timedelta(milliseconds=tf_ms)
            logger.info(
                "Syncing from last candle",
                symbol=symbol,
                timeframe=timeframe,
                last_candle=latest_ts.isoformat(),
            )
        else:
            # No existing data, use lookback
            start_date = datetime.now() - timedelta(days=lookback_days)
            logger.info(
                "No existing data, fetching from lookback",
                symbol=symbol,
                timeframe=timeframe,
                lookback_days=lookback_days,
            )

        return await self.backfill_candles(
            symbol=symbol,
            timeframe=timeframe,
            start_date=start_date,
        )

    async def fetch_latest_candles(
        self,
        symbol: str,
        timeframe: str,
        limit: int = 100,
    ) -> int:
        """Fetch and store the most recent candles.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            limit: Number of recent candles to fetch

        Returns:
            Number of candles fetched
        """
        candles = await self.adapter.get_candles(
            symbol=symbol,
            timeframe=timeframe,
            limit=limit,
        )

        if candles:
            return self.storage.insert_candles(candles, symbol, timeframe)
        return 0
||||||
54
src/tradefinder/data/schemas.py
Normal file
54
src/tradefinder/data/schemas.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Database schemas for market data storage."""
|
||||||
|
|
||||||
|
CANDLES_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS candles (
|
||||||
|
symbol VARCHAR NOT NULL,
|
||||||
|
timeframe VARCHAR NOT NULL,
|
||||||
|
timestamp TIMESTAMP NOT NULL,
|
||||||
|
open DOUBLE NOT NULL,
|
||||||
|
high DOUBLE NOT NULL,
|
||||||
|
low DOUBLE NOT NULL,
|
||||||
|
close DOUBLE NOT NULL,
|
||||||
|
volume DOUBLE NOT NULL,
|
||||||
|
PRIMARY KEY (symbol, timeframe, timestamp)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_candles_symbol_tf
|
||||||
|
ON candles (symbol, timeframe, timestamp DESC);
|
||||||
|
"""
|
||||||
|
|
||||||
|
TRADES_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS trades (
|
||||||
|
id VARCHAR PRIMARY KEY,
|
||||||
|
symbol VARCHAR NOT NULL,
|
||||||
|
side VARCHAR NOT NULL,
|
||||||
|
position_side VARCHAR NOT NULL,
|
||||||
|
entry_price DOUBLE NOT NULL,
|
||||||
|
exit_price DOUBLE,
|
||||||
|
quantity DOUBLE NOT NULL,
|
||||||
|
pnl_usdt DOUBLE,
|
||||||
|
pnl_pct DOUBLE,
|
||||||
|
strategy VARCHAR NOT NULL,
|
||||||
|
entry_time TIMESTAMP NOT NULL,
|
||||||
|
exit_time TIMESTAMP,
|
||||||
|
status VARCHAR NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_trades_symbol ON trades (symbol);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_trades_strategy ON trades (strategy);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades (entry_time DESC);
|
||||||
|
"""
|
||||||
|
|
||||||
|
FUNDING_RATES_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS funding_rates (
|
||||||
|
symbol VARCHAR NOT NULL,
|
||||||
|
funding_rate DOUBLE NOT NULL,
|
||||||
|
funding_time TIMESTAMP NOT NULL,
|
||||||
|
mark_price DOUBLE NOT NULL,
|
||||||
|
PRIMARY KEY (symbol, funding_time)
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_funding_symbol ON funding_rates (symbol, funding_time DESC);
|
||||||
|
"""
|
||||||
|
|
||||||
|
ALL_SCHEMAS = [CANDLES_SCHEMA, TRADES_SCHEMA, FUNDING_RATES_SCHEMA]
|
||||||
242
src/tradefinder/data/storage.py
Normal file
242
src/tradefinder/data/storage.py
Normal file
@@ -0,0 +1,242 @@
|
|||||||
|
"""DuckDB storage manager for market data.
|
||||||
|
|
||||||
|
Provides async-compatible interface for storing and retrieving
|
||||||
|
market data using DuckDB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import duckdb
|
||||||
|
import structlog
|
||||||
|
|
||||||
|
from tradefinder.adapters.types import Candle, FundingRate
|
||||||
|
from tradefinder.data.schemas import ALL_SCHEMAS
|
||||||
|
|
||||||
|
logger = structlog.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DataStorage:
    """DuckDB storage manager for market data.

    Usage:
        storage = DataStorage(Path("/data/tradefinder.duckdb"))
        storage.connect()
        storage.initialize_schema()
        storage.insert_candles(candles)
        storage.disconnect()
    """

    def __init__(self, db_path: Path) -> None:
        """Initialize storage with database path.

        Args:
            db_path: Path to DuckDB database file
        """
        self.db_path = db_path
        # None until connect(); all queries go through the `conn` property,
        # which raises if used while disconnected.
        self._conn: duckdb.DuckDBPyConnection | None = None

    def connect(self) -> None:
        """Connect to the database, creating the file's directory if needed."""
        # Ensure parent directory exists
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

        self._conn = duckdb.connect(str(self.db_path))
        logger.info("Connected to DuckDB", path=str(self.db_path))

    def disconnect(self) -> None:
        """Close database connection. Safe to call when already closed."""
        if self._conn:
            self._conn.close()
            self._conn = None
            logger.info("Disconnected from DuckDB")

    def __enter__(self) -> "DataStorage":
        """Context manager entry."""
        self.connect()
        return self

    def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
        """Context manager exit."""
        self.disconnect()

    @property
    def conn(self) -> duckdb.DuckDBPyConnection:
        """Get database connection.

        Raises:
            RuntimeError: If connect() has not been called.
        """
        if self._conn is None:
            raise RuntimeError("Not connected. Call connect() first.")
        return self._conn

    def initialize_schema(self) -> None:
        """Create all database tables and indexes (idempotent)."""
        for schema in ALL_SCHEMAS:
            # Execute each statement separately — a schema constant may hold
            # several ";"-separated CREATE statements.
            for statement in schema.strip().split(";"):
                statement = statement.strip()
                if statement:
                    self.conn.execute(statement)

        logger.info("Database schema initialized")

    def insert_candles(self, candles: list[Candle], symbol: str, timeframe: str) -> int:
        """Insert candles into the database.

        Args:
            candles: List of Candle objects to insert
            symbol: Trading symbol (e.g., "BTCUSDT")
            timeframe: Candle timeframe (e.g., "1h", "4h")

        Returns:
            Number of candles inserted (candles that replaced an existing
            row are counted the same as fresh inserts)
        """
        if not candles:
            return 0

        # Prepare data for insertion; Decimal fields are narrowed to float to
        # match the DOUBLE columns.
        data = [
            (
                symbol,
                timeframe,
                c.timestamp,
                float(c.open),
                float(c.high),
                float(c.low),
                float(c.close),
                float(c.volume),
            )
            for c in candles
        ]

        # Use INSERT OR REPLACE to handle duplicates
        self.conn.executemany(
            """
            INSERT OR REPLACE INTO candles
            (symbol, timeframe, timestamp, open, high, low, close, volume)
            VALUES (?, ?, ?, ?, ?, ?, ?, ?)
            """,
            data,
        )

        logger.debug(
            "Inserted candles",
            symbol=symbol,
            timeframe=timeframe,
            count=len(candles),
        )
        return len(candles)

    def get_candles(
        self,
        symbol: str,
        timeframe: str,
        start: datetime | None = None,
        end: datetime | None = None,
        limit: int | None = None,
    ) -> list[Candle]:
        """Retrieve candles from the database.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe
            start: Start timestamp (inclusive)
            end: End timestamp (inclusive)
            limit: Maximum number of candles to return

        Returns:
            List of Candle objects, oldest first (note: with ``limit`` set,
            this returns the OLDEST matching rows, not the most recent)
        """
        # Query is assembled from fixed fragments; user values only ever
        # travel through bound parameters.
        query = """
            SELECT timestamp, open, high, low, close, volume
            FROM candles
            WHERE symbol = ? AND timeframe = ?
        """
        params: list[str | datetime | int] = [symbol, timeframe]

        if start:
            query += " AND timestamp >= ?"
            params.append(start)
        if end:
            query += " AND timestamp <= ?"
            params.append(end)

        query += " ORDER BY timestamp ASC"

        if limit:
            query += " LIMIT ?"
            params.append(limit)

        result = self.conn.execute(query, params).fetchall()

        from decimal import Decimal

        # Round-trip floats through str(...) to rebuild Decimals without
        # picking up binary-float artifacts.
        return [
            Candle(
                timestamp=row[0],
                open=Decimal(str(row[1])),
                high=Decimal(str(row[2])),
                low=Decimal(str(row[3])),
                close=Decimal(str(row[4])),
                volume=Decimal(str(row[5])),
            )
            for row in result
        ]

    def get_latest_candle_timestamp(self, symbol: str, timeframe: str) -> datetime | None:
        """Get the timestamp of the most recent candle.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe

        Returns:
            Timestamp of latest candle, or None if no candles exist
        """
        result = self.conn.execute(
            """
            SELECT MAX(timestamp) FROM candles
            WHERE symbol = ? AND timeframe = ?
            """,
            [symbol, timeframe],
        ).fetchone()

        # MAX over an empty set yields a NULL cell, hence the result[0] check.
        return result[0] if result and result[0] else None

    def insert_funding_rate(self, rate: FundingRate) -> None:
        """Insert a funding rate record (upserts on (symbol, funding_time)).

        Args:
            rate: FundingRate object to insert
        """
        self.conn.execute(
            """
            INSERT OR REPLACE INTO funding_rates
            (symbol, funding_rate, funding_time, mark_price)
            VALUES (?, ?, ?, ?)
            """,
            [
                rate.symbol,
                float(rate.funding_rate),
                rate.funding_time,
                float(rate.mark_price),
            ],
        )

    def get_candle_count(self, symbol: str, timeframe: str) -> int:
        """Get the number of candles stored for a symbol/timeframe.

        Args:
            symbol: Trading symbol
            timeframe: Candle timeframe

        Returns:
            Number of candles
        """
        result = self.conn.execute(
            """
            SELECT COUNT(*) FROM candles
            WHERE symbol = ? AND timeframe = ?
            """,
            [symbol, timeframe],
        ).fetchone()

        return result[0] if result else 0
||||||
286
tests/test_data_fetcher.py
Normal file
286
tests/test_data_fetcher.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""Unit tests for DataFetcher (backfill and sync logic)."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from decimal import Decimal
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradefinder.adapters.types import Candle
|
||||||
|
from tradefinder.data.fetcher import TIMEFRAME_MS, DataFetcher
|
||||||
|
from tradefinder.data.storage import DataStorage
|
||||||
|
|
||||||
|
|
||||||
|
def make_candle(timestamp: datetime) -> Candle:
    """Create a test candle at given timestamp with fixed OHLCV values."""
    ohlc = {
        "open": Decimal("50000.00"),
        "high": Decimal("51000.00"),
        "low": Decimal("49000.00"),
        "close": Decimal("50500.00"),
    }
    return Candle(timestamp=timestamp, volume=Decimal("1000.00"), **ohlc)
||||||
|
|
||||||
|
|
||||||
|
class TestTimeframeMappings:
    """Tests for timeframe constant mappings."""

    def test_common_timeframes_are_defined(self) -> None:
        """All expected timeframes have millisecond mappings."""
        for tf in ("1m", "5m", "15m", "30m", "1h", "4h", "1d", "1w"):
            assert tf in TIMEFRAME_MS
            assert TIMEFRAME_MS[tf] > 0

    def test_timeframe_values_are_correct(self) -> None:
        """Timeframe millisecond values are accurate."""
        expected = {
            "1m": 60 * 1000,
            "1h": 60 * 60 * 1000,
            "4h": 4 * 60 * 60 * 1000,
            "1d": 24 * 60 * 60 * 1000,
        }
        for tf, ms in expected.items():
            assert TIMEFRAME_MS[tf] == ms
||||||
|
|
||||||
|
|
||||||
|
class TestDataFetcherBackfill:
    """Exercise DataFetcher.backfill_candles against a temporary DuckDB file."""

    async def test_backfill_fetches_and_stores_candles(self) -> None:
        """Fetched candles land in storage and the total is returned."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            window_start = datetime(2024, 1, 1, 0, 0, 0)
            window_end = datetime(2024, 1, 1, 4, 0, 0)

            # Adapter yields one batch of 4 candles, then signals end-of-data.
            exchange = MagicMock()
            batch = [make_candle(window_start + timedelta(hours=h)) for h in range(4)]
            exchange.get_candles = AsyncMock(side_effect=[batch, []])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()
                fetcher = DataFetcher(exchange, storage)

                stored = await fetcher.backfill_candles(
                    "BTCUSDT", "1h", window_start, window_end
                )

                assert stored == 4
                assert storage.get_candle_count("BTCUSDT", "1h") == 4

    async def test_backfill_uses_default_end_date(self) -> None:
        """Omitting the end date defaults it to the current time."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])  # nothing to return

            with DataStorage(db_file) as storage:
                storage.initialize_schema()
                fetcher = DataFetcher(exchange, storage)

                await fetcher.backfill_candles(
                    "BTCUSDT", "1h", datetime.now() - timedelta(hours=1)
                )

                # The adapter must still have been queried even with no data.
                exchange.get_candles.assert_called()

    async def test_backfill_raises_on_unknown_timeframe(self) -> None:
        """An unrecognized timeframe string is rejected with ValueError."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"
            exchange = MagicMock()

            with DataStorage(db_file) as storage:
                storage.initialize_schema()
                fetcher = DataFetcher(exchange, storage)

                with pytest.raises(ValueError, match="Unknown timeframe"):
                    await fetcher.backfill_candles(
                        "BTCUSDT",
                        "invalid_tf",
                        datetime(2024, 1, 1),
                    )

    async def test_backfill_handles_empty_response(self) -> None:
        """A dataless range completes cleanly with a total of zero."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()
                fetcher = DataFetcher(exchange, storage)

                stored = await fetcher.backfill_candles(
                    "BTCUSDT",
                    "1h",
                    datetime(2024, 1, 1),
                    datetime(2024, 1, 2),
                )

                assert stored == 0

    async def test_backfill_respects_batch_size(self) -> None:
        """Oversized batch_size values are clamped to the exchange max of 1500."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()
                fetcher = DataFetcher(exchange, storage)

                await fetcher.backfill_candles(
                    "BTCUSDT",
                    "1h",
                    datetime(2024, 1, 1),
                    datetime(2024, 1, 2),
                    batch_size=2000,
                )

                # The limit handed to the adapter must have been capped.
                assert exchange.get_candles.call_args.kwargs["limit"] == 1500

    async def test_backfill_paginates_correctly(self) -> None:
        """A long range is pulled in successive batches until a short batch ends it."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            origin = datetime(2024, 1, 1, 0, 0, 0)
            full_a = [make_candle(origin + timedelta(hours=h)) for h in range(3)]
            full_b = [make_candle(origin + timedelta(hours=h + 3)) for h in range(3)]
            tail = [make_candle(origin + timedelta(hours=6))]  # short batch ends the loop

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(side_effect=[full_a, full_b, tail])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()
                fetcher = DataFetcher(exchange, storage)

                stored = await fetcher.backfill_candles(
                    "BTCUSDT",
                    "1h",
                    origin,
                    origin + timedelta(hours=10),
                    batch_size=3,
                )

                assert stored == 7  # 3 + 3 + 1 candles across three calls
                assert exchange.get_candles.call_count == 3
class TestDataFetcherSync:
    """Exercise DataFetcher.sync_candles."""

    async def test_sync_from_latest_timestamp(self) -> None:
        """Sync resumes from just after the newest stored candle."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()

                # Seed 5 hourly candles so the fetcher has a resume point.
                origin = datetime(2024, 1, 1, 0, 0, 0)
                seeded = [make_candle(origin + timedelta(hours=h)) for h in range(5)]
                storage.insert_candles(seeded, "BTCUSDT", "1h")

                fetcher = DataFetcher(exchange, storage)
                await fetcher.sync_candles("BTCUSDT", "1h")

                # The request must start no earlier than one interval past hour 4,
                # which is the newest candle already on disk.
                requested_start = exchange.get_candles.call_args.kwargs.get("start_time")
                floor_ms = int((origin + timedelta(hours=5)).timestamp() * 1000)
                assert requested_start >= floor_ms

    async def test_sync_with_no_existing_data_uses_lookback(self) -> None:
        """With an empty table, sync falls back to the lookback window."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()

                fetcher = DataFetcher(exchange, storage)
                await fetcher.sync_candles("BTCUSDT", "1h", lookback_days=7)

                # The adapter was queried despite the empty table.
                exchange.get_candles.assert_called()
class TestDataFetcherLatest:
    """Exercise DataFetcher.fetch_latest_candles."""

    async def test_fetch_latest_stores_candles(self) -> None:
        """Recent candles are pulled from the adapter and persisted."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            moment = datetime.now()
            recent = [make_candle(moment - timedelta(hours=h)) for h in range(5)]
            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=recent)

            with DataStorage(db_file) as storage:
                storage.initialize_schema()

                fetcher = DataFetcher(exchange, storage)
                stored = await fetcher.fetch_latest_candles("BTCUSDT", "1h", limit=5)

                assert stored == 5
                assert storage.get_candle_count("BTCUSDT", "1h") == 5

    async def test_fetch_latest_with_empty_response(self) -> None:
        """No data from the adapter means a zero count."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()

                fetcher = DataFetcher(exchange, storage)
                assert await fetcher.fetch_latest_candles("BTCUSDT", "1h") == 0

    async def test_fetch_latest_respects_limit(self) -> None:
        """The caller-supplied limit is forwarded to the adapter."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"

            exchange = MagicMock()
            exchange.get_candles = AsyncMock(return_value=[])

            with DataStorage(db_file) as storage:
                storage.initialize_schema()

                fetcher = DataFetcher(exchange, storage)
                await fetcher.fetch_latest_candles("BTCUSDT", "1h", limit=50)

                assert exchange.get_candles.call_args.kwargs["limit"] == 50
293
tests/test_data_storage.py
Normal file
293
tests/test_data_storage.py
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
"""Unit tests for DataStorage (DuckDB operations)."""
|
||||||
|
|
||||||
|
import tempfile
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from decimal import Decimal
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from tradefinder.adapters.types import Candle, FundingRate
|
||||||
|
from tradefinder.data.storage import DataStorage
|
||||||
|
|
||||||
|
|
||||||
|
class TestDataStorageConnection:
    """Connection open/close behavior for DataStorage."""

    def test_connect_creates_database_file(self) -> None:
        """Connecting materializes the DuckDB file on disk."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"
            store = DataStorage(db_file)

            store.connect()
            assert db_file.exists()
            store.disconnect()

    def test_connect_creates_parent_directories(self) -> None:
        """Missing parent directories are created on connect."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "nested" / "dir" / "test.duckdb"
            store = DataStorage(db_file)

            store.connect()
            assert db_file.parent.exists()
            store.disconnect()

    def test_disconnect_closes_connection(self) -> None:
        """After disconnect the internal handle is cleared."""
        with tempfile.TemporaryDirectory() as workdir:
            store = DataStorage(Path(workdir) / "test.duckdb")

            store.connect()
            store.disconnect()
            assert store._conn is None

    def test_context_manager_lifecycle(self) -> None:
        """__enter__/__exit__ open and then release the connection."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"
            with DataStorage(db_file) as store:
                assert store._conn is not None
            assert store._conn is None

    def test_conn_property_raises_when_not_connected(self) -> None:
        """Reading .conn before connect() is a RuntimeError."""
        with tempfile.TemporaryDirectory() as workdir:
            store = DataStorage(Path(workdir) / "test.duckdb")

            with pytest.raises(RuntimeError, match="Not connected"):
                _ = store.conn
class TestDataStorageSchema:
    """Schema creation behavior."""

    def test_initialize_schema_creates_tables(self) -> None:
        """candles, trades and funding_rates tables exist after init."""
        with tempfile.TemporaryDirectory() as workdir:
            db_file = Path(workdir) / "test.duckdb"
            with DataStorage(db_file) as store:
                store.initialize_schema()

                rows = store.conn.execute(
                    "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'"
                ).fetchall()
                created = {name for (name,) in rows}

                for expected in ("candles", "trades", "funding_rates"):
                    assert expected in created

    def test_initialize_schema_is_idempotent(self) -> None:
        """Re-running schema initialization must not fail."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()
                store.initialize_schema()  # second call is a no-op
class TestDataStorageCandles:
    """Candle insert/query behavior."""

    def _make_candle(self, timestamp: datetime) -> Candle:
        """Build a fixed-price candle at *timestamp*."""
        return Candle(
            timestamp=timestamp,
            open=Decimal("50000.00"),
            high=Decimal("51000.00"),
            low=Decimal("49000.00"),
            close=Decimal("50500.00"),
            volume=Decimal("1000.50"),
        )

    def test_insert_candles_stores_data(self) -> None:
        """A single candle inserts and reports a count of one."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                row = self._make_candle(datetime(2024, 1, 1, 12, 0, 0))
                assert store.insert_candles([row], "BTCUSDT", "1h") == 1

    def test_insert_candles_empty_list_returns_zero(self) -> None:
        """An empty batch is a no-op returning zero."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                assert store.insert_candles([], "BTCUSDT", "1h") == 0

    def test_insert_candles_handles_duplicates(self) -> None:
        """Re-inserting the same candle upserts instead of duplicating."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                row = self._make_candle(datetime(2024, 1, 1, 12, 0, 0))
                store.insert_candles([row], "BTCUSDT", "1h")
                store.insert_candles([row], "BTCUSDT", "1h")  # same key again

                assert store.get_candle_count("BTCUSDT", "1h") == 1

    def test_get_candles_retrieves_data(self) -> None:
        """Stored OHLC values round-trip through get_candles."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                row = self._make_candle(datetime(2024, 1, 1, 12, 0, 0))
                store.insert_candles([row], "BTCUSDT", "1h")

                fetched = store.get_candles("BTCUSDT", "1h")
                assert len(fetched) == 1
                assert fetched[0].open == Decimal("50000.00")
                assert fetched[0].close == Decimal("50500.00")

    def test_get_candles_respects_time_range(self) -> None:
        """start/end bounds are inclusive filters."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                origin = datetime(2024, 1, 1, 0, 0, 0)
                rows = [self._make_candle(origin + timedelta(hours=h)) for h in range(10)]
                store.insert_candles(rows, "BTCUSDT", "1h")

                window = store.get_candles(
                    "BTCUSDT",
                    "1h",
                    start=origin + timedelta(hours=3),
                    end=origin + timedelta(hours=6),
                )
                assert len(window) == 4  # hours 3 through 6 inclusive

    def test_get_candles_respects_limit(self) -> None:
        """limit caps the number of rows returned."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                origin = datetime(2024, 1, 1, 0, 0, 0)
                rows = [self._make_candle(origin + timedelta(hours=h)) for h in range(10)]
                store.insert_candles(rows, "BTCUSDT", "1h")

                assert len(store.get_candles("BTCUSDT", "1h", limit=5)) == 5

    def test_get_candles_returns_ascending_order(self) -> None:
        """Rows come back oldest-first regardless of insert order."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                origin = datetime(2024, 1, 1, 0, 0, 0)
                rows = [self._make_candle(origin + timedelta(hours=h)) for h in range(5)]
                store.insert_candles(rows[::-1], "BTCUSDT", "1h")  # newest first on purpose

                stamps = [c.timestamp for c in store.get_candles("BTCUSDT", "1h")]
                assert stamps == sorted(stamps)

    def test_get_candles_filters_by_symbol(self) -> None:
        """Each symbol only sees its own rows."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                row = [self._make_candle(datetime(2024, 1, 1, 12, 0, 0))]
                store.insert_candles(row, "BTCUSDT", "1h")
                store.insert_candles(row, "ETHUSDT", "1h")

                assert len(store.get_candles("BTCUSDT", "1h")) == 1
                assert len(store.get_candles("ETHUSDT", "1h")) == 1

    def test_get_latest_candle_timestamp(self) -> None:
        """The newest stored timestamp is reported."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                origin = datetime(2024, 1, 1, 0, 0, 0)
                rows = [self._make_candle(origin + timedelta(hours=h)) for h in range(5)]
                store.insert_candles(rows, "BTCUSDT", "1h")

                newest = store.get_latest_candle_timestamp("BTCUSDT", "1h")
                assert newest == origin + timedelta(hours=4)

    def test_get_latest_candle_timestamp_returns_none_when_empty(self) -> None:
        """An empty table yields None rather than raising."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                assert store.get_latest_candle_timestamp("BTCUSDT", "1h") is None

    def test_get_candle_count(self) -> None:
        """get_candle_count matches the number of rows inserted."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                origin = datetime(2024, 1, 1, 0, 0, 0)
                rows = [self._make_candle(origin + timedelta(hours=h)) for h in range(7)]
                store.insert_candles(rows, "BTCUSDT", "1h")

                assert store.get_candle_count("BTCUSDT", "1h") == 7
class TestDataStorageFundingRates:
    """Funding-rate persistence."""

    def test_insert_funding_rate(self) -> None:
        """An inserted funding rate is queryable from the table."""
        with tempfile.TemporaryDirectory() as workdir:
            with DataStorage(Path(workdir) / "test.duckdb") as store:
                store.initialize_schema()

                store.insert_funding_rate(
                    FundingRate(
                        symbol="BTCUSDT",
                        funding_rate=Decimal("0.0001"),
                        funding_time=datetime(2024, 1, 1, 8, 0, 0),
                        mark_price=Decimal("50000.00"),
                    )
                )

                # Read it back straight from the table.
                row = store.conn.execute(
                    "SELECT symbol, funding_rate FROM funding_rates"
                ).fetchone()
                assert row is not None
                assert row[0] == "BTCUSDT"
295
tests/test_spot_adapter.py
Normal file
295
tests/test_spot_adapter.py
Normal file
@@ -0,0 +1,295 @@
|
|||||||
|
"""Unit tests for BinanceSpotAdapter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
from datetime import UTC
|
||||||
|
from decimal import Decimal
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, Mock
|
||||||
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import SecretStr
|
||||||
|
|
||||||
|
from tradefinder.adapters.base import AuthenticationError, ExchangeError
|
||||||
|
from tradefinder.adapters.binance_spot import BinanceSpotAdapter, SpotBalance
|
||||||
|
from tradefinder.core.config import BinanceSettings, Settings, TradingMode
|
||||||
|
|
||||||
|
|
||||||
|
def _make_response(status_code: int, payload: Any) -> Mock:
|
||||||
|
"""Create a mock HTTP response."""
|
||||||
|
response = Mock()
|
||||||
|
response.status_code = status_code
|
||||||
|
response.json.return_value = payload
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def settings() -> Settings:
    """Testnet Settings built entirely in memory (no .env lookup)."""
    testnet_creds = BinanceSettings(
        testnet_api_key=SecretStr("test-key"),
        testnet_secret=SecretStr("test-secret"),
    )
    return Settings(
        trading_mode=TradingMode.TESTNET,
        binance=testnet_creds,
        _env_file=None,
    )
@pytest.fixture
def adapter(settings: Settings) -> BinanceSpotAdapter:
    """Spot adapter wired to the in-memory test settings."""
    return BinanceSpotAdapter(settings)
@pytest.fixture
def mock_http_client() -> AsyncMock:
    """HTTP client double with awaitable get/aclose methods."""
    stub = AsyncMock()
    stub.get = AsyncMock()
    stub.aclose = AsyncMock()
    return stub
def _sample_account_response() -> dict[str, Any]:
|
||||||
|
"""Sample Binance spot account response."""
|
||||||
|
return {
|
||||||
|
"makerCommission": 10,
|
||||||
|
"takerCommission": 10,
|
||||||
|
"balances": [
|
||||||
|
{"asset": "BTC", "free": "0.50000000", "locked": "0.10000000"},
|
||||||
|
{"asset": "ETH", "free": "5.00000000", "locked": "0.00000000"},
|
||||||
|
{"asset": "USDT", "free": "1000.00000000", "locked": "500.00000000"},
|
||||||
|
{"asset": "BNB", "free": "0.00000000", "locked": "0.00000000"},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestSpotBalanceDataclass:
    """SpotBalance arithmetic."""

    def test_total_property(self) -> None:
        """total is the sum of free and locked amounts."""
        from datetime import datetime

        held = SpotBalance(
            asset="BTC",
            free=Decimal("1.5"),
            locked=Decimal("0.5"),
            updated_at=datetime.now(UTC),
        )
        assert held.total == Decimal("2.0")

    def test_zero_balance_total(self) -> None:
        """An all-zero balance totals zero."""
        from datetime import datetime

        empty = SpotBalance(
            asset="XRP",
            free=Decimal("0"),
            locked=Decimal("0"),
            updated_at=datetime.now(UTC),
        )
        assert empty.total == Decimal("0")
class TestBinanceSpotAdapterInit:
    """Construction, credentials, and request signing."""

    def test_credentials_are_accessible(self, adapter: BinanceSpotAdapter) -> None:
        """Settings-supplied key/secret end up on the adapter."""
        assert adapter._api_key == "test-key"
        assert adapter._secret == "test-secret"

    def test_base_url_uses_testnet(self, settings: Settings) -> None:
        """Testnet trading mode selects a testnet endpoint."""
        spot = BinanceSpotAdapter(settings)
        assert "testnet" in spot.base_url.lower()

    def test_sign_generates_valid_signature(self, adapter: BinanceSpotAdapter) -> None:
        """_sign produces the canonical HMAC-SHA256 over the encoded params."""
        query = {"symbol": "BTCUSDT", "timestamp": "1234567890"}
        reference = hmac.new(
            adapter._secret.encode("utf-8"),
            urlencode(query).encode("utf-8"),
            hashlib.sha256,
        ).hexdigest()
        assert adapter._sign(query) == reference
class TestBinanceSpotAdapterConnection:
    """Connect/disconnect lifecycle."""

    async def test_connect_creates_client(self, adapter: BinanceSpotAdapter) -> None:
        """connect() builds an httpx client and stores it on the adapter."""
        fake_client = AsyncMock()
        fake_client.get = AsyncMock(return_value=_make_response(200, _sample_account_response()))
        fake_client.aclose = AsyncMock()

        import httpx

        real_factory = httpx.AsyncClient

        try:
            # Swap the httpx factory so connect() receives our stub.
            httpx.AsyncClient = lambda **kwargs: fake_client
            await adapter.connect()
            assert adapter._client is fake_client
        finally:
            httpx.AsyncClient = real_factory
            await adapter.disconnect()

    async def test_disconnect_closes_client(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """disconnect() closes and drops the HTTP client."""
        adapter._client = mock_http_client

        await adapter.disconnect()

        mock_http_client.aclose.assert_awaited_once()
        assert adapter._client is None

    async def test_disconnect_when_not_connected(self, adapter: BinanceSpotAdapter) -> None:
        """disconnect() tolerates never having connected."""
        await adapter.disconnect()  # must not raise
class TestBinanceSpotAdapterRequest:
    """_request behavior: headers, signing, and error mapping."""

    async def test_request_adds_auth_headers(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """Every request carries the X-MBX-APIKEY header."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, {"ok": True})

        payload = await adapter._request("GET", "/api/v3/test")

        assert payload == {"ok": True}
        mock_http_client.get.assert_awaited_once()
        sent_headers = mock_http_client.get.call_args[1]["headers"]
        assert sent_headers["X-MBX-APIKEY"] == "test-key"

    async def test_request_adds_signature_when_signed(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """signed=True appends timestamp, recvWindow and signature params."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, {"ok": True})

        await adapter._request("GET", "/api/v3/test", params={}, signed=True)

        sent_params = mock_http_client.get.call_args[1]["params"]
        for required in ("timestamp", "signature", "recvWindow"):
            assert required in sent_params

    async def test_request_raises_when_not_connected(self, adapter: BinanceSpotAdapter) -> None:
        """Calling _request before connect() fails fast."""
        with pytest.raises(ExchangeError, match="Not connected"):
            await adapter._request("GET", "/api/v3/test")

    async def test_request_handles_auth_error(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """HTTP 401 maps to AuthenticationError."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(
            401, {"code": -1022, "msg": "Invalid signature"}
        )

        with pytest.raises(AuthenticationError):
            await adapter._request("GET", "/api/v3/test")

    async def test_request_handles_general_error(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """Non-auth API errors surface as ExchangeError with the API message."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(
            400, {"code": -1100, "msg": "Illegal characters"}
        )

        with pytest.raises(ExchangeError, match="Illegal characters"):
            await adapter._request("GET", "/api/v3/test")
class TestBinanceSpotAdapterBalance:
    """get_balance / get_all_balances behavior."""

    async def test_get_balance_returns_asset(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """A single asset's free/locked/total amounts are parsed."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, _sample_account_response())

        held = await adapter.get_balance("BTC")

        assert held.asset == "BTC"
        assert held.free == Decimal("0.50000000")
        assert held.locked == Decimal("0.10000000")
        assert held.total == Decimal("0.60000000")

    async def test_get_balance_raises_when_not_found(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """Requesting an absent asset raises ExchangeError."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, _sample_account_response())

        with pytest.raises(ExchangeError, match="not found"):
            await adapter.get_balance("UNKNOWN")

    async def test_get_all_balances_returns_non_zero(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """Assets with zero free and locked amounts are dropped."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, _sample_account_response())

        holdings = await adapter.get_all_balances()

        # BTC, ETH and USDT carry value in the fixture; BNB is all zeros.
        found = {b.asset for b in holdings}
        assert "BTC" in found
        assert "ETH" in found
        assert "USDT" in found
        assert "BNB" not in found

    async def test_get_all_balances_with_min_balance(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """min_balance excludes assets whose total falls below it."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, _sample_account_response())

        # Only USDT (1500 total) clears a floor of 100.
        holdings = await adapter.get_all_balances(min_balance=Decimal("100"))

        found = {b.asset for b in holdings}
        assert "USDT" in found
        assert "BTC" not in found  # total is only 0.6
        assert "ETH" not in found  # total is only 5.0

    async def test_get_all_balances_empty_response(
        self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
    ) -> None:
        """An account with no balances yields an empty list."""
        adapter._client = mock_http_client
        mock_http_client.get.return_value = _make_response(200, {"balances": []})

        assert await adapter.get_all_balances() == []
Reference in New Issue
Block a user