Files
CryptoTrading/src/tradefinder/strategies/supertrend.py
bnair123 30af8e7c70
Some checks failed
CI/CD Pipeline / build-and-push (push) Has been cancelled
CI/CD Pipeline / test (push) Has been cancelled
Format code with ruff to pass CI format check
2025-12-27 19:12:19 +04:00

195 lines
6.4 KiB
Python

"""Supertrend-based trading strategy implementation."""
from __future__ import annotations
from decimal import Decimal
from typing import Any
import pandas as pd
import pandas_ta as ta
import structlog
from tradefinder.adapters.types import Candle, Side
from tradefinder.core.regime import Regime
from tradefinder.strategies.base import Strategy
from tradefinder.strategies.signals import Signal, SignalType
# Module-level structlog logger; SupertrendStrategy emits its debug/info events through it.
logger = structlog.get_logger(__name__)
class SupertrendStrategy(Strategy):
    """Supertrend indicator strategy with ATR-based stops.

    An entry signal is emitted on the candle where the Supertrend direction
    changes sign between the last two computed values:

    * negative -> positive yields ``SignalType.ENTRY_LONG`` (``Side.BUY``)
    * positive -> negative yields ``SignalType.ENTRY_SHORT`` (``Side.SELL``)

    The stop is placed one ATR away from the entry price. It falls back to
    2% of the entry price when no positive ATR is available.
    """

    name = "supertrend"

    def __init__(self, period: int = 10, multiplier: float = 3.0) -> None:
        """Configure the indicator.

        Args:
            period: Lookback length used for both Supertrend and ATR.
                Values below 1 are clamped to 1.
            multiplier: Band multiplier passed to the Supertrend indicator.
                It is stored as a ``Decimal`` built from its string form, so
                the value is the one the caller wrote rather than its binary
                float expansion.
        """
        self._period = max(1, period)
        self._multiplier = Decimal(str(multiplier))
        # Extra candles beyond the lookback, presumably as an indicator
        # warm-up margin.
        self._min_required = self._period + 5
        # ATR from the most recent generate_signal() call. It is read by
        # get_stop_loss(), which only receives the entry price and side.
        self._last_atr: Decimal | None = None

    def generate_signal(self, candles: list[Candle]) -> Signal | None:
        """Return a Supertrend signal when the trend changes direction.

        Args:
            candles: Candle history. The last element is treated as the
                current candle, and its timestamp becomes the signal
                timestamp.

        Returns:
            An entry ``Signal`` when the last two Supertrend direction
            values flip sign. Otherwise ``None``, for example when there are
            too few candles, the indicator is unavailable, the entry price
            is missing or non-positive, or the direction did not change.
        """
        if not self.validate_candles(candles, self._min_required):
            return None
        frame = self._candles_to_frame(candles)
        if frame.empty:
            return None

        supertrend = ta.supertrend(
            high=frame["high"],
            low=frame["low"],
            close=frame["close"],
            length=self._period,
            multiplier=float(self._multiplier),
        )
        # pandas_ta may return None instead of a DataFrame when it cannot
        # compute the indicator. Guard before touching `.columns`.
        if supertrend is None or supertrend.empty:
            return None

        direction_col = next(
            (col for col in supertrend.columns if str(col).startswith("SUPERTd_")),
            None,
        )
        if direction_col is None:
            logger.debug("Supertrend direction series missing", strategy=self.name)
            return None
        direction_series = supertrend[direction_col].dropna()
        if len(direction_series) < 2:
            return None
        latest_direction = self._to_decimal(direction_series.iloc[-1])
        previous_direction = self._to_decimal(direction_series.iloc[-2])
        if latest_direction is None or previous_direction is None:
            return None

        atr_value = self._compute_atr(frame)
        # Cached for get_stop_loss(). None makes it use the 2% fallback.
        self._last_atr = atr_value
        entry_price = self._decimal_from_series_tail(frame["close"])
        # A missing ATR is tolerated (get_stop_loss has a fallback), but a
        # missing or non-positive entry price is never tradable.
        if entry_price is None or entry_price <= Decimal("0"):
            return None

        signal_type: SignalType
        signal_side: Side
        direction_label: str
        if previous_direction < Decimal("0") and latest_direction > Decimal("0"):
            signal_type = SignalType.ENTRY_LONG
            signal_side = Side.BUY
            direction_label = "bullish"
        elif previous_direction > Decimal("0") and latest_direction < Decimal("0"):
            signal_type = SignalType.ENTRY_SHORT
            signal_side = Side.SELL
            direction_label = "bearish"
        else:
            # No sign flip between the last two direction values.
            return None

        stop_loss = self.get_stop_loss(entry_price, signal_side)
        trend_value = self._trend_level(supertrend)
        metadata = {
            "direction": direction_label,
            "atr": atr_value,
            "supertrend": trend_value,
        }
        logger.info(
            "Supertrend crossover detected",
            strategy=self.name,
            signal_type=signal_type.value,
            direction=direction_label,
        )
        return Signal(
            signal_type=signal_type,
            # NOTE(review): symbol is left blank here; presumably the caller
            # fills it in. Confirm.
            symbol="",
            price=entry_price,
            stop_loss=stop_loss,
            take_profit=None,
            confidence=0.65,
            timestamp=candles[-1].timestamp,
            strategy_name=self.name,
            metadata=metadata,
        )

    def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
        """Use ATR buffer for Supertrend stop loss.

        The buffer is the ATR cached by the last ``generate_signal`` call
        when it is positive. Otherwise it is 2% of ``entry_price``. The stop
        is ``entry_price - buffer`` for ``Side.BUY`` and ``entry_price +
        buffer`` for any other side. A non-positive result is replaced with
        ``Decimal("0.01")``.
        """
        atr_buffer = (
            self._last_atr
            if self._last_atr and self._last_atr > Decimal("0")
            else entry_price * Decimal("0.02")
        )
        if side == Side.BUY:
            stop = entry_price - atr_buffer
        else:
            stop = entry_price + atr_buffer
        return stop if stop > Decimal("0") else Decimal("0.01")

    @property
    def parameters(self) -> dict[str, Decimal | int]:
        """Expose current Supertrend parameters."""
        return {"period": self._period, "multiplier": self._multiplier}

    @property
    def suitable_regimes(self) -> list[Regime]:
        """This strategy runs only in trending regimes."""
        return [Regime.TRENDING_UP, Regime.TRENDING_DOWN]

    def _compute_atr(self, frame: pd.DataFrame) -> Decimal | None:
        """Return the latest ATR over ``self._period`` candles, or None.

        Handles a Series result, a DataFrame result (using the first column
        whose name contains "ATR"), and a None result from ``ta.atr``. A
        missing or NaN last value yields None.
        """
        atr_result: Any = ta.atr(
            high=frame["high"],
            low=frame["low"],
            close=frame["close"],
            length=self._period,
        )
        if atr_result is None:
            return None
        if isinstance(atr_result, pd.Series):
            if atr_result.empty:
                return None
            return self._to_decimal(atr_result.iloc[-1])
        atr_df: pd.DataFrame = atr_result
        if atr_df.empty:
            return None
        atr_col = next((col for col in atr_df.columns if "ATR" in str(col)), None)
        if atr_col is None:
            return None
        return self._to_decimal(atr_df[atr_col].iloc[-1])

    @staticmethod
    def _trend_level(supertrend: pd.DataFrame) -> Decimal | None:
        """Return the last value of the Supertrend line column, or None.

        The line column is the one prefixed ``SUPERT_``. Sibling columns
        such as ``SUPERTd_`` cannot match this prefix, because the character
        after ``SUPERT`` differs.
        """
        trend_col = next(
            (col for col in supertrend.columns if str(col).startswith("SUPERT_")),
            None,
        )
        if trend_col is None:
            return None
        return SupertrendStrategy._decimal_from_series_tail(supertrend[trend_col])

    @staticmethod
    def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame:
        """Build a DataFrame with one row per ``candle.to_dict()``.

        The ``timestamp`` column is parsed as UTC datetimes. An empty
        DataFrame is returned for an empty list.
        """
        if not candles:
            return pd.DataFrame()
        frame = pd.DataFrame([candle.to_dict() for candle in candles])
        frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
        return frame

    @staticmethod
    def _decimal_from_series_tail(series: pd.Series) -> Decimal | None:
        """Return the last element of ``series`` as a Decimal.

        Returns None when the series is empty or the last value is
        missing.
        """
        if series.empty:
            return None
        return SupertrendStrategy._to_decimal(series.iloc[-1])

    @staticmethod
    def _to_decimal(value: float | int | Decimal | None) -> Decimal | None:
        """Convert ``value`` to ``Decimal`` via its string form.

        None and NaN-like values (per ``pd.isna``) yield None.
        """
        if value is None:
            return None
        try:
            if pd.isna(value):  # type: ignore[arg-type]
                return None
        except (TypeError, ValueError):
            # pd.isna on array-likes returns an array, whose truth value is
            # ambiguous. Fall through and let Decimal() decide.
            pass
        return Decimal(str(value))