diff --git a/scripts/evaluate.py b/scripts/evaluate.py index 07cdc3c..e0af2ae 100644 --- a/scripts/evaluate.py +++ b/scripts/evaluate.py @@ -6,8 +6,7 @@ import argbind import torch -from audiotools import AudioSignal -from audiotools import metrics +from audiotools import AudioSignal, metrics from audiotools.core import util from audiotools.ml.decorators import Tracker from train import losses @@ -34,7 +33,7 @@ def get_metrics(signal_path, recons_path, state): f"mel-{k}": state.mel_loss(x, y), f"stft-{k}": state.stft_loss(x, y), f"waveform-{k}": state.waveform_loss(x, y), - f"sisdr-{k}": state.sisdr_loss(x, y), + f"sisdr-{k}": -state.sisdr_loss(x, y), f"visqol-audio-{k}": metrics.quality.visqol(x, y), f"visqol-speech-{k}": metrics.quality.visqol(x, y, "speech"), }