#!/usr/bin/python3
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F

from pesq import pesq
from joblib import Parallel, delayed

from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d

def cal_pesq(clean, noisy, sr=16000):
    """Wide-band PESQ for a single pair of waveforms; returns -1 if PESQ fails."""
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')
    except Exception:
        # pesq can raise an error on (near-)silent segments
        pesq_score = -1
    return pesq_score

def batch_pesq(clean, noisy):
    pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
    pesq_score = np.array(pesq_score)
    if -1 in pesq_score:
        # at least one utterance failed; caller should skip the metric loss for this batch
        return None
    # normalize raw wide-band PESQ (roughly 1.0 .. 4.5) to approximately 0 .. 1
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score)

def metric_loss(metric_ref, metrics_gen):
    loss = 0
    for metric_gen in metrics_gen:
        # MSE between the reference metric (e.g. normalized PESQ) and each predicted metric
        loss += F.mse_loss(metric_ref, metric_gen.flatten())
    return loss

class MetricDiscriminator(nn.Module):
    """
    Convolutional metric discriminator: four strided, spectrally normalized conv
    blocks, global max pooling, and an MLP head that maps a pair of
    spectrogram-like inputs to a single score squashed by a learnable sigmoid,
    trained to track the normalized PESQ target from batch_pesq.
    """
    def __init__(self, dim=16, in_channel=2):
        super(MetricDiscriminator, self).__init__()
        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1)
        )
    def forward(self, x, y):
        # stack the two inputs (e.g. reference and enhanced magnitude spectrograms)
        # as channels: (batch, 2, freq, time)
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)

if __name__ == '__main__':
    pass
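    # Minimal smoke test (a sketch added for illustration, not part of the training
    # pipeline): push two random spectrogram-shaped tensors through the discriminator
    # and pair the output with metric_loss. The shapes below are assumptions chosen
    # only for this example.
    batch_size, n_freq, n_frames = 4, 201, 100
    clean_mag = torch.rand(batch_size, n_freq, n_frames)
    noisy_mag = torch.rand(batch_size, n_freq, n_frames)

    discriminator = MetricDiscriminator()
    score = discriminator(clean_mag, noisy_mag)
    print(score.shape)  # one score per batch item, e.g. torch.Size([4, 1])

    # Stand-in for batch_pesq(clean, noisy), which returns a (batch,) tensor of
    # normalized PESQ values (or None when any utterance fails).
    target = torch.rand(batch_size)
    loss = metric_loss(target, [score])
    print(loss)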