From 094621a9a85f715eac7bda259e7f2ea5537d6f1a Mon Sep 17 00:00:00 2001 From: Thomas Hallock Date: Sun, 22 Feb 2026 19:14:34 -0600 Subject: [PATCH] Add named taste profiles for per-person recommendations MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Named profiles allow each household member to get personalized recommendations without polluting each other's taste. Includes profile CRUD API, speaker→profile auto-attribution, recent listen history endpoint, and profile param on all existing endpoints. All endpoints backward compatible (no profile param = "default"). Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 33 +++- README.md | 42 ++++ alembic/versions/003_add_profiles.py | 56 ++++++ src/haunt_fm/api/admin.py | 19 +- src/haunt_fm/api/history.py | 72 ++++++- src/haunt_fm/api/playlists.py | 2 + src/haunt_fm/api/profiles.py | 206 ++++++++++++++++++++ src/haunt_fm/api/recommendations.py | 4 +- src/haunt_fm/main.py | 3 +- src/haunt_fm/models/track.py | 19 ++ src/haunt_fm/services/history_ingest.py | 58 +++++- src/haunt_fm/services/playlist_generator.py | 2 + src/haunt_fm/services/recommender.py | 24 ++- src/haunt_fm/services/taste_profile.py | 49 ++++- 14 files changed, 556 insertions(+), 33 deletions(-) create mode 100644 alembic/versions/003_add_profiles.py create mode 100644 src/haunt_fm/api/profiles.py diff --git a/CLAUDE.md b/CLAUDE.md index 78b8818..1c855f7 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -37,6 +37,9 @@ curl -X POST http://192.168.86.51:8321/api/admin/discover -H "Content-Type: appl # Get recommendations curl http://192.168.86.51:8321/api/recommendations?limit=20 +# Get recommendations for a specific profile +curl "http://192.168.86.51:8321/api/recommendations?limit=20&profile=antialias" + # Generate and play a playlist curl -X POST http://192.168.86.51:8321/api/playlists/generate \ -H "Content-Type: application/json" \ @@ -51,6 +54,24 @@ curl -X POST http://192.168.86.51:8321/api/playlists/generate \ curl -X POST http://192.168.86.51:8321/api/playlists/generate \ -H "Content-Type: application/json" \ -d '{"total_tracks":15,"vibe":"upbeat party music","alpha":0.3,"auto_play":true,"speaker_entity":"media_player.living_room_speaker_2"}' + +# Generate playlist for a specific profile +curl -X POST http://192.168.86.51:8321/api/playlists/generate \ + -H "Content-Type: application/json" \ + -d '{"total_tracks":20,"profile":"antialias","speaker_entity":"media_player.study_speaker_2","auto_play":true}' + +# Create a profile +curl -X POST http://192.168.86.51:8321/api/profiles \ + -H "Content-Type: application/json" \ + -d '{"name":"antialias","display_name":"Me"}' + +# Map speakers to a profile +curl -X PUT http://192.168.86.51:8321/api/profiles/antialias/speakers \ + -H "Content-Type: application/json" \ + -d '{"speakers":["Study speaker","Master bathroom speaker"]}' + +# Build taste profile for a specific profile +curl -X POST "http://192.168.86.51:8321/api/admin/build-taste-profile?profile=antialias" ``` ## Environment Variables @@ -61,4 +82,14 @@ All prefixed with `HAUNTFM_`. See `.env.example` for full list. - Alembic migrations in `alembic/versions/` - Run migrations: `alembic upgrade head` -- Schema: tracks, listen_events, track_embeddings, similarity_links, taste_profiles, playlists, playlist_tracks +- Schema: tracks, listen_events, track_embeddings, similarity_links, taste_profiles, playlists, playlist_tracks, profiles, speaker_profile_mappings + +## Named Profiles + +Named profiles allow per-person taste tracking. No auth — just named buckets. + +- **Default behavior**: All endpoints without `profile` param use the "default" profile (backward compatible) +- **Profile CRUD**: `GET/POST /api/profiles`, `GET/DELETE /api/profiles/{name}` +- **Speaker mappings**: `PUT/GET /api/profiles/{name}/speakers` — auto-attributes listen events from mapped speakers +- **Attribution**: Webhook accepts `"profile": "name"` or auto-resolves from speaker→profile mapping +- **Recommendations/playlists**: Pass `profile=name` to use that profile's taste diff --git a/README.md b/README.md index f8693bb..14b880d 100644 --- a/README.md +++ b/README.md @@ -49,10 +49,17 @@ docker exec haunt-fm alembic upgrade head | GET | `/api/status` | Full pipeline status JSON | | GET | `/` | HTML status dashboard | | POST | `/api/history/webhook` | Log a listen event (from HA automation) | +| GET | `/api/history/recent?limit=20&profile=name` | Recent listen events (optional profile filter) | | POST | `/api/admin/discover` | Expand listening history via Last.fm | | POST | `/api/admin/build-taste-profile` | Rebuild taste profile from embeddings | | GET | `/api/recommendations?limit=50&vibe=chill+ambient` | Get ranked recommendations (optional vibe) | | POST | `/api/playlists/generate` | Generate and optionally play a playlist | +| GET | `/api/profiles` | List all named profiles with stats | +| POST | `/api/profiles` | Create a named profile | +| GET | `/api/profiles/{name}` | Get profile details + stats | +| DELETE | `/api/profiles/{name}` | Delete profile (reassigns events to default) | +| PUT | `/api/profiles/{name}/speakers` | Set speaker→profile mappings | +| GET | `/api/profiles/{name}/speakers` | List speaker mappings | ## Usage @@ -89,6 +96,7 @@ curl -X POST http://192.168.86.51:8321/api/playlists/generate \ - `auto_play` — `true` to immediately play on the speaker - `vibe` — text description of the desired mood/vibe (e.g. "chill lo-fi beats", "upbeat party music"). Uses CLAP text embeddings to match tracks in the same vector space as audio. - `alpha` — blend factor between taste profile and vibe (default 0.5). `1.0` = pure taste profile, `0.0` = pure vibe match, `0.5` = equal blend. Ignored when no vibe is provided. +- `profile` — named taste profile to use (default: "default"). Each profile has its own listening history and taste embedding. ### Speaker entities @@ -110,6 +118,40 @@ The `speaker_entity` **must** be a Music Assistant entity (the `_2` suffix ones) | downstairs | `media_player.downstairs_2` | | upstairs | `media_player.upstairs_2` | +### Named profiles + +Named profiles let each household member get personalized recommendations without polluting each other's taste. + +```bash +# Create a profile +curl -X POST http://192.168.86.51:8321/api/profiles \ + -H "Content-Type: application/json" \ + -d '{"name":"antialias","display_name":"Me"}' + +# Map speakers to auto-attribute listens +curl -X PUT http://192.168.86.51:8321/api/profiles/antialias/speakers \ + -H "Content-Type: application/json" \ + -d '{"speakers":["Study speaker","Master bathroom speaker"]}' + +# Log a listen event with explicit profile +curl -X POST http://192.168.86.51:8321/api/history/webhook \ + -H "Content-Type: application/json" \ + -d '{"title":"Song","artist":"Artist","profile":"antialias"}' + +# Get recommendations for a profile +curl "http://192.168.86.51:8321/api/recommendations?limit=20&profile=antialias" + +# Generate playlist for a profile +curl -X POST http://192.168.86.51:8321/api/playlists/generate \ + -H "Content-Type: application/json" \ + -d '{"total_tracks":20,"profile":"antialias","speaker_entity":"media_player.study_speaker_2","auto_play":true}' + +# Build taste profile manually +curl -X POST "http://192.168.86.51:8321/api/admin/build-taste-profile?profile=antialias" +``` + +All endpoints are backward compatible — omitting `profile` uses the "default" profile. Events with no profile assignment (including all existing events) belong to "default". + ### Other operations ```bash diff --git a/alembic/versions/003_add_profiles.py b/alembic/versions/003_add_profiles.py new file mode 100644 index 0000000..aca8a9c --- /dev/null +++ b/alembic/versions/003_add_profiles.py @@ -0,0 +1,56 @@ +"""Add named taste profiles and speaker mappings + +Revision ID: 003 +Revises: 002 +Create Date: 2026-02-22 + +""" +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + +revision: str = "003" +down_revision: Union[str, None] = "002" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # 1. Create profiles table + op.create_table( + "profiles", + sa.Column("id", sa.BigInteger, primary_key=True), + sa.Column("name", sa.Text, unique=True, nullable=False), + sa.Column("display_name", sa.Text), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.func.now()), + ) + + # 2. Seed "default" profile + op.execute("INSERT INTO profiles (id, name, display_name) VALUES (1, 'default', 'Default')") + + # 3. Create speaker_profile_mappings table + op.create_table( + "speaker_profile_mappings", + sa.Column("id", sa.BigInteger, primary_key=True), + sa.Column("speaker_name", sa.Text, unique=True, nullable=False), + sa.Column("profile_id", sa.BigInteger, sa.ForeignKey("profiles.id"), nullable=False), + ) + + # 4. Add profile_id to listen_events (nullable — NULL means "default") + op.add_column("listen_events", sa.Column("profile_id", sa.BigInteger, sa.ForeignKey("profiles.id"))) + + # 5. Add profile_id to taste_profiles (nullable, unique) + op.add_column("taste_profiles", sa.Column("profile_id", sa.BigInteger, sa.ForeignKey("profiles.id"), unique=True)) + + # 6. Link existing "default" taste profile row to the default profile + op.execute( + "UPDATE taste_profiles SET profile_id = 1 WHERE name = 'default'" + ) + + +def downgrade() -> None: + op.drop_column("taste_profiles", "profile_id") + op.drop_column("listen_events", "profile_id") + op.drop_table("speaker_profile_mappings") + op.drop_table("profiles") diff --git a/src/haunt_fm/api/admin.py b/src/haunt_fm/api/admin.py index bee568a..3e2fae2 100644 --- a/src/haunt_fm/api/admin.py +++ b/src/haunt_fm/api/admin.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession @@ -47,13 +47,18 @@ async def discover(req: DiscoverRequest, session: AsyncSession = Depends(get_ses @router.post("/build-taste-profile") -async def build_profile(session: AsyncSession = Depends(get_session)): +async def build_profile( + profile: str | None = Query(default=None), + session: AsyncSession = Depends(get_session), +): """Rebuild the taste profile from listened-track embeddings.""" - profile = await build_taste_profile(session) - if profile is None: - return {"ok": False, "error": "No listened tracks with embeddings found"} + profile_name = profile or "default" + taste = await build_taste_profile(session, profile_name=profile_name) + if taste is None: + return {"ok": False, "error": f"No listened tracks with embeddings found for profile '{profile_name}'"} return { "ok": True, - "track_count": profile.track_count, - "updated_at": profile.updated_at.isoformat(), + "profile": profile_name, + "track_count": taste.track_count, + "updated_at": taste.updated_at.isoformat(), } diff --git a/src/haunt_fm/api/history.py b/src/haunt_fm/api/history.py index 4ff2a9e..50ab882 100644 --- a/src/haunt_fm/api/history.py +++ b/src/haunt_fm/api/history.py @@ -1,10 +1,12 @@ from datetime import datetime, timezone -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, Query from pydantic import BaseModel +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from haunt_fm.db import get_session +from haunt_fm.models.track import ListenEvent, Profile, Track from haunt_fm.services.history_ingest import ingest_listen_event from haunt_fm.services.taste_profile import build_taste_profile @@ -19,6 +21,59 @@ class WebhookPayload(BaseModel): duration_played: int | None = None source: str = "music_assistant" listened_at: datetime | None = None + profile: str | None = None + + +@router.get("/recent") +async def recent_listens( + limit: int = Query(default=20, ge=1, le=100), + profile: str | None = Query(default=None), + session: AsyncSession = Depends(get_session), +): + """Get recent listen events, optionally filtered by profile.""" + query = ( + select( + ListenEvent.id, + ListenEvent.listened_at, + ListenEvent.speaker_name, + ListenEvent.profile_id, + Track.title, + Track.artist, + Track.album, + ) + .join(Track, ListenEvent.track_id == Track.id) + ) + + if profile: + # Look up profile_id + profile_row = ( + await session.execute(select(Profile).where(Profile.name == profile)) + ).scalar_one_or_none() + if profile_row: + if profile == "default": + query = query.where( + (ListenEvent.profile_id == profile_row.id) | (ListenEvent.profile_id.is_(None)) + ) + else: + query = query.where(ListenEvent.profile_id == profile_row.id) + else: + return {"events": [], "count": 0, "profile": profile} + + query = query.order_by(ListenEvent.listened_at.desc()).limit(limit) + result = await session.execute(query) + + events = [] + for row in result: + events.append({ + "event_id": row.id, + "title": row.title, + "artist": row.artist, + "album": row.album, + "listened_at": row.listened_at.isoformat() if row.listened_at else None, + "speaker_name": row.speaker_name, + }) + + return {"events": events, "count": len(events), "profile": profile or "all"} @router.post("/webhook") @@ -26,7 +81,7 @@ async def receive_webhook(payload: WebhookPayload, session: AsyncSession = Depen if payload.listened_at is None: payload.listened_at = datetime.now(timezone.utc) - event = await ingest_listen_event( + event, resolved_profile = await ingest_listen_event( session=session, title=payload.title, artist=payload.artist, @@ -36,11 +91,18 @@ async def receive_webhook(payload: WebhookPayload, session: AsyncSession = Depen source=payload.source, listened_at=payload.listened_at, raw_payload=payload.model_dump(mode="json"), + profile_name=payload.profile, ) if event is None: return {"ok": True, "duplicate": True} - # Rebuild taste profile on every new listen event (cheap: just a weighted average) - await build_taste_profile(session) + # Rebuild the resolved profile's taste (or "default" if unassigned) + rebuild_profile = resolved_profile or "default" + await build_taste_profile(session, profile_name=rebuild_profile) - return {"ok": True, "track_id": event.track_id, "event_id": event.id} + return { + "ok": True, + "track_id": event.track_id, + "event_id": event.id, + "profile": rebuild_profile, + } diff --git a/src/haunt_fm/api/playlists.py b/src/haunt_fm/api/playlists.py index 70011ea..a7baf84 100644 --- a/src/haunt_fm/api/playlists.py +++ b/src/haunt_fm/api/playlists.py @@ -19,6 +19,7 @@ class GenerateRequest(BaseModel): auto_play: bool = False vibe: str | None = None alpha: float = Field(default=0.5, ge=0.0, le=1.0) + profile: str | None = None @router.post("/generate") @@ -42,6 +43,7 @@ async def generate(req: GenerateRequest, session: AsyncSession = Depends(get_ses vibe_embedding=vibe_embedding, alpha=alpha, vibe_text=req.vibe, + profile_name=req.profile or "default", ) # Load playlist tracks with track info diff --git a/src/haunt_fm/api/profiles.py b/src/haunt_fm/api/profiles.py new file mode 100644 index 0000000..90a6656 --- /dev/null +++ b/src/haunt_fm/api/profiles.py @@ -0,0 +1,206 @@ +from fastapi import APIRouter, Depends, HTTPException +from pydantic import BaseModel +from sqlalchemy import delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from haunt_fm.db import get_session +from haunt_fm.models.track import ( + ListenEvent, + Profile, + SpeakerProfileMapping, + TasteProfile, +) + +router = APIRouter(prefix="/api/profiles") + + +class CreateProfileRequest(BaseModel): + name: str + display_name: str | None = None + + +class SetSpeakersRequest(BaseModel): + speakers: list[str] + + +async def _get_profile_or_404(session: AsyncSession, name: str) -> Profile: + result = await session.execute(select(Profile).where(Profile.name == name)) + profile = result.scalar_one_or_none() + if profile is None: + raise HTTPException(status_code=404, detail=f"Profile '{name}' not found") + return profile + + +@router.get("/") +async def list_profiles(session: AsyncSession = Depends(get_session)): + """List all profiles with stats.""" + result = await session.execute( + select( + Profile.name, + Profile.display_name, + Profile.created_at, + func.count(ListenEvent.id).label("event_count"), + func.count(func.distinct(ListenEvent.track_id)).label("track_count"), + func.max(ListenEvent.listened_at).label("last_listen"), + ) + .outerjoin(ListenEvent, ListenEvent.profile_id == Profile.id) + .group_by(Profile.id) + .order_by(Profile.created_at) + ) + profiles = [] + for row in result: + profiles.append({ + "name": row.name, + "display_name": row.display_name, + "created_at": row.created_at.isoformat() if row.created_at else None, + "event_count": row.event_count, + "track_count": row.track_count, + "last_listen": row.last_listen.isoformat() if row.last_listen else None, + }) + + # Also count events with no profile_id (belong to "default") + unassigned = await session.execute( + select(func.count(ListenEvent.id)).where(ListenEvent.profile_id.is_(None)) + ) + unassigned_count = unassigned.scalar() or 0 + + # Add unassigned counts to default profile + for p in profiles: + if p["name"] == "default": + p["event_count"] += unassigned_count + break + + return {"profiles": profiles} + + +@router.post("/", status_code=201) +async def create_profile(req: CreateProfileRequest, session: AsyncSession = Depends(get_session)): + """Create a new profile.""" + existing = await session.execute(select(Profile).where(Profile.name == req.name)) + if existing.scalar_one_or_none() is not None: + raise HTTPException(status_code=409, detail=f"Profile '{req.name}' already exists") + + profile = Profile(name=req.name, display_name=req.display_name) + session.add(profile) + await session.commit() + await session.refresh(profile) + return { + "name": profile.name, + "display_name": profile.display_name, + "created_at": profile.created_at.isoformat(), + } + + +@router.get("/{name}") +async def get_profile(name: str, session: AsyncSession = Depends(get_session)): + """Get a single profile with stats.""" + profile = await _get_profile_or_404(session, name) + + # Event stats — include NULL profile_id events for "default" + if name == "default": + event_filter = (ListenEvent.profile_id == profile.id) | (ListenEvent.profile_id.is_(None)) + else: + event_filter = ListenEvent.profile_id == profile.id + + stats = await session.execute( + select( + func.count(ListenEvent.id).label("event_count"), + func.count(func.distinct(ListenEvent.track_id)).label("track_count"), + func.max(ListenEvent.listened_at).label("last_listen"), + ).where(event_filter) + ) + row = stats.one() + + # Speaker mappings + speakers = await session.execute( + select(SpeakerProfileMapping.speaker_name) + .where(SpeakerProfileMapping.profile_id == profile.id) + ) + + # Taste profile status + taste = await session.execute( + select(TasteProfile).where(TasteProfile.profile_id == profile.id) + ) + taste_profile = taste.scalar_one_or_none() + + return { + "name": profile.name, + "display_name": profile.display_name, + "created_at": profile.created_at.isoformat(), + "event_count": row.event_count, + "track_count": row.track_count, + "last_listen": row.last_listen.isoformat() if row.last_listen else None, + "speakers": [r.speaker_name for r in speakers], + "taste_profile": { + "track_count": taste_profile.track_count, + "updated_at": taste_profile.updated_at.isoformat(), + } if taste_profile else None, + } + + +@router.delete("/{name}") +async def delete_profile(name: str, session: AsyncSession = Depends(get_session)): + """Delete a profile, reassigning its events to default.""" + if name == "default": + raise HTTPException(status_code=400, detail="Cannot delete the default profile") + + profile = await _get_profile_or_404(session, name) + + # Reassign listen events to NULL (i.e. default) + await session.execute( + ListenEvent.__table__.update() + .where(ListenEvent.profile_id == profile.id) + .values(profile_id=None) + ) + + # Delete speaker mappings + await session.execute( + delete(SpeakerProfileMapping).where(SpeakerProfileMapping.profile_id == profile.id) + ) + + # Delete taste profile for this profile + await session.execute( + delete(TasteProfile).where(TasteProfile.profile_id == profile.id) + ) + + await session.delete(profile) + await session.commit() + return {"ok": True, "deleted": name} + + +@router.put("/{name}/speakers") +async def set_speakers(name: str, req: SetSpeakersRequest, session: AsyncSession = Depends(get_session)): + """Set speaker→profile mappings (replaces existing).""" + profile = await _get_profile_or_404(session, name) + + # Remove existing mappings for this profile + await session.execute( + delete(SpeakerProfileMapping).where(SpeakerProfileMapping.profile_id == profile.id) + ) + + # Create new mappings + for speaker in req.speakers: + # Check if this speaker is already mapped to another profile + existing = await session.execute( + select(SpeakerProfileMapping).where(SpeakerProfileMapping.speaker_name == speaker) + ) + if existing.scalar_one_or_none() is not None: + raise HTTPException( + status_code=409, + detail=f"Speaker '{speaker}' is already mapped to another profile", + ) + session.add(SpeakerProfileMapping(speaker_name=speaker, profile_id=profile.id)) + + await session.commit() + return {"ok": True, "profile": name, "speakers": req.speakers} + + +@router.get("/{name}/speakers") +async def get_speakers(name: str, session: AsyncSession = Depends(get_session)): + """List speaker mappings for a profile.""" + profile = await _get_profile_or_404(session, name) + result = await session.execute( + select(SpeakerProfileMapping.speaker_name) + .where(SpeakerProfileMapping.profile_id == profile.id) + ) + return {"profile": name, "speakers": [r.speaker_name for r in result]} diff --git a/src/haunt_fm/api/recommendations.py b/src/haunt_fm/api/recommendations.py index 457a2e3..7946836 100644 --- a/src/haunt_fm/api/recommendations.py +++ b/src/haunt_fm/api/recommendations.py @@ -13,6 +13,7 @@ async def recommendations( include_known: bool = Query(default=False), vibe: str | None = Query(default=None), alpha: float = Query(default=0.5, ge=0.0, le=1.0), + profile: str | None = Query(default=None), session: AsyncSession = Depends(get_session), ): vibe_embedding = None @@ -27,6 +28,7 @@ async def recommendations( results = await get_recommendations( session, limit=limit, exclude_known=not include_known, + profile_name=profile or "default", vibe_embedding=vibe_embedding, alpha=effective_alpha, ) - return {"recommendations": results, "count": len(results), "vibe": vibe, "alpha": effective_alpha} + return {"recommendations": results, "count": len(results), "vibe": vibe, "alpha": effective_alpha, "profile": profile or "default"} diff --git a/src/haunt_fm/main.py b/src/haunt_fm/main.py index 8b230a4..b71bd98 100644 --- a/src/haunt_fm/main.py +++ b/src/haunt_fm/main.py @@ -4,7 +4,7 @@ from contextlib import asynccontextmanager from fastapi import FastAPI -from haunt_fm.api import admin, health, history, playlists, recommendations, status, status_page +from haunt_fm.api import admin, health, history, playlists, profiles, recommendations, status, status_page from haunt_fm.config import settings logging.basicConfig( @@ -45,6 +45,7 @@ app.include_router(health.router) app.include_router(status.router) app.include_router(status_page.router) app.include_router(history.router) +app.include_router(profiles.router) app.include_router(admin.router) app.include_router(recommendations.router) app.include_router(playlists.router) diff --git a/src/haunt_fm/models/track.py b/src/haunt_fm/models/track.py index 5648542..4b6c68a 100644 --- a/src/haunt_fm/models/track.py +++ b/src/haunt_fm/models/track.py @@ -33,11 +33,29 @@ class Track(Base): embedding: Mapped["TrackEmbedding | None"] = relationship(back_populates="track") +class Profile(Base): + __tablename__ = "profiles" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + name: Mapped[str] = mapped_column(Text, unique=True, nullable=False) + display_name: Mapped[str | None] = mapped_column(Text) + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) + + +class SpeakerProfileMapping(Base): + __tablename__ = "speaker_profile_mappings" + + id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + speaker_name: Mapped[str] = mapped_column(Text, unique=True, nullable=False) + profile_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("profiles.id"), nullable=False) + + class ListenEvent(Base): __tablename__ = "listen_events" id: Mapped[int] = mapped_column(BigInteger, primary_key=True) track_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("tracks.id"), nullable=False) + profile_id: Mapped[int | None] = mapped_column(BigInteger, ForeignKey("profiles.id")) source: Mapped[str] = mapped_column(Text, nullable=False, default="music_assistant") speaker_name: Mapped[str | None] = mapped_column(Text) listened_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) @@ -82,6 +100,7 @@ class TasteProfile(Base): id: Mapped[int] = mapped_column(BigInteger, primary_key=True) name: Mapped[str] = mapped_column(Text, unique=True, nullable=False, default="default") + profile_id: Mapped[int | None] = mapped_column(BigInteger, ForeignKey("profiles.id"), unique=True) embedding = mapped_column(Vector(512), nullable=False) track_count: Mapped[int] = mapped_column(Integer, nullable=False) updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), server_default=func.now()) diff --git a/src/haunt_fm/services/history_ingest.py b/src/haunt_fm/services/history_ingest.py index 3f66d89..35ef103 100644 --- a/src/haunt_fm/services/history_ingest.py +++ b/src/haunt_fm/services/history_ingest.py @@ -4,7 +4,7 @@ from datetime import datetime, timedelta, timezone from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from haunt_fm.models.track import ListenEvent, Track +from haunt_fm.models.track import ListenEvent, Profile, SpeakerProfileMapping, Track logger = logging.getLogger(__name__) @@ -40,6 +40,45 @@ async def upsert_track( return track +async def _resolve_profile_id( + session: AsyncSession, + profile_name: str | None, + speaker_name: str | None, +) -> tuple[int | None, str | None]: + """Resolve a profile_id from explicit name or speaker mapping. + + Returns (profile_id, resolved_profile_name). + """ + # 1. Explicit profile name + if profile_name: + result = await session.execute( + select(Profile).where(Profile.name == profile_name) + ) + profile = result.scalar_one_or_none() + if profile: + return profile.id, profile.name + logger.warning("Profile '%s' not found, event will be unassigned", profile_name) + return None, None + + # 2. Speaker→profile mapping + if speaker_name: + result = await session.execute( + select(SpeakerProfileMapping) + .where(SpeakerProfileMapping.speaker_name == speaker_name) + ) + mapping = result.scalar_one_or_none() + if mapping: + # Look up the profile name for logging + profile_result = await session.execute( + select(Profile).where(Profile.id == mapping.profile_id) + ) + profile = profile_result.scalar_one_or_none() + return mapping.profile_id, profile.name if profile else None + + # 3. Neither — unassigned (belongs to default) + return None, None + + async def ingest_listen_event( session: AsyncSession, title: str, @@ -50,11 +89,16 @@ async def ingest_listen_event( source: str, listened_at: datetime, raw_payload: dict | None = None, -) -> ListenEvent | None: + profile_name: str | None = None, +) -> tuple[ListenEvent | None, str | None]: + """Ingest a listen event, resolving profile from name or speaker. + + Returns (event, resolved_profile_name). resolved_profile_name is None + when the event belongs to the default profile (via NULL profile_id). + """ track = await upsert_track(session, title, artist, album) # Deduplicate: skip if this track was logged within the last 60 seconds. - # Multiple HA entities (Cast, WiFi, MA) fire simultaneously for the same play event. cutoff = datetime.now(timezone.utc) - timedelta(seconds=60) recent = await session.execute( select(ListenEvent) @@ -64,10 +108,14 @@ async def ingest_listen_event( ) if recent.scalar_one_or_none() is not None: logger.debug("Skipping duplicate listen event for %s - %s", artist, title) - return None + return None, None + + # Resolve profile + profile_id, resolved_name = await _resolve_profile_id(session, profile_name, speaker_name) event = ListenEvent( track_id=track.id, + profile_id=profile_id, source=source, speaker_name=speaker_name, listened_at=listened_at, @@ -77,4 +125,4 @@ async def ingest_listen_event( session.add(event) await session.commit() await session.refresh(event) - return event + return event, resolved_name diff --git a/src/haunt_fm/services/playlist_generator.py b/src/haunt_fm/services/playlist_generator.py index 8a85391..9f24f9e 100644 --- a/src/haunt_fm/services/playlist_generator.py +++ b/src/haunt_fm/services/playlist_generator.py @@ -24,6 +24,7 @@ async def generate_playlist( vibe_embedding: np.ndarray | None = None, alpha: float = 0.5, vibe_text: str | None = None, + profile_name: str = "default", ) -> Playlist: """Generate a playlist mixing known-liked tracks with new recommendations. @@ -54,6 +55,7 @@ async def generate_playlist( # Get new recommendations recs = await get_recommendations( session, limit=new_count * 2, exclude_known=True, + profile_name=profile_name, vibe_embedding=vibe_embedding, alpha=alpha, ) new_tracks = [(r["track_id"], r["similarity"]) for r in recs[:new_count]] diff --git a/src/haunt_fm/services/recommender.py b/src/haunt_fm/services/recommender.py index 6e05b97..712d94e 100644 --- a/src/haunt_fm/services/recommender.py +++ b/src/haunt_fm/services/recommender.py @@ -5,7 +5,7 @@ from sqlalchemy import select, text from sqlalchemy.ext.asyncio import AsyncSession from haunt_fm.models.track import ( - ListenEvent, + Profile, TasteProfile, Track, TrackEmbedding, @@ -28,11 +28,27 @@ async def get_recommendations( vibe_embedding: Optional 512-dim text embedding for vibe/mood matching. alpha: Blend factor. 1.0 = pure taste, 0.0 = pure vibe, 0.5 = equal blend. """ - # Load taste profile - profile = ( - await session.execute(select(TasteProfile).where(TasteProfile.name == profile_name)) + # Load taste profile via Profile → TasteProfile join + profile_row = ( + await session.execute(select(Profile).where(Profile.name == profile_name)) ).scalar_one_or_none() + profile = None + if profile_row is not None: + profile = ( + await session.execute( + select(TasteProfile).where(TasteProfile.profile_id == profile_row.id) + ) + ).scalar_one_or_none() + + # Fallback: look up by name (for legacy rows without profile_id) + if profile is None: + profile = ( + await session.execute( + select(TasteProfile).where(TasteProfile.name == profile_name) + ) + ).scalar_one_or_none() + if profile is None and vibe_embedding is None: return [] diff --git a/src/haunt_fm/services/taste_profile.py b/src/haunt_fm/services/taste_profile.py index bc88659..d001b03 100644 --- a/src/haunt_fm/services/taste_profile.py +++ b/src/haunt_fm/services/taste_profile.py @@ -7,6 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from haunt_fm.models.track import ( ListenEvent, + Profile, TasteProfile, Track, TrackEmbedding, @@ -21,12 +22,31 @@ def _recency_weight(listened_at: datetime, now: datetime, half_life_days: float return 2 ** (-age_days / half_life_days) -async def build_taste_profile(session: AsyncSession, name: str = "default") -> TasteProfile | None: +async def _resolve_profile(session: AsyncSession, profile_name: str) -> Profile | None: + """Look up a Profile by name.""" + result = await session.execute(select(Profile).where(Profile.name == profile_name)) + return result.scalar_one_or_none() + + +async def build_taste_profile(session: AsyncSession, profile_name: str = "default") -> TasteProfile | None: """Build a taste profile as the weighted average of listened-track embeddings. Weights: play_count * recency_decay for each track. + Filters events by profile. For "default", includes events with NULL profile_id. """ - # Get all listened tracks with embeddings + profile = await _resolve_profile(session, profile_name) + + # Build the event filter based on profile + if profile is not None and profile_name == "default": + # Default profile: include events explicitly assigned + unassigned (NULL) + event_filter = (ListenEvent.profile_id == profile.id) | (ListenEvent.profile_id.is_(None)) + elif profile is not None: + event_filter = ListenEvent.profile_id == profile.id + else: + # Profile doesn't exist yet — fall back to all unassigned events + event_filter = ListenEvent.profile_id.is_(None) + + # Get all listened tracks with embeddings for this profile result = await session.execute( select( Track.id, @@ -36,12 +56,13 @@ async def build_taste_profile(session: AsyncSession, name: str = "default") -> T ) .join(TrackEmbedding, TrackEmbedding.track_id == Track.id) .join(ListenEvent, ListenEvent.track_id == Track.id) + .where(event_filter) .group_by(Track.id, TrackEmbedding.embedding) ) rows = result.all() if not rows: - logger.warning("No listened tracks with embeddings found") + logger.warning("No listened tracks with embeddings found for profile '%s'", profile_name) return None now = datetime.now(timezone.utc) @@ -64,10 +85,19 @@ async def build_taste_profile(session: AsyncSession, name: str = "default") -> T profile_emb = (embeddings_arr * weights_arr[:, np.newaxis]).sum(axis=0) profile_emb = profile_emb / np.linalg.norm(profile_emb) - # Upsert - existing = ( - await session.execute(select(TasteProfile).where(TasteProfile.name == name)) - ).scalar_one_or_none() + # Upsert keyed by profile_id + if profile is not None: + existing = ( + await session.execute( + select(TasteProfile).where(TasteProfile.profile_id == profile.id) + ) + ).scalar_one_or_none() + else: + existing = ( + await session.execute( + select(TasteProfile).where(TasteProfile.name == profile_name) + ) + ).scalar_one_or_none() if existing: existing.embedding = profile_emb.tolist() @@ -75,7 +105,8 @@ async def build_taste_profile(session: AsyncSession, name: str = "default") -> T existing.updated_at = now else: existing = TasteProfile( - name=name, + name=profile_name, + profile_id=profile.id if profile else None, embedding=profile_emb.tolist(), track_count=len(rows), updated_at=now, @@ -84,5 +115,5 @@ async def build_taste_profile(session: AsyncSession, name: str = "default") -> T await session.commit() await session.refresh(existing) - logger.info("Built taste profile '%s' from %d tracks", name, len(rows)) + logger.info("Built taste profile '%s' from %d tracks", profile_name, len(rows)) return existing