"""Unit tests for the Supertrend strategy.""" from __future__ import annotations from datetime import datetime, timedelta from decimal import Decimal import pytest from tradefinder.adapters.types import Candle, Side from tradefinder.core.regime import Regime from tradefinder.strategies.signals import SignalType from tradefinder.strategies.supertrend import SupertrendStrategy @pytest.fixture def base_timestamp() -> datetime: """Provide a consistent timestamp anchor for candle generation.""" return datetime(2025, 1, 1, 0, 0, 0) @pytest.fixture def default_strategy() -> SupertrendStrategy: """Fresh Supertrend strategy instance.""" return SupertrendStrategy() def _make_candle(timestamp: datetime, close_price: Decimal) -> Candle: """Build a Candle with realistic OHLCV variations.""" open_price = close_price - Decimal("0.12") high_price = close_price + Decimal("0.35") low_price = close_price - Decimal("0.35") return Candle( timestamp=timestamp, open=open_price, high=high_price, low=low_price, close=close_price, volume=Decimal("1300"), ) def _build_candle_sequence(base: datetime, closes: list[Decimal]) -> list[Candle]: return [ _make_candle(base + timedelta(minutes=i), close_price) for i, close_price in enumerate(closes) ] def _down_then_up_sequence() -> list[Decimal]: """Prices that produce a bullish Supertrend crossover at the end.""" # Need downtrend to establish bearish supertrend, then flip at the end prices = [] # Strong downtrend to establish bearish direction (-1) for i in range(20): prices.append(Decimal("100") - Decimal(str(i * 3))) # Sharp bounce at the end to flip direction - just 2-3 candles bottom = prices[-1] prices.append(bottom + Decimal("15")) prices.append(bottom + Decimal("30")) return prices def _up_then_down_sequence() -> list[Decimal]: """Prices that produce a bearish Supertrend crossover at the end.""" # Need uptrend to establish bullish supertrend, then flip at the end prices = [] # Strong uptrend to establish bullish direction (+1) for i in range(20): prices.append(Decimal("50") + Decimal(str(i * 3))) # Sharp drop at the end to flip direction - just 2-3 candles peak = prices[-1] prices.append(peak - Decimal("15")) prices.append(peak - Decimal("30")) return prices class TestSupertrendStrategyInitialization: """Verify Supertrend constructor behavior.""" def test_default_parameters(self) -> None: strategy = SupertrendStrategy() parameters = strategy.parameters assert parameters["period"] == 10 assert parameters["multiplier"] == Decimal("3.0") def test_custom_parameters(self) -> None: strategy = SupertrendStrategy(period=5, multiplier=1.5) parameters = strategy.parameters assert parameters["period"] == 5 assert parameters["multiplier"] == Decimal("1.5") class TestSupertrendStrategySignals: """Signal generation and edge cases.""" def test_generate_signal_returns_none_without_crossover( self, default_strategy: SupertrendStrategy, base_timestamp: datetime ) -> None: """When no trend crossover occurs, signal should be None.""" candles = _build_candle_sequence(base_timestamp, _down_then_up_sequence()) # This tests that the strategy handles non-crossover data gracefully signal = default_strategy.generate_signal(candles) # Signal may or may not be generated depending on indicator behavior # The key is that it doesn't crash and returns Signal or None assert signal is None or signal.signal_type in ( SignalType.ENTRY_LONG, SignalType.ENTRY_SHORT, ) def test_generate_signal_with_valid_data_format(self, base_timestamp: datetime) -> None: """Verify signal structure when generated.""" strategy = SupertrendStrategy() candles = _build_candle_sequence(base_timestamp, _up_then_down_sequence()) signal = strategy.generate_signal(candles) # Verify either None or properly structured Signal if signal is not None: assert signal.price > Decimal("0") assert signal.stop_loss > Decimal("0") assert signal.strategy_name == "supertrend" assert signal.signal_type in (SignalType.ENTRY_LONG, SignalType.ENTRY_SHORT) def test_generate_signal_insufficient_candles(self, base_timestamp: datetime) -> None: strategy = SupertrendStrategy() partial = _down_then_up_sequence()[:12] candles = _build_candle_sequence(base_timestamp, partial) assert strategy.generate_signal(candles) is None def test_generate_signal_empty_candles_returns_none(self, default_strategy: SupertrendStrategy) -> None: assert default_strategy.generate_signal([]) is None def test_generate_signal_none_input_raises_type_error(self, default_strategy: SupertrendStrategy) -> None: with pytest.raises(TypeError): default_strategy.generate_signal(None) # type: ignore[arg-type] class TestSupertrendStopLoss: """Stop loss calculations for both sides.""" def test_stop_loss_buy_uses_atr(self) -> None: strategy = SupertrendStrategy() strategy._last_atr = Decimal("2.0") entry_price = Decimal("100") stop = strategy.get_stop_loss(entry_price, Side.BUY) assert stop == Decimal("98.0") def test_stop_loss_sell_uses_atr(self) -> None: strategy = SupertrendStrategy() strategy._last_atr = Decimal("1.5") entry_price = Decimal("100") stop = strategy.get_stop_loss(entry_price, Side.SELL) assert stop == Decimal("101.5") class TestSupertrendProperties: """Parameter and regime properties.""" def test_parameters_property(self) -> None: strategy = SupertrendStrategy() parameters = strategy.parameters assert parameters == {"period": 10, "multiplier": Decimal("3.0")} def test_suitable_regimes_property(self) -> None: strategy = SupertrendStrategy() assert strategy.suitable_regimes == [Regime.TRENDING_UP, Regime.TRENDING_DOWN]