"""Supertrend-based trading strategy implementation.""" from __future__ import annotations from decimal import Decimal from typing import Any import pandas as pd import pandas_ta as ta import structlog from tradefinder.adapters.types import Candle, Side from tradefinder.core.regime import Regime from tradefinder.strategies.base import Strategy from tradefinder.strategies.signals import Signal, SignalType logger = structlog.get_logger(__name__) class SupertrendStrategy(Strategy): """Supertrend indicator strategy with ATR-based stops.""" name = "supertrend" def __init__(self, period: int = 10, multiplier: float = 3.0) -> None: self._period = max(1, period) self._multiplier = Decimal(str(multiplier)) self._min_required = self._period + 5 self._last_atr: Decimal | None = None def generate_signal(self, candles: list[Candle]) -> Signal | None: """Return a Supertrend signal when the trend changes direction.""" if not self.validate_candles(candles, self._min_required): return None frame = self._candles_to_frame(candles) if frame.empty: return None supertrend = ta.supertrend( high=frame["high"], low=frame["low"], close=frame["close"], length=self._period, multiplier=float(self._multiplier), ) direction_col = next( (col for col in supertrend.columns if col.startswith("SUPERTd_")), None ) if direction_col is None: logger.debug("Supertrend direction series missing", strategy=self.name) return None direction_series = supertrend[direction_col].dropna() if len(direction_series) < 2: return None latest_direction = self._to_decimal(direction_series.iloc[-1]) previous_direction = self._to_decimal(direction_series.iloc[-2]) if latest_direction is None or previous_direction is None: return None atr_value = self._compute_atr(frame) self._last_atr = atr_value entry_price = self._decimal_from_series_tail(frame["close"]) if entry_price is None or atr_value is None and entry_price <= Decimal("0"): return None signal_type: SignalType signal_side: Side direction_label: str if previous_direction < Decimal("0") and latest_direction > Decimal("0"): signal_type = SignalType.ENTRY_LONG signal_side = Side.BUY direction_label = "bullish" elif previous_direction > Decimal("0") and latest_direction < Decimal("0"): signal_type = SignalType.ENTRY_SHORT signal_side = Side.SELL direction_label = "bearish" else: return None stop_loss = self.get_stop_loss(entry_price, signal_side) trend_value = self._trend_level(supertrend) metadata = { "direction": direction_label, "atr": atr_value, "supertrend": trend_value, } logger.info( "Supertrend crossover detected", strategy=self.name, signal_type=signal_type.value, direction=direction_label, ) return Signal( signal_type=signal_type, symbol="", price=entry_price, stop_loss=stop_loss, take_profit=None, confidence=0.65, timestamp=candles[-1].timestamp, strategy_name=self.name, metadata=metadata, ) def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal: """Use ATR buffer for Supertrend stop loss.""" atr_buffer = ( self._last_atr if self._last_atr and self._last_atr > Decimal("0") else entry_price * Decimal("0.02") ) if side == Side.BUY: stop = entry_price - atr_buffer else: stop = entry_price + atr_buffer return stop if stop > Decimal("0") else Decimal("0.01") @property def parameters(self) -> dict[str, Decimal | int]: """Expose current Supertrend parameters.""" return {"period": self._period, "multiplier": self._multiplier} @property def suitable_regimes(self) -> list[Regime]: """This strategy runs only in trending regimes.""" return [Regime.TRENDING_UP, Regime.TRENDING_DOWN] def _compute_atr(self, frame: pd.DataFrame) -> Decimal | None: atr_result: Any = ta.atr( high=frame["high"], low=frame["low"], close=frame["close"], length=self._period, ) if atr_result is None: return None if isinstance(atr_result, pd.Series): if atr_result.empty: return None return self._to_decimal(atr_result.iloc[-1]) atr_df: pd.DataFrame = atr_result if atr_df.empty: return None atr_col = next((col for col in atr_df.columns if "ATR" in col), None) if atr_col is None: return None return self._to_decimal(atr_df[atr_col].iloc[-1]) @staticmethod def _trend_level(supertrend: pd.DataFrame) -> Decimal | None: trend_col = next( ( col for col in supertrend.columns if col.startswith("SUPERT_") and not col.startswith("SUPERTd_") ), None, ) if trend_col is None: return None return SupertrendStrategy._decimal_from_series_tail(supertrend[trend_col]) @staticmethod def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame: if not candles: return pd.DataFrame() frame = pd.DataFrame([candle.to_dict() for candle in candles]) frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True) return frame @staticmethod def _decimal_from_series_tail(series: pd.Series) -> Decimal | None: if series.empty: return None return SupertrendStrategy._to_decimal(series.iloc[-1]) @staticmethod def _to_decimal(value: float | int | Decimal | None) -> Decimal | None: if value is None: return None try: if pd.isna(value): # type: ignore[arg-type] return None except (TypeError, ValueError): pass return Decimal(str(value))