# NOTE: removed hosting-page scrape residue ("Spaces: / Running / Running") that preceded the source.
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union

import torch
import torch.nn as nn
import torchaudio

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d
class MetricDiscriminator(nn.Module):
    """Spectrogram-domain discriminator that scores a pair of waveforms.

    Both inputs are mapped to magnitude spectrograms, stacked along the
    channel axis as a 2-channel "image", and passed through a
    spectral-normalized CNN that emits one bounded scalar per example.
    """

    def __init__(self, config: NXMPNetConfig):
        super(MetricDiscriminator, self).__init__()
        dim = config.discriminator_dim
        self.in_channel = config.discriminator_in_channel

        self.n_fft = config.n_fft
        self.win_length = config.win_length
        self.hop_length = config.hop_length

        # Magnitude (power=1.0) STFT front-end shared by both inputs.
        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=1.0,
            window_fn=torch.hann_window,
        )

        # Four strided conv stages; channel count doubles at each stage
        # (in_channel -> dim -> 2*dim -> 4*dim -> 8*dim). The module order
        # matches the original flat Sequential, so state_dict keys are
        # unchanged.
        modules = []
        channels = self.in_channel
        for stage in range(4):
            out_channels = dim * (2 ** stage)
            modules.extend([
                nn.utils.spectral_norm(
                    nn.Conv2d(channels, out_channels, (4, 4), (2, 2), (1, 1), bias=False)
                ),
                nn.InstanceNorm2d(out_channels, affine=True),
                nn.PReLU(out_channels),
            ])
            channels = out_channels

        # Global pooling + MLP head producing a single sigmoid-bounded score.
        modules.extend([
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1),
        ])
        self.layers = nn.Sequential(*modules)

    def forward(self, x, y):
        # Convert each waveform to a magnitude spectrogram, then stack the
        # pair along a new channel dimension before scoring.
        spec_x = self.transform.forward(x)
        spec_y = self.transform.forward(y)
        pair = torch.stack((spec_x, spec_y), dim=1)
        return self.layers(pair)
MODEL_FILE = "discriminator.pt" | |
class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """MetricDiscriminator with save/load helpers for pretrained checkpoints."""

    def __init__(self,
                 config: NXMPNetConfig,
                 ):
        super(MetricDiscriminatorPretrainedModel, self).__init__(
            config=config,
        )
        # Keep the config around so save_pretrained can serialize it.
        self.config = config

    # BUG FIX: the original defined this as an instance method whose first
    # parameter was named `cls`, so `Class.from_pretrained(path)` would bind
    # `path` to `cls`. It is an alternate constructor and must be a classmethod.
    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build the model from a config and load checkpoint weights.

        :param pretrained_model_name_or_path: either a directory that contains
            MODEL_FILE (plus the config), or a direct path to a checkpoint file.
        :param kwargs: forwarded to ``NXMPNetConfig.from_pretrained``.
        :return: a model instance with weights loaded (strict key matching).
        """
        config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path
        with open(ckpt_file, "rb") as f:
            # weights_only=True avoids arbitrary code execution on untrusted files.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write weights (MODEL_FILE) and config (CONFIG_FILE) into a directory.

        :param save_directory: target directory; created if missing.
        :param state_dict: optional explicit state dict; defaults to this
            model's current ``state_dict()``.
        :return: ``save_directory``, for chaining.
        """
        model = self

        if state_dict is None:
            state_dict = model.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory
def main():
    """Smoke test: score two dummy waveforms and print the result."""
    config = NXMPNetConfig()
    discriminator = MetricDiscriminator(config=config)

    # Two dummy 1-second waveforms at 16 kHz; shape: [batch_size, num_samples].
    waveform_a = torch.ones([4, 16000])
    waveform_b = torch.ones([4, 16000])

    score = discriminator.forward(waveform_a, waveform_b)
    print(score.shape)
    print(score)
    return
# Allow running this module directly as a quick smoke test.
if __name__ == "__main__":
    main()