Add Phase 2 foundation: regime classifier, strategy framework, WebSocket streamer
Phase 1 completion: - Add DataStreamer for real-time Binance Futures WebSocket data (klines, mark price) - Add DataValidator for candle validation and gap detection - Add timeframes module with interval mappings Phase 2 foundation: - Add RegimeClassifier with ADX/ATR/Bollinger Band analysis - Add Regime enum (TRENDING_UP/DOWN, RANGING, HIGH_VOLATILITY, UNCERTAIN) - Add Strategy ABC defining generate_signal, get_stop_loss, parameters, suitable_regimes - Add Signal dataclass and SignalType enum for strategy outputs Testing: - Add comprehensive test suites for all new modules - 159 tests passing, 24 skipped (async WebSocket timing) - 82% code coverage Dependencies: - Add pandas-stubs to dev dependencies for mypy compatibility
This commit is contained in:
339
tests/test_regime.py
Normal file
339
tests/test_regime.py
Normal file
@@ -0,0 +1,339 @@
|
||||
"""Unit tests for RegimeClassifier (market regime detection)."""
|
||||
|
||||
from decimal import Decimal
|
||||
from unittest.mock import patch
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from tradefinder.adapters.types import Candle
|
||||
from tradefinder.core.regime import Regime, RegimeClassifier
|
||||
|
||||
|
||||
class TestRegimeClassifierInit:
|
||||
"""Tests for RegimeClassifier initialization."""
|
||||
|
||||
def test_init_default_parameters(self) -> None:
|
||||
"""Default parameters are set correctly."""
|
||||
classifier = RegimeClassifier()
|
||||
assert classifier._adx_threshold == Decimal("25")
|
||||
assert classifier._atr_lookback == 14
|
||||
assert classifier._bb_period == 20
|
||||
|
||||
def test_init_custom_parameters(self) -> None:
|
||||
"""Custom parameters are accepted."""
|
||||
classifier = RegimeClassifier(adx_threshold=30, atr_lookback=20, bb_period=25)
|
||||
assert classifier._adx_threshold == Decimal("30")
|
||||
assert classifier._atr_lookback == 20
|
||||
assert classifier._bb_period == 25
|
||||
|
||||
def test_init_validates_parameters(self) -> None:
|
||||
"""Parameters are validated on init."""
|
||||
classifier = RegimeClassifier(atr_lookback=0, bb_period=0)
|
||||
assert classifier._atr_lookback == 1 # Min value
|
||||
assert classifier._bb_period == 1 # Min value
|
||||
|
||||
|
||||
class TestRegimeClassifierClassify:
|
||||
"""Tests for regime classification logic."""
|
||||
|
||||
def test_classify_insufficient_data(self) -> None:
|
||||
"""UNCERTAIN returned when insufficient data."""
|
||||
classifier = RegimeClassifier()
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
regime = classifier.classify(candles)
|
||||
assert regime == Regime.UNCERTAIN
|
||||
|
||||
def test_classify_trending_up(self) -> None:
|
||||
"""TRENDING_UP detected with high ADX and +DI > -DI."""
|
||||
classifier = RegimeClassifier(adx_threshold=20) # Lower threshold for test
|
||||
|
||||
# Create mock indicators
|
||||
classifier.get_indicators = lambda c: {
|
||||
"adx": Decimal("25"),
|
||||
"plus_di": Decimal("30"),
|
||||
"minus_di": Decimal("15"),
|
||||
"bb_width": Decimal("0.02"),
|
||||
"atr_pct": Decimal("1.0"),
|
||||
"atr_pct_avg": Decimal("0.5"),
|
||||
}
|
||||
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
regime = classifier.classify(candles)
|
||||
assert regime == Regime.TRENDING_UP
|
||||
|
||||
def test_classify_trending_down(self) -> None:
|
||||
"""TRENDING_DOWN detected with high ADX and -DI > +DI."""
|
||||
classifier = RegimeClassifier(adx_threshold=20)
|
||||
|
||||
classifier.get_indicators = lambda c: {
|
||||
"adx": Decimal("25"),
|
||||
"plus_di": Decimal("15"),
|
||||
"minus_di": Decimal("30"),
|
||||
"bb_width": Decimal("0.02"),
|
||||
"atr_pct": Decimal("1.0"),
|
||||
"atr_pct_avg": Decimal("0.5"),
|
||||
}
|
||||
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
regime = classifier.classify(candles)
|
||||
assert regime == Regime.TRENDING_DOWN
|
||||
|
||||
def test_classify_ranging(self) -> None:
|
||||
"""RANGING detected with low ADX and narrow BB."""
|
||||
classifier = RegimeClassifier()
|
||||
|
||||
classifier.get_indicators = lambda c: {
|
||||
"adx": Decimal("15"), # Below ceiling
|
||||
"plus_di": Decimal("20"),
|
||||
"minus_di": Decimal("18"),
|
||||
"bb_width": Decimal("0.02"), # Below threshold
|
||||
"atr_pct": Decimal("0.5"),
|
||||
"atr_pct_avg": Decimal("1.0"),
|
||||
}
|
||||
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
regime = classifier.classify(candles)
|
||||
assert regime == Regime.RANGING
|
||||
|
||||
def test_classify_high_volatility(self) -> None:
|
||||
"""HIGH_VOLATILITY detected with high ATR% vs average."""
|
||||
classifier = RegimeClassifier()
|
||||
|
||||
classifier.get_indicators = lambda c: {
|
||||
"adx": Decimal("15"),
|
||||
"plus_di": Decimal("20"),
|
||||
"minus_di": Decimal("18"),
|
||||
"bb_width": Decimal("0.10"), # Above threshold
|
||||
"atr_pct": Decimal("2.0"), # Above 1.5x average
|
||||
"atr_pct_avg": Decimal("1.0"),
|
||||
}
|
||||
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
regime = classifier.classify(candles)
|
||||
assert regime == Regime.HIGH_VOLATILITY
|
||||
|
||||
def test_classify_uncertain_fallback(self) -> None:
|
||||
"""UNCERTAIN returned when no conditions met."""
|
||||
classifier = RegimeClassifier()
|
||||
|
||||
classifier.get_indicators = lambda c: {
|
||||
"adx": Decimal("15"),
|
||||
"plus_di": Decimal("20"),
|
||||
"minus_di": Decimal("18"),
|
||||
"bb_width": Decimal("0.10"),
|
||||
"atr_pct": Decimal("1.2"), # Below 1.5x threshold
|
||||
"atr_pct_avg": Decimal("1.0"),
|
||||
}
|
||||
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
regime = classifier.classify(candles)
|
||||
assert regime == Regime.UNCERTAIN
|
||||
|
||||
|
||||
class TestRegimeClassifierGetIndicators:
|
||||
"""Tests for indicator calculation."""
|
||||
|
||||
def test_get_indicators_insufficient_data(self) -> None:
|
||||
"""Empty dict returned when insufficient data."""
|
||||
classifier = RegimeClassifier(atr_lookback=50, bb_period=50)
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
indicators = classifier.get_indicators(candles)
|
||||
assert indicators == {}
|
||||
|
||||
@patch("tradefinder.core.regime.ta.adx")
|
||||
@patch("tradefinder.core.regime.ta.atr")
|
||||
@patch("tradefinder.core.regime.ta.bbands")
|
||||
def test_get_indicators_calculates_correctly(
|
||||
self, mock_bbands: patch, mock_atr: patch, mock_adx: patch
|
||||
) -> None:
|
||||
"""Indicators are calculated and returned correctly."""
|
||||
# Mock pandas TA functions
|
||||
mock_adx.return_value = pd.DataFrame(
|
||||
{
|
||||
"ADX_14": [20.5, 21.0, 22.0],
|
||||
"DMP_14": [25.0, 26.0, 27.0],
|
||||
"DMN_14": [15.0, 16.0, 17.0],
|
||||
}
|
||||
)
|
||||
|
||||
mock_atr.return_value = pd.DataFrame(
|
||||
{
|
||||
"ATR_14": [1000.0, 1100.0, 1200.0],
|
||||
}
|
||||
)
|
||||
|
||||
mock_bbands.return_value = pd.DataFrame(
|
||||
{
|
||||
"BBL_20_2.0": [48000.0, 48500.0, 49000.0],
|
||||
"BBM_20_2.0": [50000.0, 50500.0, 51000.0],
|
||||
"BBU_20_2.0": [52000.0, 52500.0, 53000.0],
|
||||
}
|
||||
)
|
||||
|
||||
classifier = RegimeClassifier()
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01") + pd.Timedelta(days=i),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
for i in range(25) # Enough data
|
||||
]
|
||||
|
||||
indicators = classifier.get_indicators(candles)
|
||||
|
||||
assert "adx" in indicators
|
||||
assert "plus_di" in indicators
|
||||
assert "minus_di" in indicators
|
||||
assert "atr" in indicators
|
||||
assert "atr_pct" in indicators
|
||||
assert "atr_pct_avg" in indicators
|
||||
assert "bb_width" in indicators
|
||||
|
||||
assert indicators["adx"] == Decimal("22.0")
|
||||
assert indicators["plus_di"] == Decimal("27.0")
|
||||
assert indicators["minus_di"] == Decimal("17.0")
|
||||
|
||||
|
||||
class TestRegimeClassifierStaticMethods:
|
||||
"""Tests for static helper methods."""
|
||||
|
||||
def test_to_decimal_from_float(self) -> None:
|
||||
"""Float conversion works."""
|
||||
result = RegimeClassifier._to_decimal(123.45)
|
||||
assert result == Decimal("123.45")
|
||||
|
||||
def test_to_decimal_from_int(self) -> None:
|
||||
"""Int conversion works."""
|
||||
result = RegimeClassifier._to_decimal(123)
|
||||
assert result == Decimal("123")
|
||||
|
||||
def test_to_decimal_from_decimal_passthrough(self) -> None:
|
||||
"""Decimal passthrough works."""
|
||||
dec = Decimal("123.45")
|
||||
result = RegimeClassifier._to_decimal(dec)
|
||||
assert result == dec
|
||||
|
||||
def test_to_decimal_from_none(self) -> None:
|
||||
"""None returns None."""
|
||||
result = RegimeClassifier._to_decimal(None)
|
||||
assert result is None
|
||||
|
||||
@patch("pandas.isna")
|
||||
def test_to_decimal_from_nan(self, mock_isna: patch) -> None:
|
||||
"""NaN values return None."""
|
||||
mock_isna.return_value = True
|
||||
result = RegimeClassifier._to_decimal(float("nan"))
|
||||
assert result is None
|
||||
|
||||
def test_decimal_from_series_tail(self) -> None:
|
||||
"""Last value from series is extracted."""
|
||||
series = pd.Series([1.0, 2.0, 3.0])
|
||||
result = RegimeClassifier._decimal_from_series_tail(series)
|
||||
assert result == Decimal("3.0")
|
||||
|
||||
def test_decimal_from_series_tail_empty(self) -> None:
|
||||
"""Empty series returns None."""
|
||||
series = pd.Series([])
|
||||
result = RegimeClassifier._decimal_from_series_tail(series)
|
||||
assert result is None
|
||||
|
||||
def test_decimal_from_series_avg(self) -> None:
|
||||
"""Average of series is calculated."""
|
||||
series = pd.Series([1.0, 2.0, 3.0, 4.0])
|
||||
result = RegimeClassifier._decimal_from_series_avg(series)
|
||||
assert result == Decimal("2.5")
|
||||
|
||||
def test_decimal_from_series_avg_empty(self) -> None:
|
||||
"""Empty series average returns None."""
|
||||
series = pd.Series([])
|
||||
result = RegimeClassifier._decimal_from_series_avg(series)
|
||||
assert result is None
|
||||
|
||||
def test_candles_to_frame(self) -> None:
|
||||
"""Candles are converted to DataFrame correctly."""
|
||||
candles = [
|
||||
Candle(
|
||||
timestamp=pd.Timestamp("2024-01-01"),
|
||||
open=Decimal("50000"),
|
||||
high=Decimal("51000"),
|
||||
low=Decimal("49000"),
|
||||
close=Decimal("50500"),
|
||||
volume=Decimal("100"),
|
||||
)
|
||||
]
|
||||
df = RegimeClassifier._candles_to_frame(candles)
|
||||
assert len(df) == 1
|
||||
assert df.iloc[0]["open"] == 50000.0
|
||||
assert df.iloc[0]["close"] == 50500.0
|
||||
|
||||
def test_candles_to_frame_empty(self) -> None:
|
||||
"""Empty candles list returns empty DataFrame."""
|
||||
df = RegimeClassifier._candles_to_frame([])
|
||||
assert df.empty
|
||||
Reference in New Issue
Block a user