195 lines
6.4 KiB
Python
195 lines
6.4 KiB
Python
"""Supertrend-based trading strategy implementation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from decimal import Decimal
|
|
from typing import Any
|
|
|
|
import pandas as pd
|
|
import pandas_ta as ta
|
|
import structlog
|
|
|
|
from tradefinder.adapters.types import Candle, Side
|
|
from tradefinder.core.regime import Regime
|
|
from tradefinder.strategies.base import Strategy
|
|
from tradefinder.strategies.signals import Signal, SignalType
|
|
|
|
# Module-level structlog logger used by SupertrendStrategy; __name__ is
# forwarded to structlog's configured logger factory.
logger = structlog.get_logger(__name__)
|
|
|
|
|
|
class SupertrendStrategy(Strategy):
    """Supertrend indicator strategy with ATR-based stops.

    An entry signal is emitted when the sign of the pandas_ta Supertrend
    direction series flips between the last two non-NaN values:

    * negative -> positive produces ``SignalType.ENTRY_LONG`` (buy);
    * positive -> negative produces ``SignalType.ENTRY_SHORT`` (sell).

    The stop loss sits one ATR (same ``period`` as the Supertrend) away from
    the latest close. When no positive ATR is available, it falls back to 2%
    of the entry price, and it is floored at ``0.01``.
    """

    name = "supertrend"

    def __init__(self, period: int = 10, multiplier: float = 3.0) -> None:
        """Configure the indicator parameters.

        Args:
            period: Lookback length shared by the Supertrend and the ATR.
                Values below 1 are clamped to 1.
            multiplier: Band multiplier passed to the Supertrend. Stored as a
                ``Decimal`` built from ``str(multiplier)`` so the exposed
                parameter does not carry binary-float artefacts.
        """
        self._period = max(1, period)
        self._multiplier = Decimal(str(multiplier))
        # Minimum candle count: the lookback plus a fixed margin of 5.
        self._min_required = self._period + 5
        # ATR from the most recent generate_signal() call; get_stop_loss()
        # reads it to size the stop buffer.
        self._last_atr: Decimal | None = None

    def generate_signal(self, candles: list[Candle]) -> Signal | None:
        """Return a Supertrend signal when the trend changes direction.

        Args:
            candles: Price history. The last element is treated as the most
                recent candle (presumably ordered oldest first; confirm with
                callers).

        Returns:
            A ``Signal`` when the last two non-NaN Supertrend direction values
            change sign. Otherwise ``None``, which is also returned when there
            are too few candles, the indicator cannot be computed, or the
            latest close is missing or non-positive.
        """
        if not self.validate_candles(candles, self._min_required):
            return None

        frame = self._candles_to_frame(candles)
        if frame.empty:
            return None

        supertrend = ta.supertrend(
            high=frame["high"],
            low=frame["low"],
            close=frame["close"],
            length=self._period,
            multiplier=float(self._multiplier),
        )
        # pandas_ta indicators may return None (e.g. when the input is too
        # short for the lookback); guard before touching ``.columns``.
        if supertrend is None or supertrend.empty:
            logger.debug("Supertrend indicator unavailable", strategy=self.name)
            return None

        direction_col = next(
            (col for col in supertrend.columns if str(col).startswith("SUPERTd_")),
            None,
        )
        if direction_col is None:
            logger.debug("Supertrend direction series missing", strategy=self.name)
            return None

        direction_series = supertrend[direction_col].dropna()
        if len(direction_series) < 2:
            return None

        latest_direction = self._to_decimal(direction_series.iloc[-1])
        previous_direction = self._to_decimal(direction_series.iloc[-2])
        if latest_direction is None or previous_direction is None:
            return None

        atr_value = self._compute_atr(frame)
        # Recorded before any early return so get_stop_loss() below (and any
        # later direct caller) sees the ATR of this candle window.
        self._last_atr = atr_value

        entry_price = self._decimal_from_series_tail(frame["close"])
        # A missing or non-positive close cannot anchor a trade. A missing
        # ATR is tolerated because get_stop_loss() falls back to a percentage
        # buffer. (The previous guard mixed ``or``/``and`` without
        # parentheses, which let non-positive prices through whenever an ATR
        # was available.)
        if entry_price is None or entry_price <= Decimal("0"):
            return None

        signal_type: SignalType
        signal_side: Side
        direction_label: str
        if previous_direction < Decimal("0") and latest_direction > Decimal("0"):
            signal_type = SignalType.ENTRY_LONG
            signal_side = Side.BUY
            direction_label = "bullish"
        elif previous_direction > Decimal("0") and latest_direction < Decimal("0"):
            signal_type = SignalType.ENTRY_SHORT
            signal_side = Side.SELL
            direction_label = "bearish"
        else:
            # No sign change between the last two direction values.
            return None

        stop_loss = self.get_stop_loss(entry_price, signal_side)
        trend_value = self._trend_level(supertrend)

        metadata = {
            "direction": direction_label,
            "atr": atr_value,
            "supertrend": trend_value,
        }

        logger.info(
            "Supertrend crossover detected",
            strategy=self.name,
            signal_type=signal_type.value,
            direction=direction_label,
        )

        return Signal(
            signal_type=signal_type,
            symbol="",
            price=entry_price,
            stop_loss=stop_loss,
            take_profit=None,
            confidence=0.65,
            timestamp=candles[-1].timestamp,
            strategy_name=self.name,
            metadata=metadata,
        )

    def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
        """Use ATR buffer for Supertrend stop loss.

        Args:
            entry_price: Price the position is opened at.
            side: ``Side.BUY`` puts the stop below the entry. Any other side
                puts it above.

        Returns:
            ``entry_price`` minus/plus the buffer. The buffer is the last
            computed ATR when it is positive, otherwise 2% of ``entry_price``.
            The result is floored at ``Decimal("0.01")`` whenever the raw
            stop would be zero or negative.
        """
        last_atr = self._last_atr
        if last_atr is not None and last_atr > Decimal("0"):
            atr_buffer = last_atr
        else:
            atr_buffer = entry_price * Decimal("0.02")

        if side == Side.BUY:
            stop = entry_price - atr_buffer
        else:
            stop = entry_price + atr_buffer
        return stop if stop > Decimal("0") else Decimal("0.01")

    @property
    def parameters(self) -> dict[str, Decimal | int]:
        """Expose current Supertrend parameters (``period`` and ``multiplier``)."""
        return {"period": self._period, "multiplier": self._multiplier}

    @property
    def suitable_regimes(self) -> list[Regime]:
        """This strategy runs only in trending regimes (up or down)."""
        return [Regime.TRENDING_UP, Regime.TRENDING_DOWN]

    def _compute_atr(self, frame: pd.DataFrame) -> Decimal | None:
        """Return the latest ATR over ``self._period`` as a ``Decimal``.

        Handles both a Series and a DataFrame result from ``ta.atr``. Returns
        ``None`` when the indicator is unavailable, empty, has no ATR column,
        or its latest value is NaN.
        """
        atr_result: Any = ta.atr(
            high=frame["high"],
            low=frame["low"],
            close=frame["close"],
            length=self._period,
        )
        if atr_result is None:
            return None
        if isinstance(atr_result, pd.Series):
            if atr_result.empty:
                return None
            return self._to_decimal(atr_result.iloc[-1])

        atr_df: pd.DataFrame = atr_result
        if atr_df.empty:
            return None
        atr_col = next((col for col in atr_df.columns if "ATR" in str(col)), None)
        if atr_col is None:
            return None
        return self._to_decimal(atr_df[atr_col].iloc[-1])

    @staticmethod
    def _trend_level(supertrend: pd.DataFrame) -> Decimal | None:
        """Return the latest Supertrend line value, or ``None`` if unavailable.

        The line column is the one prefixed ``SUPERT_``. The direction
        (``SUPERTd_``) and other variant columns do not match this prefix.
        """
        trend_col = next(
            (col for col in supertrend.columns if str(col).startswith("SUPERT_")),
            None,
        )
        if trend_col is None:
            return None
        return SupertrendStrategy._decimal_from_series_tail(supertrend[trend_col])

    @staticmethod
    def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame:
        """Build a DataFrame from ``Candle.to_dict()`` rows.

        The ``timestamp`` column is parsed as UTC datetimes. The price columns
        used by the indicators are cast to ``float``. An empty candle list
        yields an empty DataFrame.
        """
        if not candles:
            return pd.DataFrame()
        frame = pd.DataFrame([candle.to_dict() for candle in candles])
        frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
        # NOTE(review): to_dict() may emit Decimal or string prices; pandas_ta
        # needs a numeric dtype, so normalise here (a no-op for floats).
        for column in ("high", "low", "close"):
            if column in frame.columns:
                frame[column] = frame[column].astype(float)
        return frame

    @staticmethod
    def _decimal_from_series_tail(series: pd.Series) -> Decimal | None:
        """Return the last element of ``series`` as a ``Decimal`` (``None`` if empty/NaN)."""
        if series.empty:
            return None
        return SupertrendStrategy._to_decimal(series.iloc[-1])

    @staticmethod
    def _to_decimal(value: float | int | Decimal | None) -> Decimal | None:
        """Convert a scalar to ``Decimal``, mapping ``None``/NaN to ``None``.

        The conversion goes through ``str`` so floats keep their shortest
        repr instead of their exact binary expansion.
        """
        if value is None:
            return None
        try:
            if pd.isna(value):  # type: ignore[arg-type]
                return None
        except (TypeError, ValueError):
            # pd.isna on an unusual scalar type; fall through and let
            # Decimal() decide whether the value is convertible.
            pass
        return Decimal(str(value))