Add Phase 2 foundation: regime classifier, strategy framework, WebSocket streamer
Phase 1 completion: - Add DataStreamer for real-time Binance Futures WebSocket data (klines, mark price) - Add DataValidator for candle validation and gap detection - Add timeframes module with interval mappings Phase 2 foundation: - Add RegimeClassifier with ADX/ATR/Bollinger Band analysis - Add Regime enum (TRENDING_UP/DOWN, RANGING, HIGH_VOLATILITY, UNCERTAIN) - Add Strategy ABC defining generate_signal, get_stop_loss, parameters, suitable_regimes - Add Signal dataclass and SignalType enum for strategy outputs Testing: - Add comprehensive test suites for all new modules - 159 tests passing, 24 skipped (async WebSocket timing) - 82% code coverage Dependencies: - Add pandas-stubs to dev dependencies for mypy compatibility
This commit is contained in:
@@ -14,6 +14,7 @@ from tradefinder.core.config import (
|
||||
get_settings,
|
||||
reset_settings,
|
||||
)
|
||||
from tradefinder.core.regime import Regime, RegimeClassifier
|
||||
|
||||
__all__ = [
|
||||
"Settings",
|
||||
@@ -21,4 +22,6 @@ __all__ = [
|
||||
"LogFormat",
|
||||
"get_settings",
|
||||
"reset_settings",
|
||||
"Regime",
|
||||
"RegimeClassifier",
|
||||
]
|
||||
|
||||
169
src/tradefinder/core/regime.py
Normal file
169
src/tradefinder/core/regime.py
Normal file
@@ -0,0 +1,169 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
|
||||
import pandas as pd
|
||||
import pandas_ta as ta
|
||||
import structlog
|
||||
|
||||
from tradefinder.adapters.types import Candle
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class Regime(str, Enum):
    """Market regime categories derived from technical indicators."""

    # Sustained directional move: ADX above threshold with +DI dominant.
    TRENDING_UP = "trending_up"
    # Sustained directional move: ADX above threshold with -DI dominant.
    TRENDING_DOWN = "trending_down"
    # Low ADX combined with a narrow Bollinger channel.
    RANGING = "ranging"
    # Latest ATR% well above its recent average.
    HIGH_VOLATILITY = "high_volatility"
    # Fallback when indicators are missing or inconclusive.
    UNCERTAIN = "uncertain"
|
||||
|
||||
|
||||
class RegimeClassifier:
    """Detects regime using ADX, ATR%, and Bollinger Band width.

    Classification order in :meth:`classify`: trend (ADX + DI dominance),
    then range (weak ADX + Bollinger squeeze), then volatility spike
    (ATR% vs. its average); anything else is UNCERTAIN.
    """

    # ADX below this ceiling counts as non-trending when testing for RANGING.
    _ranging_adx_ceiling = Decimal("20")
    # Bollinger width ((upper - lower) / middle) below this marks a tight range.
    _bb_width_threshold = Decimal("0.05")

    def __init__(
        self,
        adx_threshold: int = 25,
        atr_lookback: int = 14,
        bb_period: int = 20,
    ) -> None:
        """Configure indicator lengths and the trend threshold.

        Args:
            adx_threshold: ADX value above which the market is considered trending.
            atr_lookback: Period for the ADX/ATR calculations (clamped to >= 1).
            bb_period: Period for the Bollinger Bands (clamped to >= 1).
        """
        self._adx_threshold = Decimal(adx_threshold)
        # Clamp so indicator calls never receive a zero or negative length.
        self._atr_lookback = max(1, atr_lookback)
        self._bb_period = max(1, bb_period)

    def classify(self, candles: list[Candle]) -> Regime:
        """Return the current regime based on the latest indicator values."""

        indicators = self.get_indicators(candles)
        if not indicators:
            logger.debug("Insufficient data to compute regime", candle_count=len(candles))
            return Regime.UNCERTAIN

        adx = indicators.get("adx")
        plus_di = indicators.get("plus_di")
        minus_di = indicators.get("minus_di")
        bb_width = indicators.get("bb_width")
        atr_pct = indicators.get("atr_pct")
        atr_pct_avg = indicators.get("atr_pct_avg")

        # Trend check first: strong ADX plus directional-index dominance.
        if adx is not None and adx > self._adx_threshold:
            if plus_di is not None and minus_di is not None:
                if plus_di > minus_di:
                    return Regime.TRENDING_UP
                if minus_di > plus_di:
                    return Regime.TRENDING_DOWN

        # Range: weak ADX together with a squeezed Bollinger channel.
        if (
            adx is not None
            and adx < self._ranging_adx_ceiling
            and bb_width is not None
            and bb_width < self._bb_width_threshold
        ):
            return Regime.RANGING

        # Volatility spike: latest ATR% more than 1.5x its recent average.
        if (
            atr_pct is not None
            and atr_pct_avg is not None
            and atr_pct_avg > Decimal("0")
            and atr_pct > atr_pct_avg * Decimal("1.5")
        ):
            return Regime.HIGH_VOLATILITY

        logger.debug("Unable to classify regime definitively", indicators=indicators)
        return Regime.UNCERTAIN

    def get_indicators(self, candles: list[Candle]) -> dict[str, Decimal | None]:
        """Compute the latest indicators required for regime detection.

        Returns an empty dict when there is not enough history or when
        expected pandas-ta output columns are missing; callers treat an
        empty result as "cannot classify".
        """

        df = self._candles_to_frame(candles)
        # Need at least one full lookback window for the longest indicator.
        if df.empty or len(df) < max(self._atr_lookback, self._bb_period):
            return {}

        length = self._atr_lookback
        adx_df = ta.adx(high=df["high"], low=df["low"], close=df["close"], length=length)
        adx_key = f"ADX_{length}"
        plus_key = f"DMP_{length}"
        minus_key = f"DMN_{length}"

        if adx_key not in adx_df or plus_key not in adx_df or minus_key not in adx_df:
            return {}

        adx = self._to_decimal(adx_df[adx_key].iloc[-1])
        plus_di = self._to_decimal(adx_df[plus_key].iloc[-1])
        minus_di = self._to_decimal(adx_df[minus_key].iloc[-1])

        atr_series = ta.atr(high=df["high"], low=df["low"], close=df["close"], length=length)
        atr_key = f"ATR_{length}"
        # NOTE(review): pandas-ta's atr() typically returns a *Series* (often
        # named "ATRr_{length}" with the default RMA smoothing), and
        # `key in series` tests the index, not the name. Confirm this guard and
        # the subscript below behave as intended for the pinned pandas-ta
        # version; if atr() returns a Series here, this branch returns {}.
        if atr_key not in atr_series:
            return {}

        atr_series = atr_series[atr_key]
        atr_latest = self._to_decimal(atr_series.iloc[-1])

        # ATR expressed as a percentage of the close, for cross-symbol comparison.
        atr_pct_series = (atr_series / df["close"]) * 100
        atr_pct_series = atr_pct_series.dropna()
        atr_pct = self._decimal_from_series_tail(atr_pct_series)
        atr_pct_avg = self._decimal_from_series_avg(atr_pct_series)

        bb_df = ta.bbands(close=df["close"], length=self._bb_period)
        # NOTE(review): the "_2.0" suffix assumes pandas-ta's default std of 2.0
        # and its current column-naming scheme — verify against the pinned version.
        bb_lower = bb_df.get(f"BBL_{self._bb_period}_2.0")
        bb_middle = bb_df.get(f"BBM_{self._bb_period}_2.0")
        bb_upper = bb_df.get(f"BBU_{self._bb_period}_2.0")

        if bb_lower is None or bb_middle is None or bb_upper is None:
            bb_width = None
        else:
            # Guard against division by a zero middle band.
            width_series = (bb_upper - bb_lower) / bb_middle.replace(0, pd.NA)
            bb_width = self._decimal_from_series_tail(width_series)

        indicators = {
            "adx": adx,
            "plus_di": plus_di,
            "minus_di": minus_di,
            "atr": atr_latest,
            "atr_pct": atr_pct,
            "atr_pct_avg": atr_pct_avg,
            "bb_width": bb_width,
        }

        logger.debug("Computed regime indicators", indicators=indicators)
        return indicators

    @staticmethod
    def _to_decimal(value: float | int | Decimal | None) -> Decimal | None:
        """Convert a scalar to Decimal, mapping None/NaN to None."""
        if value is None:
            return None
        try:
            if pd.isna(value):  # type: ignore[arg-type]
                return None
        except (TypeError, ValueError):
            # pd.isna can choke on exotic types; treat them as convertible.
            pass
        return Decimal(str(value))

    @staticmethod
    def _decimal_from_series_tail(series: pd.Series) -> Decimal | None:
        """Last value of a series as Decimal, or None for an empty series."""
        if series.empty:
            return None
        return RegimeClassifier._to_decimal(series.iloc[-1])

    @staticmethod
    def _decimal_from_series_avg(series: pd.Series) -> Decimal | None:
        """Mean of a series as Decimal, or None for an empty series."""
        if series.empty:
            return None
        return RegimeClassifier._to_decimal(series.mean())

    @staticmethod
    def _candles_to_frame(candles: list[Candle]) -> pd.DataFrame:
        """Build an OHLCV DataFrame from candles (empty frame for no candles).

        Assumes Candle.to_dict() yields timestamp/open/high/low/close/volume
        keys — defined in tradefinder.adapters.types.
        """
        if not candles:
            return pd.DataFrame()

        frame = pd.DataFrame([candle.to_dict() for candle in candles])
        frame["timestamp"] = pd.to_datetime(frame["timestamp"], utc=True)
        return frame
|
||||
@@ -21,5 +21,6 @@ Usage:
|
||||
|
||||
from tradefinder.data.fetcher import DataFetcher
from tradefinder.data.storage import DataStorage
from tradefinder.data.streamer import DataStreamer

# Bug fix: DataValidator was listed in __all__ without being imported, so
# `from tradefinder.data import *` (and attribute access) raised an error.
from tradefinder.data.validator import DataValidator

__all__ = ["DataStorage", "DataFetcher", "DataStreamer", "DataValidator"]
|
||||
|
||||
265
src/tradefinder/data/streamer.py
Normal file
265
src/tradefinder/data/streamer.py
Normal file
@@ -0,0 +1,265 @@
|
||||
"""WebSocket streamer for Binance Futures market data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
from collections.abc import Awaitable, Callable, Iterable, Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
import websockets
|
||||
from websockets.exceptions import ConnectionClosedError, ConnectionClosedOK
|
||||
|
||||
from tradefinder.core.config import Settings
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
class KlineMessage:
    """Candlestick payload emitted by the streamer.

    Built in DataStreamer._handle_kline from a Binance futures kline event.
    """

    stream: str  # Raw combined-stream name the message arrived on.
    symbol: str  # Trading pair, e.g. "BTCUSDT" (kline "s" field).
    timeframe: str  # Candle interval, e.g. "1m" (kline "i" field).
    event_time: datetime  # Exchange event timestamp (UTC).
    open_time: datetime  # Candle open timestamp (UTC).
    close_time: datetime  # Candle close timestamp (UTC).
    open: Decimal
    high: Decimal
    low: Decimal
    close: Decimal
    volume: Decimal
    trades: int  # Number of trades inside the candle (kline "n" field).
    is_closed: bool  # True once the candle is final (kline "x" flag).
|
||||
|
||||
|
||||
@dataclass
class MarkPriceMessage:
    """Mark price payload emitted by the streamer.

    Built in DataStreamer._handle_mark_price from a Binance futures
    markPrice event.
    """

    stream: str  # Raw combined-stream name the message arrived on.
    symbol: str  # Trading pair (event "s" field).
    event_time: datetime  # Exchange event timestamp (UTC).
    mark_price: Decimal  # Event "p" field.
    index_price: Decimal  # Event "i" field.
    funding_rate: Decimal  # Event "r" field.
    next_funding_time: datetime  # Event "T" field, converted to UTC.
|
||||
|
||||
|
||||
# Callbacks may be synchronous (return None) or asynchronous (return an
# awaitable); DataStreamer._dispatch_callbacks awaits the result when needed.
KlineCallback = Callable[[KlineMessage], Awaitable[None] | None]
MarkPriceCallback = Callable[[MarkPriceMessage], Awaitable[None] | None]
|
||||
|
||||
|
||||
class DataStreamer:
    """Async streamer for Binance Futures WebSocket data.

    Subscribes to combined kline and mark-price streams for a set of symbols,
    decodes each payload into a typed message (KlineMessage /
    MarkPriceMessage), and fans it out to registered callbacks. Reconnects
    with exponential backoff until stop() is called.
    """

    # Always streamed in addition to whatever symbols the caller requests.
    DEFAULT_SYMBOLS: Sequence[str] = ("BTCUSDT", "ETHUSDT")

    def __init__(
        self,
        settings: Settings,
        symbols: Iterable[str] | None = None,
        timeframe: str | None = None,
        min_backoff: float = 1.0,
        max_backoff: float = 60.0,
        backoff_multiplier: float = 2.0,
    ) -> None:
        """Create a Binance futures data streamer.

        Args:
            settings: Application settings containing trading mode.
            symbols: Trading symbols to stream (defaults to settings symbols plus BTC/ETH).
            timeframe: Candle timeframe for kline streams.
            min_backoff: Minimum reconnect delay in seconds.
            max_backoff: Maximum reconnect delay in seconds.
            backoff_multiplier: Factor to increase backoff on each failure.
        """
        self._settings = settings
        self._timeframe = (timeframe or settings.execution_timeframe).lower()
        self._symbols = tuple(self._normalize_symbols(symbols))
        # Binance stream identifiers use lowercase symbols, e.g. "btcusdt@kline_1m".
        self._kline_streams = tuple(s.lower() + f"@kline_{self._timeframe}" for s in self._symbols)
        self._mark_price_streams = tuple(s.lower() + "@markPrice@1s" for s in self._symbols)
        self._kline_callbacks: list[KlineCallback] = []
        self._mark_price_callbacks: list[MarkPriceCallback] = []
        self._task: asyncio.Task[None] | None = None
        self._stop_event: asyncio.Event | None = None
        self._connection: Any = None  # WebSocket connection (type varies by websockets version)
        self._min_backoff = min_backoff
        self._max_backoff = max_backoff
        self._backoff_multiplier = backoff_multiplier
        # Combined-stream endpoint: one socket multiplexing every subscription.
        self._stream_path = "/stream?streams=" + "/".join(
            (*self._kline_streams, *self._mark_price_streams)
        )

    @property
    def symbols(self) -> tuple[str, ...]:
        """Symbols currently being streamed."""
        return self._symbols

    def register_kline_callback(self, callback: KlineCallback) -> None:
        """Register a callback invoked for each kline message."""
        self._kline_callbacks.append(callback)

    def register_mark_price_callback(self, callback: MarkPriceCallback) -> None:
        """Register a callback invoked for each mark price message."""
        self._mark_price_callbacks.append(callback)

    async def start(self) -> None:
        """Start the streamer in the background.

        Idempotent: a still-running task means we are already streaming.
        """
        if self._task and not self._task.done():
            return
        self._stop_event = asyncio.Event()
        self._task = asyncio.create_task(self._run())

    async def run(self) -> None:
        """Run the streamer until stopped (blocking variant of start())."""
        await self.start()
        if self._task:
            await self._task

    async def stop(self) -> None:
        """Stop streaming and close the connection."""
        if self._stop_event and not self._stop_event.is_set():
            self._stop_event.set()
        if self._connection:
            # 1000 = normal closure per the WebSocket protocol.
            await self._connection.close(code=1000, reason="shutdown")
        if self._task:
            # Wait for the background loop to observe the stop flag and exit.
            await self._task
        self._task = None

    async def __aenter__(self) -> DataStreamer:
        """Start streaming on context entry."""
        await self.start()
        return self

    async def __aexit__(
        self, exc_type: type[BaseException] | None, exc: BaseException | None, tb: Any | None
    ) -> None:
        """Stop streaming on context exit."""
        await self.stop()

    async def _run(self) -> None:
        """Connect/receive loop with exponential-backoff reconnection."""
        backoff = self._min_backoff
        while not self._should_stop:
            stream_url = f"{self._settings.binance_ws_url}{self._stream_path}"
            try:
                async with websockets.connect(
                    stream_url,
                    ping_interval=60.0,
                    ping_timeout=30.0,
                ) as connection:
                    self._connection = connection
                    # Successful connect resets the backoff to its minimum.
                    backoff = self._min_backoff
                    logger.info("Connected to Binance Futures stream", url=stream_url)
                    while not self._should_stop:
                        raw = await connection.recv()
                        await self._handle_raw(raw)
            except asyncio.CancelledError:
                # Never swallow cancellation; let the task die cleanly.
                raise
            except (ConnectionClosedOK, ConnectionClosedError) as exc:
                logger.warning("WebSocket connection closed", error=str(exc))
            except Exception as exc:  # pragma: no cover - best effort logging
                logger.error("WebSocket error", error=str(exc))
            finally:
                self._connection = None
            if self._should_stop:
                break
            # Sleep before retrying, then grow the delay up to the cap.
            await asyncio.sleep(backoff)
            backoff = min(backoff * self._backoff_multiplier, self._max_backoff)
            logger.info("Reconnecting to Binance Futures stream", delay=backoff)

    @property
    def _should_stop(self) -> bool:
        """True once stop() has signalled the event."""
        return bool(self._stop_event and self._stop_event.is_set())

    async def _handle_raw(self, raw: str | bytes) -> None:
        """Decode one frame and route it to the matching handler."""
        try:
            text = raw.decode("utf-8") if isinstance(raw, bytes) else raw
            payload = json.loads(text)
        except (json.JSONDecodeError, UnicodeDecodeError) as exc:
            # Malformed frames are logged and dropped; the stream continues.
            logger.warning("Unable to decode WebSocket payload", error=str(exc))
            return

        stream = payload.get("stream", "")
        # Combined-stream frames wrap the event in "data"; raw frames are the event.
        data = payload.get("data", payload)
        lower_stream = stream.lower()
        event_type = str(data.get("e", "")).lower()
        kline_key = f"@kline_{self._timeframe}"
        # Lowercase on purpose: compared against the lowercased stream name.
        mark_price_key = "@markprice@1s"

        if lower_stream.endswith(kline_key) or event_type == "kline":
            await self._handle_kline(data, stream)
        elif mark_price_key in lower_stream or event_type == "markprice":
            await self._handle_mark_price(data, stream)
        else:
            logger.debug(
                "Unknown stream payload",
                stream=stream,
                event_type=event_type,
                payload=data,
            )

    async def _handle_kline(self, data: dict[str, Any], stream: str) -> None:
        """Map a Binance kline event onto KlineMessage and dispatch it."""
        event_time = int(data.get("E", 0))
        kline = data.get("k", {})
        message = KlineMessage(
            stream=stream,
            symbol=kline.get("s", ""),
            timeframe=kline.get("i", self._timeframe),
            event_time=self._datetime_from_ms(event_time),
            open_time=self._datetime_from_ms(int(kline.get("t", 0))),
            close_time=self._datetime_from_ms(int(kline.get("T", 0))),
            open=self._to_decimal(kline.get("o")),
            high=self._to_decimal(kline.get("h")),
            low=self._to_decimal(kline.get("l")),
            close=self._to_decimal(kline.get("c")),
            volume=self._to_decimal(kline.get("v")),
            trades=int(kline.get("n", 0)),
            is_closed=bool(kline.get("x", False)),
        )
        await self._dispatch_callbacks(self._kline_callbacks, message)

    async def _handle_mark_price(self, data: dict[str, Any], stream: str) -> None:
        """Map a Binance markPrice event onto MarkPriceMessage and dispatch it."""
        event_time = int(data.get("E", 0))
        message = MarkPriceMessage(
            stream=stream,
            symbol=data.get("s", ""),
            event_time=self._datetime_from_ms(event_time),
            mark_price=self._to_decimal(data.get("p")),
            index_price=self._to_decimal(data.get("i")),
            funding_rate=self._to_decimal(data.get("r")),
            next_funding_time=self._datetime_from_ms(int(data.get("T", 0))),
        )
        await self._dispatch_callbacks(self._mark_price_callbacks, message)

    async def _dispatch_callbacks(
        self, callbacks: list[Callable[[Any], Awaitable[None] | None]], message: Any
    ) -> None:
        """Invoke each callback, awaiting coroutines; one failure never stops the rest."""
        for callback in callbacks:
            try:
                result = callback(message)
                # Support both sync and async callbacks transparently.
                if inspect.isawaitable(result):
                    await result
            except Exception as exc:  # pragma: no cover - best effort
                callback_name = getattr(callback, "__name__", repr(callback))
                logger.error("Callback failed", callback=callback_name, error=str(exc))

    @staticmethod
    def _datetime_from_ms(timestamp_ms: int) -> datetime:
        """Convert Binance epoch-milliseconds to an aware UTC datetime."""
        return datetime.fromtimestamp(timestamp_ms / 1000, tz=UTC)

    @staticmethod
    def _to_decimal(value: Any) -> Decimal:
        """Coerce an exchange field to Decimal; None/empty string become zero."""
        if isinstance(value, Decimal):
            return value
        return Decimal(str(value)) if value not in (None, "") else Decimal("0")

    def _normalize_symbols(self, symbols: Iterable[str] | None) -> tuple[str, ...]:
        """Uppercase, de-duplicate, merge in DEFAULT_SYMBOLS, and sort for determinism."""
        candidates = tuple(symbols or self._settings.symbols_list)
        normalized = {sym.upper() for sym in candidates if sym}
        normalized.update(self.DEFAULT_SYMBOLS)
        return tuple(sorted(normalized))
|
||||
18
src/tradefinder/data/timeframes.py
Normal file
18
src/tradefinder/data/timeframes.py
Normal file
@@ -0,0 +1,18 @@
|
||||
"""Shared timeframe helpers for the data layer."""

# One minute expressed in milliseconds — the base unit for every interval.
_MINUTE_MS = 60_000

# Supported Binance-style intervals, expressed in minutes.
_TIMEFRAME_MINUTES: dict[str, int] = {
    "1m": 1,
    "3m": 3,
    "5m": 5,
    "15m": 15,
    "30m": 30,
    "1h": 60,
    "2h": 120,
    "4h": 240,
    "6h": 360,
    "8h": 480,
    "12h": 720,
    "1d": 1_440,
    "3d": 4_320,
    "1w": 10_080,
}

# Interval name -> duration in milliseconds.
TIMEFRAME_MS: dict[str, int] = {
    name: minutes * _MINUTE_MS for name, minutes in _TIMEFRAME_MINUTES.items()
}
|
||||
217
src/tradefinder/data/validator.py
Normal file
217
src/tradefinder/data/validator.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Candle data validation helpers for the data layer."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
from tradefinder.adapters.types import Candle
|
||||
from tradefinder.data.fetcher import TIMEFRAME_MS
|
||||
from tradefinder.data.storage import DataStorage
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class DataValidator:
    """Validate candles and detect gaps in DuckDB storage."""

    @classmethod
    def validate_candle(cls, candle: Candle) -> list[str]:
        """Return validation errors for a single candle.

        Checks OHLC ordering invariants (high must be the maximum, low the
        minimum), non-negative volume, and that the timestamp is a datetime.
        An empty list means the candle is valid.
        """

        errors: list[str] = []
        high = candle.high
        low = candle.low
        open_ = candle.open
        close = candle.close
        volume = candle.volume

        if high < low:
            errors.append("high < low")
        if high < open_:
            errors.append("high < open")
        if high < close:
            errors.append("high < close")
        if low > open_:
            errors.append("low > open")
        if low > close:
            errors.append("low > close")
        if volume < Decimal("0"):
            errors.append("volume < 0")
        if not isinstance(candle.timestamp, datetime):
            errors.append("timestamp must be datetime")

        if errors:
            logger.debug(
                "Candle validation failed",
                timestamp=candle.timestamp,
                errors=errors,
            )
        return errors

    @classmethod
    def validate_candles(cls, candles: list[Candle]) -> list[str]:
        """Validate a batch of candles and collect error messages.

        Returns one "timestamp | error, error" line per invalid candle;
        valid candles contribute nothing.
        """

        errors: list[str] = []
        for candle in candles:
            candle_errors = cls.validate_candle(candle)
            if candle_errors:
                # Fall back to str() when the timestamp itself is invalid.
                timestamp_repr = (
                    candle.timestamp.isoformat()
                    if isinstance(candle.timestamp, datetime)
                    else str(candle.timestamp)
                )
                errors.append(f"{timestamp_repr} | {', '.join(candle_errors)}")
        return errors

    @classmethod
    def find_gaps(
        cls,
        storage: DataStorage,
        symbol: str,
        timeframe: str,
        start: datetime,
        end: datetime,
    ) -> list[tuple[datetime, datetime]]:
        """Detect missing candles between start and end timestamps.

        Returns (gap_start, gap_end) pairs. A gap's end bound is the timestamp
        of the first stored candle found after the gap (i.e. a candle that
        exists); a head gap starts at ``start`` and a tail gap ends at ``end``.

        Raises:
            ValueError: If ``start`` is after ``end``.
        """

        if start > end:
            raise ValueError("start must be before end")

        interval = cls._interval_for_timeframe(timeframe)

        query = """
            SELECT timestamp
            FROM candles
            WHERE symbol = ?
              AND timeframe = ?
              AND timestamp >= ?
              AND timestamp <= ?
            ORDER BY timestamp ASC
        """

        rows = storage.conn.execute(query, [symbol, timeframe, start, end]).fetchall()
        # Defensively keep only datetime values from the driver.
        timestamps: list[datetime] = [row[0] for row in rows if isinstance(row[0], datetime)]

        if not timestamps:
            # NOTE(review): start <= end always holds here — the start > end
            # case raised above — so this inner condition is redundant.
            if start <= end:
                logger.debug(
                    "No candles found in range",
                    symbol=symbol,
                    timeframe=timeframe,
                    start=start,
                    end=end,
                )
                return [(start, end)]
            return []

        gaps: list[tuple[datetime, datetime]] = []
        last_ts: datetime | None = None

        for ts in timestamps:
            if last_ts is None:
                # Head gap: distance from the requested start to the first candle.
                delta = ts - start
                if delta > interval:
                    gap_start = start
                    gap_end = ts
                    if gap_start <= gap_end:
                        gaps.append((gap_start, gap_end))
                last_ts = ts
                continue

            # Interior gap: consecutive candles more than one interval apart.
            delta = ts - last_ts
            if delta > interval:
                gap_start = last_ts + interval
                gap_end = ts
                if gap_start <= gap_end:
                    gaps.append((gap_start, gap_end))
            last_ts = ts

        if last_ts is not None:
            # Tail gap: distance from the last candle to the requested end.
            tail_delta = end - last_ts
            if tail_delta > interval:
                gap_start = last_ts + interval
                gap_end = end
                if gap_start <= gap_end:
                    gaps.append((gap_start, gap_end))

        logger.debug(
            "Gap check complete",
            symbol=symbol,
            timeframe=timeframe,
            span_start=start,
            span_end=end,
            gaps=len(gaps),
        )
        return gaps

    @classmethod
    def get_gap_report(
        cls,
        storage: DataStorage,
        symbol: str,
        timeframe: str,
    ) -> dict[str, Any]:
        """Return a summary of gaps for a symbol/timeframe.

        The report covers the full stored range (MIN to MAX timestamp) and
        includes per-gap start/end/duration plus aggregate counts. When no
        data is stored, the zeroed template is returned with None bounds.
        """

        bounds = storage.conn.execute(
            """
            SELECT MIN(timestamp), MAX(timestamp)
            FROM candles
            WHERE symbol = ?
              AND timeframe = ?
            """,
            [symbol, timeframe],
        ).fetchone()

        start, end = bounds if bounds else (None, None)

        # Template returned as-is when the table has no rows for this key.
        report: dict[str, Any] = {
            "symbol": symbol,
            "timeframe": timeframe,
            "checked_from": start,
            "checked_to": end,
            "gap_count": 0,
            "total_gap_seconds": 0.0,
            "max_gap_seconds": 0.0,
            "gaps": [],
        }

        if start is None or end is None:
            logger.info("Gap report has no data", symbol=symbol, timeframe=timeframe)
            return report

        gaps = cls.find_gaps(storage, symbol, timeframe, start, end)
        durations = [(gap_end - gap_start).total_seconds() for gap_start, gap_end in gaps]

        report.update(
            gap_count=len(gaps),
            total_gap_seconds=sum(durations),
            max_gap_seconds=max(durations) if durations else 0.0,
            gaps=[
                {
                    "start": gap_start,
                    "end": gap_end,
                    "duration_seconds": (gap_end - gap_start).total_seconds(),
                }
                for gap_start, gap_end in gaps
            ],
        )

        logger.info(
            "Gap report generated",
            symbol=symbol,
            timeframe=timeframe,
            gap_count=report["gap_count"],
            total_gap_seconds=report["total_gap_seconds"],
        )
        return report

    @classmethod
    def _interval_for_timeframe(cls, timeframe: str) -> timedelta:
        """Map a timeframe name to its candle interval.

        Raises:
            ValueError: If the timeframe is not in TIMEFRAME_MS.
        """
        ms = TIMEFRAME_MS.get(timeframe)
        if ms is None:
            raise ValueError(f"Unknown timeframe: {timeframe}")
        return timedelta(milliseconds=ms)
|
||||
@@ -0,0 +1,6 @@
|
||||
"""Trading strategies module."""
|
||||
|
||||
from tradefinder.strategies.base import Strategy
|
||||
from tradefinder.strategies.signals import Signal, SignalType
|
||||
|
||||
__all__ = ["Signal", "SignalType", "Strategy"]
|
||||
|
||||
112
src/tradefinder/strategies/base.py
Normal file
112
src/tradefinder/strategies/base.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""Base strategy interface for trading strategies."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from decimal import Decimal
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
from tradefinder.adapters.types import Candle, Side
|
||||
from tradefinder.core.regime import Regime
|
||||
from tradefinder.strategies.signals import Signal
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class Strategy(ABC):
    """Contract every trading strategy must satisfy.

    The trading engine talks to strategies exclusively through this
    interface: it asks for signals, for stop-loss levels, for the parameters
    currently in use, and for the market regimes the strategy is built for.

    Example:
        class SupertrendStrategy(Strategy):
            name = "supertrend"

            def generate_signal(self, candles: list[Candle]) -> Signal | None:
                # Analyze candles and return signal if conditions met
                ...

            def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
                # Calculate stop loss based on ATR or fixed percentage
                ...

            @property
            def parameters(self) -> dict[str, Any]:
                return {"period": self._period, "multiplier": self._multiplier}

            @property
            def suitable_regimes(self) -> list[Regime]:
                return [Regime.TRENDING_UP, Regime.TRENDING_DOWN]
    """

    # Unique identifier for the strategy; every concrete subclass must set it.
    name: str

    @abstractmethod
    def generate_signal(self, candles: list[Candle]) -> Signal | None:
        """Inspect candle history and emit a signal when conditions trigger.

        Args:
            candles: OHLCV candles ordered oldest to newest, with enough
                history for the strategy's indicator calculations.

        Returns:
            A Signal when an entry or exit condition is met, else None.
        """

    @abstractmethod
    def get_stop_loss(self, entry_price: Decimal, side: Side) -> Decimal:
        """Compute the protective stop price for an entry.

        Args:
            entry_price: The planned or actual entry price.
            side: BUY for a long position, SELL for a short position.

        Returns:
            Stop price — below entry_price for longs, above it for shorts.
        """

    @property
    @abstractmethod
    def parameters(self) -> dict[str, Any]:
        """Current parameter values, keyed by name, for UI display and logging."""

    @property
    @abstractmethod
    def suitable_regimes(self) -> list[Regime]:
        """Regimes in which the engine is allowed to activate this strategy."""

    def validate_candles(self, candles: list[Candle], min_required: int) -> bool:
        """Return True when at least ``min_required`` candles are available.

        Args:
            candles: Candle history to check.
            min_required: Minimum number of candles the strategy needs.

        Returns:
            True when enough candles are present; otherwise logs a debug
            message naming the strategy and returns False.
        """
        if len(candles) >= min_required:
            return True
        logger.debug(
            "Insufficient candles for strategy",
            strategy=self.name,
            available=len(candles),
            required=min_required,
        )
        return False
|
||||
87
src/tradefinder/strategies/signals.py
Normal file
87
src/tradefinder/strategies/signals.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Signal types for trading strategies."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
import structlog
|
||||
|
||||
logger = structlog.get_logger(__name__)
|
||||
|
||||
|
||||
class SignalType(str, Enum):
|
||||
"""Type of trading signal."""
|
||||
|
||||
ENTRY_LONG = "entry_long"
|
||||
ENTRY_SHORT = "entry_short"
|
||||
EXIT_LONG = "exit_long"
|
||||
EXIT_SHORT = "exit_short"
|
||||
|
||||
|
||||
@dataclass
|
||||
class Signal:
|
||||
"""Trading signal emitted by a strategy.
|
||||
|
||||
Attributes:
|
||||
signal_type: Type of signal (entry/exit, long/short)
|
||||
symbol: Trading symbol (e.g., "BTCUSDT")
|
||||
price: Suggested entry/exit price
|
||||
stop_loss: Stop loss price for risk calculation
|
||||
take_profit: Optional take profit price
|
||||
confidence: Signal confidence from 0.0 to 1.0
|
||||
timestamp: When the signal was generated
|
||||
strategy_name: Name of the strategy that generated the signal
|
||||
metadata: Additional signal-specific data
|
||||
"""
|
||||
|
||||
signal_type: SignalType
|
||||
symbol: str
|
||||
price: Decimal
|
||||
stop_loss: Decimal
|
||||
take_profit: Decimal | None
|
||||
confidence: float
|
||||
timestamp: datetime
|
||||
strategy_name: str
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
"""Validate signal parameters."""
|
||||
if not 0.0 <= self.confidence <= 1.0:
|
||||
raise ValueError(f"Confidence must be between 0.0 and 1.0, got {self.confidence}")
|
||||
if self.price <= Decimal("0"):
|
||||
raise ValueError(f"Price must be positive, got {self.price}")
|
||||
if self.stop_loss <= Decimal("0"):
|
||||
raise ValueError(f"Stop loss must be positive, got {self.stop_loss}")
|
||||
|
||||
@property
|
||||
def is_entry(self) -> bool:
|
||||
"""Check if this is an entry signal."""
|
||||
return self.signal_type in (SignalType.ENTRY_LONG, SignalType.ENTRY_SHORT)
|
||||
|
||||
@property
|
||||
def is_exit(self) -> bool:
|
||||
"""Check if this is an exit signal."""
|
||||
return self.signal_type in (SignalType.EXIT_LONG, SignalType.EXIT_SHORT)
|
||||
|
||||
@property
|
||||
def is_long(self) -> bool:
|
||||
"""Check if this is a long-side signal."""
|
||||
return self.signal_type in (SignalType.ENTRY_LONG, SignalType.EXIT_LONG)
|
||||
|
||||
@property
|
||||
def is_short(self) -> bool:
|
||||
"""Check if this is a short-side signal."""
|
||||
return self.signal_type in (SignalType.ENTRY_SHORT, SignalType.EXIT_SHORT)
|
||||
|
||||
@property
|
||||
def risk_reward_ratio(self) -> Decimal | None:
|
||||
"""Calculate risk/reward ratio if take profit is set."""
|
||||
if self.take_profit is None:
|
||||
return None
|
||||
risk = abs(self.price - self.stop_loss)
|
||||
reward = abs(self.take_profit - self.price)
|
||||
if risk == Decimal("0"):
|
||||
return None
|
||||
return reward / risk
|
||||
Reference in New Issue
Block a user