#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Metric discriminator used with the MPNet model, plus save/load helpers."""
import os
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from joblib import Parallel, delayed
from pesq import pesq

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d


# NOTE(review): the PESQ helpers below are disabled. In this file they are the
# only users of the `pesq`, `joblib` and `numpy` imports above.
# def cal_pesq(clean, noisy, sr=16000):
#     try:
#         pesq_score = pesq(sr, clean, noisy, 'wb')
#     except:
#         # error can happen due to silent period
#         pesq_score = -1
#     return pesq_score


# def batch_pesq(clean, noisy):
#     pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
#     pesq_score = np.array(pesq_score)
#     if -1 in pesq_score:
#         return None
#     pesq_score = (pesq_score - 1) / 3.5
#     return torch.FloatTensor(pesq_score)


def metric_loss(metric_ref, metrics_gen):
    """Sum the MSE between a reference metric and each generated metric.

    Args:
        metric_ref: reference (target) metric tensor. It is flattened before
            comparison.
        metrics_gen: iterable of predicted metric tensors. Each one is
            flattened before comparison.

    Returns:
        The sum of the per-prediction MSE losses as a tensor, or the int ``0``
        if ``metrics_gen`` is empty.
    """
    # Flatten both sides so a (B, 1) prediction is compared element-wise with
    # a (B,) or (B, 1) reference, rather than being broadcast to (B, B).
    target = metric_ref.flatten()
    loss = 0
    for metric_gen in metrics_gen:
        # Use out-of-place addition so no autograd output is modified in place.
        loss = loss + F.mse_loss(target, metric_gen.flatten())
    return loss


class MetricDiscriminator(nn.Module):
    """Convolutional discriminator that maps a pair of inputs to one score.

    Structure: four spectral-normalised strided conv blocks, global max
    pooling, then two spectral-normalised linear layers followed by
    ``LearnableSigmoid1d``.
    """

    def __init__(self, config: MPNetConfig):
        """Build the network from ``config``.

        Args:
            config: reads ``discriminator_dim`` (base channel width) and
                ``discriminator_in_channel`` (input channels of the first conv).
                ``forward`` stacks two tensors on dim 1, so that dimension has
                size 2. ``discriminator_in_channel`` must therefore be 2.
        """
        super().__init__()
        dim = config.discriminator_dim
        in_channel = config.discriminator_in_channel

        # Keep this exact layer order: the Sequential indices (layers.0, ...)
        # are the state_dict keys used by saved checkpoints.
        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            # Global max pool makes the head independent of input height/width.
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1),
        )

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        ``x`` and ``y`` are stacked along a new dim 1, which becomes the conv
        channel axis. The inputs are presumably (batch, freq, time)
        spectrograms — TODO confirm.

        Returns:
            The output of ``LearnableSigmoid1d`` applied to a (batch, 1)
            tensor (exact output shape depends on that helper).
        """
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)


MODEL_FILE = "discriminator.pt"


class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """``MetricDiscriminator`` that keeps its config and can be saved/loaded."""

    def __init__(self, config: MPNetConfig):
        super().__init__(config=config)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a saved config and load its weights.

        Args:
            pretrained_model_name_or_path: either a directory containing
                ``MODEL_FILE``, or a path to the checkpoint file itself.
                It is also passed to ``MPNetConfig.from_pretrained``.
            **kwargs: forwarded to ``MPNetConfig.from_pretrained``.

        Returns:
            The model with weights loaded strictly on CPU.
        """
        # NOTE(review): the config is loaded from the same path even when that
        # path is a checkpoint file; confirm MPNetConfig handles that case.
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # weights_only=True refuses to unpickle arbitrary objects.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the weights and config into ``save_directory``.

        Args:
            save_directory: output directory; it is created if missing.
            state_dict: weights to save. Defaults to ``self.state_dict()``.

        Returns:
            ``save_directory``.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


if __name__ == '__main__':
    pass