from typing import Any, Tuple

import torch
from torch import nn
from torch.nn import Module
import torch.nn.functional as F
from torch.nn.utils.parametrizations import spectral_norm, weight_norm

from models.config import VocoderModelConfig


class DiscriminatorR(Module):
    r"""Spectrogram discriminator operating at a single STFT resolution for a UnivNet vocoder.

    The input waveform is converted to a magnitude spectrogram at the configured
    resolution and passed through a stack of normalized 2D convolutions. The
    network is purely feed-forward; it contains no residual connections.

    Args:
        resolution (Tuple): A tuple containing the number of FFT points, hop length, and window length.
        model_config (VocoderModelConfig): A configuration object for the UnivNet model.
            Only ``model_config.mrd.lReLU_slope`` and
            ``model_config.mrd.use_spectral_norm`` are read here.
    """

    def __init__(
        self,
        resolution: Tuple[int, int, int],
        model_config: VocoderModelConfig,
    ):
        super().__init__()

        self.resolution = resolution
        self.LRELU_SLOPE = model_config.mrd.lReLU_slope

        # Use spectral normalization or weight normalization based on the configuration
        norm_f: Any = (
            spectral_norm if model_config.mrd.use_spectral_norm else weight_norm
        )

        # Layers are created in a fixed order (convs[0..4], then conv_post) so that
        # state-dict keys and seeded parameter initialisation stay stable.
        self.convs = nn.ModuleList(
            [
                # Input layer: 1 -> 32 channels, no striding.
                norm_f(nn.Conv2d(1, 32, (3, 9), padding=(1, 4))),
                # Three layers that halve the last (time-frame) axis via stride (1, 2).
                norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                norm_f(nn.Conv2d(32, 32, (3, 9), stride=(1, 2), padding=(1, 4))),
                # Final feature layer with a square 3x3 kernel.
                norm_f(nn.Conv2d(32, 32, (3, 3), padding=(1, 1))),
            ],
        )
        # Projects the 32 feature channels down to a single score map.
        self.conv_post = norm_f(nn.Conv2d(32, 1, (3, 3), padding=(1, 1)))

    def forward(self, x: torch.Tensor) -> tuple[list[torch.Tensor], torch.Tensor]:
        r"""Forward pass of the DiscriminatorR class.

        Args:
            x (torch.Tensor): The input waveform tensor; it is passed unchanged to
                :meth:`spectrogram`.

        Returns:
            tuple: ``(fmap, out)`` where ``fmap`` is the list of intermediate feature
            maps (one per layer in ``self.convs`` after the leaky ReLU, followed by
            the raw ``conv_post`` output) and ``out`` is the ``conv_post`` output
            flattened over all non-batch dimensions.
        """
        fmap = []

        # Compute the magnitude spectrogram of the input waveform
        x = self.spectrogram(x)

        # Add a channel dimension to the spectrogram tensor
        x = x.unsqueeze(1)

        # Resolve the dtype the layers run in once, before the loop.
        # ``conv_post.weight`` is a parametrized tensor: every access recomputes the
        # weight/spectral-norm parametrization, and with spectral norm in training
        # mode it also performs a power-iteration update of the u/v buffers. Reading
        # the (unparametrized) bias avoids those repeated side effects.
        bias = self.conv_post.bias
        target_dtype = bias.dtype if bias is not None else self.conv_post.weight.dtype

        # Apply the convolutional layers with leaky ReLU activation
        for layer in self.convs:
            # ``.to`` returns the tensor itself when the dtype already matches.
            x = layer(x.to(dtype=target_dtype))
            x = F.leaky_relu(x, self.LRELU_SLOPE)
            fmap.append(x)

        # Apply the post-convolutional layer
        x = self.conv_post(x)
        fmap.append(x)

        # Flatten the output tensor
        x = torch.flatten(x, 1, -1)

        return fmap, x

    def spectrogram(self, x: torch.Tensor) -> torch.Tensor:
        r"""Computes the magnitude spectrogram of the input waveform.

        Args:
            x (torch.Tensor): Input waveform tensor of shape [B, C, T]. Dimension 1 is
                squeezed before the STFT, so C is expected to be 1.

        Returns:
            torch.Tensor: Magnitude spectrogram tensor of shape [B, F, TT], where F is the number of
            frequency bins and TT is the number of time frames.
        """
        n_fft, hop_length, win_length = self.resolution

        # Apply reflection padding to the input waveform (same amount on both sides)
        pad = int((n_fft - hop_length) / 2)
        x = F.pad(x, (pad, pad), mode="reflect")

        # Squeeze the input waveform to remove the channel dimension
        x = x.squeeze(1)

        # Compute the short-time Fourier transform of the input waveform.
        # An all-ones (rectangular) window is passed explicitly; ``center=False``
        # because the padding has already been applied above.
        x = torch.stft(
            x,
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            center=False,
            return_complex=True,
            window=torch.ones(win_length, device=x.device),
        )  # complex tensor, [B, F, TT]

        # Split into real and imaginary parts: [B, F, TT, 2]
        x = torch.view_as_real(x)

        # Compute the magnitude spectrogram from the complex spectrogram.
        # ``torch.linalg.vector_norm`` replaces the deprecated ``torch.norm``.
        return torch.linalg.vector_norm(x, ord=2, dim=-1)  # [B, F, TT]