#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Metric discriminator of the ``nx_mpnet`` model package, plus checkpoint helpers."""
import os
from typing import Optional, Union

import torch
import torch.nn as nn
import torchaudio

from toolbox.torchaudio.configuration_utils import CONFIG_FILE
from toolbox.torchaudio.models.nx_mpnet.configuration_nx_mpnet import NXMPNetConfig
from toolbox.torchaudio.models.nx_mpnet.utils import LearnableSigmoid1d


class MetricDiscriminator(nn.Module):
    """Score a pair of waveforms from their stacked magnitude spectrograms.

    Both inputs go through the same ``torchaudio.transforms.Spectrogram``
    (``power=1.0``, Hann window); the two spectrograms are stacked on a new
    dimension 1 and fed to a stack of four stride-2 convolutions followed by
    global max pooling and a two-layer linear head ending in
    ``LearnableSigmoid1d``.

    Args:
        config: model configuration. The attributes read here are
            ``discriminator_dim``, ``discriminator_in_channel``, ``n_fft``,
            ``win_length`` and ``hop_length``.
    """

    def __init__(self, config: NXMPNetConfig):
        super().__init__()
        dim = config.discriminator_dim
        # ``forward`` stacks exactly two spectrograms on dim 1, so for
        # [batch, samples] inputs the first convolution sees 2 channels.
        # NOTE(review): this presumably means the config value must be 2 - confirm.
        self.in_channel = config.discriminator_in_channel

        self.n_fft = config.n_fft
        self.win_length = config.win_length
        self.hop_length = config.hop_length

        self.transform = torchaudio.transforms.Spectrogram(
            n_fft=self.n_fft,
            win_length=self.win_length,
            hop_length=self.hop_length,
            power=1.0,
            window_fn=torch.hann_window,
            # window_fn=torch.hamming_window if window_fn == "hamming" else torch.hann_window,
        )

        # Keep this Sequential flat and in this exact order: state-dict keys are
        # derived from each layer's position, and checkpoints are loaded with
        # ``strict=True``. For the same reason the legacy
        # ``nn.utils.spectral_norm`` wrapper is kept (the parametrization-based
        # replacement stores its weights under different keys).
        self.layers = nn.Sequential(
            nn.utils.spectral_norm(nn.Conv2d(self.in_channel, dim, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim, affine=True),
            nn.PReLU(dim),
            nn.utils.spectral_norm(nn.Conv2d(dim, dim * 2, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 2, affine=True),
            nn.PReLU(dim * 2),
            nn.utils.spectral_norm(nn.Conv2d(dim * 2, dim * 4, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 4, affine=True),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Conv2d(dim * 4, dim * 8, (4, 4), (2, 2), (1, 1), bias=False)),
            nn.InstanceNorm2d(dim * 8, affine=True),
            nn.PReLU(dim * 8),
            # Global max pool + flatten: one ``dim * 8`` vector per batch item,
            # independent of the spectrogram's size.
            nn.AdaptiveMaxPool2d(1),
            nn.Flatten(),
            nn.utils.spectral_norm(nn.Linear(dim * 8, dim * 4)),
            nn.Dropout(0.3),
            nn.PReLU(dim * 4),
            nn.utils.spectral_norm(nn.Linear(dim * 4, 1)),
            LearnableSigmoid1d(1)
        )

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Return the discriminator output for the waveform pair ``(x, y)``.

        Args:
            x: first waveform batch; the demo in ``main`` uses shape
                ``[batch_size, num_samples]``.
            y: second waveform batch; must yield a spectrogram of the same
                shape as ``x`` so the two can be stacked.

        Returns:
            Output of the final ``LearnableSigmoid1d`` layer. The last linear
            layer produces one value per batch item, so this is presumably
            ``[batch_size, 1]`` - confirm against ``LearnableSigmoid1d``.
        """
        # Call the modules themselves rather than ``.forward`` so that any
        # registered forward hooks are executed.
        x = self.transform(x)
        y = self.transform(y)

        # New dimension 1 holds the two spectrograms as channels for Conv2d.
        xy = torch.stack((x, y), dim=1)
        return self.layers(xy)


MODEL_FILE = "discriminator.pt"


class MetricDiscriminatorPretrainedModel(MetricDiscriminator):
    """``MetricDiscriminator`` that keeps its config and can be saved/loaded.

    Args:
        config: model configuration; stored on the instance as ``self.config``
            so ``save_pretrained`` can write it back out.
    """

    def __init__(self,
                 config: NXMPNetConfig,
                 ):
        super().__init__(
            config=config,
        )
        self.config = config

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, **kwargs):
        """Build a model from a config and load its weights from a checkpoint.

        Args:
            pretrained_model_name_or_path: either a directory containing
                ``MODEL_FILE`` or the path of the checkpoint file itself. The
                same value is also handed to ``NXMPNetConfig.from_pretrained``.
            **kwargs: forwarded unchanged to ``NXMPNetConfig.from_pretrained``.

        Returns:
            An instance of ``cls`` with the checkpoint's state dict loaded.

        Raises:
            OSError: if the checkpoint file cannot be opened.
            RuntimeError: if the state dict does not match the model exactly
                (``strict=True``).
        """
        config = NXMPNetConfig.from_pretrained(pretrained_model_name_or_path, **kwargs)

        model = cls(config)

        if os.path.isdir(pretrained_model_name_or_path):
            ckpt_file = os.path.join(pretrained_model_name_or_path, MODEL_FILE)
        else:
            ckpt_file = pretrained_model_name_or_path

        with open(ckpt_file, "rb") as f:
            # ``weights_only=True`` restricts unpickling to tensors and plain
            # containers; weights are placed on CPU regardless of where they
            # were saved.
            state_dict = torch.load(f, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict, strict=True)
        return model

    def save_pretrained(self,
                        save_directory: Union[str, os.PathLike],
                        state_dict: Optional[dict] = None,
                        ):
        """Write the weights and the config into ``save_directory``.

        Args:
            save_directory: target directory; created if it does not exist.
            state_dict: state dict to save. When ``None``, the model's own
                ``state_dict()`` is used.

        Returns:
            ``save_directory``, as passed in.
        """
        if state_dict is None:
            state_dict = self.state_dict()

        os.makedirs(save_directory, exist_ok=True)

        # save state dict
        model_file = os.path.join(save_directory, MODEL_FILE)
        torch.save(state_dict, model_file)

        # save config
        config_file = os.path.join(save_directory, CONFIG_FILE)
        self.config.to_yaml_file(config_file)
        return save_directory


def main():
    """Smoke test: run the discriminator on constant inputs and print the result."""
    config = NXMPNetConfig()
    discriminator = MetricDiscriminator(config=config)

    # shape: [batch_size, num_samples]
    # x = torch.ones([4, int(4.5 * 16000)])
    # y = torch.ones([4, int(4.5 * 16000)])
    x = torch.ones([4, 16000])
    y = torch.ones([4, 16000])

    # Invoke the module (not ``.forward``) so forward hooks would run.
    output = discriminator(x, y)
    print(output.shape)
    print(output)

    return


if __name__ == "__main__":
    main()