Phase 1 completion: - Add DataStreamer for real-time Binance Futures WebSocket data (klines, mark price) - Add DataValidator for candle validation and gap detection - Add timeframes module with interval mappings Phase 2 foundation: - Add RegimeClassifier with ADX/ATR/Bollinger Band analysis - Add Regime enum (TRENDING_UP/DOWN, RANGING, HIGH_VOLATILITY, UNCERTAIN) - Add Strategy ABC defining generate_signal, get_stop_loss, parameters, suitable_regimes - Add Signal dataclass and SignalType enum for strategy outputs Testing: - Add comprehensive test suites for all new modules - 159 tests passing, 24 skipped (async WebSocket timing) - 82% code coverage Dependencies: - Add pandas-stubs to dev dependencies for mypy compatibility
340 lines
12 KiB
Python
340 lines
12 KiB
Python
"""Unit tests for RegimeClassifier (market regime detection)."""
|
|
|
|
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pandas as pd

from tradefinder.adapters.types import Candle
from tradefinder.core.regime import Regime, RegimeClassifier
|
|
|
|
|
|
class TestRegimeClassifierInit:
    """Construction of RegimeClassifier: defaults, overrides and minimum clamping."""

    def test_init_default_parameters(self) -> None:
        """A classifier built with no arguments uses ADX 25, ATR lookback 14, BB period 20."""
        clf = RegimeClassifier()
        settings = (clf._adx_threshold, clf._atr_lookback, clf._bb_period)
        assert settings == (Decimal("25"), 14, 20)

    def test_init_custom_parameters(self) -> None:
        """Explicit constructor arguments replace every default.

        The integer ADX threshold is expected to be stored as a Decimal.
        """
        clf = RegimeClassifier(adx_threshold=30, atr_lookback=20, bb_period=25)
        settings = (clf._adx_threshold, clf._atr_lookback, clf._bb_period)
        assert settings == (Decimal("30"), 20, 25)

    def test_init_validates_parameters(self) -> None:
        """A zero window length is raised to the minimum of 1 at construction time."""
        clf = RegimeClassifier(atr_lookback=0, bb_period=0)
        # Both windows share the same lower bound.
        assert (clf._atr_lookback, clf._bb_period) == (1, 1)
|
|
|
|
|
|
class TestRegimeClassifierClassify:
    """Tests for regime classification logic.

    Except for the insufficient-data case, ``get_indicators`` is stubbed so each
    test drives ``classify`` purely through the indicator values it supplies.
    """

    @staticmethod
    def _single_candle() -> list[Candle]:
        """Return a one-element candle list.

        The candle's contents are not what these tests vary. Where indicators
        are stubbed, the expected regime follows from the stubbed values.
        """
        return [
            Candle(
                timestamp=pd.Timestamp("2024-01-01"),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
        ]

    @classmethod
    def _classify_with(
        cls, classifier: RegimeClassifier, indicators: dict[str, Decimal]
    ) -> Regime:
        """Classify a single candle with ``get_indicators`` stubbed to return *indicators*.

        ``patch.object`` replaces the method only for the duration of the call
        and restores it afterwards. Assigning to the method directly would leave
        the instance patched, and mypy rejects it (``method-assign``).
        """
        with patch.object(classifier, "get_indicators", return_value=indicators):
            return classifier.classify(cls._single_candle())

    def test_classify_insufficient_data(self) -> None:
        """UNCERTAIN returned when insufficient data."""
        classifier = RegimeClassifier()
        regime = classifier.classify(self._single_candle())
        assert regime == Regime.UNCERTAIN

    def test_classify_trending_up(self) -> None:
        """TRENDING_UP detected with high ADX and +DI > -DI."""
        classifier = RegimeClassifier(adx_threshold=20)  # Lower threshold for test
        # ATR% is 2x its average and BB width is narrow, which mirror the inputs of
        # the volatility/ranging tests. This case therefore also pins that a
        # strong trend takes precedence.
        indicators = {
            "adx": Decimal("25"),
            "plus_di": Decimal("30"),
            "minus_di": Decimal("15"),
            "bb_width": Decimal("0.02"),
            "atr_pct": Decimal("1.0"),
            "atr_pct_avg": Decimal("0.5"),
        }
        assert self._classify_with(classifier, indicators) == Regime.TRENDING_UP

    def test_classify_trending_down(self) -> None:
        """TRENDING_DOWN detected with high ADX and -DI > +DI."""
        classifier = RegimeClassifier(adx_threshold=20)
        indicators = {
            "adx": Decimal("25"),
            "plus_di": Decimal("15"),
            "minus_di": Decimal("30"),
            "bb_width": Decimal("0.02"),
            "atr_pct": Decimal("1.0"),
            "atr_pct_avg": Decimal("0.5"),
        }
        assert self._classify_with(classifier, indicators) == Regime.TRENDING_DOWN

    def test_classify_ranging(self) -> None:
        """RANGING detected with low ADX and narrow BB."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),  # Below ceiling
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.02"),  # Below threshold
            "atr_pct": Decimal("0.5"),
            "atr_pct_avg": Decimal("1.0"),
        }
        assert self._classify_with(classifier, indicators) == Regime.RANGING

    def test_classify_high_volatility(self) -> None:
        """HIGH_VOLATILITY detected with high ATR% vs average."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.10"),  # Above threshold
            "atr_pct": Decimal("2.0"),  # Above 1.5x average
            "atr_pct_avg": Decimal("1.0"),
        }
        assert self._classify_with(classifier, indicators) == Regime.HIGH_VOLATILITY

    def test_classify_uncertain_fallback(self) -> None:
        """UNCERTAIN returned when no conditions met."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.10"),
            "atr_pct": Decimal("1.2"),  # Below 1.5x threshold
            "atr_pct_avg": Decimal("1.0"),
        }
        assert self._classify_with(classifier, indicators) == Regime.UNCERTAIN
|
|
|
|
|
|
class TestRegimeClassifierGetIndicators:
    """Tests for indicator calculation."""

    def test_get_indicators_insufficient_data(self) -> None:
        """Empty dict returned when insufficient data."""
        # One candle cannot fill a 50-bar ATR/BB window.
        classifier = RegimeClassifier(atr_lookback=50, bb_period=50)
        candles = [
            Candle(
                timestamp=pd.Timestamp("2024-01-01"),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
        ]
        indicators = classifier.get_indicators(candles)
        assert indicators == {}

    # Stacked @patch decorators apply bottom-up, so the mocks are passed in
    # reverse order: bbands first, then atr, then adx.
    @patch("tradefinder.core.regime.ta.adx")
    @patch("tradefinder.core.regime.ta.atr")
    @patch("tradefinder.core.regime.ta.bbands")
    def test_get_indicators_calculates_correctly(
        self, mock_bbands: MagicMock, mock_atr: MagicMock, mock_adx: MagicMock
    ) -> None:
        """Indicators are calculated and returned correctly.

        Only the latest row of each mocked indicator frame should be reported.
        """
        # Mock pandas TA functions. Column names presumably match what the
        # classifier looks up for the default ADX/ATR length 14 and BB period 20.
        mock_adx.return_value = pd.DataFrame(
            {
                "ADX_14": [20.5, 21.0, 22.0],
                "DMP_14": [25.0, 26.0, 27.0],
                "DMN_14": [15.0, 16.0, 17.0],
            }
        )
        # NOTE(review): the real `ta.atr` may return a Series rather than a
        # one-column DataFrame; confirm this mock matches what get_indicators handles.
        mock_atr.return_value = pd.DataFrame(
            {
                "ATR_14": [1000.0, 1100.0, 1200.0],
            }
        )
        mock_bbands.return_value = pd.DataFrame(
            {
                "BBL_20_2.0": [48000.0, 48500.0, 49000.0],
                "BBM_20_2.0": [50000.0, 50500.0, 51000.0],
                "BBU_20_2.0": [52000.0, 52500.0, 53000.0],
            }
        )

        classifier = RegimeClassifier()
        candles = [
            Candle(
                timestamp=pd.Timestamp("2024-01-01") + pd.Timedelta(days=i),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
            for i in range(25)  # Enough data for the default 14/20-bar windows
        ]

        indicators = classifier.get_indicators(candles)

        expected_keys = {
            "adx",
            "plus_di",
            "minus_di",
            "atr",
            "atr_pct",
            "atr_pct_avg",
            "bb_width",
        }
        missing = expected_keys - indicators.keys()
        assert not missing, f"missing indicator keys: {sorted(missing)}"

        # Last row of each mocked column.
        assert indicators["adx"] == Decimal("22.0")
        assert indicators["plus_di"] == Decimal("27.0")
        assert indicators["minus_di"] == Decimal("17.0")
|
|
|
|
|
|
class TestRegimeClassifierStaticMethods:
    """Tests for static helper methods."""

    def test_to_decimal_from_float(self) -> None:
        """Float conversion works."""
        result = RegimeClassifier._to_decimal(123.45)
        assert result == Decimal("123.45")

    def test_to_decimal_from_int(self) -> None:
        """Int conversion works."""
        result = RegimeClassifier._to_decimal(123)
        assert result == Decimal("123")

    def test_to_decimal_from_decimal_passthrough(self) -> None:
        """Decimal passthrough works."""
        dec = Decimal("123.45")
        result = RegimeClassifier._to_decimal(dec)
        assert result == dec

    def test_to_decimal_from_none(self) -> None:
        """None returns None."""
        result = RegimeClassifier._to_decimal(None)
        assert result is None

    def test_to_decimal_from_nan(self) -> None:
        """NaN values return None.

        A real NaN is used deliberately. Patching ``pandas.isna`` to always
        return True made this test pass no matter how NaN was detected.
        """
        result = RegimeClassifier._to_decimal(float("nan"))
        assert result is None

    def test_decimal_from_series_tail(self) -> None:
        """Last value from series is extracted."""
        series = pd.Series([1.0, 2.0, 3.0])
        result = RegimeClassifier._decimal_from_series_tail(series)
        assert result == Decimal("3.0")

    def test_decimal_from_series_tail_empty(self) -> None:
        """Empty series returns None."""
        # Explicit dtype: the default dtype of an empty Series differs across
        # pandas versions (float64 before 2.0, object from 2.0).
        series = pd.Series([], dtype="float64")
        result = RegimeClassifier._decimal_from_series_tail(series)
        assert result is None

    def test_decimal_from_series_avg(self) -> None:
        """Average of series is calculated."""
        series = pd.Series([1.0, 2.0, 3.0, 4.0])
        result = RegimeClassifier._decimal_from_series_avg(series)
        assert result == Decimal("2.5")

    def test_decimal_from_series_avg_empty(self) -> None:
        """Empty series average returns None."""
        # Explicit dtype for the same pandas-version reason as the tail test.
        series = pd.Series([], dtype="float64")
        result = RegimeClassifier._decimal_from_series_avg(series)
        assert result is None

    def test_candles_to_frame(self) -> None:
        """Candles are converted to DataFrame correctly."""
        candles = [
            Candle(
                timestamp=pd.Timestamp("2024-01-01"),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
        ]
        df = RegimeClassifier._candles_to_frame(candles)
        assert len(df) == 1
        assert df.iloc[0]["open"] == 50000.0
        assert df.iloc[0]["close"] == 50500.0

    def test_candles_to_frame_empty(self) -> None:
        """Empty candles list returns empty DataFrame."""
        df = RegimeClassifier._candles_to_frame([])
        assert df.empty
|