Spaces:
Running
on
Zero
Running
on
Zero
### | |
# Author: Kai Li | |
# Date: 2021-06-22 12:41:36 | |
# LastEditors: Please set LastEditors | |
# LastEditTime: 2022-06-05 14:48:00 | |
### | |
import csv | |
from sympy import im | |
import torch | |
import numpy as np | |
import logging | |
import os | |
import librosa | |
from torch_mir_eval.separation import bss_eval_sources | |
import fast_bss_eval | |
from visqol import visqol_lib_py | |
from visqol.pb2 import visqol_config_pb2 | |
from visqol.pb2 import similarity_result_pb2 | |
logger = logging.getLogger(__name__) | |
def is_silent(wav, threshold=1e-4):
    """Tell whether ``wav`` has a mean squared amplitude below ``threshold``.

    Args:
        wav: Torch tensor of samples; every element is pooled regardless of
            the tensor's shape.
        threshold: Exclusive upper bound on the mean power for the signal to
            count as silent.

    Returns:
        The result of the ``<`` comparison (a zero-dimensional boolean tensor).
    """
    total_energy = torch.sum(wav ** 2)
    mean_power = total_energy / wav.numel()
    return mean_power < threshold
class MetricsTracker:
    """Score separated audio with SDR, SI-SNR and ViSQOL and log rows to a CSV.

    Calling the tracker scores one (reference, estimate) pair, appends one row
    to the CSV file and remembers the values so that running means
    (``update``) and the closing ``avg``/``std`` rows (``final``) can be
    produced.

    The tracker owns an open file handle. Either call ``final()`` when done,
    or use the tracker as a context manager so the handle is released even
    when evaluation aborts early.
    """

    # Column order of the results file.
    CSV_COLUMNS = ("snt_id", "sdr", "si-snr", "visqol")
    # Rate the ViSQOL config is set to; inputs are resampled to it.
    VISQOL_SAMPLE_RATE = 48000
    # SVR model file looked up in the "model" directory next to visqol_lib_py.
    VISQOL_MODEL_FILE = "libsvm_nu_svr_model.txt"

    def __init__(self, save_file: str = "", sample_rate: int = 44100):
        """Set up the ViSQOL scorer and open the results CSV.

        Args:
            save_file: Path of the CSV file to create (truncated if present).
            sample_rate: Sampling rate in Hz of the tensors later passed to
                ``__call__``; they are resampled from this rate for ViSQOL.
        """
        self.sample_rate = sample_rate
        self.all_sdrs = []
        self.all_sisnrs = []
        self.all_visqols = []

        self.visqol_config = visqol_config_pb2.VisqolConfig()
        self.visqol_config.audio.sample_rate = self.VISQOL_SAMPLE_RATE
        self.visqol_config.options.use_speech_scoring = False
        self.visqol_config.options.svr_model_path = os.path.join(
            os.path.dirname(visqol_lib_py.__file__),
            "model",
            self.VISQOL_MODEL_FILE,
        )
        self.visqol_api = visqol_lib_py.VisqolApi()
        self.visqol_api.Create(self.visqol_config)

        # The csv module requires newline="" so that it controls line endings
        # itself; otherwise some platforms produce blank rows.
        self.results_csv = open(save_file, "w", newline="")
        self.writer = csv.DictWriter(
            self.results_csv, fieldnames=list(self.CSV_COLUMNS)
        )
        self.writer.writeheader()

    def __enter__(self):
        """Return the tracker itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the results file; exceptions are not suppressed."""
        self.close()

    def close(self):
        """Close the results file (closing an already closed file is a no-op)."""
        self.results_csv.close()

    def _to_visqol_input(self, wav):
        """Convert a tensor into the float64 array handed to ViSQOL.

        A leading size-1 axis is squeezed away and the result is averaged
        over its first axis (presumably the channel axis, i.e. a mono
        downmix; confirm against the caller), then resampled from
        ``self.sample_rate`` to ``VISQOL_SAMPLE_RATE``.
        """
        mono = wav.squeeze(0).mean(0).cpu().numpy()
        resampled = librosa.resample(
            mono, orig_sr=self.sample_rate, target_sr=self.VISQOL_SAMPLE_RATE
        )
        return resampled.astype(np.float64)

    def _summarise(self, reducer):
        """Apply ``reducer`` to each metric history, keyed by CSV column name.

        An empty history yields NaN directly, which is what NumPy would return
        anyway but without the "mean of empty slice" runtime warning.
        """
        histories = {
            "sdr": self.all_sdrs,
            "si-snr": self.all_sisnrs,
            "visqol": self.all_visqols,
        }
        return {
            column: (
                reducer(np.asarray(values, dtype=np.float64))
                if values
                else np.float64("nan")
            )
            for column, values in histories.items()
        }

    def __call__(self, clean, estimate, key):
        """Score one estimate against its reference and record the result.

        Args:
            clean: Reference signal as a torch tensor.
            estimate: Separated signal as a torch tensor shaped like ``clean``.
            key: Identifier written to the ``snt_id`` column.
        """
        # A leading axis is added before the tensors go to fast_bss_eval.
        reference = clean.unsqueeze(0)
        candidate = estimate.unsqueeze(0)
        sisnr = fast_bss_eval.si_sdr(reference, candidate, zero_mean=True).mean().item()
        sdr = fast_bss_eval.sdr(reference, candidate, zero_mean=True).mean().item()

        visqol = self.visqol_api.Measure(
            self._to_visqol_input(clean), self._to_visqol_input(estimate)
        ).moslqo

        self.writer.writerow(
            {"snt_id": key, "sdr": sdr, "si-snr": sisnr, "visqol": visqol}
        )

        # Metric accumulation for update() / final().
        self.all_sdrs.append(sdr)
        self.all_sisnrs.append(sisnr)
        self.all_visqols.append(visqol)

    def update(self):
        """Return the running mean of every metric recorded so far.

        Returns:
            Dict with keys ``"sdr"``, ``"si-snr"`` and ``"visqol"``; values
            are NaN while nothing has been recorded.
        """
        return self._summarise(np.mean)

    def final(self):
        """Append the ``avg`` and ``std`` rows, then close the results file.

        The file is closed even if writing one of the summary rows fails.
        """
        try:
            self.writer.writerow({"snt_id": "avg", **self._summarise(np.mean)})
            self.writer.writerow({"snt_id": "std", **self._summarise(np.std)})
        finally:
            self.close()