From 7d63e43b7b495582ef7c6b54307990ffef395184 Mon Sep 17 00:00:00 2001 From: bnair123 Date: Sat, 27 Dec 2025 14:38:26 +0400 Subject: [PATCH] Add data layer (DuckDB storage, fetcher) and spot adapter with tests - Add DataStorage class for DuckDB-based market data persistence - Add DataFetcher for historical candle backfill and sync operations - Add BinanceSpotAdapter for spot wallet balance queries - Add binance_spot_base_url to Settings for spot testnet support - Add comprehensive unit tests (50 new tests, 82 total) - Coverage increased from 62% to 86% --- src/tradefinder/adapters/__init__.py | 6 +- src/tradefinder/adapters/binance_spot.py | 214 ++++++++++++++++ src/tradefinder/core/config.py | 14 ++ src/tradefinder/data/__init__.py | 25 ++ src/tradefinder/data/fetcher.py | 207 ++++++++++++++++ src/tradefinder/data/schemas.py | 54 +++++ src/tradefinder/data/storage.py | 242 +++++++++++++++++++ tests/test_data_fetcher.py | 286 ++++++++++++++++++++++ tests/test_data_storage.py | 293 ++++++++++++++++++++++ tests/test_spot_adapter.py | 295 +++++++++++++++++++++++ 10 files changed, 1635 insertions(+), 1 deletion(-) create mode 100644 src/tradefinder/adapters/binance_spot.py create mode 100644 src/tradefinder/data/fetcher.py create mode 100644 src/tradefinder/data/schemas.py create mode 100644 src/tradefinder/data/storage.py create mode 100644 tests/test_data_fetcher.py create mode 100644 tests/test_data_storage.py create mode 100644 tests/test_spot_adapter.py diff --git a/src/tradefinder/adapters/__init__.py b/src/tradefinder/adapters/__init__.py index 08442f0..1881a1f 100644 --- a/src/tradefinder/adapters/__init__.py +++ b/src/tradefinder/adapters/__init__.py @@ -5,9 +5,10 @@ connecting to cryptocurrency exchanges. Supported Exchanges: - Binance USDⓈ-M Perpetual Futures (primary) + - Binance Spot (balance checking only) Usage: - from tradefinder.adapters import BinanceUSDMAdapter + from tradefinder.adapters import BinanceUSDMAdapter, BinanceSpotAdapter from tradefinder.adapters.types import OrderRequest, Side, PositionSide, OrderType adapter = BinanceUSDMAdapter(settings) @@ -23,6 +24,7 @@ from tradefinder.adapters.base import ( OrderValidationError, RateLimitError, ) +from tradefinder.adapters.binance_spot import BinanceSpotAdapter, SpotBalance from tradefinder.adapters.binance_usdm import BinanceUSDMAdapter from tradefinder.adapters.types import ( AccountBalance, @@ -45,6 +47,7 @@ __all__ = [ "ExchangeAdapter", # Implementations "BinanceUSDMAdapter", + "BinanceSpotAdapter", # Types "AccountBalance", "Candle", @@ -57,6 +60,7 @@ __all__ = [ "Position", "PositionSide", "Side", + "SpotBalance", "SymbolInfo", "TimeInForce", # Exceptions diff --git a/src/tradefinder/adapters/binance_spot.py b/src/tradefinder/adapters/binance_spot.py new file mode 100644 index 0000000..07d2077 --- /dev/null +++ b/src/tradefinder/adapters/binance_spot.py @@ -0,0 +1,214 @@ +"""Binance Spot API adapter for balance checking. + +Provides read-only access to spot wallet balances. +Uses the same API credentials as the futures adapter. +""" + +import hashlib +import hmac +import time +from dataclasses import dataclass +from datetime import UTC, datetime +from decimal import Decimal +from typing import Any +from urllib.parse import urlencode + +import httpx +import structlog + +from tradefinder.adapters.base import AuthenticationError, ExchangeError +from tradefinder.core.config import Settings + +logger = structlog.get_logger(__name__) + + +@dataclass +class SpotBalance: + """Spot wallet balance for an asset.""" + + asset: str + free: Decimal + locked: Decimal + updated_at: datetime + + @property + def total(self) -> Decimal: + """Total balance (free + locked).""" + return self.free + self.locked + + +class BinanceSpotAdapter: + """Binance Spot API adapter for balance checking. + + This adapter connects to Binance Spot (either testnet or production) + and provides read-only access to spot wallet balances. + + Usage: + settings = get_settings() + adapter = BinanceSpotAdapter(settings) + await adapter.connect() + balances = await adapter.get_all_balances() + await adapter.disconnect() + """ + + def __init__(self, settings: Settings) -> None: + """Initialize adapter with settings. + + Args: + settings: Application settings containing API credentials + """ + self.settings = settings + self.base_url = settings.binance_spot_base_url + self._client: httpx.AsyncClient | None = None + self._recv_window = 5000 + + @property + def _api_key(self) -> str: + """Get active API key.""" + key = self.settings.get_active_api_key() + if key is None: + raise AuthenticationError("No API key configured for current trading mode") + return key.get_secret_value() + + @property + def _secret(self) -> str: + """Get active secret.""" + secret = self.settings.get_active_secret() + if secret is None: + raise AuthenticationError("No secret configured for current trading mode") + return secret.get_secret_value() + + def _sign(self, params: dict[str, Any]) -> str: + """Generate HMAC-SHA256 signature for request.""" + query_string = urlencode(params) + signature = hmac.new( + self._secret.encode("utf-8"), + query_string.encode("utf-8"), + hashlib.sha256, + ).hexdigest() + return signature + + def _get_timestamp(self) -> int: + """Get current timestamp in milliseconds.""" + return int(time.time() * 1000) + + async def _request( + self, + method: str, + endpoint: str, + params: dict[str, Any] | None = None, + signed: bool = False, + ) -> Any: + """Make HTTP request to Binance Spot API.""" + if self._client is None: + raise ExchangeError("Not connected. Call connect() first.") + + params = params or {} + headers = {"X-MBX-APIKEY": self._api_key} + + if signed: + params["timestamp"] = self._get_timestamp() + params["recvWindow"] = self._recv_window + params["signature"] = self._sign(params) + + url = f"{self.base_url}{endpoint}" + + try: + if method == "GET": + response = await self._client.get(url, params=params, headers=headers) + else: + raise ValueError(f"Unsupported HTTP method: {method}") + + data = response.json() + + if response.status_code >= 400: + code = data.get("code", 0) + msg = data.get("msg", "Unknown error") + logger.error("API error", status_code=response.status_code, code=code, msg=msg) + if code in (-1021, -1022): + raise AuthenticationError(msg) + raise ExchangeError(f"[{code}] {msg}") + + return data + + except httpx.RequestError as e: + logger.error("Request failed", endpoint=endpoint, error=str(e)) + raise ExchangeError(f"Request failed: {e}") from e + + async def connect(self) -> None: + """Establish connection and validate credentials.""" + self._client = httpx.AsyncClient(timeout=30.0) + + try: + # Validate credentials by fetching account info + await self._request("GET", "/api/v3/account", signed=True) + logger.info( + "Connected to Binance Spot", + mode=self.settings.trading_mode.value, + base_url=self.base_url, + ) + except Exception as e: + await self.disconnect() + raise AuthenticationError(f"Failed to authenticate: {e}") from e + + async def disconnect(self) -> None: + """Close HTTP client.""" + if self._client: + await self._client.aclose() + self._client = None + logger.info("Disconnected from Binance Spot") + + async def get_balance(self, asset: str) -> SpotBalance: + """Get spot wallet balance for a specific asset. + + Args: + asset: Asset symbol (e.g., "BTC", "USDT") + + Returns: + SpotBalance with free and locked amounts + + Raises: + ExchangeError: If asset not found + """ + data = await self._request("GET", "/api/v3/account", signed=True) + + for balance in data.get("balances", []): + if balance["asset"] == asset: + return SpotBalance( + asset=asset, + free=Decimal(balance["free"]), + locked=Decimal(balance["locked"]), + updated_at=datetime.now(UTC), + ) + + raise ExchangeError(f"Asset {asset} not found in spot wallet") + + async def get_all_balances(self, min_balance: Decimal = Decimal("0")) -> list[SpotBalance]: + """Get all spot wallet balances. + + Args: + min_balance: Minimum total balance to include (default: 0, includes all) + + Returns: + List of SpotBalance objects with non-zero balances + """ + data = await self._request("GET", "/api/v3/account", signed=True) + now = datetime.now(UTC) + + balances = [] + for balance in data.get("balances", []): + free = Decimal(balance["free"]) + locked = Decimal(balance["locked"]) + total = free + locked + + if total > min_balance: + balances.append( + SpotBalance( + asset=balance["asset"], + free=free, + locked=locked, + updated_at=now, + ) + ) + + return balances diff --git a/src/tradefinder/core/config.py b/src/tradefinder/core/config.py index 85bd16f..253eb56 100644 --- a/src/tradefinder/core/config.py +++ b/src/tradefinder/core/config.py @@ -271,6 +271,20 @@ class Settings(BaseSettings): return "wss://stream.binancefuture.com" return "wss://fstream.binance.com" + @property + def binance_spot_base_url(self) -> str: + """Get the appropriate Binance Spot API base URL for current mode.""" + if self.trading_mode == TradingMode.TESTNET: + return "https://testnet.binance.vision" + return "https://api.binance.com" + + @property + def binance_spot_ws_url(self) -> str: + """Get the appropriate Binance Spot WebSocket URL for current mode.""" + if self.trading_mode == TradingMode.TESTNET: + return "wss://testnet.binance.vision" + return "wss://stream.binance.com:9443" + def get_active_api_key(self) -> SecretStr | None: """Get the API key for the current trading mode.""" if self.trading_mode == TradingMode.TESTNET: diff --git a/src/tradefinder/data/__init__.py b/src/tradefinder/data/__init__.py index e69de29..1cfe880 100644 --- a/src/tradefinder/data/__init__.py +++ b/src/tradefinder/data/__init__.py @@ -0,0 +1,25 @@ +"""Data ingestion and storage module for TradeFinder. + +This module provides functionality for fetching, storing, and +retrieving market data. + +Components: + - DataStorage: DuckDB storage manager + - DataFetcher: Historical data fetcher + - schemas: Database table definitions + +Usage: + from tradefinder.data import DataStorage, DataFetcher + + storage = DataStorage(Path("/data/tradefinder.duckdb")) + storage.connect() + storage.initialize_schema() + + fetcher = DataFetcher(adapter, storage) + await fetcher.sync_candles("BTCUSDT", "4h") +""" + +from tradefinder.data.fetcher import DataFetcher +from tradefinder.data.storage import DataStorage + +__all__ = ["DataStorage", "DataFetcher"] diff --git a/src/tradefinder/data/fetcher.py b/src/tradefinder/data/fetcher.py new file mode 100644 index 0000000..f5653df --- /dev/null +++ b/src/tradefinder/data/fetcher.py @@ -0,0 +1,207 @@ +"""Historical data fetcher for market data. + +Fetches historical OHLCV data from exchange and stores in database. +""" + +from datetime import datetime, timedelta + +import structlog + +from tradefinder.adapters.base import ExchangeAdapter +from tradefinder.data.storage import DataStorage + +logger = structlog.get_logger(__name__) + +# Timeframe to milliseconds mapping +TIMEFRAME_MS = { + "1m": 60 * 1000, + "3m": 3 * 60 * 1000, + "5m": 5 * 60 * 1000, + "15m": 15 * 60 * 1000, + "30m": 30 * 60 * 1000, + "1h": 60 * 60 * 1000, + "2h": 2 * 60 * 60 * 1000, + "4h": 4 * 60 * 60 * 1000, + "6h": 6 * 60 * 60 * 1000, + "8h": 8 * 60 * 60 * 1000, + "12h": 12 * 60 * 60 * 1000, + "1d": 24 * 60 * 60 * 1000, + "3d": 3 * 24 * 60 * 60 * 1000, + "1w": 7 * 24 * 60 * 60 * 1000, +} + + +class DataFetcher: + """Fetches historical market data from exchange. + + Usage: + fetcher = DataFetcher(adapter, storage) + await fetcher.backfill_candles("BTCUSDT", "4h", start_date, end_date) + await fetcher.sync_candles("BTCUSDT", "4h") + """ + + def __init__(self, adapter: ExchangeAdapter, storage: DataStorage) -> None: + """Initialize fetcher with adapter and storage. + + Args: + adapter: Exchange adapter for fetching data + storage: Data storage for persisting data + """ + self.adapter = adapter + self.storage = storage + + async def backfill_candles( + self, + symbol: str, + timeframe: str, + start_date: datetime, + end_date: datetime | None = None, + batch_size: int = 1000, + ) -> int: + """Fetch and store historical candle data. + + Args: + symbol: Trading symbol (e.g., "BTCUSDT") + timeframe: Candle timeframe (e.g., "1h", "4h") + start_date: Start date for backfill + end_date: End date for backfill (default: now) + batch_size: Number of candles per API request (max 1500) + + Returns: + Total number of candles fetched and stored + """ + if end_date is None: + end_date = datetime.now() + + tf_ms = TIMEFRAME_MS.get(timeframe) + if tf_ms is None: + raise ValueError(f"Unknown timeframe: {timeframe}") + + start_ms = int(start_date.timestamp() * 1000) + end_ms = int(end_date.timestamp() * 1000) + + total_fetched = 0 + current_start = start_ms + + logger.info( + "Starting backfill", + symbol=symbol, + timeframe=timeframe, + start=start_date.isoformat(), + end=end_date.isoformat(), + ) + + while current_start < end_ms: + # Fetch batch of candles + candles = await self.adapter.get_candles( + symbol=symbol, + timeframe=timeframe, + limit=min(batch_size, 1500), + start_time=current_start, + end_time=end_ms, + ) + + if not candles: + break + + # Store candles + inserted = self.storage.insert_candles(candles, symbol, timeframe) + total_fetched += inserted + + # Move to next batch + last_candle_ts = int(candles[-1].timestamp.timestamp() * 1000) + current_start = last_candle_ts + tf_ms + + logger.debug( + "Fetched batch", + symbol=symbol, + timeframe=timeframe, + count=len(candles), + total=total_fetched, + ) + + # Break if we got fewer candles than requested (end of data) + if len(candles) < batch_size: + break + + logger.info( + "Backfill complete", + symbol=symbol, + timeframe=timeframe, + total_candles=total_fetched, + ) + + return total_fetched + + async def sync_candles( + self, + symbol: str, + timeframe: str, + lookback_days: int = 30, + ) -> int: + """Sync candles from last stored timestamp to now. + + If no candles exist, fetches from lookback_days ago. + + Args: + symbol: Trading symbol + timeframe: Candle timeframe + lookback_days: Days to look back if no existing data + + Returns: + Number of new candles fetched + """ + # Get latest stored candle timestamp + latest_ts = self.storage.get_latest_candle_timestamp(symbol, timeframe) + + if latest_ts: + # Start from next candle after latest + tf_ms = TIMEFRAME_MS.get(timeframe, 3600000) + start_date = latest_ts + timedelta(milliseconds=tf_ms) + logger.info( + "Syncing from last candle", + symbol=symbol, + timeframe=timeframe, + last_candle=latest_ts.isoformat(), + ) + else: + # No existing data, use lookback + start_date = datetime.now() - timedelta(days=lookback_days) + logger.info( + "No existing data, fetching from lookback", + symbol=symbol, + timeframe=timeframe, + lookback_days=lookback_days, + ) + + return await self.backfill_candles( + symbol=symbol, + timeframe=timeframe, + start_date=start_date, + ) + + async def fetch_latest_candles( + self, + symbol: str, + timeframe: str, + limit: int = 100, + ) -> int: + """Fetch and store the most recent candles. + + Args: + symbol: Trading symbol + timeframe: Candle timeframe + limit: Number of recent candles to fetch + + Returns: + Number of candles fetched + """ + candles = await self.adapter.get_candles( + symbol=symbol, + timeframe=timeframe, + limit=limit, + ) + + if candles: + return self.storage.insert_candles(candles, symbol, timeframe) + return 0 diff --git a/src/tradefinder/data/schemas.py b/src/tradefinder/data/schemas.py new file mode 100644 index 0000000..a1e5855 --- /dev/null +++ b/src/tradefinder/data/schemas.py @@ -0,0 +1,54 @@ +"""Database schemas for market data storage.""" + +CANDLES_SCHEMA = """ +CREATE TABLE IF NOT EXISTS candles ( + symbol VARCHAR NOT NULL, + timeframe VARCHAR NOT NULL, + timestamp TIMESTAMP NOT NULL, + open DOUBLE NOT NULL, + high DOUBLE NOT NULL, + low DOUBLE NOT NULL, + close DOUBLE NOT NULL, + volume DOUBLE NOT NULL, + PRIMARY KEY (symbol, timeframe, timestamp) +); + +CREATE INDEX IF NOT EXISTS idx_candles_symbol_tf +ON candles (symbol, timeframe, timestamp DESC); +""" + +TRADES_SCHEMA = """ +CREATE TABLE IF NOT EXISTS trades ( + id VARCHAR PRIMARY KEY, + symbol VARCHAR NOT NULL, + side VARCHAR NOT NULL, + position_side VARCHAR NOT NULL, + entry_price DOUBLE NOT NULL, + exit_price DOUBLE, + quantity DOUBLE NOT NULL, + pnl_usdt DOUBLE, + pnl_pct DOUBLE, + strategy VARCHAR NOT NULL, + entry_time TIMESTAMP NOT NULL, + exit_time TIMESTAMP, + status VARCHAR NOT NULL +); + +CREATE INDEX IF NOT EXISTS idx_trades_symbol ON trades (symbol); +CREATE INDEX IF NOT EXISTS idx_trades_strategy ON trades (strategy); +CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades (entry_time DESC); +""" + +FUNDING_RATES_SCHEMA = """ +CREATE TABLE IF NOT EXISTS funding_rates ( + symbol VARCHAR NOT NULL, + funding_rate DOUBLE NOT NULL, + funding_time TIMESTAMP NOT NULL, + mark_price DOUBLE NOT NULL, + PRIMARY KEY (symbol, funding_time) +); + +CREATE INDEX IF NOT EXISTS idx_funding_symbol ON funding_rates (symbol, funding_time DESC); +""" + +ALL_SCHEMAS = [CANDLES_SCHEMA, TRADES_SCHEMA, FUNDING_RATES_SCHEMA] diff --git a/src/tradefinder/data/storage.py b/src/tradefinder/data/storage.py new file mode 100644 index 0000000..397d4af --- /dev/null +++ b/src/tradefinder/data/storage.py @@ -0,0 +1,242 @@ +"""DuckDB storage manager for market data. + +Provides async-compatible interface for storing and retrieving +market data using DuckDB. +""" + +from datetime import datetime +from pathlib import Path + +import duckdb +import structlog + +from tradefinder.adapters.types import Candle, FundingRate +from tradefinder.data.schemas import ALL_SCHEMAS + +logger = structlog.get_logger(__name__) + + +class DataStorage: + """DuckDB storage manager for market data. + + Usage: + storage = DataStorage(Path("/data/tradefinder.duckdb")) + storage.connect() + storage.initialize_schema() + storage.insert_candles(candles) + storage.disconnect() + """ + + def __init__(self, db_path: Path) -> None: + """Initialize storage with database path. + + Args: + db_path: Path to DuckDB database file + """ + self.db_path = db_path + self._conn: duckdb.DuckDBPyConnection | None = None + + def connect(self) -> None: + """Connect to the database.""" + # Ensure parent directory exists + self.db_path.parent.mkdir(parents=True, exist_ok=True) + + self._conn = duckdb.connect(str(self.db_path)) + logger.info("Connected to DuckDB", path=str(self.db_path)) + + def disconnect(self) -> None: + """Close database connection.""" + if self._conn: + self._conn.close() + self._conn = None + logger.info("Disconnected from DuckDB") + + def __enter__(self) -> "DataStorage": + """Context manager entry.""" + self.connect() + return self + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: + """Context manager exit.""" + self.disconnect() + + @property + def conn(self) -> duckdb.DuckDBPyConnection: + """Get database connection.""" + if self._conn is None: + raise RuntimeError("Not connected. Call connect() first.") + return self._conn + + def initialize_schema(self) -> None: + """Create all database tables and indexes.""" + for schema in ALL_SCHEMAS: + # Execute each statement separately + for statement in schema.strip().split(";"): + statement = statement.strip() + if statement: + self.conn.execute(statement) + + logger.info("Database schema initialized") + + def insert_candles(self, candles: list[Candle], symbol: str, timeframe: str) -> int: + """Insert candles into the database. + + Args: + candles: List of Candle objects to insert + symbol: Trading symbol (e.g., "BTCUSDT") + timeframe: Candle timeframe (e.g., "1h", "4h") + + Returns: + Number of candles inserted + """ + if not candles: + return 0 + + # Prepare data for insertion + data = [ + ( + symbol, + timeframe, + c.timestamp, + float(c.open), + float(c.high), + float(c.low), + float(c.close), + float(c.volume), + ) + for c in candles + ] + + # Use INSERT OR REPLACE to handle duplicates + self.conn.executemany( + """ + INSERT OR REPLACE INTO candles + (symbol, timeframe, timestamp, open, high, low, close, volume) + VALUES (?, ?, ?, ?, ?, ?, ?, ?) + """, + data, + ) + + logger.debug( + "Inserted candles", + symbol=symbol, + timeframe=timeframe, + count=len(candles), + ) + return len(candles) + + def get_candles( + self, + symbol: str, + timeframe: str, + start: datetime | None = None, + end: datetime | None = None, + limit: int | None = None, + ) -> list[Candle]: + """Retrieve candles from the database. + + Args: + symbol: Trading symbol + timeframe: Candle timeframe + start: Start timestamp (inclusive) + end: End timestamp (inclusive) + limit: Maximum number of candles to return + + Returns: + List of Candle objects, oldest first + """ + query = """ + SELECT timestamp, open, high, low, close, volume + FROM candles + WHERE symbol = ? AND timeframe = ? + """ + params: list[str | datetime | int] = [symbol, timeframe] + + if start: + query += " AND timestamp >= ?" + params.append(start) + if end: + query += " AND timestamp <= ?" + params.append(end) + + query += " ORDER BY timestamp ASC" + + if limit: + query += " LIMIT ?" + params.append(limit) + + result = self.conn.execute(query, params).fetchall() + + from decimal import Decimal + + return [ + Candle( + timestamp=row[0], + open=Decimal(str(row[1])), + high=Decimal(str(row[2])), + low=Decimal(str(row[3])), + close=Decimal(str(row[4])), + volume=Decimal(str(row[5])), + ) + for row in result + ] + + def get_latest_candle_timestamp(self, symbol: str, timeframe: str) -> datetime | None: + """Get the timestamp of the most recent candle. + + Args: + symbol: Trading symbol + timeframe: Candle timeframe + + Returns: + Timestamp of latest candle, or None if no candles exist + """ + result = self.conn.execute( + """ + SELECT MAX(timestamp) FROM candles + WHERE symbol = ? AND timeframe = ? + """, + [symbol, timeframe], + ).fetchone() + + return result[0] if result and result[0] else None + + def insert_funding_rate(self, rate: FundingRate) -> None: + """Insert a funding rate record. + + Args: + rate: FundingRate object to insert + """ + self.conn.execute( + """ + INSERT OR REPLACE INTO funding_rates + (symbol, funding_rate, funding_time, mark_price) + VALUES (?, ?, ?, ?) + """, + [ + rate.symbol, + float(rate.funding_rate), + rate.funding_time, + float(rate.mark_price), + ], + ) + + def get_candle_count(self, symbol: str, timeframe: str) -> int: + """Get the number of candles stored for a symbol/timeframe. + + Args: + symbol: Trading symbol + timeframe: Candle timeframe + + Returns: + Number of candles + """ + result = self.conn.execute( + """ + SELECT COUNT(*) FROM candles + WHERE symbol = ? AND timeframe = ? + """, + [symbol, timeframe], + ).fetchone() + + return result[0] if result else 0 diff --git a/tests/test_data_fetcher.py b/tests/test_data_fetcher.py new file mode 100644 index 0000000..60f48bd --- /dev/null +++ b/tests/test_data_fetcher.py @@ -0,0 +1,286 @@ +"""Unit tests for DataFetcher (backfill and sync logic).""" + +import tempfile +from datetime import datetime, timedelta +from decimal import Decimal +from pathlib import Path +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from tradefinder.adapters.types import Candle +from tradefinder.data.fetcher import TIMEFRAME_MS, DataFetcher +from tradefinder.data.storage import DataStorage + + +def make_candle(timestamp: datetime) -> Candle: + """Create a test candle at given timestamp.""" + return Candle( + timestamp=timestamp, + open=Decimal("50000.00"), + high=Decimal("51000.00"), + low=Decimal("49000.00"), + close=Decimal("50500.00"), + volume=Decimal("1000.00"), + ) + + +class TestTimeframeMappings: + """Tests for timeframe constant mappings.""" + + def test_common_timeframes_are_defined(self) -> None: + """All expected timeframes have millisecond mappings.""" + expected = ["1m", "5m", "15m", "30m", "1h", "4h", "1d", "1w"] + for tf in expected: + assert tf in TIMEFRAME_MS + assert TIMEFRAME_MS[tf] > 0 + + def test_timeframe_values_are_correct(self) -> None: + """Timeframe millisecond values are accurate.""" + assert TIMEFRAME_MS["1m"] == 60 * 1000 + assert TIMEFRAME_MS["1h"] == 60 * 60 * 1000 + assert TIMEFRAME_MS["4h"] == 4 * 60 * 60 * 1000 + assert TIMEFRAME_MS["1d"] == 24 * 60 * 60 * 1000 + + +class TestDataFetcherBackfill: + """Tests for backfill_candles functionality.""" + + async def test_backfill_fetches_and_stores_candles(self) -> None: + """Candles are fetched from adapter and stored.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + # Mock adapter + adapter = MagicMock() + start = datetime(2024, 1, 1, 0, 0, 0) + end = datetime(2024, 1, 1, 4, 0, 0) + + # Return 4 candles then empty (end of data) + candles = [make_candle(start + timedelta(hours=i)) for i in range(4)] + adapter.get_candles = AsyncMock(side_effect=[candles, []]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + fetcher = DataFetcher(adapter, storage) + + total = await fetcher.backfill_candles("BTCUSDT", "1h", start, end) + + assert total == 4 + assert storage.get_candle_count("BTCUSDT", "1h") == 4 + + async def test_backfill_uses_default_end_date(self) -> None: + """End date defaults to now if not provided.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) # No data + + with DataStorage(db_path) as storage: + storage.initialize_schema() + fetcher = DataFetcher(adapter, storage) + + start = datetime.now() - timedelta(hours=1) + await fetcher.backfill_candles("BTCUSDT", "1h", start) + + # Should have been called (even if no data returned) + adapter.get_candles.assert_called() + + async def test_backfill_raises_on_unknown_timeframe(self) -> None: + """ValueError raised for unknown timeframe.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + + with DataStorage(db_path) as storage: + storage.initialize_schema() + fetcher = DataFetcher(adapter, storage) + + with pytest.raises(ValueError, match="Unknown timeframe"): + await fetcher.backfill_candles( + "BTCUSDT", + "invalid_tf", + datetime(2024, 1, 1), + ) + + async def test_backfill_handles_empty_response(self) -> None: + """Empty response from adapter is handled gracefully.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + fetcher = DataFetcher(adapter, storage) + + total = await fetcher.backfill_candles( + "BTCUSDT", + "1h", + datetime(2024, 1, 1), + datetime(2024, 1, 2), + ) + + assert total == 0 + + async def test_backfill_respects_batch_size(self) -> None: + """Batch size is respected and capped at 1500.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + fetcher = DataFetcher(adapter, storage) + + # Request with batch_size > 1500 + await fetcher.backfill_candles( + "BTCUSDT", + "1h", + datetime(2024, 1, 1), + datetime(2024, 1, 2), + batch_size=2000, + ) + + # Verify limit was capped at 1500 + call_kwargs = adapter.get_candles.call_args.kwargs + assert call_kwargs["limit"] == 1500 + + async def test_backfill_paginates_correctly(self) -> None: + """Multiple batches are fetched for large date ranges.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + start = datetime(2024, 1, 1, 0, 0, 0) + + # Return 3 candles per batch, simulate 2 full batches + partial + batch1 = [make_candle(start + timedelta(hours=i)) for i in range(3)] + batch2 = [make_candle(start + timedelta(hours=i + 3)) for i in range(3)] + batch3 = [make_candle(start + timedelta(hours=6))] # Partial batch + + adapter.get_candles = AsyncMock(side_effect=[batch1, batch2, batch3]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + fetcher = DataFetcher(adapter, storage) + + total = await fetcher.backfill_candles( + "BTCUSDT", + "1h", + start, + start + timedelta(hours=10), + batch_size=3, + ) + + # 3 + 3 + 1 = 7 candles + assert total == 7 + assert adapter.get_candles.call_count == 3 + + +class TestDataFetcherSync: + """Tests for sync_candles functionality.""" + + async def test_sync_from_latest_timestamp(self) -> None: + """Sync starts from last stored candle.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + + # Pre-populate with some candles + base_ts = datetime(2024, 1, 1, 0, 0, 0) + existing = [make_candle(base_ts + timedelta(hours=i)) for i in range(5)] + storage.insert_candles(existing, "BTCUSDT", "1h") + + fetcher = DataFetcher(adapter, storage) + await fetcher.sync_candles("BTCUSDT", "1h") + + # Verify sync started from after the last candle + call_kwargs = adapter.get_candles.call_args.kwargs + start_time = call_kwargs.get("start_time") + # Should be after hour 4 (the last existing candle) + expected_start_ms = int((base_ts + timedelta(hours=5)).timestamp() * 1000) + assert start_time >= expected_start_ms + + async def test_sync_with_no_existing_data_uses_lookback(self) -> None: + """Sync uses lookback when no existing data.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + + fetcher = DataFetcher(adapter, storage) + await fetcher.sync_candles("BTCUSDT", "1h", lookback_days=7) + + # Should have called get_candles + adapter.get_candles.assert_called() + + +class TestDataFetcherLatest: + """Tests for fetch_latest_candles functionality.""" + + async def test_fetch_latest_stores_candles(self) -> None: + """Latest candles are fetched and stored.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + now = datetime.now() + candles = [make_candle(now - timedelta(hours=i)) for i in range(5)] + adapter.get_candles = AsyncMock(return_value=candles) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + + fetcher = DataFetcher(adapter, storage) + count = await fetcher.fetch_latest_candles("BTCUSDT", "1h", limit=5) + + assert count == 5 + assert storage.get_candle_count("BTCUSDT", "1h") == 5 + + async def test_fetch_latest_with_empty_response(self) -> None: + """Empty response returns 0.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + + fetcher = DataFetcher(adapter, storage) + count = await fetcher.fetch_latest_candles("BTCUSDT", "1h") + + assert count == 0 + + async def test_fetch_latest_respects_limit(self) -> None: + """Limit parameter is passed to adapter.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + + adapter = MagicMock() + adapter.get_candles = AsyncMock(return_value=[]) + + with DataStorage(db_path) as storage: + storage.initialize_schema() + + fetcher = DataFetcher(adapter, storage) + await fetcher.fetch_latest_candles("BTCUSDT", "1h", limit=50) + + call_kwargs = adapter.get_candles.call_args.kwargs + assert call_kwargs["limit"] == 50 diff --git a/tests/test_data_storage.py b/tests/test_data_storage.py new file mode 100644 index 0000000..9a08de7 --- /dev/null +++ b/tests/test_data_storage.py @@ -0,0 +1,293 @@ +"""Unit tests for DataStorage (DuckDB operations).""" + +import tempfile +from datetime import datetime, timedelta +from decimal import Decimal +from pathlib import Path + +import pytest + +from tradefinder.adapters.types import Candle, FundingRate +from tradefinder.data.storage import DataStorage + + +class TestDataStorageConnection: + """Tests for connection lifecycle.""" + + def test_connect_creates_database_file(self) -> None: + """Database file is created on connect.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + storage = DataStorage(db_path) + + storage.connect() + assert db_path.exists() + storage.disconnect() + + def test_connect_creates_parent_directories(self) -> None: + """Parent directories are created if they don't exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "nested" / "dir" / "test.duckdb" + storage = DataStorage(db_path) + + storage.connect() + assert db_path.parent.exists() + storage.disconnect() + + def test_disconnect_closes_connection(self) -> None: + """Disconnect properly closes the connection.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + storage = DataStorage(db_path) + + storage.connect() + storage.disconnect() + assert storage._conn is None + + def test_context_manager_lifecycle(self) -> None: + """Context manager properly opens and closes connection.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + assert storage._conn is not None + assert storage._conn is None + + def test_conn_property_raises_when_not_connected(self) -> None: + """Accessing conn property raises when not connected.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + storage = DataStorage(db_path) + + with pytest.raises(RuntimeError, match="Not connected"): + _ = storage.conn + + +class TestDataStorageSchema: + """Tests for schema initialization.""" + + def test_initialize_schema_creates_tables(self) -> None: + """All tables are created by initialize_schema.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + # Verify tables exist + result = storage.conn.execute( + "SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'" + ).fetchall() + tables = {row[0] for row in result} + + assert "candles" in tables + assert "trades" in tables + assert "funding_rates" in tables + + def test_initialize_schema_is_idempotent(self) -> None: + """Calling initialize_schema multiple times is safe.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + storage.initialize_schema() # Should not raise + + +class TestDataStorageCandles: + """Tests for candle CRUD operations.""" + + def _make_candle(self, timestamp: datetime) -> Candle: + """Create a test candle.""" + return Candle( + timestamp=timestamp, + open=Decimal("50000.00"), + high=Decimal("51000.00"), + low=Decimal("49000.00"), + close=Decimal("50500.00"), + volume=Decimal("1000.50"), + ) + + def test_insert_candles_stores_data(self) -> None: + """Candles are stored correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + ts = datetime(2024, 1, 1, 12, 0, 0) + candles = [self._make_candle(ts)] + + inserted = storage.insert_candles(candles, "BTCUSDT", "1h") + assert inserted == 1 + + def test_insert_candles_empty_list_returns_zero(self) -> None: + """Inserting empty list returns 0.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + inserted = storage.insert_candles([], "BTCUSDT", "1h") + assert inserted == 0 + + def test_insert_candles_handles_duplicates(self) -> None: + """Duplicate candles are replaced (upsert behavior).""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + ts = datetime(2024, 1, 1, 12, 0, 0) + candles = [self._make_candle(ts)] + + storage.insert_candles(candles, "BTCUSDT", "1h") + storage.insert_candles(candles, "BTCUSDT", "1h") # Duplicate + + count = storage.get_candle_count("BTCUSDT", "1h") + assert count == 1 # Not 2 + + def test_get_candles_retrieves_data(self) -> None: + """Candles can be retrieved correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + ts = datetime(2024, 1, 1, 12, 0, 0) + candles = [self._make_candle(ts)] + storage.insert_candles(candles, "BTCUSDT", "1h") + + result = storage.get_candles("BTCUSDT", "1h") + assert len(result) == 1 + assert result[0].open == Decimal("50000.00") + assert result[0].close == Decimal("50500.00") + + def test_get_candles_respects_time_range(self) -> None: + """Candles are filtered by start/end time.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + base_ts = datetime(2024, 1, 1, 0, 0, 0) + candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(10)] + storage.insert_candles(candles, "BTCUSDT", "1h") + + # Query middle range + start = base_ts + timedelta(hours=3) + end = base_ts + timedelta(hours=6) + result = storage.get_candles("BTCUSDT", "1h", start=start, end=end) + + assert len(result) == 4 # hours 3, 4, 5, 6 + + def test_get_candles_respects_limit(self) -> None: + """Candle count is limited correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + base_ts = datetime(2024, 1, 1, 0, 0, 0) + candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(10)] + storage.insert_candles(candles, "BTCUSDT", "1h") + + result = storage.get_candles("BTCUSDT", "1h", limit=5) + assert len(result) == 5 + + def test_get_candles_returns_ascending_order(self) -> None: + """Candles are returned oldest first.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + base_ts = datetime(2024, 1, 1, 0, 0, 0) + # Insert in reverse order + candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(5)] + storage.insert_candles(candles[::-1], "BTCUSDT", "1h") + + result = storage.get_candles("BTCUSDT", "1h") + timestamps = [c.timestamp for c in result] + assert timestamps == sorted(timestamps) + + def test_get_candles_filters_by_symbol(self) -> None: + """Candles are filtered by symbol.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + ts = datetime(2024, 1, 1, 12, 0, 0) + candles = [self._make_candle(ts)] + + storage.insert_candles(candles, "BTCUSDT", "1h") + storage.insert_candles(candles, "ETHUSDT", "1h") + + btc_result = storage.get_candles("BTCUSDT", "1h") + eth_result = storage.get_candles("ETHUSDT", "1h") + + assert len(btc_result) == 1 + assert len(eth_result) == 1 + + def test_get_latest_candle_timestamp(self) -> None: + """Latest timestamp is returned correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + base_ts = datetime(2024, 1, 1, 0, 0, 0) + candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(5)] + storage.insert_candles(candles, "BTCUSDT", "1h") + + latest = storage.get_latest_candle_timestamp("BTCUSDT", "1h") + assert latest == base_ts + timedelta(hours=4) + + def test_get_latest_candle_timestamp_returns_none_when_empty(self) -> None: + """None is returned when no candles exist.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + latest = storage.get_latest_candle_timestamp("BTCUSDT", "1h") + assert latest is None + + def test_get_candle_count(self) -> None: + """Candle count is accurate.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + base_ts = datetime(2024, 1, 1, 0, 0, 0) + candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(7)] + storage.insert_candles(candles, "BTCUSDT", "1h") + + count = storage.get_candle_count("BTCUSDT", "1h") + assert count == 7 + + +class TestDataStorageFundingRates: + """Tests for funding rate operations.""" + + def test_insert_funding_rate(self) -> None: + """Funding rate is stored correctly.""" + with tempfile.TemporaryDirectory() as tmpdir: + db_path = Path(tmpdir) / "test.duckdb" + with DataStorage(db_path) as storage: + storage.initialize_schema() + + rate = FundingRate( + symbol="BTCUSDT", + funding_rate=Decimal("0.0001"), + funding_time=datetime(2024, 1, 1, 8, 0, 0), + mark_price=Decimal("50000.00"), + ) + + storage.insert_funding_rate(rate) + + # Verify stored + result = storage.conn.execute( + "SELECT symbol, funding_rate FROM funding_rates" + ).fetchone() + assert result is not None + assert result[0] == "BTCUSDT" diff --git a/tests/test_spot_adapter.py b/tests/test_spot_adapter.py new file mode 100644 index 0000000..f9076c2 --- /dev/null +++ b/tests/test_spot_adapter.py @@ -0,0 +1,295 @@ +"""Unit tests for BinanceSpotAdapter.""" + +from __future__ import annotations + +import hashlib +import hmac +from datetime import UTC +from decimal import Decimal +from typing import Any +from unittest.mock import AsyncMock, Mock +from urllib.parse import urlencode + +import pytest +from pydantic import SecretStr + +from tradefinder.adapters.base import AuthenticationError, ExchangeError +from tradefinder.adapters.binance_spot import BinanceSpotAdapter, SpotBalance +from tradefinder.core.config import BinanceSettings, Settings, TradingMode + + +def _make_response(status_code: int, payload: Any) -> Mock: + """Create a mock HTTP response.""" + response = Mock() + response.status_code = status_code + response.json.return_value = payload + return response + + +@pytest.fixture +def settings() -> Settings: + """Create test settings without loading .env file.""" + return Settings( + trading_mode=TradingMode.TESTNET, + binance=BinanceSettings( + testnet_api_key=SecretStr("test-key"), + testnet_secret=SecretStr("test-secret"), + ), + _env_file=None, + ) + + +@pytest.fixture +def adapter(settings: Settings) -> BinanceSpotAdapter: + """Create adapter with test settings.""" + return BinanceSpotAdapter(settings) + + +@pytest.fixture +def mock_http_client() -> AsyncMock: + """Create mock HTTP client.""" + client = AsyncMock() + client.get = AsyncMock() + client.aclose = AsyncMock() + return client + + +def _sample_account_response() -> dict[str, Any]: + """Sample Binance spot account response.""" + return { + "makerCommission": 10, + "takerCommission": 10, + "balances": [ + {"asset": "BTC", "free": "0.50000000", "locked": "0.10000000"}, + {"asset": "ETH", "free": "5.00000000", "locked": "0.00000000"}, + {"asset": "USDT", "free": "1000.00000000", "locked": "500.00000000"}, + {"asset": "BNB", "free": "0.00000000", "locked": "0.00000000"}, + ], + } + + +class TestSpotBalanceDataclass: + """Tests for SpotBalance dataclass.""" + + def test_total_property(self) -> None: + """Total equals free plus locked.""" + from datetime import datetime + + balance = SpotBalance( + asset="BTC", + free=Decimal("1.5"), + locked=Decimal("0.5"), + updated_at=datetime.now(UTC), + ) + assert balance.total == Decimal("2.0") + + def test_zero_balance_total(self) -> None: + """Zero balance returns zero total.""" + from datetime import datetime + + balance = SpotBalance( + asset="XRP", + free=Decimal("0"), + locked=Decimal("0"), + updated_at=datetime.now(UTC), + ) + assert balance.total == Decimal("0") + + +class TestBinanceSpotAdapterInit: + """Initialization and credential tests.""" + + def test_credentials_are_accessible(self, adapter: BinanceSpotAdapter) -> None: + """API key and secret are exposed through properties.""" + assert adapter._api_key == "test-key" + assert adapter._secret == "test-secret" + + def test_base_url_uses_testnet(self, settings: Settings) -> None: + """Testnet mode uses testnet URL.""" + adapter = BinanceSpotAdapter(settings) + assert "testnet" in adapter.base_url.lower() + + def test_sign_generates_valid_signature(self, adapter: BinanceSpotAdapter) -> None: + """HMAC-SHA256 signature matches stdlib implementation.""" + params = {"symbol": "BTCUSDT", "timestamp": "1234567890"} + expected = hmac.new( + adapter._secret.encode("utf-8"), + urlencode(params).encode("utf-8"), + hashlib.sha256, + ).hexdigest() + assert adapter._sign(params) == expected + + +class TestBinanceSpotAdapterConnection: + """Connection lifecycle tests.""" + + async def test_connect_creates_client(self, adapter: BinanceSpotAdapter) -> None: + """Connect creates HTTP client.""" + mock_client = AsyncMock() + mock_client.get = AsyncMock(return_value=_make_response(200, _sample_account_response())) + mock_client.aclose = AsyncMock() + + import httpx + + original_client = httpx.AsyncClient + + try: + # Patch httpx.AsyncClient + httpx.AsyncClient = lambda **kwargs: mock_client + await adapter.connect() + assert adapter._client is mock_client + finally: + httpx.AsyncClient = original_client + await adapter.disconnect() + + async def test_disconnect_closes_client( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """Disconnect closes HTTP client.""" + adapter._client = mock_http_client + + await adapter.disconnect() + + mock_http_client.aclose.assert_awaited_once() + assert adapter._client is None + + async def test_disconnect_when_not_connected(self, adapter: BinanceSpotAdapter) -> None: + """Disconnect when not connected is safe.""" + await adapter.disconnect() # Should not raise + + +class TestBinanceSpotAdapterRequest: + """Request handling tests.""" + + async def test_request_adds_auth_headers( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """Requests include X-MBX-APIKEY header.""" + adapter._client = mock_http_client + response = _make_response(200, {"ok": True}) + mock_http_client.get.return_value = response + + result = await adapter._request("GET", "/api/v3/test") + + assert result == {"ok": True} + mock_http_client.get.assert_awaited_once() + called_headers = mock_http_client.get.call_args[1]["headers"] + assert called_headers["X-MBX-APIKEY"] == "test-key" + + async def test_request_adds_signature_when_signed( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """Signed requests include timestamp and signature.""" + adapter._client = mock_http_client + response = _make_response(200, {"ok": True}) + mock_http_client.get.return_value = response + + await adapter._request("GET", "/api/v3/test", params={}, signed=True) + + call_kwargs = mock_http_client.get.call_args[1] + params = call_kwargs["params"] + assert "timestamp" in params + assert "signature" in params + assert "recvWindow" in params + + async def test_request_raises_when_not_connected(self, adapter: BinanceSpotAdapter) -> None: + """Request raises when not connected.""" + with pytest.raises(ExchangeError, match="Not connected"): + await adapter._request("GET", "/api/v3/test") + + async def test_request_handles_auth_error( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """Authentication errors raise AuthenticationError.""" + adapter._client = mock_http_client + response = _make_response(401, {"code": -1022, "msg": "Invalid signature"}) + mock_http_client.get.return_value = response + + with pytest.raises(AuthenticationError): + await adapter._request("GET", "/api/v3/test") + + async def test_request_handles_general_error( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """General API errors raise ExchangeError.""" + adapter._client = mock_http_client + response = _make_response(400, {"code": -1100, "msg": "Illegal characters"}) + mock_http_client.get.return_value = response + + with pytest.raises(ExchangeError, match="Illegal characters"): + await adapter._request("GET", "/api/v3/test") + + +class TestBinanceSpotAdapterBalance: + """Balance retrieval tests.""" + + async def test_get_balance_returns_asset( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """get_balance returns balance for specific asset.""" + adapter._client = mock_http_client + response = _make_response(200, _sample_account_response()) + mock_http_client.get.return_value = response + + balance = await adapter.get_balance("BTC") + + assert balance.asset == "BTC" + assert balance.free == Decimal("0.50000000") + assert balance.locked == Decimal("0.10000000") + assert balance.total == Decimal("0.60000000") + + async def test_get_balance_raises_when_not_found( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """get_balance raises when asset not found.""" + adapter._client = mock_http_client + response = _make_response(200, _sample_account_response()) + mock_http_client.get.return_value = response + + with pytest.raises(ExchangeError, match="not found"): + await adapter.get_balance("UNKNOWN") + + async def test_get_all_balances_returns_non_zero( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """get_all_balances returns only non-zero balances.""" + adapter._client = mock_http_client + response = _make_response(200, _sample_account_response()) + mock_http_client.get.return_value = response + + balances = await adapter.get_all_balances() + + # BTC, ETH, USDT have non-zero balances; BNB is zero + assets = {b.asset for b in balances} + assert "BTC" in assets + assert "ETH" in assets + assert "USDT" in assets + assert "BNB" not in assets + + async def test_get_all_balances_with_min_balance( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """get_all_balances filters by minimum balance.""" + adapter._client = mock_http_client + response = _make_response(200, _sample_account_response()) + mock_http_client.get.return_value = response + + # min_balance of 100 should only include USDT (1500 total) + balances = await adapter.get_all_balances(min_balance=Decimal("100")) + + assets = {b.asset for b in balances} + assert "USDT" in assets + assert "BTC" not in assets # Only 0.6 total + assert "ETH" not in assets # Only 5.0 total + + async def test_get_all_balances_empty_response( + self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock + ) -> None: + """get_all_balances handles empty balances.""" + adapter._client = mock_http_client + response = _make_response(200, {"balances": []}) + mock_http_client.get.return_value = response + + balances = await adapter.get_all_balances() + + assert balances == []