Files
CryptoTrading/tests/test_regime.py
bnair123 eca17b42fe Add Phase 2 foundation: regime classifier, strategy framework, WebSocket streamer
Phase 1 completion:
- Add DataStreamer for real-time Binance Futures WebSocket data (klines, mark price)
- Add DataValidator for candle validation and gap detection
- Add timeframes module with interval mappings

Phase 2 foundation:
- Add RegimeClassifier with ADX/ATR/Bollinger Band analysis
- Add Regime enum (TRENDING_UP/DOWN, RANGING, HIGH_VOLATILITY, UNCERTAIN)
- Add Strategy ABC defining generate_signal, get_stop_loss, parameters, suitable_regimes
- Add Signal dataclass and SignalType enum for strategy outputs

Testing:
- Add comprehensive test suites for all new modules
- 159 tests passing, 24 skipped (async WebSocket timing)
- 82% code coverage

Dependencies:
- Add pandas-stubs to dev dependencies for mypy compatibility
2025-12-27 15:28:28 +04:00

340 lines
12 KiB
Python

"""Unit tests for RegimeClassifier (market regime detection)."""
from decimal import Decimal
from unittest.mock import MagicMock, patch

import pandas as pd

from tradefinder.adapters.types import Candle
from tradefinder.core.regime import Regime, RegimeClassifier
class TestRegimeClassifierInit:
    """Constructor behaviour of RegimeClassifier."""

    @staticmethod
    def _settings(classifier: RegimeClassifier) -> tuple[Decimal, int, int]:
        """Collect the tunables as (adx_threshold, atr_lookback, bb_period)."""
        return (
            classifier._adx_threshold,
            classifier._atr_lookback,
            classifier._bb_period,
        )

    def test_init_default_parameters(self) -> None:
        """Omitting every argument yields the default tunables."""
        expected = (Decimal("25"), 14, 20)
        assert self._settings(RegimeClassifier()) == expected

    def test_init_custom_parameters(self) -> None:
        """Explicit arguments override each default."""
        custom = RegimeClassifier(adx_threshold=30, atr_lookback=20, bb_period=25)
        assert self._settings(custom) == (Decimal("30"), 20, 25)

    def test_init_validates_parameters(self) -> None:
        """A zero lookback / period is raised to the minimum value of 1."""
        clamped = RegimeClassifier(atr_lookback=0, bb_period=0)
        lookback, period = clamped._atr_lookback, clamped._bb_period
        assert lookback == 1  # Min value
        assert period == 1  # Min value
class TestRegimeClassifierClassify:
    """Tests for regime classification logic.

    Apart from the insufficient-data case, each test stubs ``get_indicators``
    with a fixed indicator dict so ``classify``'s decision rules are exercised
    in isolation; the stub ignores the candles it is given.
    """

    @staticmethod
    def _single_candle() -> list[Candle]:
        """Return a one-candle list with fixed OHLCV values."""
        return [
            Candle(
                timestamp=pd.Timestamp("2024-01-01"),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
        ]

    def _classify_with(
        self, classifier: RegimeClassifier, indicators: dict[str, Decimal]
    ) -> Regime:
        """Classify one candle while ``get_indicators`` returns *indicators*.

        ``patch.object`` replaces assigning a lambda to the method: type
        checkers reject assignment to a method, and the patch is undone
        automatically when the ``with`` block exits.
        """
        with patch.object(classifier, "get_indicators", return_value=indicators):
            return classifier.classify(self._single_candle())

    def test_classify_insufficient_data(self) -> None:
        """UNCERTAIN returned when insufficient data."""
        classifier = RegimeClassifier()
        regime = classifier.classify(self._single_candle())
        assert regime == Regime.UNCERTAIN

    def test_classify_trending_up(self) -> None:
        """TRENDING_UP detected with high ADX and +DI > -DI."""
        classifier = RegimeClassifier(adx_threshold=20)  # Lower threshold for test
        indicators = {
            "adx": Decimal("25"),
            "plus_di": Decimal("30"),
            "minus_di": Decimal("15"),
            "bb_width": Decimal("0.02"),
            "atr_pct": Decimal("1.0"),
            "atr_pct_avg": Decimal("0.5"),
        }
        assert self._classify_with(classifier, indicators) == Regime.TRENDING_UP

    def test_classify_trending_down(self) -> None:
        """TRENDING_DOWN detected with high ADX and -DI > +DI."""
        classifier = RegimeClassifier(adx_threshold=20)
        indicators = {
            "adx": Decimal("25"),
            "plus_di": Decimal("15"),
            "minus_di": Decimal("30"),
            "bb_width": Decimal("0.02"),
            "atr_pct": Decimal("1.0"),
            "atr_pct_avg": Decimal("0.5"),
        }
        assert self._classify_with(classifier, indicators) == Regime.TRENDING_DOWN

    def test_classify_ranging(self) -> None:
        """RANGING detected with low ADX and narrow BB."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),  # Below ceiling
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.02"),  # Below threshold
            "atr_pct": Decimal("0.5"),
            "atr_pct_avg": Decimal("1.0"),
        }
        assert self._classify_with(classifier, indicators) == Regime.RANGING

    def test_classify_high_volatility(self) -> None:
        """HIGH_VOLATILITY detected with high ATR% vs average."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.10"),  # Above threshold
            "atr_pct": Decimal("2.0"),  # Above 1.5x average
            "atr_pct_avg": Decimal("1.0"),
        }
        assert self._classify_with(classifier, indicators) == Regime.HIGH_VOLATILITY

    def test_classify_uncertain_fallback(self) -> None:
        """UNCERTAIN returned when no conditions met."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.10"),
            "atr_pct": Decimal("1.2"),  # Below 1.5x threshold
            "atr_pct_avg": Decimal("1.0"),
        }
        assert self._classify_with(classifier, indicators) == Regime.UNCERTAIN
class TestRegimeClassifierGetIndicators:
    """Tests for indicator calculation."""

    def test_get_indicators_insufficient_data(self) -> None:
        """Empty dict returned when insufficient data."""
        classifier = RegimeClassifier(atr_lookback=50, bb_period=50)
        candles = [
            Candle(
                timestamp=pd.Timestamp("2024-01-01"),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
        ]
        indicators = classifier.get_indicators(candles)
        assert indicators == {}

    @patch("tradefinder.core.regime.ta.adx")
    @patch("tradefinder.core.regime.ta.atr")
    @patch("tradefinder.core.regime.ta.bbands")
    def test_get_indicators_calculates_correctly(
        self, mock_bbands: MagicMock, mock_atr: MagicMock, mock_adx: MagicMock
    ) -> None:
        """Indicators are calculated and returned correctly.

        Stacked ``@patch`` decorators inject their mocks bottom-up, so the
        innermost patch (``bbands``) is the first mock argument. The mocks
        are ``MagicMock`` instances; ``patch`` itself is not a type.
        """
        # Mock pandas TA functions
        mock_adx.return_value = pd.DataFrame(
            {
                "ADX_14": [20.5, 21.0, 22.0],
                "DMP_14": [25.0, 26.0, 27.0],
                "DMN_14": [15.0, 16.0, 17.0],
            }
        )
        mock_atr.return_value = pd.DataFrame(
            {
                "ATR_14": [1000.0, 1100.0, 1200.0],
            }
        )
        mock_bbands.return_value = pd.DataFrame(
            {
                "BBL_20_2.0": [48000.0, 48500.0, 49000.0],
                "BBM_20_2.0": [50000.0, 50500.0, 51000.0],
                "BBU_20_2.0": [52000.0, 52500.0, 53000.0],
            }
        )
        classifier = RegimeClassifier()
        candles = [
            Candle(
                timestamp=pd.Timestamp("2024-01-01") + pd.Timedelta(days=i),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
            # 25 candles exceed the default bb_period (20) and atr_lookback (14).
            for i in range(25)
        ]
        indicators = classifier.get_indicators(candles)
        assert "adx" in indicators
        assert "plus_di" in indicators
        assert "minus_di" in indicators
        assert "atr" in indicators
        assert "atr_pct" in indicators
        assert "atr_pct_avg" in indicators
        assert "bb_width" in indicators
        # The last row of each mocked frame is what get_indicators reports.
        assert indicators["adx"] == Decimal("22.0")
        assert indicators["plus_di"] == Decimal("27.0")
        assert indicators["minus_di"] == Decimal("17.0")
class TestRegimeClassifierStaticMethods:
    """Tests for static helper methods."""

    def test_to_decimal_from_float(self) -> None:
        """Float conversion works."""
        result = RegimeClassifier._to_decimal(123.45)
        assert result == Decimal("123.45")

    def test_to_decimal_from_int(self) -> None:
        """Int conversion works."""
        result = RegimeClassifier._to_decimal(123)
        assert result == Decimal("123")

    def test_to_decimal_from_decimal_passthrough(self) -> None:
        """Decimal passthrough works."""
        dec = Decimal("123.45")
        result = RegimeClassifier._to_decimal(dec)
        assert result == dec

    def test_to_decimal_from_none(self) -> None:
        """None returns None."""
        result = RegimeClassifier._to_decimal(None)
        assert result is None

    def test_to_decimal_from_nan(self) -> None:
        """NaN values return None.

        Uses a real NaN instead of patching ``pandas.isna`` to always return
        True: that patch only proved the helper returns None when ``isna``
        says so, not that an actual NaN is detected.
        """
        result = RegimeClassifier._to_decimal(float("nan"))
        assert result is None

    def test_decimal_from_series_tail(self) -> None:
        """Last value from series is extracted."""
        series = pd.Series([1.0, 2.0, 3.0])
        result = RegimeClassifier._decimal_from_series_tail(series)
        assert result == Decimal("3.0")

    def test_decimal_from_series_tail_empty(self) -> None:
        """Empty series returns None."""
        # Explicit dtype: an empty Series without one warns about the default
        # dtype on older pandas, and the indicator series are float anyway.
        series = pd.Series([], dtype=float)
        result = RegimeClassifier._decimal_from_series_tail(series)
        assert result is None

    def test_decimal_from_series_avg(self) -> None:
        """Average of series is calculated."""
        series = pd.Series([1.0, 2.0, 3.0, 4.0])
        result = RegimeClassifier._decimal_from_series_avg(series)
        assert result == Decimal("2.5")

    def test_decimal_from_series_avg_empty(self) -> None:
        """Empty series average returns None."""
        series = pd.Series([], dtype=float)
        result = RegimeClassifier._decimal_from_series_avg(series)
        assert result is None

    def test_candles_to_frame(self) -> None:
        """Candles are converted to DataFrame correctly."""
        candles = [
            Candle(
                timestamp=pd.Timestamp("2024-01-01"),
                open=Decimal("50000"),
                high=Decimal("51000"),
                low=Decimal("49000"),
                close=Decimal("50500"),
                volume=Decimal("100"),
            )
        ]
        df = RegimeClassifier._candles_to_frame(candles)
        assert len(df) == 1
        assert df.iloc[0]["open"] == 50000.0
        assert df.iloc[0]["close"] == 50500.0

    def test_candles_to_frame_empty(self) -> None:
        """Empty candles list returns empty DataFrame."""
        df = RegimeClassifier._candles_to_frame([])
        assert df.empty