import torch
import torchaudio
from pesq import pesq
from pystoi import stoi

from .other import si_sdr, pad_spec
# Settings
sr = 8000            # evaluation sample rate in Hz
snr = 0.5            # SNR parameter of the annealed Langevin dynamics corrector
N = 30               # number of reverse diffusion steps
corrector_steps = 1  # corrector iterations per predictor step
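# Note: with sr = 8000, PESQ is evaluated in narrowband ('nb') mode below;
# wideband ('wb') PESQ would require 16 kHz audio.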


def evaluate_model(model, num_eval_files):
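    """Evaluate `model` on `num_eval_files` utterances from the validation set.

    Each noisy/mixture pair is enhanced with the model's predictor-corrector
    sampler, and the averages of PESQ, SI-SDR, and ESTOI against the clean
    references are returned.
    """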
    clean_files = model.data_module.valid_set.clean_files
    noisy_files = model.data_module.valid_set.noisy_files
    mixture_files = model.data_module.valid_set.mixture_files

    # Select test files uniformly across the validation set
    total_num_files = len(clean_files)
    indices = torch.linspace(0, total_num_files - 1, num_eval_files, dtype=torch.int)
    clean_files = [clean_files[i] for i in indices]
    noisy_files = [noisy_files[i] for i in indices]
    mixture_files = [mixture_files[i] for i in indices]
    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    # Iterate over the selected files
    for clean_file, noisy_file, mixture_file in zip(clean_files, noisy_files, mixture_files):
        # Load wavs and resample to the evaluation rate if necessary
        x, sr_ = torchaudio.load(clean_file)
        if sr_ != sr:
            x = torchaudio.transforms.Resample(sr_, sr)(x)
        y, sr_ = torchaudio.load(noisy_file)
        if sr_ != sr:
            y = torchaudio.transforms.Resample(sr_, sr)(y)
        m, sr_ = torchaudio.load(mixture_file)
        if sr_ != sr:
            m = torchaudio.transforms.Resample(sr_, sr)(m)

        # Trim all signals to a common length
        min_len = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_len]
        y = y[..., :min_len]
        m = m[..., :min_len]
        T_orig = x.size(1)
        # Normalize per utterance by the peak of the noisy signal
        norm_factor = y.abs().max()
        y = y / norm_factor
        m = m / norm_factor

        # Prepare DNN input: batched, padded complex spectrograms
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)
        M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
        M = pad_spec(M)
        y = y * norm_factor
        # Reverse sampling with the predictor-corrector sampler
        sampler = model.get_pc_sampler(
            'reverse_diffusion', 'ald', Y.cuda(), M.cuda(), N=N,
            corrector_steps=corrector_steps, snr=snr)
        sample, _ = sampler()
        x_hat = model.to_audio(sample.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()
        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()
        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'nb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return _pesq / num_eval_files, _si_sdr / num_eval_files, _estoi / num_eval_files


def evaluate_model2(model, num_eval_files, inference_N, inference_start=0.5):
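    """Variant of `evaluate_model` that uses plain Euler-Maruyama sampling.

    Instead of the predictor-corrector sampler, the reverse SDE is integrated
    with `inference_N` Euler-Maruyama steps, starting at time `inference_start`
    from a prior sample centered on the noisy spectrogram. Returns the average
    PESQ, SI-SDR, ESTOI, and a (currently unused) loss term.
    """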
    N = inference_N
    reverse_start_time = inference_start
    clean_files = model.data_module.valid_set.clean_files
    noisy_files = model.data_module.valid_set.noisy_files
    mixture_files = model.data_module.valid_set.mixture_files

    # Select test files uniformly across the validation set
    total_num_files = len(clean_files)
    indices = torch.linspace(0, total_num_files - 1, num_eval_files, dtype=torch.int)
    clean_files = [clean_files[i] for i in indices]
    noisy_files = [noisy_files[i] for i in indices]
    mixture_files = [mixture_files[i] for i in indices]

    _pesq = 0
    _si_sdr = 0
    _estoi = 0
    # Iterate over the selected files
    for clean_file, noisy_file, mixture_file in zip(clean_files, noisy_files, mixture_files):
        # Load wavs and resample to the evaluation rate if necessary
        x, sr_ = torchaudio.load(clean_file)
        if sr_ != sr:
            x = torchaudio.transforms.Resample(sr_, sr)(x)
        y, sr_ = torchaudio.load(noisy_file)
        if sr_ != sr:
            y = torchaudio.transforms.Resample(sr_, sr)(y)
        m, sr_ = torchaudio.load(mixture_file)
        if sr_ != sr:
            m = torchaudio.transforms.Resample(sr_, sr)(m)

        # Required only for BWE, where clean and noisy files differ in length
        min_len = min(x.shape[-1], y.shape[-1], m.shape[-1])
        x = x[..., :min_len]
        y = y[..., :min_len]
        m = m[..., :min_len]
        T_orig = x.size(1)
        # Normalize per utterance by the peak of the noisy signal
        norm_factor = y.abs().max()
        y = y / norm_factor
        x = x / norm_factor
        m = m / norm_factor

        # Prepare DNN input: batched, padded complex spectrograms
        Y = torch.unsqueeze(model._forward_transform(model._stft(y.cuda())), 0)
        Y = pad_spec(Y)
        X = torch.unsqueeze(model._forward_transform(model._stft(x.cuda())), 0)
        X = pad_spec(X)
        M = torch.unsqueeze(model._forward_transform(model._stft(m.cuda())), 0)
        M = pad_spec(M)
        y = y * norm_factor
        x = x * norm_factor
        x = x.squeeze().cpu().numpy()
        y = y.squeeze().cpu().numpy()

        total_loss = 0  # never accumulated below; returned as a placeholder
        timesteps = torch.linspace(reverse_start_time, 0.03, N, device=Y.device)

        # Prior sampling: start the reverse process at reverse_start_time,
        # centered on the noisy spectrogram Y with the matching noise level
        std = model.sde._std(reverse_start_time * torch.ones((Y.shape[0],), device=Y.device))
        z = torch.randn_like(Y)
        X_t = Y + z * std[:, None, None, None]
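        # Each step below discretizes the reverse SDE
        #   dx = [f(x, t) - g(t)^2 * score(x, t, y)] dt + g(t) dW
        # as x_{t-dt} = x_t - (f - g^2 * score) * dt + g * sqrt(dt) * z, z ~ N(0, I).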
        # Reverse steps by Euler-Maruyama
        for i in range(len(timesteps)):
            t = timesteps[i]
            if i != len(timesteps) - 1:
                dt = t - timesteps[i + 1]
            else:
                dt = timesteps[-1]  # final step integrates the rest of the way to t = 0
            with torch.no_grad():
                # Take an Euler step of the reverse SDE
                f, g = model.sde.sde(X_t, t, Y)
                vec_t = torch.ones(Y.shape[0], device=Y.device) * t
                score = model.forward(X_t, vec_t, Y, M, vec_t[:, None, None, None])
                mean_x_tm1 = X_t - (f - g**2 * score) * dt  # mean of x_{t-1}
                if i == len(timesteps) - 1:
                    X_t = mean_x_tm1  # output the mean; no noise on the last step
                    break
                z = torch.randn_like(X)
                X_t = mean_x_tm1 + z * g * torch.sqrt(dt)
        x_hat = model.to_audio(X_t.squeeze(), T_orig)
        x_hat = x_hat * norm_factor
        x_hat = x_hat.squeeze().cpu().numpy()

        _si_sdr += si_sdr(x, x_hat)
        _pesq += pesq(sr, x, x_hat, 'nb')
        _estoi += stoi(x, x_hat, sr, extended=True)

    return (_pesq / num_eval_files, _si_sdr / num_eval_files,
            _estoi / num_eval_files, total_loss / num_eval_files)


def convert_to_audio(X, deemp, T_orig, model, norm_factor):
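    """Apply a per-frequency de-emphasis vector to the spectrogram `X` and
    convert the result back to a waveform of length `T_orig`."""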
    sample = X.squeeze()
    # Broadcast the de-emphasis vector along the frequency axis, whose
    # position depends on how many dimensions survive the squeeze
    if len(sample.shape) == 4:
        sample = sample * deemp[None, None, :, None].to(device=sample.device)
    elif len(sample.shape) == 3:
        sample = sample * deemp[None, :, None].to(device=sample.device)
    else:
        sample = sample * deemp[:, None].to(device=sample.device)
    x_hat = model.to_audio(sample.squeeze(), T_orig)
    x_hat = x_hat * norm_factor
    x_hat = x_hat.squeeze().cpu().numpy()
    return x_hat
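

# Minimal usage sketch (assumptions: `model` is this repo's Lightning score
# model with its `data_module` attached, and a CUDA device is available, since
# the functions above call `.cuda()`; the import path and checkpoint file
# below are hypothetical):
#
#     from sgmse.model import ScoreModel  # assumed import path
#
#     model = ScoreModel.load_from_checkpoint("path/to/checkpoint.ckpt")
#     model.eval().cuda()
#     avg_pesq, avg_si_sdr, avg_estoi = evaluate_model(model, num_eval_files=10)
#     avg_pesq, avg_si_sdr, avg_estoi, _ = evaluate_model2(
#         model, num_eval_files=10, inference_N=30, inference_start=0.5)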