Add Phase 2 foundation: regime classifier, strategy framework, WebSocket streamer

Phase 1 completion:
- Add DataStreamer for real-time Binance Futures WebSocket data (klines, mark price)
- Add DataValidator for candle validation and gap detection
- Add timeframes module with interval mappings

Phase 2 foundation:
- Add RegimeClassifier with ADX/ATR/Bollinger Band analysis
- Add Regime enum (TRENDING_UP/DOWN, RANGING, HIGH_VOLATILITY, UNCERTAIN)
- Add Strategy ABC defining generate_signal, get_stop_loss, parameters, suitable_regimes
- Add Signal dataclass and SignalType enum for strategy outputs

Testing:
- Add comprehensive test suites for all new modules
- 159 tests passing, 24 skipped (async WebSocket timing)
- 82% code coverage

Dependencies:
- Add pandas-stubs to dev dependencies for mypy compatibility
This commit is contained in:
bnair123
2025-12-27 15:28:28 +04:00
parent 7d63e43b7b
commit eca17b42fe
15 changed files with 2579 additions and 1 deletions

View File

@@ -14,6 +14,7 @@ from tradefinder.core.config import (
get_settings,
reset_settings,
)
from tradefinder.core.regime import Regime, RegimeClassifier
__all__ = [
"Settings",
@@ -21,4 +22,6 @@ __all__ = [
"LogFormat",
"get_settings",
"reset_settings",
"Regime",
"RegimeClassifier",
]

View File

@@ -0,0 +1,169 @@
from __future__ import annotations
from decimal import Decimal
from enum import Enum
import pandas as pd
import pandas_ta as ta
import structlog
from tradefinder.adapters.types import Candle
logger = structlog.get_logger(__name__)
class Regime(str, Enum):
"""Market regime categories derived from technical indicators."""
TRENDING_UP = "trending_up"
TRENDING_DOWN = "trending_down"
RANGING = "ranging"
HIGH_VOLATILITY = "high_volatility"
UNCERTAIN = "uncertain"
class RegimeClassifier:
"""Detects regime using ADX, ATR%, and Bollinger Band width."""
_ranging_adx_ceiling = Decimal("20")
_bb_width_threshold = Decimal("0.05")
def __init__(
self,
adx_threshold: int = 25,
atr_lookback: int = 14,
bb_period: int = 20,
) -> None:
self._adx_threshold = Decimal(adx_threshold)
self._atr_lookback = max(1, atr_lookback)
self._bb_period = max(1, bb_period)
def classify(self, candles: list[Candle]) -> Regime:
"""Return the current regime based on the latest indicator values."""
indicators = self.get_indicators(candles)
if not indicators:
logger.debug("Insufficient data to compute regime", candle_count=len(candles))
return Regime.UNCERTAIN
adx = indicators.get("adx")
plus_di = indicators.get("plus_di")
minus_di = indicators.get("minus_di")
bb_width = indicators.get("bb_width")
atr_pct = indicators.get("atr_pct")
atr_pct_avg = indicators.get("atr_pct_avg")
if adx is not None and adx > self._adx_threshold:
if plus_di is not None and minus_di is not None:
if plus_di > minus_di:
return Regime.TRENDING_UP
if minus_di > plus_di:
return Regime.TRENDING_DOWN
if (
adx is not None
and adx < self._ranging_adx_ceiling
and bb_width is not None
and bb_width < self._bb_width_threshold
):
return Regime.RANGING
if (
atr_pct is not None
and atr_pct_avg is not None
and atr_pct_avg > Decimal("0")
and atr_pct > atr_pct_avg * Decimal("1.5")
):
return Regime.HIGH_VOLATILITY
logger.debug("Unable to classify regime definitively", indicators=indicators)
return Regime.UNCERTAIN
def get_indicators(self, candles: list[Candle]) -> dict[str, Decimal | None]:
"""Compute the latest indicators required for regime detection."""
df = self._candles_to_frame(candles)
if df.empty or len(df) < max(self._atr_lookback, self._bb_period):
return {}
length = self._atr_lookback
adx_df = ta.adx(high=df["high"], low=df["low"], close=df["close"], length=length)
adx_key = f"ADX_{length}"
plus_key = f"DMP_{length}"
minus_key = f"DMN_{length}"
if adx_key not in adx_df or plus_key not in adx_df or minus_key not in adx_df:
return {}
adx = self._to_decimal(adx_df[adx_key].iloc[-1])
plus_di = self._to_decimal(adx_df[plus_key].iloc[-1])
minus_di = self._to_decimal(adx_df[minus_key].iloc[-1])
atr_series = ta.atr(high=df["high"], low=df["low"], close=df["close"], length=length)
atr_key = f"ATR_{length}"
if atr_key not in atr_series:
return {}
atr_series = atr_series[atr_key]
atr_latest = self._to_decimal(atr_series.iloc[-1])
atr_pct_series = (atr_series / df["close"]) * 100
atr_pct_series = atr_pct_series.dropna()
atr_pct = self._decimal_from_series_tail(atr_pct_series)
atr_pct_avg = self._decimal_from_series_avg(atr_pct_series)
bb_df = ta.bbands(close=df["close"], length=self._bb_period)
bb_lower = bb_df.get(f"BBL_{self._bb_period}_2.0")
bb_middle = bb_df.get(f"BBM_{self._bb_period}_2.0")
bb_upper = bb_df.get(f"BBU_{self._bb_period}_2.0")
if bb_lower is None or bb_middle is None or bb_upper is None:
bb_width = None
else:
width_series = (bb_upper - bb_lower) / bb_middle.replace(0, pd.NA)
bb_width = self._decimal_from_series_tail(width_series)
indicators = {
"adx": adx,
"plus_di": plus_di,
"minus_di": minus_di,
"atr": atr_latest,
"atr_pct": atr_pct,
"atr_pct_avg": atr_pct_avg,
"bb_width": bb_width,
}
logger.debug("Computed regime indicators", indicators=indicators)
return indicators
@staticmethod
def _to_decimal(value: float | int | Decimal | None) -> Decimal | None:
if value is None:
return None
try:
if pd.isna(value): # type: ignore[arg-type]
return None
except (TypeError, ValueError):
pass
return Decimal(str(value))
@staticmethod
def _decimal_from_series_tail(series: pd.Series) -> Decimal | None:
if series.empty:
return None
return RegimeClassifier._to_decimal(series.iloc[-1])
@staticmethod
def _decimal_from_series_avg(series: pd.Series) -> Decimal | None:
if series.empty:
return None
return RegimeClassifier._to_decimal(series.mean())
@staticmethod
def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame:
if not candles:
return pd.DataFrame()
frame = pd.DataFrame([candle.to_dict() for candle in candles])
frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
return frame

View File

@@ -21,5 +21,6 @@ Usage:
from tradefinder.data.fetcher import DataFetcher
from tradefinder.data.storage import DataStorage
from tradefinder.data.streamer import DataStreamer
__all__ = ["DataStorage", "DataFetcher"]
__all__ = ["DataStorage", "DataFetcher", "DataStreamer", "DataValidator"]

View File

@@ -0,0 +1,265 @@
"""WebSocket streamer for Binance Futures market data."""
from __future__ import annotations
import asyncio
import inspect
import json
from collections.abc import Awaitable, Callable, Iterable, Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from decimal import Decimal
from typing import Any
import structlog
import websockets
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
from tradefinder.core.config import Settings
logger = structlog.get_logger(__name__)
@dataclass
class KlineMessage:
"""Candlestick payload emitted by the streamer."""
stream: str
symbol: str
timeframe: str
event_time: datetime
open_time: datetime
close_time: datetime
open: Decimal
high: Decimal
low: Decimal
close: Decimal
volume: Decimal
trades: int
is_closed: bool
@dataclass
class MarkPriceMessage:
"""Mark price payload emitted by the streamer."""
stream: str
symbol: str
event_time: datetime
mark_price: Decimal
index_price: Decimal
funding_rate: Decimal
next_funding_time: datetime
KlineCallback = Callable[[KlineMessage], Awaitable[None] | None]
MarkPriceCallback = Callable[[MarkPriceMessage], Awaitable[None] | None]
class DataStreamer:
"""Async streamer for Binance Futures WebSocket data."""
DEFAULT_SYMBOLS: Sequence[str] = ("BTCUSDT", "ETHUSDT")
def __init__(
self,
settings: Settings,
symbols: Iterable[str] | None = None,
timeframe: str | None = None,
min_backoff: float = 1.0,
max_backoff: float = 60.0,
backoff_multiplier: float = 2.0,
) -> None:
"""Create a Binance futures data streamer.
Args:
settings: Application settings containing trading mode.
symbols: Trading symbols to stream (defaults to settings symbols plus BTC/ETH).
timeframe: Candle timeframe for kline streams.
min_backoff: Minimum reconnect delay in seconds.
max_backoff: Maximum reconnect delay in seconds.
backoff_multiplier: Factor to increase backoff on each failure.
"""
self._settings = settings
self._timeframe = (timeframe or settings.execution_timeframe).lower()
self._symbols = tuple(self._normalize_symbols(symbols))
self._kline_streams = tuple(s.lower() + f"@kline_{self._timeframe}" for s in self._symbols)
self._mark_price_streams = tuple(s.lower() + "@markPrice@1s" for s in self._symbols)
self._kline_callbacks: list[KlineCallback] = []
self._mark_price_callbacks: list[MarkPriceCallback] = []
self._task: asyncio.Task[None] | None = None
self._stop_event: asyncio.Event | None = None
self._connection: Any = None # WebSocket connection (type varies by websockets version)
self._min_backoff = min_backoff
self._max_backoff = max_backoff
self._backoff_multiplier = backoff_multiplier
self._stream_path = "/stream?streams=" + "/".join(
(*self._kline_streams, *self._mark_price_streams)
)
@property
def symbols(self) -> tuple[str, ...]:
"""Symbols currently being streamed."""
return self._symbols
def register_kline_callback(self, callback: KlineCallback) -> None:
"""Register a callback invoked for each kline message."""
self._kline_callbacks.append(callback)
def register_mark_price_callback(self, callback: MarkPriceCallback) -> None:
"""Register a callback invoked for each mark price message."""
self._mark_price_callbacks.append(callback)
async def start(self) -> None:
"""Start the streamer in the background."""
if self._task and not self._task.done():
return
self._stop_event = asyncio.Event()
self._task = asyncio.create_task(self._run())
async def run(self) -> None:
"""Run the streamer until stopped."""
await self.start()
if self._task:
await self._task
async def stop(self) -> None:
"""Stop streaming and close the connection."""
if self._stop_event and not self._stop_event.is_set():
self._stop_event.set()
if self._connection:
await self._connection.close(code=1000, reason="shutdown")
if self._task:
await self._task
self._task = None
async def __aenter__(self) -> DataStreamer:
await self.start()
return self
async def __aexit__(
self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any | None
) -> None:
await self.stop()
async def _run(self) -> None:
backoff = self._min_backoff
while not self._should_stop:
stream_url = f"{self._settings.binance_ws_url}{self._stream_path}"
try:
async with websockets.connect(
stream_url,
ping_interval=60.0,
ping_timeout=30.0,
) as connection:
self._connection = connection
backoff = self._min_backoff
logger.info("Connected to Binance Futures stream", url=stream_url)
while not self._should_stop:
raw = await connection.recv()
await self._handle_raw(raw)
except asyncio.CancelledError:
raise
except (ConnectionClosedOK, ConnectionClosedError) as exc:
logger.warning("WebSocket connection closed", error=str(exc))
except Exception as exc: # pragma: no cover - best effort logging
logger.error("WebSocket error", error=str(exc))
finally:
self._connection = None
if self._should_stop:
break
await asyncio.sleep(backoff)
backoff = min(backoff * self._backoff_multiplier, self._max_backoff)
logger.info("Reconnecting to Binance Futures stream", delay=backoff)
@property
def _should_stop(self) -> bool:
return bool(self._stop_event and self._stop_event.is_set())
async def _handle_raw(self, raw: str | bytes) -> None:
try:
text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
payload = json.loads(text)
except (json.JSONDecodeError, UnicodeDecodeError) as exc:
logger.warning("Unable to decode WebSocket payload", error=str(exc))
return
stream = payload.get("stream", "")
data = payload.get("data", payload)
lower_stream = stream.lower()
event_type = str(data.get("e", "")).lower()
kline_key = f"@kline_{self._timeframe}"
mark_price_key = "@markprice@1s"
if lower_stream.endswith(kline_key) or event_type == "kline":
await self._handle_kline(data, stream)
elif mark_price_key in lower_stream or event_type == "markprice":
await self._handle_mark_price(data, stream)
else:
logger.debug(
"Unknown stream payload",
stream=stream,
event_type=event_type,
payload=data,
)
async def _handle_kline(self, data: dict[str, Any], stream: str) -> None:
event_time = int(data.get("E", 0))
kline = data.get("k", {})
message = KlineMessage(
stream=stream,
symbol=kline.get("s", ""),
timeframe=kline.get("i", self._timeframe),
event_time=self._datetime_from_ms(event_time),
open_time=self._datetime_from_ms(int(kline.get("t", 0))),
close_time=self._datetime_from_ms(int(kline.get("T", 0))),
open=self._to_decimal(kline.get("o")),
high=self._to_decimal(kline.get("h")),
low=self._to_decimal(kline.get("l")),
close=self._to_decimal(kline.get("c")),
volume=self._to_decimal(kline.get("v")),
trades=int(kline.get("n", 0)),
is_closed=bool(kline.get("x", False)),
)
await self._dispatch_callbacks(self._kline_callbacks, message)
async def _handle_mark_price(self, data: dict[str, Any], stream: str) -> None:
event_time = int(data.get("E", 0))
message = MarkPriceMessage(
stream=stream,
symbol=data.get("s", ""),
event_time=self._datetime_from_ms(event_time),
mark_price=self._to_decimal(data.get("p")),
index_price=self._to_decimal(data.get("i")),
funding_rate=self._to_decimal(data.get("r")),
next_funding_time=self._datetime_from_ms(int(data.get("T", 0))),
)
await self._dispatch_callbacks(self._mark_price_callbacks, message)
async def _dispatch_callbacks(
self, callbacks: list[Callable[[Any], Awaitable[None] | None]], message: Any
) -> None:
for callback in callbacks:
try:
result = callback(message)
if inspect.isawaitable(result):
await result
except Exception as exc: # pragma: no cover - best effort
callback_name = getattr(callback, "__name__", repr(callback))
logger.error("Callback failed", callback=callback_name, error=str(exc))
@staticmethod
def _datetime_from_ms(timestamp_ms: int) -> datetime:
return datetime.fromtimestamp(timestamp_ms / 1000, tz=UTC)
@staticmethod
def _to_decimal(value: Any) -> Decimal:
if isinstance(value, Decimal):
return value
return Decimal(str(value)) if value not in (None, "") else Decimal("0")
def _normalize_symbols(self, symbols: Iterable[str] | None) -> tuple[str, ...]:
candidates = tuple(symbols or self._settings.symbols_list)
normalized = {sym.upper() for sym in candidates if sym}
normalized.update(self.DEFAULT_SYMBOLS)
return tuple(sorted(normalized))

View File

@@ -0,0 +1,18 @@
"""Shared timeframe helpers for the data layer."""
TIMEFRAME_MS: dict[str, int] = {
"1m": 60 * 1000,
"3m": 3 * 60 * 1000,
"5m": 5 * 60 * 1000,
"15m": 15 * 60 * 1000,
"30m": 30 * 60 * 1000,
"1h": 60 * 60 * 1000,
"2h": 2 * 60 * 60 * 1000,
"4h": 4 * 60 * 60 * 1000,
"6h": 6 * 60 * 60 * 1000,
"8h": 8 * 60 * 60 * 1000,
"12h": 12 * 60 * 60 * 1000,
"1d": 24 * 60 * 60 * 1000,
"3d": 3 * 24 * 60 * 60 * 1000,
"1w": 7 * 24 * 60 * 60 * 1000,
}

View File

@@ -0,0 +1,217 @@
"""Candle data validation helpers for the data layer."""
from datetime import datetime, timedelta
from decimal import Decimal
from typing import Any
import structlog
from tradefinder.adapters.types import Candle
from tradefinder.data.fetcher import TIMEFRAME_MS
from tradefinder.data.storage import DataStorage
logger = structlog.get_logger(__name__)
class DataValidator:
"""Validate candles and detect gaps in DuckDB storage."""
@classmethod
def validate_candle(cls, candle: Candle) -> list[str]:
"""Return validation errors for a single candle."""
errors: list[str] = []
high = candle.high
low = candle.low
open_ = candle.open
close = candle.close
volume = candle.volume
if high < low:
errors.append("high < low")
if high < open_:
errors.append("high < open")
if high < close:
errors.append("high < close")
if low > open_:
errors.append("low > open")
if low > close:
errors.append("low > close")
if volume < Decimal("0"):
errors.append("volume < 0")
if not isinstance(candle.timestamp, datetime):
errors.append("timestamp must be datetime")
if errors:
logger.debug(
"Candle validation failed",
timestamp=candle.timestamp,
errors=errors,
)
return errors
@classmethod
def validate_candles(cls, candles: list[Candle]) -> list[str]:
"""Validate a batch of candles and collect error messages."""
errors: list[str] = []
for candle in candles:
candle_errors = cls.validate_candle(candle)
if candle_errors:
timestamp_repr = (
candle.timestamp.isoformat()
if isinstance(candle.timestamp, datetime)
else str(candle.timestamp)
)
errors.append(f"{timestamp_repr} | {', '.join(candle_errors)}")
return errors
@classmethod
def find_gaps(
cls,
storage: DataStorage,
symbol: str,
timeframe: str,
start: datetime,
end: datetime,
) -> list[tuple[datetime, datetime]]:
"""Detect missing candles between start and end timestamps."""
if start > end:
raise ValueError("start must be before end")
interval = cls._interval_for_timeframe(timeframe)
query = """
SELECT timestamp
FROM candles
WHERE symbol = ?
AND timeframe = ?
AND timestamp >= ?
AND timestamp <= ?
ORDER BY timestamp ASC
"""
rows = storage.conn.execute(query, [symbol, timeframe, start, end]).fetchall()
timestamps: list[datetime] = [row[0] for row in rows if isinstance(row[0], datetime)]
if not timestamps:
if start <= end:
logger.debug(
"No candles found in range",
symbol=symbol,
timeframe=timeframe,
start=start,
end=end,
)
return [(start, end)]
return []
gaps: list[tuple[datetime, datetime]] = []
last_ts: datetime | None = None
for ts in timestamps:
if last_ts is None:
delta = ts - start
if delta > interval:
gap_start = start
gap_end = ts
if gap_start <= gap_end:
gaps.append((gap_start, gap_end))
last_ts = ts
continue
delta = ts - last_ts
if delta > interval:
gap_start = last_ts + interval
gap_end = ts
if gap_start <= gap_end:
gaps.append((gap_start, gap_end))
last_ts = ts
if last_ts is not None:
tail_delta = end - last_ts
if tail_delta > interval:
gap_start = last_ts + interval
gap_end = end
if gap_start <= gap_end:
gaps.append((gap_start, gap_end))
logger.debug(
"Gap check complete",
symbol=symbol,
timeframe=timeframe,
span_start=start,
span_end=end,
gaps=len(gaps),
)
return gaps
@classmethod
def get_gap_report(
cls,
storage: DataStorage,
symbol: str,
timeframe: str,
) -> dict[str, Any]:
"""Return a summary of gaps for a symbol/timeframe."""
bounds = storage.conn.execute(
"""
SELECT MIN(timestamp), MAX(timestamp)
FROM candles
WHERE symbol = ?
AND timeframe = ?
""",
[symbol, timeframe],
).fetchone()
start, end = bounds if bounds else (None, None)
report: dict[str, Any] = {
"symbol": symbol,
"timeframe": timeframe,
"checked_from": start,
"checked_to": end,
"gap_count": 0,
"total_gap_seconds": 0.0,
"max_gap_seconds": 0.0,
"gaps": [],
}
if start is None or end is None:
logger.info("Gap report has no data", symbol=symbol, timeframe=timeframe)
return report
gaps = cls.find_gaps(storage, symbol, timeframe, start, end)
durations = [(gap_end - gap_start).total_seconds() for gap_start, gap_end in gaps]
report.update(
gap_count=len(gaps),
total_gap_seconds=sum(durations),
max_gap_seconds=max(durations) if durations else 0.0,
gaps=[
{
"start": gap_start,
"end": gap_end,
"duration_seconds": (gap_end - gap_start).total_seconds(),
}
for gap_start, gap_end in gaps
],
)
logger.info(
"Gap report generated",
symbol=symbol,
timeframe=timeframe,
gap_count=report["gap_count"],
total_gap_seconds=report["total_gap_seconds"],
)
return report
@classmethod
def _interval_for_timeframe(cls, timeframe: str) -> timedelta:
ms = TIMEFRAME_MS.get(timeframe)
if ms is None:
raise ValueError(f"Unknown timeframe: {timeframe}")
return timedelta(milliseconds=ms)

View File

@@ -0,0 +1,6 @@
"""Trading strategies module."""
from tradefinder.strategies.base import Strategy
from tradefinder.strategies.signals import Signal, SignalType
__all__ = ["Signal", "SignalType", "Strategy"]

View File

@@ -0,0 +1,112 @@
"""Base strategy interface for trading strategies."""
from abc import ABC, abstractmethod
from decimal import Decimal
from typing import Any
import structlog
from tradefinder.adapters.types import Candle, Side
from tradefinder.core.regime import Regime
from tradefinder.strategies.signals import Signal
logger = structlog.get_logger(__name__)
class Strategy(ABC):
"""Abstract base class for trading strategies.
All concrete strategies must implement this interface to be compatible
with the trading engine's strategy selection and signal generation.
Example:
class SupertrendStrategy(Strategy):
name = "supertrend"
def generate_signal(self, candles: list[Candle]) -> Signal | None:
# Analyze candles and return signal if conditions met
...
def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
# Calculate stop loss based on ATR or fixed percentage
...
@property
def parameters(self) -> dict[str, Any]:
return {"period": self._period, "multiplier": self._multiplier}
@property
def suitable_regimes(self) -> list[Regime]:
return [Regime.TRENDING_UP, Regime.TRENDING_DOWN]
"""
name: str # Strategy identifier, must be set by subclasses
@abstractmethod
def generate_signal(self, candles: list[Candle]) -> Signal | None:
"""Analyze candles and generate a trading signal if conditions are met.
Args:
candles: List of OHLCV candles, ordered oldest to newest.
Must contain sufficient history for indicator calculation.
Returns:
Signal object if entry/exit conditions are met, None otherwise.
"""
pass
@abstractmethod
def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
"""Calculate stop loss price for a given entry.
Args:
entry_price: The planned or actual entry price.
side: Whether this is a BUY (long) or SELL (short) position.
Returns:
Stop loss price. For longs, this should be below entry_price.
For shorts, this should be above entry_price.
"""
pass
@property
@abstractmethod
def parameters(self) -> dict[str, Any]:
"""Return current strategy parameters for display and logging.
Returns:
Dictionary of parameter names to values. Used by UI and logging
to show what settings the strategy is using.
"""
pass
@property
@abstractmethod
def suitable_regimes(self) -> list[Regime]:
"""Return list of regimes where this strategy should be active.
Returns:
List of Regime values. The strategy will only be selected
when the current market regime matches one of these.
"""
pass
def validate_candles(self, candles: list[Candle], min_required: int) -> bool:
"""Check if sufficient candle data is available for signal generation.
Args:
candles: List of candles to validate.
min_required: Minimum number of candles needed.
Returns:
True if enough candles are available, False otherwise.
"""
if len(candles) < min_required:
logger.debug(
"Insufficient candles for strategy",
strategy=self.name,
available=len(candles),
required=min_required,
)
return False
return True

View File

@@ -0,0 +1,87 @@
"""Signal types for trading strategies."""
from dataclasses import dataclass, field
from datetime import datetime
from decimal import Decimal
from enum import Enum
from typing import Any
import structlog
logger = structlog.get_logger(__name__)
class SignalType(str, Enum):
"""Type of trading signal."""
ENTRY_LONG = "entry_long"
ENTRY_SHORT = "entry_short"
EXIT_LONG = "exit_long"
EXIT_SHORT = "exit_short"
@dataclass
class Signal:
"""Trading signal emitted by a strategy.
Attributes:
signal_type: Type of signal (entry/exit, long/short)
symbol: Trading symbol (e.g., "BTCUSDT")
price: Suggested entry/exit price
stop_loss: Stop loss price for risk calculation
take_profit: Optional take profit price
confidence: Signal confidence from 0.0 to 1.0
timestamp: When the signal was generated
strategy_name: Name of the strategy that generated the signal
metadata: Additional signal-specific data
"""
signal_type: SignalType
symbol: str
price: Decimal
stop_loss: Decimal
take_profit: Decimal | None
confidence: float
timestamp: datetime
strategy_name: str
metadata: dict[str, Any] = field(default_factory=dict)
def __post_init__(self) -> None:
"""Validate signal parameters."""
if not 0.0 <= self.confidence <= 1.0:
raise ValueError(f"Confidence must be between 0.0 and 1.0, got {self.confidence}")
if self.price <= Decimal("0"):
raise ValueError(f"Price must be positive, got {self.price}")
if self.stop_loss <= Decimal("0"):
raise ValueError(f"Stop loss must be positive, got {self.stop_loss}")
@property
def is_entry(self) -> bool:
"""Check if this is an entry signal."""
return self.signal_type in (SignalType.ENTRY_LONG, SignalType.ENTRY_SHORT)
@property
def is_exit(self) -> bool:
"""Check if this is an exit signal."""
return self.signal_type in (SignalType.EXIT_LONG, SignalType.EXIT_SHORT)
@property
def is_long(self) -> bool:
"""Check if this is a long-side signal."""
return self.signal_type in (SignalType.ENTRY_LONG, SignalType.EXIT_LONG)
@property
def is_short(self) -> bool:
"""Check if this is a short-side signal."""
return self.signal_type in (SignalType.ENTRY_SHORT, SignalType.EXIT_SHORT)
@property
def risk_reward_ratio(self) -> Decimal | None:
"""Calculate risk/reward ratio if take profit is set."""
if self.take_profit is None:
return None
risk = abs(self.price - self.stop_loss)
reward = abs(self.take_profit - self.price)
if risk == Decimal("0"):
return None
return reward / risk