#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from pesq import pesq
from joblib import Parallel, delayed

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d


class MetricDiscriminator(nn.Module):
    """Metric discriminator: maps a pair of magnitude spectrograms to one scalar score per sample."""
    def __init__(self, config: NXMPNetConfig):
        super(MetricDiscriminator, self).__init__()
        dim = config.discriminator_dim
        in_channel = config.discriminator_in_channel

        self.layers = nn.Sequential(
            # four spectrally-normalized conv blocks, each roughly halving freq and time
            nn.utils.spectral_norm(nn.Conv2d(in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            # global pooling followed by a small linear head producing a single score
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1)
        )

    def forward(self, x, y):
        # x, y: magnitude spectrograms of shape (batch, freq, time),
        # stacked along a new channel dimension -> (batch, 2, freq, time)
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)


MODEL_FILE = "discriminator.pt"


class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    def __init__(self,
                 config: NXMPNetConfig,
                 ):
        super(MetricDiscriminatorPretrainedModel, self).__init__(
            config=config,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        # accept either a directory containing MODEL_FILE or a direct checkpoint path
        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path
        with open(ckpt_file, "rb") as f:
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
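

# A minimal usage sketch, not part of the original file; it can be called from
# the __main__ guard below. It feeds two random magnitude spectrograms through
# the discriminator and prints the score shape. `_StubConfig` is a hypothetical
# stand-in exposing only the two fields the discriminator reads; a real run
# would construct an NXMPNetConfig instead.
def _usage_sketch():
    class _StubConfig:
        discriminator_dim = 16
        discriminator_in_channel = 2

    discriminator = MetricDiscriminator(_StubConfig())

    # two (batch, freq, time) magnitude inputs, e.g. clean and enhanced speech
    clean_mag = torch.randn(4, 201, 121)
    denoise_mag = torch.randn(4, 201, 121)

    scores = discriminator(clean_mag, denoise_mag)
    print(scores.shape)  # expected: torch.Size([4, 1])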


if __name__ == '__main__':
    pass