Spaces:
Running
Running
File size: 4,095 Bytes
f74ae8e f69c753 f74ae8e f69c753 04e3488 f74ae8e a88ebd1 f74ae8e f69c753 f74ae8e f69c753 f74ae8e f69c753 20d2f3e f69c753 f74ae8e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union
import torch
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from pesq import pesq
from joblib import Parallel, delayed
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.mpnet.configuration_mpnet import MPNetConfig
from toolbox.torchaudio.models.mpnet.utils import LearnableSigmoid1d
# def cal_pesq(clean, noisy, sr=16000):
# try:
# pesq_score = pesq(sr, clean, noisy, 'wb')
# except:
# # error can happen due to silent period
# pesq_score = -1
# return pesq_score
# def batch_pesq(clean, noisy):
# pesq_score = Parallel(n_jobs=15)(delayed(cal_pesq)(c, n) for c, n in zip(clean, noisy))
# pesq_score = np.array(pesq_score)
# if -1 in pesq_score:
# return None
# pesq_score = (pesq_score - 1) / 3.5
# return torch.FloatTensor(pesq_score)
def metric_loss(metric_ref, metrics_gen):
    """Sum the mean-squared error between ``metric_ref`` and every generated score.

    Each element of ``metrics_gen`` is flattened before it is compared with
    ``metric_ref``. The discriminator emits one score per item, so
    ``metric_ref`` is presumably a 1-D ``(batch,)`` tensor — confirm at call sites.

    Args:
        metric_ref: target score tensor passed as the first argument of ``F.mse_loss``.
        metrics_gen: iterable of score tensors to compare against ``metric_ref``.

    Returns:
        The sum of ``F.mse_loss(metric_ref, g.flatten())`` over ``metrics_gen``,
        or the integer ``0`` when ``metrics_gen`` is empty.
    """
    return sum(F.mse_loss(metric_ref, scores.flatten()) for scores in metrics_gen)
class MetricDiscriminator(nn.Module):
    """Convolutional discriminator that scores a pair of inputs.

    ``x`` and ``y`` are stacked as two channels along dim 1. The result passes through:

    - four spectral-normalised ``Conv2d`` stages (kernel 4, stride 2, padding 1,
      so each stage halves both spatial sizes, rounding down), each followed by
      ``InstanceNorm2d`` and ``PReLU``;
    - a global ``AdaptiveMaxPool2d(1)`` and ``Flatten``;
    - two spectral-normalised ``Linear`` layers, with dropout and ``PReLU`` in between;
    - a final ``LearnableSigmoid1d(1)``.

    Channel widths grow as ``in_channel -> dim -> 2*dim -> 4*dim -> 8*dim``,
    where ``dim = config.discriminator_dim`` and
    ``in_channel = config.discriminator_in_channel``.
    """

    def __init__(self, config: MPNetConfig):
        super().__init__()
        width = config.discriminator_dim
        widths = [config.discriminator_in_channel, width, width * 2, width * 4, width * 8]

        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.utils.spectral_norm(nn.Conv2d(c_in, c_out, (4, 4), (2, 2), (1, 1), bias=False)))
            stages.append(nn.InstanceNorm2d(c_out, affine=True))
            stages.append(nn.PReLU(c_out))

        head = [
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(width * 8, width * 4)),
            nn.Dropout(0.3),
            nn.PReLU(width * 4),
            nn.utils.spectral_norm(nn.Linear(width * 4, 1)),
            LearnableSigmoid1d(1),
        ]

        # Keep ONE flat Sequential with modules created in exactly this order.
        # State-dict keys ("layers.<idx>.*") depend on the indices, and
        # MetricDiscriminatorPretrainedModel.from_pretrained loads with strict=True.
        self.layers = nn.Sequential(*stages, *head)

    def forward(self, x, y):
        """Score the pair ``(x, y)``.

        The two tensors are stacked on a new dim 1 and treated as two input
        channels of the first ``Conv2d``.

        NOTE(review): this implies ``config.discriminator_in_channel == 2``, and
        ``x``/``y`` are presumably same-shaped ``(batch, H, W)`` tensors, e.g.
        spectrograms — confirm. For batched input the last ``Linear`` gives
        ``(batch, 1)``; the shape after ``LearnableSigmoid1d`` is assumed unchanged.
        """
        pair = torch.stack((x, y), dim=1)
        return self.layers(pair)
# File name of the serialized state dict inside a pretrained-model directory.
# save_pretrained() writes it and from_pretrained() reads it when given a directory.
MODEL_FILE = "discriminator.pt"
class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """``MetricDiscriminator`` with directory-based save/load helpers.

    The ``MPNetConfig`` used for construction is kept on ``self.config`` so
    that ``save_pretrained`` can write it next to the weights.
    """

    def __init__(self, config: MPNetConfig):
        super().__init__(config=config)
        # Retained for save_pretrained(), which serializes it via to_yaml_file().
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from its config, then load weights with strict key matching.

        Args:
            pretrained_model_name_or_path: either a directory containing
                ``MODEL_FILE``, or a direct path to a checkpoint file. The same
                value is also handed to ``MPNetConfig.from_pretrained``.
            **kwargs: forwarded only to ``MPNetConfig.from_pretrained``.

        Returns:
            A new instance whose state dict was loaded with
            ``map_location="cpu"``, ``weights_only=True`` and ``strict=True``.
        """
        config = MPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        model = cls(config)

        # NOTE(review): when a file path is given, that same path is used both
        # to resolve the config and as the checkpoint — confirm MPNetConfig
        # supports this.
        ckpt_file = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if os.path.isdir(pretrained_model_name_or_path)
            else pretrained_model_name_or_path
        )
        with open(ckpt_file, "rb") as handle:
            weights = torch.load(handle, map_location="cpu", weights_only=True)
        model.load_state_dict(weights, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write weights (``MODEL_FILE``) and config (``CONFIG_FILE``) into ``save_directory``.

        Args:
            save_directory: target directory; it is created if missing.
            state_dict: weights to save. Defaults to ``self.state_dict()``.

        Returns:
            ``save_directory``, unchanged.
        """
        weights = self.state_dict() if state_dict is None else state_dict
        os.makedirs(save_directory, exist_ok=True)
        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))
        return save_directory
if __name__ == "__main__":
    # Intentionally empty: this module only defines the discriminator, its loss
    # and save/load helpers; it has no command-line behaviour.
    pass
|