HoneyTian's picture
update
a88ebd1
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from pesq import pesq
from joblib import Parallel, delayed
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
# NOTE: the PESQ helpers below are disabled; they are kept commented out for reference.
# def cal_pesq(clean, noisy, sr=16000):
#     try:
#         pesq_score = pesq(sr, clean, noisy, 'wb')
#     except:
#         # an error can occur during silent periods
#         pesq_score = -1
#     return pesq_score
#
# def batch_pesq(clean, noisy):
#     pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
#     pesq_score = np.array(pesq_score)
#     if -1 in pesq_score:
#         return None
#     pesq_score = (pesq_score - 1) / 3.5
#     return torch.FloatTensor(pesq_score)
def metric_loss(metric_ref: torch.Tensor, metrics_gen) -> torch.Tensor:
    """Sum the mean-squared error between a reference metric and each prediction.

    Args:
        metric_ref: Target metric tensor. It is compared against every
            flattened prediction, so it is presumably 1-D with one entry per
            prediction element (assumption - confirm against the caller).
        metrics_gen: Iterable of predicted-metric tensors. Each one is
            flattened before being compared with ``metric_ref``.

    Returns:
        A scalar tensor: the sum of ``F.mse_loss(metric_ref, pred.flatten())``
        over all predictions. When ``metrics_gen`` is empty, a zero scalar
        tensor created from ``metric_ref`` is returned, so the result is
        always a tensor.
    """
    # Start from a zero tensor (not the int 0) so the return type is a tensor
    # even when there are no predictions to accumulate.
    total = metric_ref.new_zeros(())
    for metric_gen in metrics_gen:
        # Out-of-place add: lets dtype promotion follow the MSE term.
        total = total + F.mse_loss(metric_ref, metric_gen.flatten())
    return total
class MetricDiscriminator(nn.Module):
    """Convolutional discriminator that maps a pair of inputs to one score.

    The network is a single ``nn.Sequential`` stored in ``self.layers``:

    * four stages of spectrally-normalised ``Conv2d`` (4x4 kernel, stride 2,
      padding 1, no bias) -> ``InstanceNorm2d`` (affine) -> ``PReLU``, with
      output widths ``dim``, ``2*dim``, ``4*dim`` and ``8*dim``;
    * ``AdaptiveMaxPool2d(1)`` and ``Flatten``;
    * spectrally-normalised ``Linear(8*dim, 4*dim)`` -> ``Dropout(0.3)`` ->
      ``PReLU`` -> spectrally-normalised ``Linear(4*dim, 1)`` ->
      ``LearnableSigmoid1d(1)`` (project-defined; see its own definition).

    ``dim`` is ``config.discriminator_dim`` and the first convolution takes
    ``config.discriminator_in_channel`` input channels.
    """

    def __init__(self, config: MPNetConfig):
        super().__init__()

        base = config.discriminator_dim
        # Channel widths seen by consecutive conv stages: input, then 1x/2x/4x/8x.
        channels = (
            config.discriminator_in_channel,
            base,
            base * 2,
            base * 4,
            base * 8,
        )

        # Modules are appended in exactly the order they sit in the Sequential,
        # because nn.Sequential names parameters by position.
        modules = []
        for c_in, c_out in zip(channels[:-1], channels[1:]):
            modules += [
                nn.utils.spectral_norm(
                    nn.Conv2d(
                        c_in,
                        c_out,
                        kernel_size=(4, 4),
                        stride=(2, 2),
                        padding=(1, 1),
                        bias=False,
                    )
                ),
                nn.InstanceNorm2d(c_out, affine=True),
                nn.PReLU(c_out),
            ]

        # Pool to one value per channel, then a small MLP down to one output.
        modules += [
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(base * 8, base * 4)),
            nn.Dropout(0.3),
            nn.PReLU(base * 4),
            nn.utils.spectral_norm(nn.Linear(base * 4, 1)),
            LearnableSigmoid1d(1),
        ]

        self.layers = nn.Sequential(*modules)

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        ``x`` and ``y`` are stacked along a new dimension 1 and passed through
        ``self.layers``. Presumably both are 3-D (batch, freq, time)
        spectrograms, which makes the stacked axis the channel axis and
        requires ``discriminator_in_channel == 2`` - TODO confirm.
        """
        pair = torch.stack((x, y), dim=1)
        return self.layers(pair)
# File name of the discriminator weights inside a pretrained-model directory;
# it is joined with the directory path when the state dict is loaded or saved.
MODEL_FILE = "discriminator.pt"
class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """``MetricDiscriminator`` with load/save helpers for weights and config.

    The config passed to the constructor is kept on ``self.config`` so that
    ``save_pretrained`` can write it next to the weights.
    """

    def __init__(self, config: MPNetConfig):
        super().__init__(config=config)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a saved config and load its weights.

        Args:
            pretrained_model_name_or_path: Either a directory that contains
                ``MODEL_FILE``, or the path of the checkpoint file itself.
                The same value is handed to ``MPNetConfig.from_pretrained``.
            **kwargs: Forwarded to ``MPNetConfig.from_pretrained``.

        Returns:
            The model with the checkpoint loaded (strict key matching).
        """
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # A directory means "use the default weights file inside it";
        # anything else is taken as the checkpoint path directly.
        ckpt_file = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if os.path.isdir(pretrained_model_name_or_path)
            else pretrained_model_name_or_path
        )

        # Load onto CPU; weights_only=True restricts unpickling to tensor data.
        with open(ckpt_file, "rb") as f:
            weights = torch.load(f, map_location="cpu", weights_only=True)

        model.load_state_dict(weights, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the weights and the config into ``save_directory``.

        Args:
            save_directory: Target directory; created if it does not exist.
            state_dict: State dict to store. Defaults to ``self.state_dict()``.

        Returns:
            ``save_directory``, unchanged.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # Weights first, then the config next to them.
        torch.save(state_dict, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))

        return save_directory
# Script entry point: intentionally a no-op, this module only provides definitions.
if __name__ == "__main__":
    pass