#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""PESQ metric helpers and a spectral-normalized metric discriminator network."""
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from pesq import pesq
from joblib import Parallel, delayed

from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d


def cal_pesq(clean, noisy, sr=16000):
    """Compute the wideband PESQ score of ``noisy`` against the reference ``clean``.

    Args:
        clean: reference waveform (presumably a 1-D numpy array -- confirm with callers).
        noisy: degraded/enhanced waveform to score against ``clean``.
        sr: sample rate in Hz passed to ``pesq``.

    Returns:
        The PESQ score, or -1 if ``pesq`` raised (e.g. on a silent period).
        ``batch_pesq`` relies on this -1 sentinel.
    """
    try:
        pesq_score = pesq(sr, clean, noisy, 'wb')
    except Exception:
        # Best-effort: pesq can fail on silent periods. Catch Exception rather than
        # using a bare ``except`` so KeyboardInterrupt/SystemExit still propagate.
        pesq_score = -1
    return pesq_score


def batch_pesq(clean, noisy, sr=16000, n_jobs=15):
    """Compute normalized PESQ scores for a batch of utterance pairs in parallel.

    Args:
        clean: iterable of reference waveforms.
        noisy: iterable of degraded/enhanced waveforms, paired with ``clean`` by
            position (extra items in the longer iterable are ignored, as with ``zip``).
        sr: sample rate in Hz forwarded to ``cal_pesq``.
        n_jobs: number of joblib workers used to score the pairs.

    Returns:
        A ``torch.FloatTensor`` of scores transformed by ``(pesq - 1) / 3.5``
        (a PESQ of 1 maps to 0 and 4.5 maps to 1), or ``None`` if any pair
        failed to score.
    """
    pesq_score = Parallel(n_jobs=n_jobs)(
        delayed(cal_pesq)(c, n, sr) for c, n in zip(clean, noisy)
    )
    pesq_score = np.array(pesq_score)
    # A single failed pair invalidates the whole batch; callers must handle None.
    if np.any(pesq_score == -1):
        return None
    pesq_score = (pesq_score - 1) / 3.5
    return torch.FloatTensor(pesq_score)


def metric_loss(metric_ref, metrics_gen):
    """Sum the MSE losses between ``metric_ref`` and each prediction in ``metrics_gen``.

    Each prediction is flattened before comparison. For example, a
    discriminator output of shape (B, 1) is compared against a 1-D reference
    of shape (B,). Returns the integer 0 if ``metrics_gen`` is empty.
    """
    loss = 0
    for metric_gen in metrics_gen:
        loss += F.mse_loss(metric_ref, metric_gen.flatten())
    return loss


class MetricDiscriminator(nn.Module):
    """CNN that predicts a scalar metric score from a pair of stacked inputs.

    Architecture:
        - Four conv blocks: spectral-normalized Conv2d (kernel 4, stride 2,
          padding 1, which halves both spatial dimensions), then InstanceNorm2d,
          then PReLU.
        - Global max pooling.
        - A two-layer spectral-normalized MLP (with dropout).
        - A final ``LearnableSigmoid1d``.

    NOTE: the order of modules in ``self.layers`` determines the state_dict
    keys; do not reorder it or existing checkpoints will no longer load.
    """

    def __init__(self, dim=16, in_channel=2):
        """Build the discriminator.

        Args:
            dim: base channel width; the conv blocks use dim, 2*dim, 4*dim, 8*dim.
            in_channel: number of input channels. ``forward`` stacks two tensors,
                so this must be 2 when the model is called through ``forward``.
        """
        super().__init__()
        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            # Global max pool collapses the spatial dims, so inputs of any size work.
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1),
        )

    def forward(self, x, y):
        """Score the pair (x, y).

        ``x`` and ``y`` are stacked on a new channel dimension (dim=1), so they
        must have identical shapes. They are presumably (B, F, T) spectrograms,
        giving a (B, 2, F, T) input to the first Conv2d; confirm with callers.
        Returns the output of the final ``LearnableSigmoid1d``.
        """
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)


if __name__ == '__main__':
    pass