"""Unit tests for RegimeClassifier (market regime detection)."""

from decimal import Decimal
from unittest.mock import MagicMock, patch

import pandas as pd

from tradefinder.adapters.types import Candle
from tradefinder.core.regime import Regime, RegimeClassifier


def _make_candle(day_offset: int = 0) -> Candle:
    """Build the fixed OHLCV candle used throughout these tests.

    Args:
        day_offset: Number of days after 2024-01-01 to stamp the candle with,
            so callers can build a strictly increasing series of candles.
    """
    return Candle(
        timestamp=pd.Timestamp("2024-01-01") + pd.Timedelta(days=day_offset),
        open=Decimal("50000"),
        high=Decimal("51000"),
        low=Decimal("49000"),
        close=Decimal("50500"),
        volume=Decimal("100"),
    )


def _classify_with(
    classifier: RegimeClassifier, indicators: dict[str, Decimal]
) -> Regime:
    """Classify a single candle with ``get_indicators`` stubbed to ``indicators``.

    ``patch.object`` scopes the stub to this call and keeps type checkers happy,
    unlike assigning a lambda over the bound method.
    """
    with patch.object(classifier, "get_indicators", return_value=indicators):
        return classifier.classify([_make_candle()])


class TestRegimeClassifierInit:
    """Tests for RegimeClassifier initialization."""

    def test_init_default_parameters(self) -> None:
        """Default parameters are set correctly."""
        classifier = RegimeClassifier()

        assert classifier._adx_threshold == Decimal("25")
        assert classifier._atr_lookback == 14
        assert classifier._bb_period == 20

    def test_init_custom_parameters(self) -> None:
        """Custom parameters are accepted."""
        classifier = RegimeClassifier(adx_threshold=30, atr_lookback=20, bb_period=25)

        assert classifier._adx_threshold == Decimal("30")
        assert classifier._atr_lookback == 20
        assert classifier._bb_period == 25

    def test_init_validates_parameters(self) -> None:
        """Zero lookback/period values are clamped up to the minimum of 1."""
        classifier = RegimeClassifier(atr_lookback=0, bb_period=0)

        assert classifier._atr_lookback == 1  # Min value
        assert classifier._bb_period == 1  # Min value


class TestRegimeClassifierClassify:
    """Tests for regime classification logic.

    Apart from the insufficient-data case, ``get_indicators`` is stubbed so each
    test controls exactly which indicator values ``classify`` sees.
    """

    def test_classify_insufficient_data(self) -> None:
        """UNCERTAIN returned when insufficient data."""
        classifier = RegimeClassifier()

        regime = classifier.classify([_make_candle()])

        assert regime == Regime.UNCERTAIN

    def test_classify_trending_up(self) -> None:
        """TRENDING_UP detected with high ADX and +DI > -DI."""
        classifier = RegimeClassifier(adx_threshold=20)  # Lower threshold for test
        indicators = {
            "adx": Decimal("25"),
            "plus_di": Decimal("30"),
            "minus_di": Decimal("15"),
            "bb_width": Decimal("0.02"),
            "atr_pct": Decimal("1.0"),
            "atr_pct_avg": Decimal("0.5"),
        }

        assert _classify_with(classifier, indicators) == Regime.TRENDING_UP

    def test_classify_trending_down(self) -> None:
        """TRENDING_DOWN detected with high ADX and -DI > +DI."""
        classifier = RegimeClassifier(adx_threshold=20)
        indicators = {
            "adx": Decimal("25"),
            "plus_di": Decimal("15"),
            "minus_di": Decimal("30"),
            "bb_width": Decimal("0.02"),
            "atr_pct": Decimal("1.0"),
            "atr_pct_avg": Decimal("0.5"),
        }

        assert _classify_with(classifier, indicators) == Regime.TRENDING_DOWN

    def test_classify_ranging(self) -> None:
        """RANGING detected with low ADX and narrow BB."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),  # Below ceiling
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.02"),  # Below threshold
            "atr_pct": Decimal("0.5"),
            "atr_pct_avg": Decimal("1.0"),
        }

        assert _classify_with(classifier, indicators) == Regime.RANGING

    def test_classify_high_volatility(self) -> None:
        """HIGH_VOLATILITY detected with high ATR% vs average."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.10"),  # Above threshold
            "atr_pct": Decimal("2.0"),  # Above 1.5x average
            "atr_pct_avg": Decimal("1.0"),
        }

        assert _classify_with(classifier, indicators) == Regime.HIGH_VOLATILITY

    def test_classify_uncertain_fallback(self) -> None:
        """UNCERTAIN returned when no conditions met."""
        classifier = RegimeClassifier()
        indicators = {
            "adx": Decimal("15"),
            "plus_di": Decimal("20"),
            "minus_di": Decimal("18"),
            "bb_width": Decimal("0.10"),
            "atr_pct": Decimal("1.2"),  # Below 1.5x threshold
            "atr_pct_avg": Decimal("1.0"),
        }

        assert _classify_with(classifier, indicators) == Regime.UNCERTAIN


class TestRegimeClassifierGetIndicators:
    """Tests for indicator calculation."""

    def test_get_indicators_insufficient_data(self) -> None:
        """Empty dict returned when insufficient data."""
        classifier = RegimeClassifier(atr_lookback=50, bb_period=50)

        indicators = classifier.get_indicators([_make_candle()])

        assert indicators == {}

    # Decorators apply bottom-up, so the bottom patch (bbands) is the first
    # mock argument and the top patch (adx) is the last.
    @patch("tradefinder.core.regime.ta.adx")
    @patch("tradefinder.core.regime.ta.atr")
    @patch("tradefinder.core.regime.ta.bbands")
    def test_get_indicators_calculates_correctly(
        self, mock_bbands: MagicMock, mock_atr: MagicMock, mock_adx: MagicMock
    ) -> None:
        """Indicators are calculated and returned correctly."""
        # Mock pandas TA functions
        mock_adx.return_value = pd.DataFrame(
            {
                "ADX_14": [20.5, 21.0, 22.0],
                "DMP_14": [25.0, 26.0, 27.0],
                "DMN_14": [15.0, 16.0, 17.0],
            }
        )
        mock_atr.return_value = pd.DataFrame(
            {
                "ATR_14": [1000.0, 1100.0, 1200.0],
            }
        )
        mock_bbands.return_value = pd.DataFrame(
            {
                "BBL_20_2.0": [48000.0, 48500.0, 49000.0],
                "BBM_20_2.0": [50000.0, 50500.0, 51000.0],
                "BBU_20_2.0": [52000.0, 52500.0, 53000.0],
            }
        )

        classifier = RegimeClassifier()
        candles = [_make_candle(i) for i in range(25)]  # Enough data

        indicators = classifier.get_indicators(candles)

        for key in (
            "adx",
            "plus_di",
            "minus_di",
            "atr",
            "atr_pct",
            "atr_pct_avg",
            "bb_width",
        ):
            assert key in indicators, key
        # The most recent (last) row of each mocked frame is what is reported.
        assert indicators["adx"] == Decimal("22.0")
        assert indicators["plus_di"] == Decimal("27.0")
        assert indicators["minus_di"] == Decimal("17.0")


class TestRegimeClassifierStaticMethods:
    """Tests for static helper methods."""

    def test_to_decimal_from_float(self) -> None:
        """Float conversion works."""
        result = RegimeClassifier._to_decimal(123.45)

        assert result == Decimal("123.45")

    def test_to_decimal_from_int(self) -> None:
        """Int conversion works."""
        result = RegimeClassifier._to_decimal(123)

        assert result == Decimal("123")

    def test_to_decimal_from_decimal_passthrough(self) -> None:
        """Decimal passthrough works."""
        dec = Decimal("123.45")

        result = RegimeClassifier._to_decimal(dec)

        assert result == dec

    def test_to_decimal_from_none(self) -> None:
        """None returns None."""
        result = RegimeClassifier._to_decimal(None)

        assert result is None

    def test_to_decimal_from_nan(self) -> None:
        """NaN values return None.

        NaN detection is exercised for real here; stubbing ``pandas.isna`` to
        always return True would make this test pass regardless of behavior.
        """
        result = RegimeClassifier._to_decimal(float("nan"))

        assert result is None

    def test_decimal_from_series_tail(self) -> None:
        """Last value from series is extracted."""
        series = pd.Series([1.0, 2.0, 3.0])

        result = RegimeClassifier._decimal_from_series_tail(series)

        assert result == Decimal("3.0")

    def test_decimal_from_series_tail_empty(self) -> None:
        """Empty series returns None."""
        # Explicit dtype: an untyped empty Series would otherwise be inferred as object.
        series = pd.Series([], dtype=float)

        result = RegimeClassifier._decimal_from_series_tail(series)

        assert result is None

    def test_decimal_from_series_avg(self) -> None:
        """Average of series is calculated."""
        series = pd.Series([1.0, 2.0, 3.0, 4.0])

        result = RegimeClassifier._decimal_from_series_avg(series)

        assert result == Decimal("2.5")

    def test_decimal_from_series_avg_empty(self) -> None:
        """Empty series average returns None."""
        series = pd.Series([], dtype=float)

        result = RegimeClassifier._decimal_from_series_avg(series)

        assert result is None

    def test_candles_to_frame(self) -> None:
        """Candles are converted to DataFrame correctly."""
        df = RegimeClassifier._candles_to_frame([_make_candle()])

        assert len(df) == 1
        assert df.iloc[0]["open"] == 50000.0
        assert df.iloc[0]["close"] == 50500.0

    def test_candles_to_frame_empty(self) -> None:
        """Empty candles list returns empty DataFrame."""
        df = RegimeClassifier._candles_to_frame([])

        assert df.empty