diff --git a/src/haunt_fm/services/embedding.py b/src/haunt_fm/services/embedding.py index 842e0ce..570b5a4 100644 --- a/src/haunt_fm/services/embedding.py +++ b/src/haunt_fm/services/embedding.py @@ -42,9 +42,14 @@ def embed_audio(audio: np.ndarray, sample_rate: int = 48000) -> np.ndarray: inputs = _processor(audio=audio, sampling_rate=sample_rate, return_tensors="pt") with torch.no_grad(): - embeddings = _model.get_audio_features(**inputs) + output = _model.get_audio_features(**inputs) - # Flatten to 1-D and normalize - emb = embeddings[0].numpy().flatten() + # transformers 5.x returns BaseModelOutputWithPooling; extract pooler_output + if hasattr(output, "pooler_output"): + emb = output.pooler_output[0].numpy() + else: + emb = output[0].numpy() + + # Normalize to unit vector emb = emb / np.linalg.norm(emb) return emb