Fix embedding dimensions and worker session management

Two issues:
1. CLAP model output needed .flatten() to produce a 1-D vector for
   pgvector. Without it, the nested array caused "expected ndim to be 1".
2. Worker now uses a fresh session per track instead of sharing one
   across a batch, preventing PendingRollbackError cascading from one
   failure to the next.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-02-22 10:43:44 -06:00
parent 771f714384
commit e9cf1e9b17
2 changed files with 65 additions and 62 deletions

View File

@@ -44,7 +44,7 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray:
with torch.no_grad(): with torch.no_grad():
embeddings = _model.get_audio_features(**inputs) embeddings = _model.get_audio_features(**inputs)
# Normalize # Flatten to 1-D and normalize
emb = embeddings[0].numpy() emb = embeddings[0].numpy().flatten()
emb = emb / np.linalg.norm(emb) emb = emb / np.linalg.norm(emb)
return emb return emb

View File

@@ -26,29 +26,30 @@ def last_processed() -> datetime | None:
return _last_processed return _last_processed
async def _process_track(session: AsyncSession, track: Track) -> bool: async def _process_track(track_id: int, artist: str, title: str, preview_url: str | None) -> bool:
"""Process a single track: find preview, download, embed, store. Returns True on success.""" """Process a single track: find preview, download, embed, store. Returns True on success."""
global _last_processed global _last_processed
async with async_session() as session:
# Mark as downloading # Mark as downloading
await session.execute( await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="downloading") update(Track).where(Track.id == track_id).values(embedding_status="downloading")
) )
await session.commit() await session.commit()
# Find iTunes preview # Find iTunes preview
if not track.itunes_preview_url: if not preview_url:
result = await search_track(track.artist, track.title) result = await search_track(artist, title)
if result is None: if result is None:
await session.execute( await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="no_preview") update(Track).where(Track.id == track_id).values(embedding_status="no_preview")
) )
await session.commit() await session.commit()
return False return False
await session.execute( await session.execute(
update(Track) update(Track)
.where(Track.id == track.id) .where(Track.id == track_id)
.values( .values(
itunes_track_id=result["track_id"], itunes_track_id=result["track_id"],
itunes_preview_url=result["preview_url"], itunes_preview_url=result["preview_url"],
@@ -59,11 +60,9 @@ async def _process_track(session: AsyncSession, track: Track) -> bool:
) )
await session.commit() await session.commit()
preview_url = result["preview_url"] preview_url = result["preview_url"]
else:
preview_url = track.itunes_preview_url
# Download and decode # Download and decode
filename = f"{track.id}.m4a" filename = f"{track_id}.m4a"
filepath = await download_preview(preview_url, settings.audio_cache_dir, filename) filepath = await download_preview(preview_url, settings.audio_cache_dir, filename)
audio = decode_audio(filepath) audio = decode_audio(filepath)
@@ -72,12 +71,12 @@ async def _process_track(session: AsyncSession, track: Track) -> bool:
# Store # Store
track_embedding = TrackEmbedding( track_embedding = TrackEmbedding(
track_id=track.id, track_id=track_id,
embedding=embedding.tolist(), embedding=embedding.tolist(),
) )
session.add(track_embedding) session.add(track_embedding)
await session.execute( await session.execute(
update(Track).where(Track.id == track.id).values(embedding_status="done") update(Track).where(Track.id == track_id).values(embedding_status="done")
) )
await session.commit() await session.commit()
@@ -109,27 +108,31 @@ async def run_worker():
.order_by(Track.created_at) .order_by(Track.created_at)
.limit(settings.embedding_batch_size) .limit(settings.embedding_batch_size)
) )
tracks = result.scalars().all() tracks = [
(t.id, t.artist, t.title, t.itunes_preview_url)
for t in result.scalars().all()
]
if not tracks: if not tracks:
await asyncio.sleep(settings.embedding_interval_seconds) await asyncio.sleep(settings.embedding_interval_seconds)
continue continue
for track in tracks: for track_id, artist, title, preview_url in tracks:
try: try:
await _process_track(session, track) await _process_track(track_id, artist, title, preview_url)
logger.info("Embedded: %s - %s", track.artist, track.title) logger.info("Embedded: %s - %s", artist, title)
except Exception as e: except Exception as e:
logger.exception("Failed to embed %s - %s", track.artist, track.title) logger.exception("Failed to embed %s - %s", artist, title)
await session.execute( async with async_session() as err_session:
await err_session.execute(
update(Track) update(Track)
.where(Track.id == track.id) .where(Track.id == track_id)
.values( .values(
embedding_status="failed", embedding_status="failed",
embedding_error=str(e), embedding_error=str(e),
) )
) )
await session.commit() await err_session.commit()
except Exception: except Exception:
logger.exception("Embedding worker error") logger.exception("Embedding worker error")