Add profile-scoped feedback endpoint
New POST /api/profiles/{name}/feedback accepts explicit vibe text and
records feedback against a named profile. GET history endpoint added too.
Scoring now filters feedback by profile_name for profile-aware playlists.
Migration 005 adds profile_name column and makes playlist_id nullable.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
30
alembic/versions/005_add_profile_feedback.py
Normal file
30
alembic/versions/005_add_profile_feedback.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
"""Add profile_name to feedback_events, make playlist_id nullable
|
||||||
|
|
||||||
|
Revision ID: 005
|
||||||
|
Revises: 004
|
||||||
|
Create Date: 2026-02-23
|
||||||
|
|
||||||
|
"""
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "005"
|
||||||
|
down_revision: Union[str, None] = "004"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column("feedback_events", sa.Column("profile_name", sa.Text, nullable=True))
|
||||||
|
op.create_index("ix_feedback_events_profile_name", "feedback_events", ["profile_name"])
|
||||||
|
|
||||||
|
# Make playlist_id nullable (profile-scoped feedback doesn't require a playlist)
|
||||||
|
op.alter_column("feedback_events", "playlist_id", existing_type=sa.BigInteger, nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column("feedback_events", "playlist_id", existing_type=sa.BigInteger, nullable=False)
|
||||||
|
op.drop_index("ix_feedback_events_profile_name", table_name="feedback_events")
|
||||||
|
op.drop_column("feedback_events", "profile_name")
|
||||||
@@ -129,6 +129,7 @@ async def get_history(
|
|||||||
"events": [
|
"events": [
|
||||||
{
|
{
|
||||||
"id": e.id,
|
"id": e.id,
|
||||||
|
"profile_name": e.profile_name,
|
||||||
"playlist_id": e.playlist_id,
|
"playlist_id": e.playlist_id,
|
||||||
"track_id": e.track_id,
|
"track_id": e.track_id,
|
||||||
"artist": e.track.artist,
|
"artist": e.track.artist,
|
||||||
|
|||||||
@@ -1,15 +1,18 @@
|
|||||||
from fastapi import APIRouter, Depends, HTTPException
|
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, or_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import joinedload
|
||||||
|
|
||||||
from haunt_fm.db import get_session
|
from haunt_fm.db import get_session
|
||||||
from haunt_fm.models.track import (
|
from haunt_fm.models.track import (
|
||||||
|
FeedbackEvent,
|
||||||
ListenEvent,
|
ListenEvent,
|
||||||
Profile,
|
Profile,
|
||||||
SpeakerProfileMapping,
|
SpeakerProfileMapping,
|
||||||
TasteProfile,
|
TasteProfile,
|
||||||
)
|
)
|
||||||
|
from haunt_fm.services.feedback import record_profile_feedback
|
||||||
|
|
||||||
router = APIRouter(prefix="/api/profiles")
|
router = APIRouter(prefix="/api/profiles")
|
||||||
|
|
||||||
@@ -23,6 +26,13 @@ class SetSpeakersRequest(BaseModel):
|
|||||||
speakers: list[str]
|
speakers: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ProfileFeedbackRequest(BaseModel):
|
||||||
|
track_id: int
|
||||||
|
signal: str
|
||||||
|
vibe: str
|
||||||
|
playlist_id: int | None = None
|
||||||
|
|
||||||
|
|
||||||
async def _get_profile_or_404(session: AsyncSession, name: str) -> Profile:
|
async def _get_profile_or_404(session: AsyncSession, name: str) -> Profile:
|
||||||
result = await session.execute(select(Profile).where(Profile.name == name))
|
result = await session.execute(select(Profile).where(Profile.name == name))
|
||||||
profile = result.scalar_one_or_none()
|
profile = result.scalar_one_or_none()
|
||||||
@@ -204,3 +214,93 @@ async def get_speakers(name: str, session: AsyncSession = Depends(get_session)):
|
|||||||
.where(SpeakerProfileMapping.profile_id == profile.id)
|
.where(SpeakerProfileMapping.profile_id == profile.id)
|
||||||
)
|
)
|
||||||
return {"profile": name, "speakers": [r.speaker_name for r in result]}
|
return {"profile": name, "speakers": [r.speaker_name for r in result]}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{name}/feedback")
|
||||||
|
async def submit_profile_feedback(
|
||||||
|
name: str, req: ProfileFeedbackRequest, session: AsyncSession = Depends(get_session)
|
||||||
|
):
|
||||||
|
"""Submit feedback for a track scoped to a named profile with explicit vibe text."""
|
||||||
|
from haunt_fm.services.embedding import embed_text, is_model_loaded, load_model
|
||||||
|
|
||||||
|
if not is_model_loaded():
|
||||||
|
load_model()
|
||||||
|
vibe_embedding = embed_text(req.vibe)
|
||||||
|
|
||||||
|
try:
|
||||||
|
event = await record_profile_feedback(
|
||||||
|
session,
|
||||||
|
profile_name=name,
|
||||||
|
track_id=req.track_id,
|
||||||
|
signal=req.signal,
|
||||||
|
vibe_text=req.vibe,
|
||||||
|
vibe_embedding=vibe_embedding.tolist(),
|
||||||
|
playlist_id=req.playlist_id,
|
||||||
|
)
|
||||||
|
except ValueError as e:
|
||||||
|
raise HTTPException(status_code=400, detail=str(e))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": event.id,
|
||||||
|
"profile_name": event.profile_name,
|
||||||
|
"playlist_id": event.playlist_id,
|
||||||
|
"track_id": event.track_id,
|
||||||
|
"signal": event.signal,
|
||||||
|
"signal_weight": event.signal_weight,
|
||||||
|
"vibe_text": event.vibe_text,
|
||||||
|
"created_at": event.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{name}/feedback/history")
|
||||||
|
async def get_profile_feedback_history(
|
||||||
|
name: str,
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
track_id: int | None = Query(default=None),
|
||||||
|
session: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""Get feedback history scoped to a profile."""
|
||||||
|
# Verify profile exists
|
||||||
|
await _get_profile_or_404(session, name)
|
||||||
|
|
||||||
|
# For "default", include events where profile_name IS NULL or "default"
|
||||||
|
if name == "default":
|
||||||
|
profile_filter = or_(
|
||||||
|
FeedbackEvent.profile_name.is_(None),
|
||||||
|
FeedbackEvent.profile_name == "default",
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
profile_filter = FeedbackEvent.profile_name == name
|
||||||
|
|
||||||
|
query = (
|
||||||
|
select(FeedbackEvent)
|
||||||
|
.options(joinedload(FeedbackEvent.track))
|
||||||
|
.where(profile_filter)
|
||||||
|
.order_by(FeedbackEvent.created_at.desc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
if track_id is not None:
|
||||||
|
query = query.where(FeedbackEvent.track_id == track_id)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
|
events = result.scalars().unique().all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"profile": name,
|
||||||
|
"events": [
|
||||||
|
{
|
||||||
|
"id": e.id,
|
||||||
|
"profile_name": e.profile_name,
|
||||||
|
"playlist_id": e.playlist_id,
|
||||||
|
"track_id": e.track_id,
|
||||||
|
"artist": e.track.artist,
|
||||||
|
"title": e.track.title,
|
||||||
|
"signal": e.signal,
|
||||||
|
"signal_weight": e.signal_weight,
|
||||||
|
"vibe_text": e.vibe_text,
|
||||||
|
"created_at": e.created_at.isoformat(),
|
||||||
|
}
|
||||||
|
for e in events
|
||||||
|
],
|
||||||
|
"count": len(events),
|
||||||
|
}
|
||||||
|
|||||||
@@ -34,6 +34,6 @@ async def recommendations(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Apply feedback adjustments (re-ranks based on contextual feedback)
|
# Apply feedback adjustments (re-ranks based on contextual feedback)
|
||||||
results = await apply_feedback_adjustments(session, results, vibe_embedding)
|
results = await apply_feedback_adjustments(session, results, vibe_embedding, profile_name=profile or "default")
|
||||||
|
|
||||||
return {"recommendations": results, "count": len(results), "vibe": vibe, "alpha": effective_alpha, "profile": profile or "default"}
|
return {"recommendations": results, "count": len(results), "vibe": vibe, "alpha": effective_alpha, "profile": profile or "default"}
|
||||||
|
|||||||
@@ -140,8 +140,9 @@ class FeedbackEvent(Base):
|
|||||||
__tablename__ = "feedback_events"
|
__tablename__ = "feedback_events"
|
||||||
|
|
||||||
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||||
playlist_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("playlists.id"), nullable=False)
|
playlist_id: Mapped[int | None] = mapped_column(BigInteger, ForeignKey("playlists.id"), nullable=True)
|
||||||
track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
|
track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False)
|
||||||
|
profile_name: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
vibe_embedding = mapped_column(Vector(512), nullable=False)
|
vibe_embedding = mapped_column(Vector(512), nullable=False)
|
||||||
vibe_text: Mapped[str | None] = mapped_column(Text)
|
vibe_text: Mapped[str | None] = mapped_column(Text)
|
||||||
signal: Mapped[str] = mapped_column(Text, nullable=False)
|
signal: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
@@ -150,7 +151,8 @@ class FeedbackEvent(Base):
|
|||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("ix_feedback_events_track_id", "track_id"),
|
Index("ix_feedback_events_track_id", "track_id"),
|
||||||
|
Index("ix_feedback_events_profile_name", "profile_name"),
|
||||||
)
|
)
|
||||||
|
|
||||||
playlist: Mapped[Playlist] = relationship(back_populates="feedback_events")
|
playlist: Mapped[Playlist | None] = relationship(back_populates="feedback_events")
|
||||||
track: Mapped[Track] = relationship()
|
track: Mapped[Track] = relationship()
|
||||||
|
|||||||
@@ -4,8 +4,10 @@ import numpy as np
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from sqlalchemy import or_
|
||||||
|
|
||||||
from haunt_fm.config import settings
|
from haunt_fm.config import settings
|
||||||
from haunt_fm.models.track import FeedbackEvent, Playlist, Track
|
from haunt_fm.models.track import FeedbackEvent, Playlist, Profile, Track
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -51,6 +53,55 @@ async def record_feedback(
|
|||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
async def record_profile_feedback(
|
||||||
|
session: AsyncSession,
|
||||||
|
profile_name: str,
|
||||||
|
track_id: int,
|
||||||
|
signal: str,
|
||||||
|
vibe_text: str,
|
||||||
|
vibe_embedding: list[float],
|
||||||
|
playlist_id: int | None = None,
|
||||||
|
) -> FeedbackEvent:
|
||||||
|
"""Record a feedback event scoped to a named profile with explicit vibe."""
|
||||||
|
if signal not in VALID_SIGNALS:
|
||||||
|
raise ValueError(f"Invalid signal '{signal}'. Must be one of: {', '.join(sorted(VALID_SIGNALS))}")
|
||||||
|
|
||||||
|
# Verify profile exists
|
||||||
|
result = await session.execute(select(Profile).where(Profile.name == profile_name))
|
||||||
|
if result.scalar_one_or_none() is None:
|
||||||
|
raise ValueError(f"Profile '{profile_name}' not found")
|
||||||
|
|
||||||
|
# Verify track exists
|
||||||
|
track = await session.get(Track, track_id)
|
||||||
|
if track is None:
|
||||||
|
raise ValueError(f"Track {track_id} not found")
|
||||||
|
|
||||||
|
# Verify playlist if provided
|
||||||
|
if playlist_id is not None:
|
||||||
|
playlist = await session.get(Playlist, playlist_id)
|
||||||
|
if playlist is None:
|
||||||
|
raise ValueError(f"Playlist {playlist_id} not found")
|
||||||
|
|
||||||
|
weight = settings.feedback_signal_weights[signal]
|
||||||
|
|
||||||
|
event = FeedbackEvent(
|
||||||
|
playlist_id=playlist_id,
|
||||||
|
track_id=track_id,
|
||||||
|
profile_name=profile_name,
|
||||||
|
vibe_embedding=vibe_embedding,
|
||||||
|
vibe_text=vibe_text,
|
||||||
|
signal=signal,
|
||||||
|
signal_weight=weight,
|
||||||
|
)
|
||||||
|
session.add(event)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(event)
|
||||||
|
|
||||||
|
logger.info("Recorded %s feedback for track %d, profile '%s' (vibe: %s)",
|
||||||
|
signal, track_id, profile_name, vibe_text)
|
||||||
|
return event
|
||||||
|
|
||||||
|
|
||||||
def compute_contextual_score(
|
def compute_contextual_score(
|
||||||
events: list[FeedbackEvent],
|
events: list[FeedbackEvent],
|
||||||
current_vibe_embedding: np.ndarray,
|
current_vibe_embedding: np.ndarray,
|
||||||
@@ -139,20 +190,30 @@ async def apply_feedback_adjustments(
|
|||||||
session: AsyncSession,
|
session: AsyncSession,
|
||||||
recommendations: list[dict],
|
recommendations: list[dict],
|
||||||
current_vibe_embedding: np.ndarray | None,
|
current_vibe_embedding: np.ndarray | None,
|
||||||
|
profile_name: str = "default",
|
||||||
) -> list[dict]:
|
) -> list[dict]:
|
||||||
"""Adjust recommendation scores based on contextual feedback.
|
"""Adjust recommendation scores based on contextual feedback.
|
||||||
|
|
||||||
Fetches feedback events for the recommended tracks, computes contextual
|
Fetches feedback events for the recommended tracks, computes contextual
|
||||||
scores, adds them to similarity, and re-sorts.
|
scores, adds them to similarity, and re-sorts.
|
||||||
|
|
||||||
|
When profile_name is "default", includes events where profile_name IS NULL
|
||||||
|
or profile_name = "default" (backward compatible).
|
||||||
"""
|
"""
|
||||||
if current_vibe_embedding is None or not recommendations:
|
if current_vibe_embedding is None or not recommendations:
|
||||||
return recommendations
|
return recommendations
|
||||||
|
|
||||||
track_ids = [r["track_id"] for r in recommendations]
|
track_ids = [r["track_id"] for r in recommendations]
|
||||||
|
|
||||||
result = await session.execute(
|
query = select(FeedbackEvent).where(FeedbackEvent.track_id.in_(track_ids))
|
||||||
select(FeedbackEvent).where(FeedbackEvent.track_id.in_(track_ids))
|
if profile_name == "default":
|
||||||
)
|
query = query.where(
|
||||||
|
or_(FeedbackEvent.profile_name.is_(None), FeedbackEvent.profile_name == "default")
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
query = query.where(FeedbackEvent.profile_name == profile_name)
|
||||||
|
|
||||||
|
result = await session.execute(query)
|
||||||
events = list(result.scalars().all())
|
events = list(result.scalars().all())
|
||||||
|
|
||||||
if not events:
|
if not events:
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ async def generate_playlist(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Apply feedback adjustments (re-ranks based on contextual feedback)
|
# Apply feedback adjustments (re-ranks based on contextual feedback)
|
||||||
recs = await apply_feedback_adjustments(session, recs, vibe_embedding)
|
recs = await apply_feedback_adjustments(session, recs, vibe_embedding, profile_name=profile_name)
|
||||||
|
|
||||||
new_tracks = [(r["track_id"], r.get("adjusted_score", r["similarity"])) for r in recs[:new_count]]
|
new_tracks = [(r["track_id"], r.get("adjusted_score", r["similarity"])) for r in recs[:new_count]]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user