import unittest
from datetime import datetime, timedelta
from unittest.mock import MagicMock

from app.services.stats_service import StatsService
from app.models import PlayHistory, Track, Artist


class TestStatsService(unittest.TestCase):
    """Unit tests for StatsService with the database session replaced by a MagicMock."""

    # Fixed query window. The DB is mocked, so the window only has to be a
    # valid (start, end) pair; a constant keeps the tests deterministic,
    # avoids the deprecated ``datetime.utcnow()``, and contains every fixture
    # timestamp below in case the service also filters plays by the window.
    WINDOW_START = datetime(2023, 1, 1)
    WINDOW_END = WINDOW_START + timedelta(days=1)

    def setUp(self):
        self.mock_db = MagicMock()
        self.service = StatsService(self.mock_db)

    def _mock_ordered_plays(self, plays):
        """Make ``db.query(...).filter(...).order_by(...).all()`` return *plays*."""
        ordered = self.mock_db.query.return_value.filter.return_value.order_by.return_value
        ordered.all.return_value = plays

    def test_compute_volume_stats_empty(self):
        """An empty result set yields zero plays and zero unique tracks."""
        # Mock empty query result for db.query(...).filter(...).all()
        self.mock_db.query.return_value.filter.return_value.all.return_value = []

        stats = self.service.compute_volume_stats(self.WINDOW_START, self.WINDOW_END)

        self.assertEqual(stats["total_plays"], 0)
        self.assertEqual(stats["unique_tracks"], 0)

    def test_compute_session_stats(self):
        """A long gap between plays splits them into separate sessions."""
        t1 = datetime(2023, 1, 1, 10, 0, 0)
        t2 = datetime(2023, 1, 1, 10, 5, 0)  # 5 min gap (same session)
        t3 = datetime(2023, 1, 1, 12, 0, 0)  # 1h 55m gap (new session)
        plays = [
            PlayHistory(played_at=t1, track_id="1"),
            PlayHistory(played_at=t2, track_id="2"),
            PlayHistory(played_at=t3, track_id="3"),
        ]
        self._mock_ordered_plays(plays)

        stats = self.service.compute_session_stats(self.WINDOW_START, self.WINDOW_END)

        # Expected: 2 sessions ([t1, t2], [t3])
        self.assertEqual(stats["count"], 2)
        # Avg tracks: 3 plays / 2 sessions = 1.5
        self.assertAlmostEqual(stats["avg_tracks"], 1.5)

    def test_compute_skip_stats(self):
        """A play interrupted early by the next play is counted as a skip."""
        # Track duration = 30 s.
        track = Track(id="t1", duration_ms=30000)

        # p2 starts 10 s after p1, so p1 played for only 10 s of its 30 s.
        # 10 s < 30 s - 10 s = 20 s, so p1 is expected to count as a skip
        # (threshold as reasoned in the original test — confirm against
        # StatsService).
        p1 = PlayHistory(played_at=datetime(2023, 1, 1, 10, 0, 0), track_id="t1")
        p2 = PlayHistory(played_at=datetime(2023, 1, 1, 10, 0, 10), track_id="t1")
        self._mock_ordered_plays([p1, p2])

        # Mock track lookup: db.query(...).filter(...).all() (no order_by),
        # a separate MagicMock chain from the ordered plays above.
        self.mock_db.query.return_value.filter.return_value.all.return_value = [track]

        stats = self.service.compute_skip_stats(self.WINDOW_START, self.WINDOW_END)

        self.assertEqual(stats["total_skips"], 1)


if __name__ == '__main__':
    unittest.main()