# Apollo / look2hear / metrics / wrapper.py
# Serhiy Stetskovych
# Initial code
# 78e32cc
###
# Author: Kai Li
# Date: 2021-06-22 12:41:36
# LastEditors: Please set LastEditors
# LastEditTime: 2022-06-05 14:48:00
###
import csv
from sympy import im
import torch
import numpy as np
import logging
import os
import librosa
from torch_mir_eval.separation import bss_eval_sources
import fast_bss_eval
from visqol import visqol_lib_py
from visqol.pb2 import visqol_config_pb2
from visqol.pb2 import similarity_result_pb2
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def is_silent(wav, threshold=1e-4):
    """Report whether the mean squared value of ``wav`` is below ``threshold``.

    Args:
        wav: Tensor of samples; every element contributes to the average.
        threshold: Upper bound (exclusive) on the mean of the squared samples.

    Returns:
        A zero-dimensional boolean tensor: true when the sum of squares
        divided by the element count is strictly less than ``threshold``.
    """
    total_energy = wav.pow(2).sum()
    mean_power = total_energy / wav.numel()
    return mean_power < threshold
class MetricsTracker:
    """Score (clean, estimate) pairs with SDR, SI-SNR and ViSQOL and log them.

    Each call appends the three scores to in-memory lists and, when a CSV
    path was given, writes one row per utterance. ``update()`` reports the
    running means; ``final()`` appends "avg" and "std" rows and closes the
    file. The tracker can also be used as a context manager so the CSV file
    is closed even if evaluation aborts early.

    Args:
        save_file: Path of the CSV file to create. An empty string (the
            default) disables CSV output; metrics are still accumulated.
        sample_rate: Sample rate (Hz) of the tensors passed to ``__call__``.
            They are resampled from this rate to ``VISQOL_SAMPLE_RATE``
            before ViSQOL is run. Defaults to 44100.
    """

    # Column order of the results CSV.
    CSV_COLUMNS = ("snt_id", "sdr", "si-snr", "visqol")
    # Sample rate the ViSQOL API is configured with; inputs are resampled to it.
    VISQOL_SAMPLE_RATE = 48000

    def __init__(self, save_file: str = "", sample_rate: int = 44100):
        self.all_sdrs = []
        self.all_sisnrs = []
        self.all_visqols = []
        self.sample_rate = sample_rate

        self.visqol_config = visqol_config_pb2.VisqolConfig()
        self.visqol_config.audio.sample_rate = self.VISQOL_SAMPLE_RATE
        self.visqol_config.options.use_speech_scoring = False
        svr_model_path = "libsvm_nu_svr_model.txt"
        # The SVR model is looked up in the "model" directory next to the
        # installed visqol_lib_py module.
        self.visqol_config.options.svr_model_path = os.path.join(
            os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
        )
        self.visqol_api = visqol_lib_py.VisqolApi()
        self.visqol_api.Create(self.visqol_config)

        # The file is opened last so a ViSQOL set-up failure cannot leak it.
        self.results_csv = None
        self.writer = None
        if save_file:
            # newline="" lets the csv module control line endings itself.
            self.results_csv = open(save_file, "w", newline="")
            try:
                self.writer = csv.DictWriter(
                    self.results_csv, fieldnames=list(self.CSV_COLUMNS)
                )
                self.writer.writeheader()
            except Exception:
                self.results_csv.close()
                raise

    def __enter__(self):
        """Return the tracker itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the CSV file on leaving the ``with`` block."""
        self.close()

    def __call__(self, clean, estimate, key):
        """Score one utterance, record the scores and write its CSV row.

        Args:
            clean: Reference tensor.
            estimate: Estimated tensor, same layout as ``clean``.
                NOTE(review): the ViSQOL path applies ``squeeze(0).mean(0)``,
                so a (channels, time) or (1, channels, time) layout is
                presumably expected -- confirm against the caller.
            key: Identifier written to the "snt_id" column.
        """
        reference = clean.unsqueeze(0)
        prediction = estimate.unsqueeze(0)
        sisnr = fast_bss_eval.si_sdr(reference, prediction, zero_mean=True).mean().item()
        sdr = fast_bss_eval.sdr(reference, prediction, zero_mean=True).mean().item()

        visqol = self.visqol_api.Measure(
            self._to_visqol_input(clean), self._to_visqol_input(estimate)
        ).moslqo

        self._write_row({"snt_id": key, "sdr": sdr, "si-snr": sisnr, "visqol": visqol})

        # Metric accumulation
        self.all_sdrs.append(sdr)
        self.all_sisnrs.append(sisnr)
        self.all_visqols.append(visqol)

    def update(self):
        """Return the running mean of every metric as a dict.

        Metrics with no recorded values are reported as NaN.
        """
        return self._summary(self._mean)

    def final(self):
        """Write the "avg" and "std" rows, then close the CSV file."""
        try:
            self._write_row({"snt_id": "avg", **self._summary(self._mean)})
            self._write_row({"snt_id": "std", **self._summary(self._std)})
        finally:
            self.close()

    def close(self):
        """Close the CSV file if one is open; safe to call repeatedly."""
        if self.results_csv is not None and not self.results_csv.closed:
            self.results_csv.close()

    def _to_visqol_input(self, wav):
        """Average over dim 0 after squeezing, resample, and cast to float64."""
        # detach() keeps .numpy() working for tensors that require grad.
        mono = wav.squeeze(0).mean(0).detach().cpu().numpy()
        resampled = librosa.resample(
            mono, orig_sr=self.sample_rate, target_sr=self.VISQOL_SAMPLE_RATE
        )
        return resampled.astype(np.float64)

    def _write_row(self, row):
        """Write ``row`` to the CSV file; no-op when CSV output is disabled."""
        if self.writer is not None:
            self.writer.writerow(row)

    def _summary(self, reducer):
        """Apply ``reducer`` to each metric list and return the results by name."""
        return {
            "sdr": reducer(self.all_sdrs),
            "si-snr": reducer(self.all_sisnrs),
            "visqol": reducer(self.all_visqols),
        }

    @staticmethod
    def _mean(values):
        """Mean of ``values``; NaN (without a NumPy warning) when empty."""
        data = np.array(values)
        return data.mean() if data.size else np.float64(np.nan)

    @staticmethod
    def _std(values):
        """Population standard deviation of ``values``; NaN when empty."""
        data = np.array(values)
        return data.std() if data.size else np.float64(np.nan)