Add data layer (DuckDB storage, fetcher) and spot adapter with tests

- Add DataStorage class for DuckDB-based market data persistence
- Add DataFetcher for historical candle backfill and sync operations
- Add BinanceSpotAdapter for spot wallet balance queries
- Add binance_spot_base_url to Settings for spot testnet support
- Add comprehensive unit tests (50 new tests, 82 total)
- Coverage increased from 62% to 86%
This commit is contained in:
bnair123
2025-12-27 14:38:26 +04:00
parent 17d51c4f78
commit 7d63e43b7b
10 changed files with 1635 additions and 1 deletions

View File

@@ -5,9 +5,10 @@ connecting to cryptocurrency exchanges.
Supported Exchanges: Supported Exchanges:
- Binance USDⓈ-M Perpetual Futures (primary) - Binance USDⓈ-M Perpetual Futures (primary)
- Binance Spot (balance checking only)
Usage: Usage:
from tradefinder.adapters import BinanceUSDMAdapter from tradefinder.adapters import BinanceUSDMAdapter, BinanceSpotAdapter
from tradefinder.adapters.types import OrderRequest, Side, PositionSide, OrderType from tradefinder.adapters.types import OrderRequest, Side, PositionSide, OrderType
adapter = BinanceUSDMAdapter(settings) adapter = BinanceUSDMAdapter(settings)
@@ -23,6 +24,7 @@ from tradefinder.adapters.base import (
OrderValidationError, OrderValidationError,
RateLimitError, RateLimitError,
) )
from tradefinder.adapters.binance_spot import BinanceSpotAdapter, SpotBalance
from tradefinder.adapters.binance_usdm import BinanceUSDMAdapter from tradefinder.adapters.binance_usdm import BinanceUSDMAdapter
from tradefinder.adapters.types import ( from tradefinder.adapters.types import (
AccountBalance, AccountBalance,
@@ -45,6 +47,7 @@ __all__ = [
"ExchangeAdapter", "ExchangeAdapter",
# Implementations # Implementations
"BinanceUSDMAdapter", "BinanceUSDMAdapter",
"BinanceSpotAdapter",
# Types # Types
"AccountBalance", "AccountBalance",
"Candle", "Candle",
@@ -57,6 +60,7 @@ __all__ = [
"Position", "Position",
"PositionSide", "PositionSide",
"Side", "Side",
"SpotBalance",
"SymbolInfo", "SymbolInfo",
"TimeInForce", "TimeInForce",
# Exceptions # Exceptions

View File

@@ -0,0 +1,214 @@
"""Binance Spot API adapter for balance checking.
Provides read-only access to spot wallet balances.
Uses the same API credentials as the futures adapter.
"""
import hashlib
import hmac
import time
from dataclasses import dataclass
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
from urllib.parse import urlencode
import httpx
import structlog
from tradefinder.adapters.base import AuthenticationError, ExchangeError
from tradefinder.core.config import Settings
logger = structlog.get_logger(__name__)
@dataclass
class SpotBalance:
"""Spot wallet balance for an asset."""
asset: str
free: Decimal
locked: Decimal
updated_at: datetime
@property
def total(self) -> Decimal:
"""Total balance (free + locked)."""
return self.free + self.locked
class BinanceSpotAdapter:
"""Binance Spot API adapter for balance checking.
This adapter connects to Binance Spot (either testnet or production)
and provides read-only access to spot wallet balances.
Usage:
settings = get_settings()
adapter = BinanceSpotAdapter(settings)
await adapter.connect()
balances = await adapter.get_all_balances()
await adapter.disconnect()
"""
def __init__(self, settings: Settings) -> None:
"""Initialize adapter with settings.
Args:
settings: Application settings containing API credentials
"""
self.settings = settings
self.base_url = settings.binance_spot_base_url
self._client: httpx.AsyncClient | None = None
self._recv_window = 5000
@property
def _api_key(self) -> str:
"""Get active API key."""
key = self.settings.get_active_api_key()
if key is None:
raise AuthenticationError("No API key configured for current trading mode")
return key.get_secret_value()
@property
def _secret(self) -> str:
"""Get active secret."""
secret = self.settings.get_active_secret()
if secret is None:
raise AuthenticationError("No secret configured for current trading mode")
return secret.get_secret_value()
def _sign(self, params: dict[str, Any]) -> str:
"""Generate HMAC-SHA256 signature for request."""
query_string = urlencode(params)
signature = hmac.new(
self._secret.encode("utf-8"),
query_string.encode("utf-8"),
hashlib.sha256,
).hexdigest()
return signature
def _get_timestamp(self) -> int:
"""Get current timestamp in milliseconds."""
return int(time.time() * 1000)
async def _request(
self,
method: str,
endpoint: str,
params: dict[str, Any] | None = None,
signed: bool = False,
) -> Any:
"""Make HTTP request to Binance Spot API."""
if self._client is None:
raise ExchangeError("Not connected. Call connect() first.")
params = params or {}
headers = {"X-MBX-APIKEY": self._api_key}
if signed:
params["timestamp"] = self._get_timestamp()
params["recvWindow"] = self._recv_window
params["signature"] = self._sign(params)
url = f"{self.base_url}{endpoint}"
try:
if method == "GET":
response = await self._client.get(url, params=params, headers=headers)
else:
raise ValueError(f"Unsupported HTTP method: {method}")
data = response.json()
if response.status_code >= 400:
code = data.get("code", 0)
msg = data.get("msg", "Unknown error")
logger.error("API error", status_code=response.status_code, code=code, msg=msg)
if code in (-1021, -1022):
raise AuthenticationError(msg)
raise ExchangeError(f"[{code}] {msg}")
return data
except httpx.RequestError as e:
logger.error("Request failed", endpoint=endpoint, error=str(e))
raise ExchangeError(f"Request failed: {e}") from e
async def connect(self) -> None:
"""Establish connection and validate credentials."""
self._client = httpx.AsyncClient(timeout=30.0)
try:
# Validate credentials by fetching account info
await self._request("GET", "/api/v3/account", signed=True)
logger.info(
"Connected to Binance Spot",
mode=self.settings.trading_mode.value,
base_url=self.base_url,
)
except Exception as e:
await self.disconnect()
raise AuthenticationError(f"Failed to authenticate: {e}") from e
async def disconnect(self) -> None:
"""Close HTTP client."""
if self._client:
await self._client.aclose()
self._client = None
logger.info("Disconnected from Binance Spot")
async def get_balance(self, asset: str) -> SpotBalance:
"""Get spot wallet balance for a specific asset.
Args:
asset: Asset symbol (e.g., "BTC", "USDT")
Returns:
SpotBalance with free and locked amounts
Raises:
ExchangeError: If asset not found
"""
data = await self._request("GET", "/api/v3/account", signed=True)
for balance in data.get("balances", []):
if balance["asset"] == asset:
return SpotBalance(
asset=asset,
free=Decimal(balance["free"]),
locked=Decimal(balance["locked"]),
updated_at=datetime.now(UTC),
)
raise ExchangeError(f"Asset {asset} not found in spot wallet")
async def get_all_balances(self, min_balance: Decimal = Decimal("0")) -> list[SpotBalance]:
"""Get all spot wallet balances.
Args:
min_balance: Minimum total balance to include (default: 0, includes all)
Returns:
List of SpotBalance objects with non-zero balances
"""
data = await self._request("GET", "/api/v3/account", signed=True)
now = datetime.now(UTC)
balances = []
for balance in data.get("balances", []):
free = Decimal(balance["free"])
locked = Decimal(balance["locked"])
total = free + locked
if total > min_balance:
balances.append(
SpotBalance(
asset=balance["asset"],
free=free,
locked=locked,
updated_at=now,
)
)
return balances

View File

@@ -271,6 +271,20 @@ class Settings(BaseSettings):
return "wss://stream.binancefuture.com" return "wss://stream.binancefuture.com"
return "wss://fstream.binance.com" return "wss://fstream.binance.com"
@property
def binance_spot_base_url(self) -> str:
"""Get the appropriate Binance Spot API base URL for current mode."""
if self.trading_mode == TradingMode.TESTNET:
return "https://testnet.binance.vision"
return "https://api.binance.com"
@property
def binance_spot_ws_url(self) -> str:
"""Get the appropriate Binance Spot WebSocket URL for current mode."""
if self.trading_mode == TradingMode.TESTNET:
return "wss://testnet.binance.vision"
return "wss://stream.binance.com:9443"
def get_active_api_key(self) -> SecretStr | None: def get_active_api_key(self) -> SecretStr | None:
"""Get the API key for the current trading mode.""" """Get the API key for the current trading mode."""
if self.trading_mode == TradingMode.TESTNET: if self.trading_mode == TradingMode.TESTNET:

View File

@@ -0,0 +1,25 @@
"""Data ingestion and storage module for TradeFinder.
This module provides functionality for fetching, storing, and
retrieving market data.
Components:
- DataStorage: DuckDB storage manager
- DataFetcher: Historical data fetcher
- schemas: Database table definitions
Usage:
from tradefinder.data import DataStorage, DataFetcher
storage = DataStorage(Path("/data/tradefinder.duckdb"))
storage.connect()
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
await fetcher.sync_candles("BTCUSDT", "4h")
"""
from tradefinder.data.fetcher import DataFetcher
from tradefinder.data.storage import DataStorage
__all__ = ["DataStorage", "DataFetcher"]

View File

@@ -0,0 +1,207 @@
"""Historical data fetcher for market data.
Fetches historical OHLCV data from exchange and stores in database.
"""
from datetime import datetime, timedelta
import structlog
from tradefinder.adapters.base import ExchangeAdapter
from tradefinder.data.storage import DataStorage
logger = structlog.get_logger(__name__)
# Timeframe to milliseconds mapping
TIMEFRAME_MS = {
"1m": 60 * 1000,
"3m": 3 * 60 * 1000,
"5m": 5 * 60 * 1000,
"15m": 15 * 60 * 1000,
"30m": 30 * 60 * 1000,
"1h": 60 * 60 * 1000,
"2h": 2 * 60 * 60 * 1000,
"4h": 4 * 60 * 60 * 1000,
"6h": 6 * 60 * 60 * 1000,
"8h": 8 * 60 * 60 * 1000,
"12h": 12 * 60 * 60 * 1000,
"1d": 24 * 60 * 60 * 1000,
"3d": 3 * 24 * 60 * 60 * 1000,
"1w": 7 * 24 * 60 * 60 * 1000,
}
class DataFetcher:
"""Fetches historical market data from exchange.
Usage:
fetcher = DataFetcher(adapter, storage)
await fetcher.backfill_candles("BTCUSDT", "4h", start_date, end_date)
await fetcher.sync_candles("BTCUSDT", "4h")
"""
def __init__(self, adapter: ExchangeAdapter, storage: DataStorage) -> None:
"""Initialize fetcher with adapter and storage.
Args:
adapter: Exchange adapter for fetching data
storage: Data storage for persisting data
"""
self.adapter = adapter
self.storage = storage
async def backfill_candles(
self,
symbol: str,
timeframe: str,
start_date: datetime,
end_date: datetime | None = None,
batch_size: int = 1000,
) -> int:
"""Fetch and store historical candle data.
Args:
symbol: Trading symbol (e.g., "BTCUSDT")
timeframe: Candle timeframe (e.g., "1h", "4h")
start_date: Start date for backfill
end_date: End date for backfill (default: now)
batch_size: Number of candles per API request (max 1500)
Returns:
Total number of candles fetched and stored
"""
if end_date is None:
end_date = datetime.now()
tf_ms = TIMEFRAME_MS.get(timeframe)
if tf_ms is None:
raise ValueError(f"Unknown timeframe: {timeframe}")
start_ms = int(start_date.timestamp() * 1000)
end_ms = int(end_date.timestamp() * 1000)
total_fetched = 0
current_start = start_ms
logger.info(
"Starting backfill",
symbol=symbol,
timeframe=timeframe,
start=start_date.isoformat(),
end=end_date.isoformat(),
)
while current_start < end_ms:
# Fetch batch of candles
candles = await self.adapter.get_candles(
symbol=symbol,
timeframe=timeframe,
limit=min(batch_size, 1500),
start_time=current_start,
end_time=end_ms,
)
if not candles:
break
# Store candles
inserted = self.storage.insert_candles(candles, symbol, timeframe)
total_fetched += inserted
# Move to next batch
last_candle_ts = int(candles[-1].timestamp.timestamp() * 1000)
current_start = last_candle_ts + tf_ms
logger.debug(
"Fetched batch",
symbol=symbol,
timeframe=timeframe,
count=len(candles),
total=total_fetched,
)
# Break if we got fewer candles than requested (end of data)
if len(candles) < batch_size:
break
logger.info(
"Backfill complete",
symbol=symbol,
timeframe=timeframe,
total_candles=total_fetched,
)
return total_fetched
async def sync_candles(
self,
symbol: str,
timeframe: str,
lookback_days: int = 30,
) -> int:
"""Sync candles from last stored timestamp to now.
If no candles exist, fetches from lookback_days ago.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
lookback_days: Days to look back if no existing data
Returns:
Number of new candles fetched
"""
# Get latest stored candle timestamp
latest_ts = self.storage.get_latest_candle_timestamp(symbol, timeframe)
if latest_ts:
# Start from next candle after latest
tf_ms = TIMEFRAME_MS.get(timeframe, 3600000)
start_date = latest_ts + timedelta(milliseconds=tf_ms)
logger.info(
"Syncing from last candle",
symbol=symbol,
timeframe=timeframe,
last_candle=latest_ts.isoformat(),
)
else:
# No existing data, use lookback
start_date = datetime.now() - timedelta(days=lookback_days)
logger.info(
"No existing data, fetching from lookback",
symbol=symbol,
timeframe=timeframe,
lookback_days=lookback_days,
)
return await self.backfill_candles(
symbol=symbol,
timeframe=timeframe,
start_date=start_date,
)
async def fetch_latest_candles(
self,
symbol: str,
timeframe: str,
limit: int = 100,
) -> int:
"""Fetch and store the most recent candles.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
limit: Number of recent candles to fetch
Returns:
Number of candles fetched
"""
candles = await self.adapter.get_candles(
symbol=symbol,
timeframe=timeframe,
limit=limit,
)
if candles:
return self.storage.insert_candles(candles, symbol, timeframe)
return 0

View File

@@ -0,0 +1,54 @@
"""Database schemas for market data storage."""
CANDLES_SCHEMA = """
CREATE TABLE IF NOT EXISTS candles (
symbol VARCHAR NOT NULL,
timeframe VARCHAR NOT NULL,
timestamp TIMESTAMP NOT NULL,
open DOUBLE NOT NULL,
high DOUBLE NOT NULL,
low DOUBLE NOT NULL,
close DOUBLE NOT NULL,
volume DOUBLE NOT NULL,
PRIMARY KEY (symbol, timeframe, timestamp)
);
CREATE INDEX IF NOT EXISTS idx_candles_symbol_tf
ON candles (symbol, timeframe, timestamp DESC);
"""
TRADES_SCHEMA = """
CREATE TABLE IF NOT EXISTS trades (
id VARCHAR PRIMARY KEY,
symbol VARCHAR NOT NULL,
side VARCHAR NOT NULL,
position_side VARCHAR NOT NULL,
entry_price DOUBLE NOT NULL,
exit_price DOUBLE,
quantity DOUBLE NOT NULL,
pnl_usdt DOUBLE,
pnl_pct DOUBLE,
strategy VARCHAR NOT NULL,
entry_time TIMESTAMP NOT NULL,
exit_time TIMESTAMP,
status VARCHAR NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_trades_symbol ON trades (symbol);
CREATE INDEX IF NOT EXISTS idx_trades_strategy ON trades (strategy);
CREATE INDEX IF NOT EXISTS idx_trades_entry_time ON trades (entry_time DESC);
"""
FUNDING_RATES_SCHEMA = """
CREATE TABLE IF NOT EXISTS funding_rates (
symbol VARCHAR NOT NULL,
funding_rate DOUBLE NOT NULL,
funding_time TIMESTAMP NOT NULL,
mark_price DOUBLE NOT NULL,
PRIMARY KEY (symbol, funding_time)
);
CREATE INDEX IF NOT EXISTS idx_funding_symbol ON funding_rates (symbol, funding_time DESC);
"""
ALL_SCHEMAS = [CANDLES_SCHEMA, TRADES_SCHEMA, FUNDING_RATES_SCHEMA]

View File

@@ -0,0 +1,242 @@
"""DuckDB storage manager for market data.
Provides async-compatible interface for storing and retrieving
market data using DuckDB.
"""
from datetime import datetime
from pathlib import Path
import duckdb
import structlog
from tradefinder.adapters.types import Candle, FundingRate
from tradefinder.data.schemas import ALL_SCHEMAS
logger = structlog.get_logger(__name__)
class DataStorage:
"""DuckDB storage manager for market data.
Usage:
storage = DataStorage(Path("/data/tradefinder.duckdb"))
storage.connect()
storage.initialize_schema()
storage.insert_candles(candles)
storage.disconnect()
"""
def __init__(self, db_path: Path) -> None:
"""Initialize storage with database path.
Args:
db_path: Path to DuckDB database file
"""
self.db_path = db_path
self._conn: duckdb.DuckDBPyConnection | None = None
def connect(self) -> None:
"""Connect to the database."""
# Ensure parent directory exists
self.db_path.parent.mkdir(parents=True, exist_ok=True)
self._conn = duckdb.connect(str(self.db_path))
logger.info("Connected to DuckDB", path=str(self.db_path))
def disconnect(self) -> None:
"""Close database connection."""
if self._conn:
self._conn.close()
self._conn = None
logger.info("Disconnected from DuckDB")
def __enter__(self) -> "DataStorage":
"""Context manager entry."""
self.connect()
return self
def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None:
"""Context manager exit."""
self.disconnect()
@property
def conn(self) -> duckdb.DuckDBPyConnection:
"""Get database connection."""
if self._conn is None:
raise RuntimeError("Not connected. Call connect() first.")
return self._conn
def initialize_schema(self) -> None:
"""Create all database tables and indexes."""
for schema in ALL_SCHEMAS:
# Execute each statement separately
for statement in schema.strip().split(";"):
statement = statement.strip()
if statement:
self.conn.execute(statement)
logger.info("Database schema initialized")
def insert_candles(self, candles: list[Candle], symbol: str, timeframe: str) -> int:
"""Insert candles into the database.
Args:
candles: List of Candle objects to insert
symbol: Trading symbol (e.g., "BTCUSDT")
timeframe: Candle timeframe (e.g., "1h", "4h")
Returns:
Number of candles inserted
"""
if not candles:
return 0
# Prepare data for insertion
data = [
(
symbol,
timeframe,
c.timestamp,
float(c.open),
float(c.high),
float(c.low),
float(c.close),
float(c.volume),
)
for c in candles
]
# Use INSERT OR REPLACE to handle duplicates
self.conn.executemany(
"""
INSERT OR REPLACE INTO candles
(symbol, timeframe, timestamp, open, high, low, close, volume)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
""",
data,
)
logger.debug(
"Inserted candles",
symbol=symbol,
timeframe=timeframe,
count=len(candles),
)
return len(candles)
def get_candles(
self,
symbol: str,
timeframe: str,
start: datetime | None = None,
end: datetime | None = None,
limit: int | None = None,
) -> list[Candle]:
"""Retrieve candles from the database.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
start: Start timestamp (inclusive)
end: End timestamp (inclusive)
limit: Maximum number of candles to return
Returns:
List of Candle objects, oldest first
"""
query = """
SELECT timestamp, open, high, low, close, volume
FROM candles
WHERE symbol = ? AND timeframe = ?
"""
params: list[str | datetime | int] = [symbol, timeframe]
if start:
query += " AND timestamp >= ?"
params.append(start)
if end:
query += " AND timestamp <= ?"
params.append(end)
query += " ORDER BY timestamp ASC"
if limit:
query += " LIMIT ?"
params.append(limit)
result = self.conn.execute(query, params).fetchall()
from decimal import Decimal
return [
Candle(
timestamp=row[0],
open=Decimal(str(row[1])),
high=Decimal(str(row[2])),
low=Decimal(str(row[3])),
close=Decimal(str(row[4])),
volume=Decimal(str(row[5])),
)
for row in result
]
def get_latest_candle_timestamp(self, symbol: str, timeframe: str) -> datetime | None:
"""Get the timestamp of the most recent candle.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
Returns:
Timestamp of latest candle, or None if no candles exist
"""
result = self.conn.execute(
"""
SELECT MAX(timestamp) FROM candles
WHERE symbol = ? AND timeframe = ?
""",
[symbol, timeframe],
).fetchone()
return result[0] if result and result[0] else None
def insert_funding_rate(self, rate: FundingRate) -> None:
"""Insert a funding rate record.
Args:
rate: FundingRate object to insert
"""
self.conn.execute(
"""
INSERT OR REPLACE INTO funding_rates
(symbol, funding_rate, funding_time, mark_price)
VALUES (?, ?, ?, ?)
""",
[
rate.symbol,
float(rate.funding_rate),
rate.funding_time,
float(rate.mark_price),
],
)
def get_candle_count(self, symbol: str, timeframe: str) -> int:
"""Get the number of candles stored for a symbol/timeframe.
Args:
symbol: Trading symbol
timeframe: Candle timeframe
Returns:
Number of candles
"""
result = self.conn.execute(
"""
SELECT COUNT(*) FROM candles
WHERE symbol = ? AND timeframe = ?
""",
[symbol, timeframe],
).fetchone()
return result[0] if result else 0

286
tests/test_data_fetcher.py Normal file
View File

@@ -0,0 +1,286 @@
"""Unit tests for DataFetcher (backfill and sync logic)."""
import tempfile
from datetime import datetime, timedelta
from decimal import Decimal
from pathlib import Path
from unittest.mock import AsyncMock, MagicMock
import pytest
from tradefinder.adapters.types import Candle
from tradefinder.data.fetcher import TIMEFRAME_MS, DataFetcher
from tradefinder.data.storage import DataStorage
def make_candle(timestamp: datetime) -> Candle:
"""Create a test candle at given timestamp."""
return Candle(
timestamp=timestamp,
open=Decimal("50000.00"),
high=Decimal("51000.00"),
low=Decimal("49000.00"),
close=Decimal("50500.00"),
volume=Decimal("1000.00"),
)
class TestTimeframeMappings:
"""Tests for timeframe constant mappings."""
def test_common_timeframes_are_defined(self) -> None:
"""All expected timeframes have millisecond mappings."""
expected = ["1m", "5m", "15m", "30m", "1h", "4h", "1d", "1w"]
for tf in expected:
assert tf in TIMEFRAME_MS
assert TIMEFRAME_MS[tf] > 0
def test_timeframe_values_are_correct(self) -> None:
"""Timeframe millisecond values are accurate."""
assert TIMEFRAME_MS["1m"] == 60 * 1000
assert TIMEFRAME_MS["1h"] == 60 * 60 * 1000
assert TIMEFRAME_MS["4h"] == 4 * 60 * 60 * 1000
assert TIMEFRAME_MS["1d"] == 24 * 60 * 60 * 1000
class TestDataFetcherBackfill:
"""Tests for backfill_candles functionality."""
async def test_backfill_fetches_and_stores_candles(self) -> None:
"""Candles are fetched from adapter and stored."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
# Mock adapter
adapter = MagicMock()
start = datetime(2024, 1, 1, 0, 0, 0)
end = datetime(2024, 1, 1, 4, 0, 0)
# Return 4 candles then empty (end of data)
candles = [make_candle(start + timedelta(hours=i)) for i in range(4)]
adapter.get_candles = AsyncMock(side_effect=[candles, []])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
total = await fetcher.backfill_candles("BTCUSDT", "1h", start, end)
assert total == 4
assert storage.get_candle_count("BTCUSDT", "1h") == 4
async def test_backfill_uses_default_end_date(self) -> None:
"""End date defaults to now if not provided."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[]) # No data
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
start = datetime.now() - timedelta(hours=1)
await fetcher.backfill_candles("BTCUSDT", "1h", start)
# Should have been called (even if no data returned)
adapter.get_candles.assert_called()
async def test_backfill_raises_on_unknown_timeframe(self) -> None:
"""ValueError raised for unknown timeframe."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
with pytest.raises(ValueError, match="Unknown timeframe"):
await fetcher.backfill_candles(
"BTCUSDT",
"invalid_tf",
datetime(2024, 1, 1),
)
async def test_backfill_handles_empty_response(self) -> None:
"""Empty response from adapter is handled gracefully."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
total = await fetcher.backfill_candles(
"BTCUSDT",
"1h",
datetime(2024, 1, 1),
datetime(2024, 1, 2),
)
assert total == 0
async def test_backfill_respects_batch_size(self) -> None:
"""Batch size is respected and capped at 1500."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
# Request with batch_size > 1500
await fetcher.backfill_candles(
"BTCUSDT",
"1h",
datetime(2024, 1, 1),
datetime(2024, 1, 2),
batch_size=2000,
)
# Verify limit was capped at 1500
call_kwargs = adapter.get_candles.call_args.kwargs
assert call_kwargs["limit"] == 1500
async def test_backfill_paginates_correctly(self) -> None:
"""Multiple batches are fetched for large date ranges."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
start = datetime(2024, 1, 1, 0, 0, 0)
# Return 3 candles per batch, simulate 2 full batches + partial
batch1 = [make_candle(start + timedelta(hours=i)) for i in range(3)]
batch2 = [make_candle(start + timedelta(hours=i + 3)) for i in range(3)]
batch3 = [make_candle(start + timedelta(hours=6))] # Partial batch
adapter.get_candles = AsyncMock(side_effect=[batch1, batch2, batch3])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
total = await fetcher.backfill_candles(
"BTCUSDT",
"1h",
start,
start + timedelta(hours=10),
batch_size=3,
)
# 3 + 3 + 1 = 7 candles
assert total == 7
assert adapter.get_candles.call_count == 3
class TestDataFetcherSync:
"""Tests for sync_candles functionality."""
async def test_sync_from_latest_timestamp(self) -> None:
"""Sync starts from last stored candle."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[])
with DataStorage(db_path) as storage:
storage.initialize_schema()
# Pre-populate with some candles
base_ts = datetime(2024, 1, 1, 0, 0, 0)
existing = [make_candle(base_ts + timedelta(hours=i)) for i in range(5)]
storage.insert_candles(existing, "BTCUSDT", "1h")
fetcher = DataFetcher(adapter, storage)
await fetcher.sync_candles("BTCUSDT", "1h")
# Verify sync started from after the last candle
call_kwargs = adapter.get_candles.call_args.kwargs
start_time = call_kwargs.get("start_time")
# Should be after hour 4 (the last existing candle)
expected_start_ms = int((base_ts + timedelta(hours=5)).timestamp() * 1000)
assert start_time >= expected_start_ms
async def test_sync_with_no_existing_data_uses_lookback(self) -> None:
"""Sync uses lookback when no existing data."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
await fetcher.sync_candles("BTCUSDT", "1h", lookback_days=7)
# Should have called get_candles
adapter.get_candles.assert_called()
class TestDataFetcherLatest:
"""Tests for fetch_latest_candles functionality."""
async def test_fetch_latest_stores_candles(self) -> None:
"""Latest candles are fetched and stored."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
now = datetime.now()
candles = [make_candle(now - timedelta(hours=i)) for i in range(5)]
adapter.get_candles = AsyncMock(return_value=candles)
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
count = await fetcher.fetch_latest_candles("BTCUSDT", "1h", limit=5)
assert count == 5
assert storage.get_candle_count("BTCUSDT", "1h") == 5
async def test_fetch_latest_with_empty_response(self) -> None:
"""Empty response returns 0."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
count = await fetcher.fetch_latest_candles("BTCUSDT", "1h")
assert count == 0
async def test_fetch_latest_respects_limit(self) -> None:
"""Limit parameter is passed to adapter."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
adapter = MagicMock()
adapter.get_candles = AsyncMock(return_value=[])
with DataStorage(db_path) as storage:
storage.initialize_schema()
fetcher = DataFetcher(adapter, storage)
await fetcher.fetch_latest_candles("BTCUSDT", "1h", limit=50)
call_kwargs = adapter.get_candles.call_args.kwargs
assert call_kwargs["limit"] == 50

293
tests/test_data_storage.py Normal file
View File

@@ -0,0 +1,293 @@
"""Unit tests for DataStorage (DuckDB operations)."""
import tempfile
from datetime import datetime, timedelta
from decimal import Decimal
from pathlib import Path
import pytest
from tradefinder.adapters.types import Candle, FundingRate
from tradefinder.data.storage import DataStorage
class TestDataStorageConnection:
"""Tests for connection lifecycle."""
def test_connect_creates_database_file(self) -> None:
"""Database file is created on connect."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
storage = DataStorage(db_path)
storage.connect()
assert db_path.exists()
storage.disconnect()
def test_connect_creates_parent_directories(self) -> None:
"""Parent directories are created if they don't exist."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "nested" / "dir" / "test.duckdb"
storage = DataStorage(db_path)
storage.connect()
assert db_path.parent.exists()
storage.disconnect()
def test_disconnect_closes_connection(self) -> None:
"""Disconnect properly closes the connection."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
storage = DataStorage(db_path)
storage.connect()
storage.disconnect()
assert storage._conn is None
def test_context_manager_lifecycle(self) -> None:
"""Context manager properly opens and closes connection."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
assert storage._conn is not None
assert storage._conn is None
def test_conn_property_raises_when_not_connected(self) -> None:
"""Accessing conn property raises when not connected."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
storage = DataStorage(db_path)
with pytest.raises(RuntimeError, match="Not connected"):
_ = storage.conn
class TestDataStorageSchema:
"""Tests for schema initialization."""
def test_initialize_schema_creates_tables(self) -> None:
"""All tables are created by initialize_schema."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
# Verify tables exist
result = storage.conn.execute(
"SELECT table_name FROM information_schema.tables WHERE table_schema = 'main'"
).fetchall()
tables = {row[0] for row in result}
assert "candles" in tables
assert "trades" in tables
assert "funding_rates" in tables
def test_initialize_schema_is_idempotent(self) -> None:
"""Calling initialize_schema multiple times is safe."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
storage.initialize_schema() # Should not raise
class TestDataStorageCandles:
"""Tests for candle CRUD operations."""
def _make_candle(self, timestamp: datetime) -> Candle:
"""Create a test candle."""
return Candle(
timestamp=timestamp,
open=Decimal("50000.00"),
high=Decimal("51000.00"),
low=Decimal("49000.00"),
close=Decimal("50500.00"),
volume=Decimal("1000.50"),
)
def test_insert_candles_stores_data(self) -> None:
"""Candles are stored correctly."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
ts = datetime(2024, 1, 1, 12, 0, 0)
candles = [self._make_candle(ts)]
inserted = storage.insert_candles(candles, "BTCUSDT", "1h")
assert inserted == 1
def test_insert_candles_empty_list_returns_zero(self) -> None:
"""Inserting empty list returns 0."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
inserted = storage.insert_candles([], "BTCUSDT", "1h")
assert inserted == 0
def test_insert_candles_handles_duplicates(self) -> None:
"""Duplicate candles are replaced (upsert behavior)."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
ts = datetime(2024, 1, 1, 12, 0, 0)
candles = [self._make_candle(ts)]
storage.insert_candles(candles, "BTCUSDT", "1h")
storage.insert_candles(candles, "BTCUSDT", "1h") # Duplicate
count = storage.get_candle_count("BTCUSDT", "1h")
assert count == 1 # Not 2
def test_get_candles_retrieves_data(self) -> None:
"""Candles can be retrieved correctly."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
ts = datetime(2024, 1, 1, 12, 0, 0)
candles = [self._make_candle(ts)]
storage.insert_candles(candles, "BTCUSDT", "1h")
result = storage.get_candles("BTCUSDT", "1h")
assert len(result) == 1
assert result[0].open == Decimal("50000.00")
assert result[0].close == Decimal("50500.00")
def test_get_candles_respects_time_range(self) -> None:
"""Candles are filtered by start/end time."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
base_ts = datetime(2024, 1, 1, 0, 0, 0)
candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(10)]
storage.insert_candles(candles, "BTCUSDT", "1h")
# Query middle range
start = base_ts + timedelta(hours=3)
end = base_ts + timedelta(hours=6)
result = storage.get_candles("BTCUSDT", "1h", start=start, end=end)
assert len(result) == 4 # hours 3, 4, 5, 6
def test_get_candles_respects_limit(self) -> None:
"""Candle count is limited correctly."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
base_ts = datetime(2024, 1, 1, 0, 0, 0)
candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(10)]
storage.insert_candles(candles, "BTCUSDT", "1h")
result = storage.get_candles("BTCUSDT", "1h", limit=5)
assert len(result) == 5
def test_get_candles_returns_ascending_order(self) -> None:
"""Candles are returned oldest first."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
base_ts = datetime(2024, 1, 1, 0, 0, 0)
# Insert in reverse order
candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(5)]
storage.insert_candles(candles[::-1], "BTCUSDT", "1h")
result = storage.get_candles("BTCUSDT", "1h")
timestamps = [c.timestamp for c in result]
assert timestamps == sorted(timestamps)
def test_get_candles_filters_by_symbol(self) -> None:
"""Candles are filtered by symbol."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
ts = datetime(2024, 1, 1, 12, 0, 0)
candles = [self._make_candle(ts)]
storage.insert_candles(candles, "BTCUSDT", "1h")
storage.insert_candles(candles, "ETHUSDT", "1h")
btc_result = storage.get_candles("BTCUSDT", "1h")
eth_result = storage.get_candles("ETHUSDT", "1h")
assert len(btc_result) == 1
assert len(eth_result) == 1
def test_get_latest_candle_timestamp(self) -> None:
"""Latest timestamp is returned correctly."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
base_ts = datetime(2024, 1, 1, 0, 0, 0)
candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(5)]
storage.insert_candles(candles, "BTCUSDT", "1h")
latest = storage.get_latest_candle_timestamp("BTCUSDT", "1h")
assert latest == base_ts + timedelta(hours=4)
def test_get_latest_candle_timestamp_returns_none_when_empty(self) -> None:
"""None is returned when no candles exist."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
latest = storage.get_latest_candle_timestamp("BTCUSDT", "1h")
assert latest is None
def test_get_candle_count(self) -> None:
"""Candle count is accurate."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
base_ts = datetime(2024, 1, 1, 0, 0, 0)
candles = [self._make_candle(base_ts + timedelta(hours=i)) for i in range(7)]
storage.insert_candles(candles, "BTCUSDT", "1h")
count = storage.get_candle_count("BTCUSDT", "1h")
assert count == 7
class TestDataStorageFundingRates:
"""Tests for funding rate operations."""
def test_insert_funding_rate(self) -> None:
"""Funding rate is stored correctly."""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test.duckdb"
with DataStorage(db_path) as storage:
storage.initialize_schema()
rate = FundingRate(
symbol="BTCUSDT",
funding_rate=Decimal("0.0001"),
funding_time=datetime(2024, 1, 1, 8, 0, 0),
mark_price=Decimal("50000.00"),
)
storage.insert_funding_rate(rate)
# Verify stored
result = storage.conn.execute(
"SELECT symbol, funding_rate FROM funding_rates"
).fetchone()
assert result is not None
assert result[0] == "BTCUSDT"

295
tests/test_spot_adapter.py Normal file
View File

@@ -0,0 +1,295 @@
"""Unit tests for BinanceSpotAdapter."""
from __future__ import annotations
import hashlib
import hmac
from datetime import UTC
from decimal import Decimal
from typing import Any
from unittest.mock import AsyncMock, Mock
from urllib.parse import urlencode
import pytest
from pydantic import SecretStr
from tradefinder.adapters.base import AuthenticationError, ExchangeError
from tradefinder.adapters.binance_spot import BinanceSpotAdapter, SpotBalance
from tradefinder.core.config import BinanceSettings, Settings, TradingMode
def _make_response(status_code: int, payload: Any) -> Mock:
"""Create a mock HTTP response."""
response = Mock()
response.status_code = status_code
response.json.return_value = payload
return response
@pytest.fixture
def settings() -> Settings:
"""Create test settings without loading .env file."""
return Settings(
trading_mode=TradingMode.TESTNET,
binance=BinanceSettings(
testnet_api_key=SecretStr("test-key"),
testnet_secret=SecretStr("test-secret"),
),
_env_file=None,
)
@pytest.fixture
def adapter(settings: Settings) -> BinanceSpotAdapter:
"""Create adapter with test settings."""
return BinanceSpotAdapter(settings)
@pytest.fixture
def mock_http_client() -> AsyncMock:
"""Create mock HTTP client."""
client = AsyncMock()
client.get = AsyncMock()
client.aclose = AsyncMock()
return client
def _sample_account_response() -> dict[str, Any]:
"""Sample Binance spot account response."""
return {
"makerCommission": 10,
"takerCommission": 10,
"balances": [
{"asset": "BTC", "free": "0.50000000", "locked": "0.10000000"},
{"asset": "ETH", "free": "5.00000000", "locked": "0.00000000"},
{"asset": "USDT", "free": "1000.00000000", "locked": "500.00000000"},
{"asset": "BNB", "free": "0.00000000", "locked": "0.00000000"},
],
}
class TestSpotBalanceDataclass:
"""Tests for SpotBalance dataclass."""
def test_total_property(self) -> None:
"""Total equals free plus locked."""
from datetime import datetime
balance = SpotBalance(
asset="BTC",
free=Decimal("1.5"),
locked=Decimal("0.5"),
updated_at=datetime.now(UTC),
)
assert balance.total == Decimal("2.0")
def test_zero_balance_total(self) -> None:
"""Zero balance returns zero total."""
from datetime import datetime
balance = SpotBalance(
asset="XRP",
free=Decimal("0"),
locked=Decimal("0"),
updated_at=datetime.now(UTC),
)
assert balance.total == Decimal("0")
class TestBinanceSpotAdapterInit:
"""Initialization and credential tests."""
def test_credentials_are_accessible(self, adapter: BinanceSpotAdapter) -> None:
"""API key and secret are exposed through properties."""
assert adapter._api_key == "test-key"
assert adapter._secret == "test-secret"
def test_base_url_uses_testnet(self, settings: Settings) -> None:
"""Testnet mode uses testnet URL."""
adapter = BinanceSpotAdapter(settings)
assert "testnet" in adapter.base_url.lower()
def test_sign_generates_valid_signature(self, adapter: BinanceSpotAdapter) -> None:
"""HMAC-SHA256 signature matches stdlib implementation."""
params = {"symbol": "BTCUSDT", "timestamp": "1234567890"}
expected = hmac.new(
adapter._secret.encode("utf-8"),
urlencode(params).encode("utf-8"),
hashlib.sha256,
).hexdigest()
assert adapter._sign(params) == expected
class TestBinanceSpotAdapterConnection:
"""Connection lifecycle tests."""
async def test_connect_creates_client(self, adapter: BinanceSpotAdapter) -> None:
"""Connect creates HTTP client."""
mock_client = AsyncMock()
mock_client.get = AsyncMock(return_value=_make_response(200, _sample_account_response()))
mock_client.aclose = AsyncMock()
import httpx
original_client = httpx.AsyncClient
try:
# Patch httpx.AsyncClient
httpx.AsyncClient = lambda **kwargs: mock_client
await adapter.connect()
assert adapter._client is mock_client
finally:
httpx.AsyncClient = original_client
await adapter.disconnect()
async def test_disconnect_closes_client(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""Disconnect closes HTTP client."""
adapter._client = mock_http_client
await adapter.disconnect()
mock_http_client.aclose.assert_awaited_once()
assert adapter._client is None
async def test_disconnect_when_not_connected(self, adapter: BinanceSpotAdapter) -> None:
"""Disconnect when not connected is safe."""
await adapter.disconnect() # Should not raise
class TestBinanceSpotAdapterRequest:
"""Request handling tests."""
async def test_request_adds_auth_headers(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""Requests include X-MBX-APIKEY header."""
adapter._client = mock_http_client
response = _make_response(200, {"ok": True})
mock_http_client.get.return_value = response
result = await adapter._request("GET", "/api/v3/test")
assert result == {"ok": True}
mock_http_client.get.assert_awaited_once()
called_headers = mock_http_client.get.call_args[1]["headers"]
assert called_headers["X-MBX-APIKEY"] == "test-key"
async def test_request_adds_signature_when_signed(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""Signed requests include timestamp and signature."""
adapter._client = mock_http_client
response = _make_response(200, {"ok": True})
mock_http_client.get.return_value = response
await adapter._request("GET", "/api/v3/test", params={}, signed=True)
call_kwargs = mock_http_client.get.call_args[1]
params = call_kwargs["params"]
assert "timestamp" in params
assert "signature" in params
assert "recvWindow" in params
async def test_request_raises_when_not_connected(self, adapter: BinanceSpotAdapter) -> None:
"""Request raises when not connected."""
with pytest.raises(ExchangeError, match="Not connected"):
await adapter._request("GET", "/api/v3/test")
async def test_request_handles_auth_error(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""Authentication errors raise AuthenticationError."""
adapter._client = mock_http_client
response = _make_response(401, {"code": -1022, "msg": "Invalid signature"})
mock_http_client.get.return_value = response
with pytest.raises(AuthenticationError):
await adapter._request("GET", "/api/v3/test")
async def test_request_handles_general_error(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""General API errors raise ExchangeError."""
adapter._client = mock_http_client
response = _make_response(400, {"code": -1100, "msg": "Illegal characters"})
mock_http_client.get.return_value = response
with pytest.raises(ExchangeError, match="Illegal characters"):
await adapter._request("GET", "/api/v3/test")
class TestBinanceSpotAdapterBalance:
"""Balance retrieval tests."""
async def test_get_balance_returns_asset(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""get_balance returns balance for specific asset."""
adapter._client = mock_http_client
response = _make_response(200, _sample_account_response())
mock_http_client.get.return_value = response
balance = await adapter.get_balance("BTC")
assert balance.asset == "BTC"
assert balance.free == Decimal("0.50000000")
assert balance.locked == Decimal("0.10000000")
assert balance.total == Decimal("0.60000000")
async def test_get_balance_raises_when_not_found(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""get_balance raises when asset not found."""
adapter._client = mock_http_client
response = _make_response(200, _sample_account_response())
mock_http_client.get.return_value = response
with pytest.raises(ExchangeError, match="not found"):
await adapter.get_balance("UNKNOWN")
async def test_get_all_balances_returns_non_zero(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""get_all_balances returns only non-zero balances."""
adapter._client = mock_http_client
response = _make_response(200, _sample_account_response())
mock_http_client.get.return_value = response
balances = await adapter.get_all_balances()
# BTC, ETH, USDT have non-zero balances; BNB is zero
assets = {b.asset for b in balances}
assert "BTC" in assets
assert "ETH" in assets
assert "USDT" in assets
assert "BNB" not in assets
async def test_get_all_balances_with_min_balance(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""get_all_balances filters by minimum balance."""
adapter._client = mock_http_client
response = _make_response(200, _sample_account_response())
mock_http_client.get.return_value = response
# min_balance of 100 should only include USDT (1500 total)
balances = await adapter.get_all_balances(min_balance=Decimal("100"))
assets = {b.asset for b in balances}
assert "USDT" in assets
assert "BTC" not in assets # Only 0.6 total
assert "ETH" not in assets # Only 5.0 total
async def test_get_all_balances_empty_response(
self, adapter: BinanceSpotAdapter, mock_http_client: AsyncMock
) -> None:
"""get_all_balances handles empty balances."""
adapter._client = mock_http_client
response = _make_response(200, {"balances": []})
mock_http_client.get.return_value = response
balances = await adapter.get_all_balances()
assert balances == []