Spaces:
Running
Running
File size: 4,231 Bytes
33aff71 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import os
from typing import Optional, Union
import torch
import torch.nn as nn
import torchaudio
from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d
class MetricDiscriminator(nn.Module):
    """
    Discriminator that scores a pair of signals from their magnitude spectrograms.

    Both inputs are turned into magnitude spectrograms (``power=1.0``, Hann
    window), stacked along a new axis at ``dim=1`` and fed through four
    spectrally-normalized strided convolution stages, a global max-pool and a
    two-layer spectrally-normalized MLP that ends in ``LearnableSigmoid1d``.

    Args:
        config: model configuration. The attributes read here are
            ``discriminator_dim``, ``discriminator_in_channel``, ``n_fft``,
            ``win_length`` and ``hop_length``.
    """

    def __init__(self, config: NXMPNetConfig):
        super().__init__()
        dim = config.discriminator_dim

        self.in_channel = config.discriminator_in_channel
        self.n_fft = config.n_fft
        self.win_length = config.win_length
        self.hop_length = config.hop_length

        # power=1.0 yields a magnitude (not power) spectrogram.
        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=1.0,
            window_fn=torch.hann_window,
        )

        # Channel widths of the four conv stages: in -> dim -> 2*dim -> 4*dim -> 8*dim.
        channels = [self.in_channel, dim, dim * 2, dim * 4, dim * 8]

        # The modules are collected into ONE flat nn.Sequential, in the same
        # order as before, so the state-dict keys ("layers.0...", "layers.1...")
        # stay compatible with existing checkpoints.
        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.extend([
                nn.utils.spectral_norm(
                    nn.Conv2d(in_ch, out_ch, (4, 4), (2, 2), (1, 1), bias=False)
                ),
                nn.InstanceNorm2d(out_ch, affine=True),
                nn.PReLU(out_ch),
            ])
        layers.extend([
            # Collapse the two trailing axes to 1x1, then flatten to [batch, dim*8].
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1),
        ])
        self.layers = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """
        Score the pair ``(x, y)``.

        Args:
            x: first signal. NOTE(review): the demo in this file passes
                ``[batch_size, num_samples]`` waveforms - confirm for other callers.
            y: second signal; must produce a spectrogram of the same shape as
                ``x`` because the two are stacked.

        Returns:
            The output of ``self.layers`` for the stacked spectrograms; the last
            linear layer has a single output feature.
        """
        # Invoke the module itself rather than ``.forward`` so that any
        # registered forward hooks are not silently skipped.
        x = self.transform(x)
        y = self.transform(y)

        # Stacking two spectrograms creates an axis of size 2 at dim=1, which
        # the first Conv2d consumes as its channel axis (so
        # ``discriminator_in_channel`` is expected to be 2 for 2-D inputs).
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)
# File name under which the discriminator state dict is stored inside a model directory.
MODEL_FILE = "discriminator.pt"
class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """
    ``MetricDiscriminator`` that keeps its config and can be saved to / restored
    from disk.

    Args:
        config: model configuration; stored on the instance as ``self.config``.
    """

    def __init__(self, config: NXMPNetConfig):
        super().__init__(config=config)
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """
        Build a model from a saved config and load its weights.

        Args:
            pretrained_model_name_or_path: either a directory that contains
                ``MODEL_FILE``, or the path of the checkpoint file itself. The
                same value is also handed to ``NXMPNetConfig.from_pretrained``.
            **kwargs: forwarded to ``NXMPNetConfig.from_pretrained``.

        Returns:
            A new instance of ``cls`` with the checkpoint's state dict loaded
            (``strict=True``, tensors mapped to CPU).
        """
        config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)
        instance = cls(config)

        # A directory means "look for the default checkpoint name inside it";
        # anything else is taken as the checkpoint path directly.
        checkpoint_path = (
            os.path.join(pretrained_model_name_or_path, MODEL_FILE)
            if os.path.isdir(pretrained_model_name_or_path)
            else pretrained_model_name_or_path
        )

        with open(checkpoint_path, "rb") as handle:
            weights = torch.load(handle, map_location="cpu", weights_only=True)

        instance.load_state_dict(weights, strict=True)
        return instance

    def save_pretrained(
        self,
        save_directory: Union[str, os.PathLike],
        state_dict: Optional[dict] = None,
    ):
        """
        Write the weights and the config into ``save_directory``.

        Args:
            save_directory: target directory; created if it does not exist.
            state_dict: weights to save. When ``None``, this model's own
                ``state_dict()`` is used.

        Returns:
            ``save_directory``, unchanged.
        """
        weights = self.state_dict() if state_dict is None else state_dict

        os.makedirs(save_directory, exist_ok=True)

        # Weights first, then the YAML config next to them.
        torch.save(weights, os.path.join(save_directory, MODEL_FILE))
        self.config.to_yaml_file(os.path.join(save_directory, CONFIG_FILE))

        return save_directory
def main():
    """
    Smoke test: run a default-configured ``MetricDiscriminator`` on two
    all-ones inputs and print the shape and values of the result.
    """
    config = NXMPNetConfig()
    discriminator = MetricDiscriminator(config=config)

    # shape: [batch_size, num_samples]
    x = torch.ones([4, 16000])
    y = torch.ones([4, 16000])

    # Call the module (not ``.forward``) so registered hooks are honoured.
    output = discriminator(x, y)
    print(output.shape)
    print(output)


if __name__ == "__main__":
    main()
|