mirror of
https://github.com/bnair123/MusicAnalyser.git
synced 2026-02-25 11:46:07 +00:00
- Add reccobeats_id column to Track model for API mapping - Fix ReccoBeats batch size limit (max 40 IDs per request) - Extract spotify_id from href field in ReccoBeats responses - Fix OpenAI API: remove unsupported temperature param, increase max_completion_tokens to 4000 - Add playlist/user management methods to spotify_client for future auto-playlist feature
170 lines
5.6 KiB
Python
170 lines
5.6 KiB
Python
import os
|
|
import json
|
|
import re
|
|
from typing import Dict, Any
|
|
|
|
try:
|
|
from openai import OpenAI
|
|
except ImportError:
|
|
OpenAI = None
|
|
|
|
try:
|
|
from google import genai
|
|
except ImportError:
|
|
genai = None
|
|
|
|
|
|
class NarrativeService:
    """Generate an LLM-written narrative report from listening statistics.

    The backing provider ("openai" or "gemini") is picked from the model
    name, the API keys present in the environment, and which SDKs imported
    successfully. ``generate_full_narrative`` returns a canned fallback
    narrative instead of raising when anything goes wrong.
    """

    # Maximum number of top tracks/artists/genres forwarded to the LLM.
    _TOP_N = 5
    # Volume keys removed from the scalar copy; all but top_albums are
    # re-added as trimmed name lists by _shape_payload.
    _VOLUME_LIST_KEYS = ("top_tracks", "top_artists", "top_albums", "top_genres")
    # Greedy first-"{" to last-"}" match, used to recover a JSON object that
    # the model wrapped in prose or markdown fences.
    _JSON_OBJECT_RE = re.compile(r"\{.*\}", re.DOTALL)

    def __init__(self, model_name: str = "gpt-5-mini-2025-08-07"):
        """Pick a provider for *model_name* and build its client (or None)."""
        self.model_name = model_name
        self.provider = self._detect_provider()
        self.client = self._init_client()

    def _detect_provider(self) -> str:
        """Return "openai", "gemini", or "none".

        OpenAI is preferred for "gpt*" models when its key and SDK are
        available; otherwise Gemini is used if available, then OpenAI.
        """
        openai_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_APIKEY")
        gemini_key = os.getenv("GEMINI_API_KEY")
        openai_ready = bool(openai_key and OpenAI)
        gemini_ready = bool(gemini_key and genai)

        if self.model_name.startswith("gpt") and openai_ready:
            return "openai"
        if gemini_ready:
            return "gemini"
        if openai_ready:
            return "openai"
        return "none"

    def _init_client(self):
        """Return an SDK client for ``self.provider``, or None for "none"."""
        if self.provider == "openai":
            api_key = os.getenv("OPENAI_API_KEY") or os.getenv("OPENAI_APIKEY")
            return OpenAI(api_key=api_key)
        if self.provider == "gemini":
            api_key = os.getenv("GEMINI_API_KEY")
            return genai.Client(api_key=api_key)
        return None

    def generate_full_narrative(self, stats_json: Dict[str, Any]) -> Dict[str, Any]:
        """Return the LLM narrative dict for *stats_json*.

        Returns the fallback narrative when no client is configured, when the
        stats cannot be shaped into a prompt, or when the provider call or
        response parsing fails.
        """
        if not self.client:
            print("WARNING: No LLM client available")
            return self._get_fallback_narrative()

        try:
            # Shaping is inside the try so malformed stats degrade to the
            # fallback narrative like every other failure path here.
            prompt = self._build_prompt(self._shape_payload(stats_json))
            if self.provider == "openai":
                return self._call_openai(prompt)
            if self.provider == "gemini":
                return self._call_gemini(prompt)
        except Exception as e:
            print(f"LLM Generation Error: {e}")
            return self._get_fallback_narrative()

        return self._get_fallback_narrative()

    def _call_openai(self, prompt: str) -> Dict[str, Any]:
        """Send *prompt* to OpenAI Chat Completions in JSON mode and parse it.

        The reply content may be None or empty (e.g. when the token budget is
        exhausted); _clean_and_parse_json maps that to the fallback narrative.
        """
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=[
                {
                    "role": "system",
                    "content": "You are a witty music critic. Output only valid JSON.",
                },
                {"role": "user", "content": prompt},
            ],
            response_format={"type": "json_object"},
            max_completion_tokens=4000,
        )
        return self._clean_and_parse_json(response.choices[0].message.content)

    def _call_gemini(self, prompt: str) -> Dict[str, Any]:
        """Send *prompt* to Gemini requesting a JSON response and parse it."""
        response = self.client.models.generate_content(
            model=self.model_name,
            contents=prompt,
            config=genai.types.GenerateContentConfig(
                response_mime_type="application/json"
            ),
        )
        return self._clean_and_parse_json(response.text)

    def _build_prompt(self, clean_stats: Dict[str, Any]) -> str:
        """Render the analysis prompt embedding *clean_stats* as JSON."""
        era = clean_stats.get("era") or {}
        # Guard against "era": None (or a non-dict) instead of crashing.
        musical_age = era.get("musical_age", "N/A") if isinstance(era, dict) else "N/A"
        return f"""Analyze this Spotify listening data and generate a personalized report.

**RULES:**
1. NO mental health diagnoses. Use behavioral descriptors only.
2. Be specific - reference actual metrics from the data.
3. Be playful but not cruel.
4. Return ONLY valid JSON.

**DATA:**
{json.dumps(clean_stats, indent=2)}

**REQUIRED JSON:**
{{
  "vibe_check_short": "1-2 sentence hook for the hero banner.",
  "vibe_check": "2-3 paragraphs describing their overall listening personality.",
  "patterns": ["Observation 1", "Observation 2", "Observation 3"],
  "persona": "A creative label (e.g., 'The Genre Chameleon').",
  "era_insight": "Comment on Musical Age ({musical_age}).",
  "roast": "1-2 sentence playful roast.",
  "comparison": "Compare to previous period if data exists."
}}"""

    def _shape_payload(self, stats: Dict[str, Any]) -> Dict[str, Any]:
        """Return a trimmed copy of *stats* suitable for an LLM prompt.

        - ``volume``: keeps the other entries, reduces top tracks/artists/
          genres to at most ``_TOP_N`` names, and drops ``top_albums``.
        - ``time_habits``: drops ``heatmap``.
        - ``sessions``: drops ``session_list``.

        Neither *stats* nor its nested dicts are mutated. Sections that are
        not dicts (e.g. None) are passed through unchanged.
        """
        shaped = dict(stats)

        volume = shaped.get("volume")
        if isinstance(volume, dict):
            trimmed = {
                k: v for k, v in volume.items() if k not in self._VOLUME_LIST_KEYS
            }
            for key in ("top_tracks", "top_artists", "top_genres"):
                trimmed[key] = self._top_names(volume.get(key))
            shaped["volume"] = trimmed

        time_habits = shaped.get("time_habits")
        if isinstance(time_habits, dict):
            shaped["time_habits"] = {
                k: v for k, v in time_habits.items() if k != "heatmap"
            }

        sessions = shaped.get("sessions")
        if isinstance(sessions, dict):
            shaped["sessions"] = {
                k: v for k, v in sessions.items() if k != "session_list"
            }

        return shaped

    @classmethod
    def _top_names(cls, items: Any) -> list:
        """Return the ``name`` values of the first ``_TOP_N`` entries of *items*.

        None counts as an empty list. Entries that are not dicts with a
        ``name`` key are skipped instead of raising.
        """
        return [
            item["name"]
            for item in (items or [])[: cls._TOP_N]
            if isinstance(item, dict) and "name" in item
        ]

    @staticmethod
    def _try_load_json(text: str) -> Any:
        """Return ``json.loads(text)``, or None when *text* is not valid JSON."""
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            return None

    def _clean_and_parse_json(self, raw_text: str) -> Dict[str, Any]:
        """Parse an LLM reply into a dict, falling back on failure.

        The whole text is tried first. If that does not yield a JSON object,
        the outermost ``{...}`` span is tried. Empty or None text, unparsable
        text, and non-object JSON (lists, scalars) all produce the fallback
        narrative.
        """
        if not raw_text:
            return self._get_fallback_narrative()

        parsed = self._try_load_json(raw_text)
        if not isinstance(parsed, dict):
            match = self._JSON_OBJECT_RE.search(raw_text)
            parsed = self._try_load_json(match.group(0)) if match else None

        if isinstance(parsed, dict):
            return parsed
        return self._get_fallback_narrative()

    def _get_fallback_narrative(self) -> Dict[str, Any]:
        """Return a fresh canned narrative with the same keys the prompt requests."""
        return {
            "vibe_check_short": "Your taste is... interesting.",
            "vibe_check": "Data processing error. You're too mysterious to analyze right now.",
            "patterns": [],
            "persona": "The Enigma",
            "era_insight": "Time is a flat circle.",
            "roast": "You broke the machine. Congratulations.",
            "comparison": "N/A",
        }
|