Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 44 additions & 27 deletions mteb/models/model_implementations/msclap_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import logging
import tempfile
import warnings
from pathlib import Path
from typing import Any

import numpy as np
Expand Down Expand Up @@ -58,6 +60,7 @@ def get_audio_embeddings(
show_progress_bar: bool = True,
**kwargs: Any,
) -> np.ndarray:
import soundfile as sf
import torchaudio

all_embeddings = []
Expand All @@ -66,34 +69,48 @@ def get_audio_embeddings(
inputs,
disable=not show_progress_bar,
):
audio_arrays = []
for a in batch["audio"]:
array = torch.tensor(a["array"], dtype=torch.float32)
sr = a.get("sampling_rate", None)
if sr is None:
warnings.warn(
f"No sampling_rate provided for an audio sample. "
f"Assuming {self.sampling_rate} Hz (model default)."
temp_files = []
try:
for a in batch["audio"]:
array = torch.tensor(a["array"], dtype=torch.float32)
sr = a.get("sampling_rate", None)
if sr is None:
warnings.warn(
f"No sampling_rate provided for an audio sample. "
f"Assuming {self.sampling_rate} Hz (model default)."
)
sr = self.sampling_rate

if sr != self.sampling_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=sr, new_freq=self.sampling_rate
)
array = resampler(array)

# Write to temp file - msclap expects file paths
temp_file = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
temp_files.append(temp_file.name)
sf.write(temp_file.name, array.numpy(), self.sampling_rate)

with torch.no_grad():
# Use the official msclap API that expects file paths
# https://github.com/microsoft/CLAP#api
audio_features = self.model.get_audio_embeddings(
temp_files, resample=False
)
sr = self.sampling_rate

if sr != self.sampling_rate:
resampler = torchaudio.transforms.Resample(
orig_freq=sr, new_freq=self.sampling_rate
# Normalize embeddings
audio_features = audio_features / audio_features.norm(
dim=-1, keepdim=True
)
array = resampler(array)
audio_arrays.append(array.numpy())

with torch.no_grad():
# Use the internal audio encoder directly
# [0] gives audio embeddings, [1] gives class probabilities
audio_features = self.model.clap.audio_encoder(audio_arrays)[0]
all_embeddings.append(audio_features.cpu().detach().numpy())
finally:
# Clean up temp files

# Normalize embeddings
audio_features = audio_features / audio_features.norm(
dim=-1, keepdim=True
)
all_embeddings.append(audio_features.cpu().detach().numpy())
for f in temp_files:
try:
Path(f).unlink()
except OSError:
pass

return np.vstack(all_embeddings)

Expand Down Expand Up @@ -162,7 +179,7 @@ def encode(
loader=MSClapWrapper,
name="microsoft/msclap-2022",
languages=["eng-Latn"],
revision="N/A",
revision="no_revision",
release_date="2022-12-01",
modalities=["audio", "text"],
n_parameters=196_000_000,
Expand All @@ -184,7 +201,7 @@ def encode(
loader=MSClapWrapper,
name="microsoft/msclap-2023",
languages=["eng-Latn"],
revision="N/A",
revision="no_revision",
release_date="2023-09-01",
modalities=["audio", "text"],
n_parameters=160_000_000,
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ youtu = ["tencentcloud-sdk-python-common>=3.0.1454", "tencentcloud-sdk-python-lk
llama-embed-nemotron = ["transformers==4.51.0"]
faiss-cpu = ["faiss-cpu>=1.12.0"]
eager_embed = ["qwen_vl_utils>=0.0.14"]
speechbrain = ["speechbrain>=0.5.12", "torchaudio>=2.6.0,<2.8"]
speechbrain = ["speechbrain>=0.5.12; python_full_version < '3.14'", "torchaudio>=2.6.0,<2.8; python_full_version < '3.14'"]
muq = ["muq==0.1.0"]
wav2clip = ["wav2clip==0.1.0"]
torch-vggish-yamnet = ["torch-vggish-yamnet==0.2.1"]
Expand Down
Loading
Loading