Add automatic skip detection for playlist playback

Background poller monitors HA media_player state during playlist sessions.
When a track transition occurs and the previous track was played < 40% of
its duration, automatically records "skip" feedback. Also includes the
previously uncommitted delete_feedback endpoint.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-23 09:17:52 -06:00
parent 9f301497df
commit af6159a297
8 changed files with 400 additions and 8 deletions

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import joinedload
from haunt_fm.db import get_session
from haunt_fm.models.track import FeedbackEvent, Track
from haunt_fm.services.feedback import compute_contextual_score, record_feedback
from haunt_fm.services.feedback import compute_contextual_score, delete_feedback, record_feedback
router = APIRouter(prefix="/api/feedback")
@@ -41,6 +41,24 @@ async def submit_feedback(req: FeedbackRequest, session: AsyncSession = Depends(
}
@router.delete("/{event_id}")
async def retract_feedback(event_id: int, session: AsyncSession = Depends(get_session)):
    """Remove a single feedback event so it no longer influences scoring.

    Responds with the fields of the deleted event. Raises a 404
    ``HTTPException`` when ``delete_feedback`` reports that no event with
    ``event_id`` exists.
    """
    deleted = await delete_feedback(session, event_id)
    if deleted is not None:
        # Plain attributes are copied as-is; created_at is serialized below.
        exported = ("id", "playlist_id", "track_id", "signal", "signal_weight", "vibe_text")
        payload = {name: getattr(deleted, name) for name in exported}
        payload["created_at"] = deleted.created_at.isoformat()
        return payload
    raise HTTPException(status_code=404, detail=f"Feedback event {event_id} not found")
@router.post("/score")
async def get_score(req: ScoreRequest, session: AsyncSession = Depends(get_session)):
"""Compute the contextual feedback score for a track given a vibe description."""

View File

@@ -62,6 +62,7 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses
"artist": t.artist,
"title": t.title,
"album": t.album,
"duration_ms": t.duration_ms,
"is_known": pt.is_known,
"similarity_score": pt.similarity_score,
}
@@ -72,6 +73,11 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses
if req.auto_play and req.speaker_entity:
await play_playlist_on_speaker(track_list, req.speaker_entity)
# Register with skip detector for automatic skip feedback
from haunt_fm.services.skip_detector import register_session
register_session(req.speaker_entity, playlist.id, track_list)
return {
"playlist_id": playlist.id,
"name": playlist.name,

View File

@@ -16,6 +16,8 @@ from haunt_fm.config import settings
from haunt_fm.services.embedding import is_model_loaded
from haunt_fm.services.embedding_worker import is_running as is_worker_running
from haunt_fm.services.embedding_worker import last_processed as worker_last_processed
from haunt_fm.services.skip_detector import get_sessions as get_skip_sessions
from haunt_fm.services.skip_detector import is_running as is_skip_detector_running
router = APIRouter(prefix="/api")
@@ -103,6 +105,25 @@ async def status(session: AsyncSession = Depends(get_session)):
"total_generated": total_playlists,
"last_generated": last_playlist.isoformat() if last_playlist else None,
},
"skip_detector": {
"running": is_skip_detector_running(),
"active_sessions": len(get_skip_sessions()),
"sessions": [
{
"speaker_entity": entity,
"playlist_id": s.playlist_id,
"current_position": s.current_position,
"total_tracks": len(s.tracks),
"current_track": (
f"{s.tracks[s.current_position]['artist']} - {s.tracks[s.current_position]['title']}"
if s.current_position < len(s.tracks)
else None
),
"last_activity": s.last_activity_at.isoformat(),
}
for entity, s in get_skip_sessions().items()
],
},
},
"dependencies": {
"lastfm_api": "configured" if settings.lastfm_api_key else "not_configured",

View File

@@ -25,7 +25,14 @@ class Settings(BaseSettings):
# Feedback
feedback_overlap_threshold: float = 0.85
feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3}
feedback_signal_weights: dict = {"up": 1.0, "down": -1.0, "skip": -0.3, "neutral": 0.0}
# Skip detection
skip_detector_enabled: bool = True
skip_detector_poll_interval_seconds: float = 3.0
skip_detector_skip_threshold: float = 0.4 # < 40% played = skip
skip_detector_session_timeout_minutes: int = 30
skip_detector_min_track_duration_ms: int = 30000 # ignore tracks < 30s
settings = Settings()

View File

@@ -27,15 +27,24 @@ async def lifespan(app: FastAPI):
worker_task = asyncio.create_task(run_worker())
logger.info("Embedding worker task created")
# Start skip detector in background
skip_detector_task = None
if settings.skip_detector_enabled:
from haunt_fm.services.skip_detector import run_skip_detector
skip_detector_task = asyncio.create_task(run_skip_detector())
logger.info("Skip detector task created")
yield
# Shutdown
if worker_task:
worker_task.cancel()
try:
await worker_task
except asyncio.CancelledError:
pass
for task in [worker_task, skip_detector_task]:
if task:
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
logger.info("haunt-fm shut down")

View File

@@ -122,6 +122,19 @@ def compute_contextual_score(
return score
async def delete_feedback(session: AsyncSession, event_id: int) -> FeedbackEvent | None:
    """Remove the feedback event with the given ID and commit the change.

    Returns the deleted ``FeedbackEvent``, or ``None`` when no event with
    ``event_id`` exists (in which case nothing is deleted or committed).
    """
    found = await session.get(FeedbackEvent, event_id)
    if found is not None:
        await session.delete(found)
        await session.commit()
        logger.info(
            "Deleted feedback event %d (signal=%s, track=%d)",
            found.id,
            found.signal,
            found.track_id,
        )
    return found
async def apply_feedback_adjustments(
session: AsyncSession,
recommendations: list[dict],

View File

@@ -33,6 +33,27 @@ async def is_ha_reachable() -> bool:
return False
async def get_speaker_state(speaker_entity: str) -> dict | None:
    """Fetch a playback snapshot for a HA media_player entity.

    Requests ``/api/states/<speaker_entity>`` and returns a dict holding the
    entity ``state`` plus the ``media_title``, ``media_artist``,
    ``media_duration`` and ``media_position`` attributes (``None`` for any
    that are absent). Any failure is logged at debug level and reported as
    ``None``.
    """
    media_keys = ("media_title", "media_artist", "media_duration", "media_position")
    try:
        payload = await _ha_request("GET", f"/api/states/{speaker_entity}")
        attributes = payload.get("attributes", {})
        snapshot = {"state": payload.get("state")}
        snapshot.update((key, attributes.get(key)) for key in media_keys)
        return snapshot
    except Exception:
        logger.debug("Failed to get state for %s", speaker_entity)
        return None
async def play_media_on_speaker(
media_content_id: str,
speaker_entity: str,

View File

@@ -0,0 +1,297 @@
"""Background skip detection for haunt-fm playlists.
Polls Home Assistant media_player state to detect when tracks are skipped
(played < threshold % of duration) and automatically records skip feedback.
"""
import asyncio
import difflib
import logging
import re
from dataclasses import dataclass, field
from datetime import datetime, timezone
from haunt_fm.config import settings
from haunt_fm.db import async_session
from haunt_fm.services.feedback import record_feedback
from haunt_fm.services.music_assistant import get_speaker_state
logger = logging.getLogger(__name__)
# True only while run_skip_detector()'s polling loop is executing.
_running = False
# Active playlist sessions, keyed by the speaker entity they were registered for.
_sessions: dict[str, "PlaylistSession"] = {}
@dataclass
class PlaylistSession:
playlist_id: int
speaker_entity: str
tracks: list[dict] # each: {track_id, artist, title, duration_ms, position}
current_position: int = 0
current_track_started_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
last_activity_at: datetime = field(default_factory=lambda: datetime.now(timezone.utc))
last_ha_media_title: str | None = None
last_ha_media_artist: str | None = None
def is_running() -> bool:
    """Report whether the skip detector's background polling loop is active."""
    return _running
def get_sessions() -> dict[str, "PlaylistSession"]:
    """Return a shallow copy of the active-session registry, keyed by speaker entity.

    Only the mapping is copied; the ``PlaylistSession`` values are the live
    objects held by the detector.
    """
    return _sessions.copy()
def register_session(speaker_entity: str, playlist_id: int, tracks: list[dict]) -> None:
    """Start tracking a playlist on a speaker for skip detection.

    Any session already registered for ``speaker_entity`` is replaced.
    Called from playlists.py after auto_play succeeds.
    """
    _sessions[speaker_entity] = PlaylistSession(
        playlist_id=playlist_id, speaker_entity=speaker_entity, tracks=tracks
    )
    track_count = len(tracks)
    logger.info(
        "Skip detector: registered session for %s (playlist %d, %d tracks)",
        speaker_entity,
        playlist_id,
        track_count,
    )
# --- String normalization and fuzzy matching ---
_PAREN_SUFFIX = re.compile(r"\s*\(.*\)\s*$")
_EXTRA_WHITESPACE = re.compile(r"\s+")
def _normalize(s: str) -> str:
"""Normalize a string for fuzzy comparison."""
s = s.lower().strip()
s = _PAREN_SUFFIX.sub("", s) # strip "(Remastered)", "(Live)", etc.
s = _EXTRA_WHITESPACE.sub(" ", s)
return s
def _fuzzy_score(a: str, b: str) -> float:
    """Return a 0-1 similarity ratio between the normalized forms of two strings."""
    left = _normalize(a)
    right = _normalize(b)
    matcher = difflib.SequenceMatcher(None, left, right)
    return matcher.ratio()
def _match_track(
    ha_title: str, ha_artist: str, tracks: list[dict], hint_position: int
) -> int | None:
    """Find the playlist track that best matches what HA reports as playing.

    Every track is scored as ``0.7 * title similarity + 0.3 * artist
    similarity``. The highest-scoring track wins, provided its score reaches
    the 0.5 threshold.

    ``hint_position`` is the position that was playing before; the position
    right after it (sequential playback, the common case) is only used to
    break ties, e.g. when the same song appears twice in the playlist. It is
    deliberately NOT accepted early on a merely passable score: neighbouring
    tracks by the same artist easily reach 0.5, which would misattribute a
    replay or a jump to another position.

    Returns the matching index into ``tracks``, or None if nothing scores at
    or above the threshold.
    """
    threshold = 0.5
    next_pos = hint_position + 1

    # Score the sequential candidate first so that, with the strict ">" below,
    # it wins any tie. The lower bound keeps a negative hint from silently
    # indexing from the end of the list.
    candidates = []
    if 0 <= next_pos < len(tracks):
        candidates.append(next_pos)
    candidates.extend(i for i in range(len(tracks)) if i != next_pos)

    best_idx = None
    best_score = 0.0
    for idx in candidates:
        track = tracks[idx]
        title_score = _fuzzy_score(ha_title, track["title"])
        artist_score = _fuzzy_score(ha_artist, track["artist"])
        combined = title_score * 0.7 + artist_score * 0.3
        if combined > best_score:
            best_score = combined
            best_idx = idx

    return best_idx if best_score >= threshold else None
# --- Core polling loop ---
async def _evaluate_skip(
    session: PlaylistSession, track: dict, elapsed_ms: float
) -> None:
    """Decide whether ``track`` was skipped and, if so, record "skip" feedback.

    A track counts as skipped when ``elapsed_ms`` divided by its
    ``duration_ms`` is below ``settings.skip_detector_skip_threshold``.
    Tracks with no usable duration, or shorter than
    ``settings.skip_detector_min_track_duration_ms``, are never evaluated.
    Failures while recording feedback are logged and not propagated.
    """
    duration_ms = track.get("duration_ms")

    # Guard: without a positive duration there is nothing to compare against.
    if duration_ms is None or duration_ms <= 0:
        logger.debug("Skip eval: no duration for track %d, skipping evaluation", track["track_id"])
        return

    # Guard: very short tracks are excluded from skip detection.
    if duration_ms < settings.skip_detector_min_track_duration_ms:
        logger.debug(
            "Skip eval: track %d too short (%dms), ignoring",
            track["track_id"],
            duration_ms,
        )
        return

    fraction_played = elapsed_ms / duration_ms
    percent_played = fraction_played * 100
    was_skipped = fraction_played < settings.skip_detector_skip_threshold

    if not was_skipped:
        logger.debug(
            "Track '%s - %s' played %.0f%%, not a skip",
            track["artist"],
            track["title"],
            percent_played,
        )
        return

    logger.info(
        "Skip detected: '%s - %s' played %.0f%% (track_id=%d, playlist_id=%d)",
        track["artist"],
        track["title"],
        percent_played,
        track["track_id"],
        session.playlist_id,
    )
    try:
        async with async_session() as db:
            await record_feedback(db, session.playlist_id, track["track_id"], "skip")
        logger.info(
            "Skip feedback recorded for track %d in playlist %d",
            track["track_id"],
            session.playlist_id,
        )
    except ValueError as e:
        # Playlist without vibe_embedding — can't record contextual feedback
        logger.warning(
            "Could not record skip feedback for track %d: %s",
            track["track_id"],
            e,
        )
    except Exception:
        logger.exception(
            "Error recording skip feedback for track %d", track["track_id"]
        )
async def _poll_all_sessions() -> None:
    """Run a single poll cycle across all active sessions.

    For each registered session this:

    * drops the session if it has been idle longer than
      ``settings.skip_detector_session_timeout_minutes``;
    * asks HA for the speaker's current media title/artist and does nothing
      further if HA is unreachable, no media is reported, or the track is
      unchanged;
    * on a track change, evaluates the previous track for a skip, then
      matches the new track against the playlist and updates the session.
      A track that matches nothing in the playlist ends the session.
    """
    now = datetime.now(timezone.utc)
    # Removal is deferred to the end of the cycle. The session object is kept
    # alongside its key so that a session re-registered for the same speaker
    # while this cycle was awaiting HA/DB calls is not removed by mistake.
    to_remove = []

    for entity, session in list(_sessions.items()):
        # Check session timeout
        idle_minutes = (now - session.last_activity_at).total_seconds() / 60
        if idle_minutes > settings.skip_detector_session_timeout_minutes:
            logger.info(
                "Skip detector: session for %s timed out (idle %.0f min)",
                entity,
                idle_minutes,
            )
            to_remove.append((entity, session))
            continue

        # Poll HA for current state
        state = await get_speaker_state(entity)
        if state is None:
            continue  # HA unreachable, try again next cycle

        ha_title = state.get("media_title")
        ha_artist = state.get("media_artist")

        # No media info — speaker might be idle
        if not ha_title or not ha_artist:
            continue

        session.last_activity_at = now

        # Same track still playing — no transition
        if (
            ha_title == session.last_ha_media_title
            and ha_artist == session.last_ha_media_artist
        ):
            continue

        # Track transition detected
        logger.debug(
            "Track transition on %s: '%s - %s' -> '%s - %s'",
            entity,
            session.last_ha_media_artist,
            session.last_ha_media_title,
            ha_artist,
            ha_title,
        )

        first_observation = session.last_ha_media_title is None

        # Evaluate previous track for skip (if we had one)
        if not first_observation and session.current_position < len(session.tracks):
            prev_track = session.tracks[session.current_position]
            if prev_track:
                elapsed_ms = (now - session.current_track_started_at).total_seconds() * 1000
                await _evaluate_skip(session, prev_track, elapsed_ms)

        # Match new track to playlist. Nothing has been observed yet on the
        # first poll, so the expected track is current_position itself, not
        # the one after it; hinting with position - 1 makes the matcher's
        # "next sequential position" point at it.
        hint_position = (
            session.current_position - 1 if first_observation else session.current_position
        )
        matched_pos = _match_track(ha_title, ha_artist, session.tracks, hint_position)
        if matched_pos is None:
            logger.info(
                "Skip detector: '%s - %s' on %s doesn't match playlist — killing session",
                ha_artist,
                ha_title,
                entity,
            )
            to_remove.append((entity, session))
            continue

        # Update session state
        session.current_position = matched_pos
        session.current_track_started_at = now
        session.last_ha_media_title = ha_title
        session.last_ha_media_artist = ha_artist
        logger.debug(
            "Skip detector: %s now at position %d (%s - %s)",
            entity,
            matched_pos,
            session.tracks[matched_pos]["artist"],
            session.tracks[matched_pos]["title"],
        )

    for entity, session in to_remove:
        # Only drop the exact session that was judged dead.
        if _sessions.get(entity) is session:
            del _sessions[entity]
async def run_skip_detector() -> None:
    """Poll HA for skips in the background until the task is cancelled.

    Returns immediately when ``settings.skip_detector_enabled`` is false.
    Otherwise loops forever: runs one poll cycle if any session is
    registered, logs (and survives) any ``Exception`` from that cycle, then
    sleeps for ``settings.skip_detector_poll_interval_seconds``. The module's
    running flag is set for the lifetime of the loop and cleared on exit.
    """
    global _running

    if settings.skip_detector_enabled:
        _running = True
        logger.info("Skip detector started (poll interval: %.1fs)", settings.skip_detector_poll_interval_seconds)
        try:
            while True:
                if _sessions:
                    try:
                        await _poll_all_sessions()
                    except Exception:
                        # One bad cycle must not kill the detector.
                        logger.exception("Skip detector poll error")
                await asyncio.sleep(settings.skip_detector_poll_interval_seconds)
        finally:
            _running = False
            logger.info("Skip detector stopped")
    else:
        logger.info("Skip detector disabled")