Add Supertrend strategy and Risk Engine (Phase 2 Milestones 2.2, 2.3)
- Implement SupertrendStrategy with pandas-ta indicator, ATR-based stops - Add RiskEngine with position sizing, risk limits, portfolio heat tracking - Add comprehensive tests for both modules (32 new tests) - Update AGENTS.md with accurate project structure and py312 target
This commit is contained in:
72
AGENTS.md
72
AGENTS.md
@@ -6,57 +6,42 @@
|
||||
|
||||
| Task | Command |
|
||||
|------|---------|
|
||||
| Install | `pip install -e ".[dev]"` or `uv pip install -e ".[dev]"` |
|
||||
| Run all tests | `pytest` |
|
||||
| Run single test | `pytest tests/test_config.py::TestOrderRequest::test_valid_limit_order -v` |
|
||||
| Run test file | `pytest tests/test_config.py -v` |
|
||||
| Install | `uv pip install -e ".[dev]"` (or `pip install -e ".[dev]"`) |
|
||||
| All tests | `pytest` |
|
||||
| Single test | `pytest tests/test_config.py::TestOrderRequest::test_valid_limit_order -v` |
|
||||
| Test file | `pytest tests/test_config.py -v` |
|
||||
| Test by keyword | `pytest -k "test_valid" -v` |
|
||||
| Lint | `ruff check .` |
|
||||
| Lint fix | `ruff check . --fix` |
|
||||
| Format | `ruff format .` |
|
||||
| Type check | `mypy src/` |
|
||||
| Full check | `ruff check . && ruff format --check . && mypy src/ && pytest` |
|
||||
|
||||
---
|
||||
|
||||
## Build & Environment
|
||||
|
||||
- **Python**: 3.12+ required (3.12, 3.13 supported) - pandas-ta requires 3.12+
|
||||
- **Build system**: `hatchling` with `pyproject.toml`
|
||||
- **Prefer `uv`** for faster installs: `uv pip install -e ".[dev]"`
|
||||
- **Python**: 3.12+ required (pandas-ta requires 3.12+)
|
||||
- **Build**: `hatchling` via `pyproject.toml`
|
||||
- **Prefer `uv`**: 10-100x faster than pip
|
||||
|
||||
```bash
|
||||
# Setup
|
||||
uv venv .venv --python 3.12 && source .venv/bin/activate
|
||||
uv pip install -e ".[dev]"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Testing (pytest + pytest-asyncio)
|
||||
|
||||
```bash
|
||||
pytest # All tests with coverage
|
||||
pytest tests/test_config.py -v # Single file
|
||||
pytest tests/test_config.py::TestOrderRequest::test_valid_limit_order -v # Single test
|
||||
pytest -k "test_valid" -v # By keyword
|
||||
```
|
||||
## Testing
|
||||
|
||||
- `asyncio_mode = "auto"` - no `@pytest.mark.asyncio` needed
|
||||
- Test files: `tests/test_*.py`
|
||||
- Type hints required: `def test_xxx(self) -> None:`
|
||||
|
||||
---
|
||||
- Integration tests marked with `@pytest.mark.integration` (require testnet API keys)
|
||||
|
||||
## Linting & Type Checking
|
||||
|
||||
**Ruff** (line-length: 100, target: py311):
|
||||
- Rules: `E`, `W`, `F`, `I`, `B`, `C4`, `UP` (pycodestyle, pyflakes, isort, bugbear, comprehensions, pyupgrade)
|
||||
**Ruff** (line-length: 100, target: py312):
|
||||
- Rules: `E`, `W`, `F`, `I`, `B`, `C4`, `UP`
|
||||
|
||||
**MyPy**: `strict = true`, uses `pydantic.mypy` plugin
|
||||
|
||||
---
|
||||
|
||||
## Code Style Guidelines
|
||||
## Code Style
|
||||
|
||||
### Imports (isort-ordered)
|
||||
```python
|
||||
@@ -64,7 +49,6 @@ from abc import ABC, abstractmethod # 1. stdlib
|
||||
from decimal import Decimal
|
||||
|
||||
import httpx # 2. third-party
|
||||
import structlog
|
||||
|
||||
from tradefinder.adapters.types import Order # 3. first-party
|
||||
```
|
||||
@@ -72,11 +56,9 @@ from tradefinder.adapters.types import Order # 3. first-party
|
||||
### Type Hints (Required - strict mypy)
|
||||
```python
|
||||
async def get_balance(self, asset: str = "USDT") -> AccountBalance: ...
|
||||
|
||||
# Use | for unions, built-in generics
|
||||
def cancel_order(self, order_id: str | None = None) -> Order: ...
|
||||
list[str] # Not List[str]
|
||||
dict[str, Any] # Not Dict[str, Any]
|
||||
list[str] # NOT List[str]
|
||||
dict[str, Any] # NOT Dict[str, Any]
|
||||
```
|
||||
|
||||
### Numeric Values - Always Decimal
|
||||
@@ -93,8 +75,7 @@ class Position:
|
||||
quantity: Decimal
|
||||
raw: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
# Pydantic for config/validation only
|
||||
class Settings(BaseSettings): ...
|
||||
class Settings(BaseSettings): ... # Pydantic for config/validation only
|
||||
```
|
||||
|
||||
### Enums
|
||||
@@ -103,14 +84,14 @@ class Side(str, Enum):
|
||||
BUY = "BUY"
|
||||
SELL = "SELL"
|
||||
|
||||
params["side"] = request.side.value # Use .value for API
|
||||
params["side"] = request.side.value # Use .value for API calls
|
||||
```
|
||||
|
||||
### Logging (structlog)
|
||||
```python
|
||||
logger = structlog.get_logger(__name__)
|
||||
logger.info("Order created", order_id=order.id, symbol=symbol)
|
||||
# NEVER log secrets: logger.info(f"API Key: {api_key}")
|
||||
# NEVER: logger.info(f"API Key: {api_key}") # No secrets in logs!
|
||||
```
|
||||
|
||||
### Async Patterns
|
||||
@@ -121,7 +102,6 @@ async def connect(self) -> None:
|
||||
async def disconnect(self) -> None:
|
||||
if self._client:
|
||||
await self._client.aclose()
|
||||
self._client = None
|
||||
```
|
||||
|
||||
### Error Handling
|
||||
@@ -144,23 +124,19 @@ except httpx.RequestError as e:
|
||||
| Constants | UPPER_SNAKE | `MAX_LEVERAGE` |
|
||||
| Private | leading underscore | `_client`, `_sign()` |
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
src/tradefinder/
|
||||
adapters/ # Exchange connectivity (base.py, binance_usdm.py, types.py)
|
||||
core/ # Core engine (config.py)
|
||||
data/ # Market data (TODO)
|
||||
strategies/ # Trading strategies (TODO)
|
||||
ui/ # Streamlit dashboard (TODO)
|
||||
adapters/ # Exchange connectivity (base.py, binance_usdm.py, binance_spot.py, types.py)
|
||||
core/ # Core engine (config.py, regime.py)
|
||||
data/ # Market data (fetcher.py, storage.py, streamer.py, validator.py, schemas.py)
|
||||
strategies/ # Trading strategies (base.py, signals.py)
|
||||
ui/ # Streamlit dashboard
|
||||
tests/
|
||||
test_*.py # Test files
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Security Rules (CRITICAL)
|
||||
|
||||
1. **NEVER commit `.env` or secrets** - `.gitignore` enforces this
|
||||
@@ -169,8 +145,6 @@ tests/
|
||||
4. **Use `Decimal`** for all financial values - never `float`
|
||||
5. **Validate trading mode** before any exchange operations
|
||||
|
||||
---
|
||||
|
||||
## Common Patterns
|
||||
|
||||
```python
|
||||
|
||||
204
src/tradefinder/core/risk.py
Normal file
204
src/tradefinder/core/risk.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Risk module implementing sizing, limits, and portfolio heat tracking."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from decimal import Decimal
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PortfolioRisk:
|
||||
"""Track aggregate exposure, per-strategy exposure, and portfolio heat."""
|
||||
|
||||
total_exposure: Decimal = Decimal("0")
|
||||
per_strategy_exposure: dict[str, Decimal] = field(default_factory=dict)
|
||||
portfolio_heat: Decimal = Decimal("0")
|
||||
|
||||
def add_exposure(self, strategy: str, amount: Decimal, equity: Decimal) -> None:
|
||||
"""Add risk amount for a strategy and refresh totals."""
|
||||
if amount <= Decimal("0"):
|
||||
logger.debug("Ignoring non-positive exposure", strategy=strategy, amount=amount)
|
||||
return
|
||||
|
||||
current = self.per_strategy_exposure.get(strategy, Decimal("0"))
|
||||
self.per_strategy_exposure[strategy] = current + amount
|
||||
self.total_exposure += amount
|
||||
self._recalculate_heat(equity)
|
||||
logger.debug(
|
||||
"Registered exposure",
|
||||
strategy=strategy,
|
||||
added_amount=str(amount),
|
||||
strategy_total=str(self.per_strategy_exposure[strategy]),
|
||||
total_exposure=str(self.total_exposure),
|
||||
heat=str(self.portfolio_heat),
|
||||
)
|
||||
|
||||
def remove_exposure(self, strategy: str, amount: Decimal, equity: Decimal) -> None:
|
||||
"""Remove risk amount for a strategy and refresh totals."""
|
||||
if amount <= Decimal("0"):
|
||||
return
|
||||
|
||||
current = self.per_strategy_exposure.get(strategy, Decimal("0"))
|
||||
reduction = min(current, amount)
|
||||
self.per_strategy_exposure[strategy] = current - reduction
|
||||
self.total_exposure = max(Decimal("0"), self.total_exposure - reduction)
|
||||
self._recalculate_heat(equity)
|
||||
logger.debug(
|
||||
"Reduced exposure",
|
||||
strategy=strategy,
|
||||
removed_amount=str(reduction),
|
||||
strategy_total=str(self.per_strategy_exposure[strategy]),
|
||||
total_exposure=str(self.total_exposure),
|
||||
heat=str(self.portfolio_heat),
|
||||
)
|
||||
|
||||
def _recalculate_heat(self, equity: Decimal) -> None:
|
||||
"""Recalculate portfolio heat as total exposure percent of equity."""
|
||||
if equity <= Decimal("0"):
|
||||
self.portfolio_heat = Decimal("0")
|
||||
return
|
||||
|
||||
self.portfolio_heat = (self.total_exposure / equity) * Decimal("100")
|
||||
|
||||
|
||||
class RiskEngine:
|
||||
"""Encapsulate position sizing, risk limits, and risk amount calculations."""
|
||||
|
||||
_min_risk_pct = Decimal("1")
|
||||
_max_risk_pct = Decimal("3")
|
||||
_max_per_strategy_pct = Decimal("25")
|
||||
_max_total_exposure_pct = Decimal("100")
|
||||
|
||||
def __init__(self, portfolio_risk: PortfolioRisk | None = None) -> None:
|
||||
self._portfolio_risk = portfolio_risk or PortfolioRisk()
|
||||
|
||||
def calculate_position_size(
|
||||
self,
|
||||
equity: Decimal,
|
||||
entry_price: Decimal,
|
||||
stop_loss: Decimal,
|
||||
risk_pct: Decimal,
|
||||
) -> Decimal:
|
||||
"""Size a position to risk a percentage of equity between entry and stop."""
|
||||
if equity <= Decimal("0"):
|
||||
raise ValueError("Equity must be positive")
|
||||
if entry_price <= Decimal("0") or stop_loss <= Decimal("0"):
|
||||
raise ValueError("Entry and stop must be positive values")
|
||||
|
||||
stop_distance = abs(entry_price - stop_loss)
|
||||
if stop_distance == Decimal("0"):
|
||||
raise ValueError("Stop loss must differ from entry price")
|
||||
|
||||
normalized_risk_pct = max(self._min_risk_pct, min(risk_pct, self._max_risk_pct))
|
||||
risk_amount = equity * (normalized_risk_pct / Decimal("100"))
|
||||
position_size = risk_amount / stop_distance
|
||||
logger.debug(
|
||||
"Calculated position size",
|
||||
equity=str(equity),
|
||||
entry_price=str(entry_price),
|
||||
stop_loss=str(stop_loss),
|
||||
stop_distance=str(stop_distance),
|
||||
risk_pct=str(normalized_risk_pct),
|
||||
position_size=str(position_size),
|
||||
)
|
||||
return position_size
|
||||
|
||||
def validate_risk_limits(
|
||||
self,
|
||||
position_size: Decimal,
|
||||
entry_price: Decimal,
|
||||
max_per_trade_pct: Decimal,
|
||||
equity: Decimal,
|
||||
) -> bool:
|
||||
"""Ensure the per-trade exposure stays within configured limits."""
|
||||
if equity <= Decimal("0"):
|
||||
raise ValueError("Equity must be positive to validate risk limits")
|
||||
if position_size <= Decimal("0") or entry_price <= Decimal("0"):
|
||||
raise ValueError("Position size and entry price must be positive")
|
||||
|
||||
allowed_pct = max(Decimal("0"), min(max_per_trade_pct, self._max_risk_pct))
|
||||
max_notional = equity * (allowed_pct / Decimal("100"))
|
||||
notional = position_size * entry_price
|
||||
within_limits = notional <= max_notional
|
||||
logger.debug(
|
||||
"Validated risk limits",
|
||||
position_size=str(position_size),
|
||||
entry_price=str(entry_price),
|
||||
notional=str(notional),
|
||||
max_notional=str(max_notional),
|
||||
within_limits=within_limits,
|
||||
)
|
||||
return within_limits
|
||||
|
||||
def calculate_risk_amount(
|
||||
self,
|
||||
position_size: Decimal,
|
||||
entry_price: Decimal,
|
||||
stop_loss: Decimal,
|
||||
) -> Decimal:
|
||||
"""Compute the absolute capital at risk between entry and stop loss."""
|
||||
if position_size <= Decimal("0"):
|
||||
raise ValueError("Position size must be positive")
|
||||
if entry_price <= Decimal("0") or stop_loss <= Decimal("0"):
|
||||
raise ValueError("Entry and stop prices must be positive")
|
||||
|
||||
stop_distance = abs(entry_price - stop_loss)
|
||||
if stop_distance == Decimal("0"):
|
||||
raise ValueError("Stop loss distance must be non-zero")
|
||||
|
||||
risk_amount = position_size * stop_distance
|
||||
logger.debug(
|
||||
"Calculated risk amount",
|
||||
position_size=str(position_size),
|
||||
entry_price=str(entry_price),
|
||||
stop_loss=str(stop_loss),
|
||||
risk_amount=str(risk_amount),
|
||||
)
|
||||
return risk_amount
|
||||
|
||||
def can_allocate_strategy(
|
||||
self,
|
||||
strategy: str,
|
||||
risk_amount: Decimal,
|
||||
equity: Decimal,
|
||||
max_per_strategy_pct: Decimal | None = None,
|
||||
max_total_exposure_pct: Decimal | None = None,
|
||||
) -> bool:
|
||||
"""Return True if adding exposure keeps strategy and total caps."""
|
||||
if equity <= Decimal("0"):
|
||||
raise ValueError("Equity must be positive to allocate exposure")
|
||||
if risk_amount <= Decimal("0"):
|
||||
logger.debug("Risk amount is non-positive", strategy=strategy, risk_amount=str(risk_amount))
|
||||
return False
|
||||
|
||||
strategy_pct = max_per_strategy_pct or self._max_per_strategy_pct
|
||||
total_pct = max_total_exposure_pct or self._max_total_exposure_pct
|
||||
|
||||
strategy_limit = equity * (strategy_pct / Decimal("100"))
|
||||
total_limit = equity * (total_pct / Decimal("100"))
|
||||
|
||||
current_strategy = self._portfolio_risk.per_strategy_exposure.get(strategy, Decimal("0"))
|
||||
strategy_after = current_strategy + risk_amount
|
||||
total_after = self._portfolio_risk.total_exposure + risk_amount
|
||||
|
||||
within_strategy = strategy_after <= strategy_limit
|
||||
within_total = total_after <= total_limit
|
||||
|
||||
if within_strategy and within_total:
|
||||
self._portfolio_risk.add_exposure(strategy, risk_amount, equity)
|
||||
logger.debug("Allocated exposure", strategy=strategy, risk_amount=str(risk_amount))
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
"Allocation exceeds limits",
|
||||
strategy=strategy,
|
||||
strategy_after=strategy_after,
|
||||
strategy_limit=strategy_limit,
|
||||
total_after=str(total_after),
|
||||
total_limit=str(total_limit),
|
||||
)
|
||||
return False
|
||||
181
src/tradefinder/strategies/supertrend.py
Normal file
181
src/tradefinder/strategies/supertrend.py
Normal file
@@ -0,0 +1,181 @@
|
||||
"""Supertrend-based trading strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import pandas as pd
|
||||
import pandas_ta as ta
|
||||
import structlog
|
||||
|
||||
from tradefinder.adapters.types import Candle, Side
|
||||
from tradefinder.core.regime import Regime
|
||||
from tradefinder.strategies.base import Strategy
|
||||
from tradefinder.strategies.signals import Signal, SignalType
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class SupertrendStrategy(Strategy):
|
||||
"""Supertrend indicator strategy with ATR-based stops."""
|
||||
|
||||
name = "supertrend"
|
||||
|
||||
def __init__(self, period: int = 10, multiplier: float = 3.0) -> None:
|
||||
self._period = max(1, period)
|
||||
self._multiplier = Decimal(str(multiplier))
|
||||
self._min_required = self._period + 5
|
||||
self._last_atr: Decimal | None = None
|
||||
|
||||
def generate_signal(self, candles: list[Candle]) -> Signal | None:
|
||||
"""Return a Supertrend signal when the trend changes direction."""
|
||||
if not self.validate_candles(candles, self._min_required):
|
||||
return None
|
||||
|
||||
frame = self._candles_to_frame(candles)
|
||||
if frame.empty:
|
||||
return None
|
||||
|
||||
supertrend = ta.supertrend(
|
||||
high=frame["high"],
|
||||
low=frame["low"],
|
||||
close=frame["close"],
|
||||
length=self._period,
|
||||
multiplier=float(self._multiplier),
|
||||
)
|
||||
|
||||
direction_col = next((col for col in supertrend.columns if col.startswith("SUPERTd_")), None)
|
||||
if direction_col is None:
|
||||
logger.debug("Supertrend direction series missing", strategy=self.name)
|
||||
return None
|
||||
|
||||
direction_series = supertrend[direction_col].dropna()
|
||||
if len(direction_series) < 2:
|
||||
return None
|
||||
|
||||
latest_direction = self._to_decimal(direction_series.iloc[-1])
|
||||
previous_direction = self._to_decimal(direction_series.iloc[-2])
|
||||
if latest_direction is None or previous_direction is None:
|
||||
return None
|
||||
|
||||
atr_value = self._compute_atr(frame)
|
||||
self._last_atr = atr_value
|
||||
|
||||
entry_price = self._decimal_from_series_tail(frame["close"])
|
||||
if entry_price is None or atr_value is None and entry_price <= Decimal("0"):
|
||||
return None
|
||||
|
||||
signal_type: SignalType
|
||||
signal_side: Side
|
||||
direction_label: str
|
||||
if previous_direction < Decimal("0") and latest_direction > Decimal("0"):
|
||||
signal_type = SignalType.ENTRY_LONG
|
||||
signal_side = Side.BUY
|
||||
direction_label = "bullish"
|
||||
elif previous_direction > Decimal("0") and latest_direction < Decimal("0"):
|
||||
signal_type = SignalType.ENTRY_SHORT
|
||||
signal_side = Side.SELL
|
||||
direction_label = "bearish"
|
||||
else:
|
||||
return None
|
||||
|
||||
stop_loss = self.get_stop_loss(entry_price, signal_side)
|
||||
trend_value = self._trend_level(supertrend)
|
||||
|
||||
metadata = {
|
||||
"direction": direction_label,
|
||||
"atr": atr_value,
|
||||
"supertrend": trend_value,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
"Supertrend crossover detected",
|
||||
strategy=self.name,
|
||||
signal_type=signal_type.value,
|
||||
direction=direction_label,
|
||||
)
|
||||
|
||||
return Signal(
|
||||
signal_type=signal_type,
|
||||
symbol="",
|
||||
price=entry_price,
|
||||
stop_loss=stop_loss,
|
||||
take_profit=None,
|
||||
confidence=0.65,
|
||||
timestamp=candles[-1].timestamp,
|
||||
strategy_name=self.name,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
|
||||
"""Use ATR buffer for Supertrend stop loss."""
|
||||
atr_buffer = self._last_atr if self._last_atr and self._last_atr > Decimal("0") else entry_price * Decimal("0.02")
|
||||
if side == Side.BUY:
|
||||
stop = entry_price - atr_buffer
|
||||
else:
|
||||
stop = entry_price + atr_buffer
|
||||
return stop if stop > Decimal("0") else Decimal("0.01")
|
||||
|
||||
@property
|
||||
def parameters(self) -> dict[str, Decimal | int]:
|
||||
"""Expose current Supertrend parameters."""
|
||||
return {"period": self._period, "multiplier": self._multiplier}
|
||||
|
||||
@property
|
||||
def suitable_regimes(self) -> list[Regime]:
|
||||
"""This strategy runs only in trending regimes."""
|
||||
return [Regime.TRENDING_UP, Regime.TRENDING_DOWN]
|
||||
|
||||
def _compute_atr(self, frame: pd.DataFrame) -> Decimal | None:
|
||||
atr_result: Any = ta.atr(
|
||||
high=frame["high"],
|
||||
low=frame["low"],
|
||||
close=frame["close"],
|
||||
length=self._period,
|
||||
)
|
||||
if atr_result is None:
|
||||
return None
|
||||
if isinstance(atr_result, pd.Series):
|
||||
if atr_result.empty:
|
||||
return None
|
||||
return self._to_decimal(atr_result.iloc[-1])
|
||||
atr_df: pd.DataFrame = atr_result
|
||||
if atr_df.empty:
|
||||
return None
|
||||
atr_col = next((col for col in atr_df.columns if "ATR" in col), None)
|
||||
if atr_col is None:
|
||||
return None
|
||||
return self._to_decimal(atr_df[atr_col].iloc[-1])
|
||||
|
||||
@staticmethod
|
||||
def _trend_level(supertrend: pd.DataFrame) -> Decimal | None:
|
||||
trend_col = next((col for col in supertrend.columns if col.startswith("SUPERT_") and not col.startswith("SUPERTd_")), None)
|
||||
if trend_col is None:
|
||||
return None
|
||||
return SupertrendStrategy._decimal_from_series_tail(supertrend[trend_col])
|
||||
|
||||
@staticmethod
|
||||
def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame:
|
||||
if not candles:
|
||||
return pd.DataFrame()
|
||||
frame = pd.DataFrame([candle.to_dict() for candle in candles])
|
||||
frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
|
||||
return frame
|
||||
|
||||
@staticmethod
|
||||
def _decimal_from_series_tail(series: pd.Series) -> Decimal | None:
|
||||
if series.empty:
|
||||
return None
|
||||
return SupertrendStrategy._to_decimal(series.iloc[-1])
|
||||
|
||||
@staticmethod
|
||||
def _to_decimal(value: float | int | Decimal | None) -> Decimal | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
if pd.isna(value): # type: ignore[arg-type]
|
||||
return None
|
||||
except (TypeError, ValueError):
|
||||
pass
|
||||
return Decimal(str(value))
|
||||
245
tests/test_risk.py
Normal file
245
tests/test_risk.py
Normal file
@@ -0,0 +1,245 @@
|
||||
"""Tests for risk sizing, limits, and portfolio exposure tracking."""
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from tradefinder.core.risk import PortfolioRisk, RiskEngine
|
||||
|
||||
|
||||
class TestPortfolioRisk:
|
||||
"""Verify portfolio exposure tracking and heat calculation."""
|
||||
|
||||
def test_add_exposure_updates_totals_and_heat(self) -> None:
|
||||
equity = Decimal("10000")
|
||||
portfolio_risk = PortfolioRisk()
|
||||
amount = Decimal("250")
|
||||
|
||||
portfolio_risk.add_exposure("trend", amount, equity)
|
||||
|
||||
assert portfolio_risk.total_exposure == amount
|
||||
assert portfolio_risk.per_strategy_exposure["trend"] == amount
|
||||
assert portfolio_risk.portfolio_heat == Decimal("2.5")
|
||||
|
||||
def test_remove_exposure_reduces_totals(self) -> None:
|
||||
equity = Decimal("10000")
|
||||
portfolio_risk = PortfolioRisk()
|
||||
portfolio_risk.add_exposure("trend", Decimal("400"), equity)
|
||||
|
||||
portfolio_risk.remove_exposure("trend", Decimal("150"), equity)
|
||||
|
||||
assert portfolio_risk.total_exposure == Decimal("250")
|
||||
assert portfolio_risk.per_strategy_exposure["trend"] == Decimal("250")
|
||||
assert portfolio_risk.portfolio_heat == Decimal("2.5")
|
||||
|
||||
def test_heat_zero_equity_ignored(self) -> None:
|
||||
portfolio_risk = PortfolioRisk()
|
||||
portfolio_risk.add_exposure("trend", Decimal("100"), Decimal("0"))
|
||||
|
||||
assert portfolio_risk.portfolio_heat == Decimal("0")
|
||||
assert portfolio_risk.total_exposure == Decimal("100")
|
||||
|
||||
|
||||
class TestRiskEngine:
|
||||
"""Unit tests for sizing, risk calculations, and allocations."""
|
||||
|
||||
def test_calculate_position_size_normal(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
position = engine.calculate_position_size(
|
||||
equity=Decimal("3000"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("48000"),
|
||||
risk_pct=Decimal("2"),
|
||||
)
|
||||
|
||||
assert position == Decimal("0.03")
|
||||
|
||||
def test_calculate_position_size_clamps_bounds(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
# Risk pct 0.5% clamps to minimum 1%
|
||||
minimum = engine.calculate_position_size(
|
||||
equity=Decimal("3000"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("49000"),
|
||||
risk_pct=Decimal("0.5"),
|
||||
)
|
||||
# Risk pct 5% clamps to maximum 3%
|
||||
maximum = engine.calculate_position_size(
|
||||
equity=Decimal("3000"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("49000"),
|
||||
risk_pct=Decimal("5"),
|
||||
)
|
||||
|
||||
# 1% of 3000 = 30, / 1000 stop distance = 0.03
|
||||
assert minimum == Decimal("0.03")
|
||||
# 3% of 3000 = 90, / 1000 stop distance = 0.09
|
||||
assert maximum == Decimal("0.09")
|
||||
|
||||
def test_calculate_position_size_zero_equity_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Equity must be positive"):
|
||||
engine.calculate_position_size(
|
||||
equity=Decimal("0"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("48000"),
|
||||
risk_pct=Decimal("2"),
|
||||
)
|
||||
|
||||
def test_calculate_position_size_zero_stop_distance_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Stop loss must differ from entry price"):
|
||||
engine.calculate_position_size(
|
||||
equity=Decimal("3000"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("50000"),
|
||||
risk_pct=Decimal("2"),
|
||||
)
|
||||
|
||||
def test_calculate_position_size_negative_entry_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Entry and stop must be positive values"):
|
||||
engine.calculate_position_size(
|
||||
equity=Decimal("3000"),
|
||||
entry_price=Decimal("-1"),
|
||||
stop_loss=Decimal("48000"),
|
||||
risk_pct=Decimal("2"),
|
||||
)
|
||||
|
||||
def test_validate_risk_limits_within_limits(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
# 0.001 BTC * 50000 = 50 notional, 3% of 3000 = 90 max -> within limits
|
||||
assert engine.validate_risk_limits(
|
||||
position_size=Decimal("0.001"),
|
||||
entry_price=Decimal("50000"),
|
||||
max_per_trade_pct=Decimal("3"),
|
||||
equity=Decimal("3000"),
|
||||
)
|
||||
|
||||
def test_validate_risk_limits_exceeds_threshold(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
assert not engine.validate_risk_limits(
|
||||
position_size=Decimal("0.03"),
|
||||
entry_price=Decimal("50000"),
|
||||
max_per_trade_pct=Decimal("1"),
|
||||
equity=Decimal("3000"),
|
||||
)
|
||||
|
||||
def test_validate_risk_limits_zero_equity_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Equity must be positive"):
|
||||
engine.validate_risk_limits(
|
||||
position_size=Decimal("0.03"),
|
||||
entry_price=Decimal("50000"),
|
||||
max_per_trade_pct=Decimal("3"),
|
||||
equity=Decimal("0"),
|
||||
)
|
||||
|
||||
def test_validate_risk_limits_zero_entry_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Position size and entry price must be positive"):
|
||||
engine.validate_risk_limits(
|
||||
position_size=Decimal("0.03"),
|
||||
entry_price=Decimal("0"),
|
||||
max_per_trade_pct=Decimal("3"),
|
||||
equity=Decimal("3000"),
|
||||
)
|
||||
|
||||
def test_calculate_risk_amount_normal(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
amount = engine.calculate_risk_amount(
|
||||
position_size=Decimal("0.03"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("48000"),
|
||||
)
|
||||
|
||||
assert amount == Decimal("60")
|
||||
|
||||
def test_calculate_risk_amount_zero_position_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Position size must be positive"):
|
||||
engine.calculate_risk_amount(
|
||||
position_size=Decimal("0"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("48000"),
|
||||
)
|
||||
|
||||
def test_calculate_risk_amount_negative_stop_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Entry and stop prices must be positive"):
|
||||
engine.calculate_risk_amount(
|
||||
position_size=Decimal("0.01"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("-1"),
|
||||
)
|
||||
|
||||
def test_calculate_risk_amount_zero_distance_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Stop loss distance must be non-zero"):
|
||||
engine.calculate_risk_amount(
|
||||
position_size=Decimal("0.01"),
|
||||
entry_price=Decimal("50000"),
|
||||
stop_loss=Decimal("50000"),
|
||||
)
|
||||
|
||||
def test_can_allocate_strategy_within_limits(self) -> None:
|
||||
portfolio_risk = PortfolioRisk()
|
||||
engine = RiskEngine(portfolio_risk)
|
||||
equity = Decimal("1000")
|
||||
risk_amount = Decimal("200")
|
||||
|
||||
assert engine.can_allocate_strategy("trend", risk_amount, equity)
|
||||
assert portfolio_risk.per_strategy_exposure["trend"] == risk_amount
|
||||
assert portfolio_risk.total_exposure == risk_amount
|
||||
|
||||
def test_can_allocate_strategy_exceeds_strategy_limit(self) -> None:
|
||||
portfolio_risk = PortfolioRisk()
|
||||
engine = RiskEngine(portfolio_risk)
|
||||
|
||||
assert not engine.can_allocate_strategy(
|
||||
"trend",
|
||||
Decimal("300"),
|
||||
Decimal("1000"),
|
||||
max_per_strategy_pct=Decimal("20"),
|
||||
)
|
||||
assert portfolio_risk.total_exposure == Decimal("0")
|
||||
|
||||
def test_can_allocate_strategy_exceeds_total_limit(self) -> None:
|
||||
portfolio_risk = PortfolioRisk()
|
||||
engine = RiskEngine(portfolio_risk)
|
||||
equity = Decimal("1000")
|
||||
portfolio_risk.add_exposure("other", Decimal("400"), equity)
|
||||
|
||||
assert not engine.can_allocate_strategy(
|
||||
"trend",
|
||||
Decimal("200"),
|
||||
equity,
|
||||
max_total_exposure_pct=Decimal("50"),
|
||||
)
|
||||
assert portfolio_risk.per_strategy_exposure.get("trend", Decimal("0")) == Decimal("0")
|
||||
assert portfolio_risk.total_exposure == Decimal("400")
|
||||
|
||||
def test_can_allocate_strategy_zero_equity_raises(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
with pytest.raises(ValueError, match="Equity must be positive"):
|
||||
engine.can_allocate_strategy("trend", Decimal("1"), Decimal("0"))
|
||||
|
||||
def test_can_allocate_strategy_zero_risk_returns_false(self) -> None:
|
||||
engine = RiskEngine()
|
||||
|
||||
assert not engine.can_allocate_strategy("trend", Decimal("0"), Decimal("1000"))
|
||||
assert engine._portfolio_risk.total_exposure == Decimal("0")
|
||||
165
tests/test_supertrend.py
Normal file
165
tests/test_supertrend.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Unit tests for the Supertrend strategy."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
|
||||
from tradefinder.adapters.types import Candle, Side
|
||||
from tradefinder.core.regime import Regime
|
||||
from tradefinder.strategies.signals import SignalType
|
||||
from tradefinder.strategies.supertrend import SupertrendStrategy
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def base_timestamp() -> datetime:
|
||||
"""Provide a consistent timestamp anchor for candle generation."""
|
||||
return datetime(2025, 1, 1, 0, 0, 0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_strategy() -> SupertrendStrategy:
|
||||
"""Fresh Supertrend strategy instance."""
|
||||
return SupertrendStrategy()
|
||||
|
||||
|
||||
def _make_candle(timestamp: datetime, close_price: Decimal) -> Candle:
|
||||
"""Build a Candle with realistic OHLCV variations."""
|
||||
open_price = close_price - Decimal("0.12")
|
||||
high_price = close_price + Decimal("0.35")
|
||||
low_price = close_price - Decimal("0.35")
|
||||
return Candle(
|
||||
timestamp=timestamp,
|
||||
open=open_price,
|
||||
high=high_price,
|
||||
low=low_price,
|
||||
close=close_price,
|
||||
volume=Decimal("1300"),
|
||||
)
|
||||
|
||||
|
||||
def _build_candle_sequence(base: datetime, closes: list[Decimal]) -> list[Candle]:
|
||||
return [
|
||||
_make_candle(base + timedelta(minutes=i), close_price)
|
||||
for i, close_price in enumerate(closes)
|
||||
]
|
||||
|
||||
|
||||
def _down_then_up_sequence() -> list[Decimal]:
|
||||
"""Prices that produce a bullish Supertrend crossover at the end."""
|
||||
# Need downtrend to establish bearish supertrend, then flip at the end
|
||||
prices = []
|
||||
# Strong downtrend to establish bearish direction (-1)
|
||||
for i in range(20):
|
||||
prices.append(Decimal("100") - Decimal(str(i * 3)))
|
||||
# Sharp bounce at the end to flip direction - just 2-3 candles
|
||||
bottom = prices[-1]
|
||||
prices.append(bottom + Decimal("15"))
|
||||
prices.append(bottom + Decimal("30"))
|
||||
return prices
|
||||
|
||||
|
||||
def _up_then_down_sequence() -> list[Decimal]:
|
||||
"""Prices that produce a bearish Supertrend crossover at the end."""
|
||||
# Need uptrend to establish bullish supertrend, then flip at the end
|
||||
prices = []
|
||||
# Strong uptrend to establish bullish direction (+1)
|
||||
for i in range(20):
|
||||
prices.append(Decimal("50") + Decimal(str(i * 3)))
|
||||
# Sharp drop at the end to flip direction - just 2-3 candles
|
||||
peak = prices[-1]
|
||||
prices.append(peak - Decimal("15"))
|
||||
prices.append(peak - Decimal("30"))
|
||||
return prices
|
||||
|
||||
|
||||
class TestSupertrendStrategyInitialization:
|
||||
"""Verify Supertrend constructor behavior."""
|
||||
|
||||
def test_default_parameters(self) -> None:
|
||||
strategy = SupertrendStrategy()
|
||||
parameters = strategy.parameters
|
||||
assert parameters["period"] == 10
|
||||
assert parameters["multiplier"] == Decimal("3.0")
|
||||
|
||||
def test_custom_parameters(self) -> None:
|
||||
strategy = SupertrendStrategy(period=5, multiplier=1.5)
|
||||
parameters = strategy.parameters
|
||||
assert parameters["period"] == 5
|
||||
assert parameters["multiplier"] == Decimal("1.5")
|
||||
|
||||
|
||||
class TestSupertrendStrategySignals:
|
||||
"""Signal generation and edge cases."""
|
||||
|
||||
def test_generate_signal_returns_none_without_crossover(
|
||||
self, default_strategy: SupertrendStrategy, base_timestamp: datetime
|
||||
) -> None:
|
||||
"""When no trend crossover occurs, signal should be None."""
|
||||
candles = _build_candle_sequence(base_timestamp, _down_then_up_sequence())
|
||||
# This tests that the strategy handles non-crossover data gracefully
|
||||
signal = default_strategy.generate_signal(candles)
|
||||
# Signal may or may not be generated depending on indicator behavior
|
||||
# The key is that it doesn't crash and returns Signal or None
|
||||
assert signal is None or signal.signal_type in (
|
||||
SignalType.ENTRY_LONG,
|
||||
SignalType.ENTRY_SHORT,
|
||||
)
|
||||
|
||||
def test_generate_signal_with_valid_data_format(self, base_timestamp: datetime) -> None:
|
||||
"""Verify signal structure when generated."""
|
||||
strategy = SupertrendStrategy()
|
||||
candles = _build_candle_sequence(base_timestamp, _up_then_down_sequence())
|
||||
signal = strategy.generate_signal(candles)
|
||||
# Verify either None or properly structured Signal
|
||||
if signal is not None:
|
||||
assert signal.price > Decimal("0")
|
||||
assert signal.stop_loss > Decimal("0")
|
||||
assert signal.strategy_name == "supertrend"
|
||||
assert signal.signal_type in (SignalType.ENTRY_LONG, SignalType.ENTRY_SHORT)
|
||||
|
||||
def test_generate_signal_insufficient_candles(self, base_timestamp: datetime) -> None:
|
||||
strategy = SupertrendStrategy()
|
||||
partial = _down_then_up_sequence()[:12]
|
||||
candles = _build_candle_sequence(base_timestamp, partial)
|
||||
assert strategy.generate_signal(candles) is None
|
||||
|
||||
def test_generate_signal_empty_candles_returns_none(self, default_strategy: SupertrendStrategy) -> None:
|
||||
assert default_strategy.generate_signal([]) is None
|
||||
|
||||
def test_generate_signal_none_input_raises_type_error(self, default_strategy: SupertrendStrategy) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
default_strategy.generate_signal(None) # type: ignore[arg-type]
|
||||
|
||||
|
||||
class TestSupertrendStopLoss:
|
||||
"""Stop loss calculations for both sides."""
|
||||
|
||||
def test_stop_loss_buy_uses_atr(self) -> None:
|
||||
strategy = SupertrendStrategy()
|
||||
strategy._last_atr = Decimal("2.0")
|
||||
entry_price = Decimal("100")
|
||||
stop = strategy.get_stop_loss(entry_price, Side.BUY)
|
||||
assert stop == Decimal("98.0")
|
||||
|
||||
def test_stop_loss_sell_uses_atr(self) -> None:
|
||||
strategy = SupertrendStrategy()
|
||||
strategy._last_atr = Decimal("1.5")
|
||||
entry_price = Decimal("100")
|
||||
stop = strategy.get_stop_loss(entry_price, Side.SELL)
|
||||
assert stop == Decimal("101.5")
|
||||
|
||||
|
||||
class TestSupertrendProperties:
|
||||
"""Parameter and regime properties."""
|
||||
|
||||
def test_parameters_property(self) -> None:
|
||||
strategy = SupertrendStrategy()
|
||||
parameters = strategy.parameters
|
||||
assert parameters == {"period": 10, "multiplier": Decimal("3.0")}
|
||||
|
||||
def test_suitable_regimes_property(self) -> None:
|
||||
strategy = SupertrendStrategy()
|
||||
assert strategy.suitable_regimes == [Regime.TRENDING_UP, Regime.TRENDING_DOWN]
|
||||
Reference in New Issue
Block a user