Add Supertrend strategy and Risk Engine (Phase 2 Milestones 2.2, 2.3)

- Implement SupertrendStrategy with pandas-ta indicator, ATR-based stops
- Add RiskEngine with position sizing, risk limits, portfolio heat tracking
- Add comprehensive tests for both modules (32 new tests)
- Update AGENTS.md with accurate project structure and py312 target
This commit is contained in:
bnair123
2025-12-27 18:24:20 +04:00
parent eca17b42fe
commit e17c3bf508
5 changed files with 818 additions and 49 deletions

View File

@@ -0,0 +1,204 @@
"""Risk module implementing sizing, limits, and portfolio heat tracking."""
from __future__ import annotations
from dataclasses import dataclass, field
from decimal import Decimal
import structlog
logger = structlog.get_logger(__name__)
@dataclass
class PortfolioRisk:
"""Track aggregate exposure, per-strategy exposure, and portfolio heat."""
total_exposure: Decimal = Decimal("0")
per_strategy_exposure: dict[str, Decimal] = field(default_factory=dict)
portfolio_heat: Decimal = Decimal("0")
def add_exposure(self, strategy: str, amount: Decimal, equity: Decimal) -> None:
"""Add risk amount for a strategy and refresh totals."""
if amount <= Decimal("0"):
logger.debug("Ignoring non-positive exposure", strategy=strategy, amount=amount)
return
current = self.per_strategy_exposure.get(strategy, Decimal("0"))
self.per_strategy_exposure[strategy] = current + amount
self.total_exposure += amount
self._recalculate_heat(equity)
logger.debug(
"Registered exposure",
strategy=strategy,
added_amount=str(amount),
strategy_total=str(self.per_strategy_exposure[strategy]),
total_exposure=str(self.total_exposure),
heat=str(self.portfolio_heat),
)
def remove_exposure(self, strategy: str, amount: Decimal, equity: Decimal) -> None:
"""Remove risk amount for a strategy and refresh totals."""
if amount <= Decimal("0"):
return
current = self.per_strategy_exposure.get(strategy, Decimal("0"))
reduction = min(current, amount)
self.per_strategy_exposure[strategy] = current - reduction
self.total_exposure = max(Decimal("0"), self.total_exposure - reduction)
self._recalculate_heat(equity)
logger.debug(
"Reduced exposure",
strategy=strategy,
removed_amount=str(reduction),
strategy_total=str(self.per_strategy_exposure[strategy]),
total_exposure=str(self.total_exposure),
heat=str(self.portfolio_heat),
)
def _recalculate_heat(self, equity: Decimal) -> None:
"""Recalculate portfolio heat as total exposure percent of equity."""
if equity <= Decimal("0"):
self.portfolio_heat = Decimal("0")
return
self.portfolio_heat = (self.total_exposure / equity) * Decimal("100")
class RiskEngine:
"""Encapsulate position sizing, risk limits, and risk amount calculations."""
_min_risk_pct = Decimal("1")
_max_risk_pct = Decimal("3")
_max_per_strategy_pct = Decimal("25")
_max_total_exposure_pct = Decimal("100")
def __init__(self, portfolio_risk: PortfolioRisk | None = None) -> None:
self._portfolio_risk = portfolio_risk or PortfolioRisk()
def calculate_position_size(
self,
equity: Decimal,
entry_price: Decimal,
stop_loss: Decimal,
risk_pct: Decimal,
) -> Decimal:
"""Size a position to risk a percentage of equity between entry and stop."""
if equity <= Decimal("0"):
raise ValueError("Equity must be positive")
if entry_price <= Decimal("0") or stop_loss <= Decimal("0"):
raise ValueError("Entry and stop must be positive values")
stop_distance = abs(entry_price - stop_loss)
if stop_distance == Decimal("0"):
raise ValueError("Stop loss must differ from entry price")
normalized_risk_pct = max(self._min_risk_pct, min(risk_pct, self._max_risk_pct))
risk_amount = equity * (normalized_risk_pct / Decimal("100"))
position_size = risk_amount / stop_distance
logger.debug(
"Calculated position size",
equity=str(equity),
entry_price=str(entry_price),
stop_loss=str(stop_loss),
stop_distance=str(stop_distance),
risk_pct=str(normalized_risk_pct),
position_size=str(position_size),
)
return position_size
def validate_risk_limits(
self,
position_size: Decimal,
entry_price: Decimal,
max_per_trade_pct: Decimal,
equity: Decimal,
) -> bool:
"""Ensure the per-trade exposure stays within configured limits."""
if equity <= Decimal("0"):
raise ValueError("Equity must be positive to validate risk limits")
if position_size <= Decimal("0") or entry_price <= Decimal("0"):
raise ValueError("Position size and entry price must be positive")
allowed_pct = max(Decimal("0"), min(max_per_trade_pct, self._max_risk_pct))
max_notional = equity * (allowed_pct / Decimal("100"))
notional = position_size * entry_price
within_limits = notional <= max_notional
logger.debug(
"Validated risk limits",
position_size=str(position_size),
entry_price=str(entry_price),
notional=str(notional),
max_notional=str(max_notional),
within_limits=within_limits,
)
return within_limits
def calculate_risk_amount(
self,
position_size: Decimal,
entry_price: Decimal,
stop_loss: Decimal,
) -> Decimal:
"""Compute the absolute capital at risk between entry and stop loss."""
if position_size <= Decimal("0"):
raise ValueError("Position size must be positive")
if entry_price <= Decimal("0") or stop_loss <= Decimal("0"):
raise ValueError("Entry and stop prices must be positive")
stop_distance = abs(entry_price - stop_loss)
if stop_distance == Decimal("0"):
raise ValueError("Stop loss distance must be non-zero")
risk_amount = position_size * stop_distance
logger.debug(
"Calculated risk amount",
position_size=str(position_size),
entry_price=str(entry_price),
stop_loss=str(stop_loss),
risk_amount=str(risk_amount),
)
return risk_amount
def can_allocate_strategy(
self,
strategy: str,
risk_amount: Decimal,
equity: Decimal,
max_per_strategy_pct: Decimal | None = None,
max_total_exposure_pct: Decimal | None = None,
) -> bool:
"""Return True if adding exposure keeps strategy and total caps."""
if equity <= Decimal("0"):
raise ValueError("Equity must be positive to allocate exposure")
if risk_amount <= Decimal("0"):
logger.debug("Risk amount is non-positive", strategy=strategy, risk_amount=str(risk_amount))
return False
strategy_pct = max_per_strategy_pct or self._max_per_strategy_pct
total_pct = max_total_exposure_pct or self._max_total_exposure_pct
strategy_limit = equity * (strategy_pct / Decimal("100"))
total_limit = equity * (total_pct / Decimal("100"))
current_strategy = self._portfolio_risk.per_strategy_exposure.get(strategy, Decimal("0"))
strategy_after = current_strategy + risk_amount
total_after = self._portfolio_risk.total_exposure + risk_amount
within_strategy = strategy_after <= strategy_limit
within_total = total_after <= total_limit
if within_strategy and within_total:
self._portfolio_risk.add_exposure(strategy, risk_amount, equity)
logger.debug("Allocated exposure", strategy=strategy, risk_amount=str(risk_amount))
return True
logger.warning(
"Allocation exceeds limits",
strategy=strategy,
strategy_after=strategy_after,
strategy_limit=strategy_limit,
total_after=str(total_after),
total_limit=str(total_limit),
)
return False

View File

@@ -0,0 +1,181 @@
"""Supertrend-based trading strategy implementation."""
from __future__ import annotations
from decimal import Decimal
from typing import Any
import pandas as pd
import pandas_ta as ta
import structlog
from tradefinder.adapters.types import Candle, Side
from tradefinder.core.regime import Regime
from tradefinder.strategies.base import Strategy
from tradefinder.strategies.signals import Signal, SignalType
logger = structlog.get_logger(__name__)
class SupertrendStrategy(Strategy):
"""Supertrend indicator strategy with ATR-based stops."""
name = "supertrend"
def __init__(self, period: int = 10, multiplier: float = 3.0) -> None:
self._period = max(1, period)
self._multiplier = Decimal(str(multiplier))
self._min_required = self._period + 5
self._last_atr: Decimal | None = None
def generate_signal(self, candles: list[Candle]) -> Signal | None:
"""Return a Supertrend signal when the trend changes direction."""
if not self.validate_candles(candles, self._min_required):
return None
frame = self._candles_to_frame(candles)
if frame.empty:
return None
supertrend = ta.supertrend(
high=frame["high"],
low=frame["low"],
close=frame["close"],
length=self._period,
multiplier=float(self._multiplier),
)
direction_col = next((col for col in supertrend.columns if col.startswith("SUPERTd_")), None)
if direction_col is None:
logger.debug("Supertrend direction series missing", strategy=self.name)
return None
direction_series = supertrend[direction_col].dropna()
if len(direction_series) < 2:
return None
latest_direction = self._to_decimal(direction_series.iloc[-1])
previous_direction = self._to_decimal(direction_series.iloc[-2])
if latest_direction is None or previous_direction is None:
return None
atr_value = self._compute_atr(frame)
self._last_atr = atr_value
entry_price = self._decimal_from_series_tail(frame["close"])
if entry_price is None or atr_value is None and entry_price <= Decimal("0"):
return None
signal_type: SignalType
signal_side: Side
direction_label: str
if previous_direction < Decimal("0") and latest_direction > Decimal("0"):
signal_type = SignalType.ENTRY_LONG
signal_side = Side.BUY
direction_label = "bullish"
elif previous_direction > Decimal("0") and latest_direction < Decimal("0"):
signal_type = SignalType.ENTRY_SHORT
signal_side = Side.SELL
direction_label = "bearish"
else:
return None
stop_loss = self.get_stop_loss(entry_price, signal_side)
trend_value = self._trend_level(supertrend)
metadata = {
"direction": direction_label,
"atr": atr_value,
"supertrend": trend_value,
}
logger.info(
"Supertrend crossover detected",
strategy=self.name,
signal_type=signal_type.value,
direction=direction_label,
)
return Signal(
signal_type=signal_type,
symbol="",
price=entry_price,
stop_loss=stop_loss,
take_profit=None,
confidence=0.65,
timestamp=candles[-1].timestamp,
strategy_name=self.name,
metadata=metadata,
)
def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
"""Use ATR buffer for Supertrend stop loss."""
atr_buffer = self._last_atr if self._last_atr and self._last_atr > Decimal("0") else entry_price * Decimal("0.02")
if side == Side.BUY:
stop = entry_price - atr_buffer
else:
stop = entry_price + atr_buffer
return stop if stop > Decimal("0") else Decimal("0.01")
@property
def parameters(self) -> dict[str, Decimal | int]:
"""Expose current Supertrend parameters."""
return {"period": self._period, "multiplier": self._multiplier}
@property
def suitable_regimes(self) -> list[Regime]:
"""This strategy runs only in trending regimes."""
return [Regime.TRENDING_UP, Regime.TRENDING_DOWN]
def _compute_atr(self, frame: pd.DataFrame) -> Decimal | None:
atr_result: Any = ta.atr(
high=frame["high"],
low=frame["low"],
close=frame["close"],
length=self._period,
)
if atr_result is None:
return None
if isinstance(atr_result, pd.Series):
if atr_result.empty:
return None
return self._to_decimal(atr_result.iloc[-1])
atr_df: pd.DataFrame = atr_result
if atr_df.empty:
return None
atr_col = next((col for col in atr_df.columns if "ATR" in col), None)
if atr_col is None:
return None
return self._to_decimal(atr_df[atr_col].iloc[-1])
@staticmethod
def _trend_level(supertrend: pd.DataFrame) -> Decimal | None:
trend_col = next((col for col in supertrend.columns if col.startswith("SUPERT_") and not col.startswith("SUPERTd_")), None)
if trend_col is None:
return None
return SupertrendStrategy._decimal_from_series_tail(supertrend[trend_col])
@staticmethod
def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame:
if not candles:
return pd.DataFrame()
frame = pd.DataFrame([candle.to_dict() for candle in candles])
frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
return frame
@staticmethod
def _decimal_from_series_tail(series: pd.Series) -> Decimal | None:
if series.empty:
return None
return SupertrendStrategy._to_decimal(series.iloc[-1])
@staticmethod
def _to_decimal(value: float | int | Decimal | None) -> Decimal | None:
if value is None:
return None
try:
if pd.isna(value): # type: ignore[arg-type]
return None
except (TypeError, ValueError):
pass
return Decimal(str(value))