from typing import Optional, Tuple

import librosa
import torch
from torch.nn import Module


class TacotronSTFT(Module):
    r"""STFT front-end producing linear and log-mel spectrograms from waveforms."""

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[int],
        mel_fmin: float = 0.0,
    ):
        r"""TacotronSTFT module that computes mel-spectrograms from a batch of waves.

        Args:
            filter_length (int): FFT size (``n_fft``) used for the STFT and the mel filter bank.
            hop_length (int): Number of samples between successive frames.
            win_length (int): Size of the STFT (Hann) window.
            n_mel_channels (int): Number of mel bins.
            sampling_rate (int): Sampling rate of the input waveforms.
            center (bool): Forwarded to ``torch.stft`` as ``center``. Note that the
                input is always reflect-padded by ``(filter_length - hop_length) / 2``
                samples per side before the STFT, independently of this flag.
            mel_fmax (int or None): Maximum frequency for the mel filter bank; passed
                through unchanged to ``librosa.filters.mel`` (``None`` defers to librosa).
            mel_fmin (float): Minimum frequency for the mel filter bank. Defaults to 0.0.
        """
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.fmin = mel_fmin
        self.fmax = mel_fmax
        self.center = center

        # Mel filter bank, built once and stored as a float32 tensor.
        mel = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel).float()

        # Analysis window for the STFT.
        hann_window = torch.hann_window(win_length)

        # Buffers (not parameters): they follow the module across devices and are
        # saved in the state dict, but are never trained.
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the complex STFT of a batch of waves as a real-valued tensor.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1].

        Returns:
            torch.Tensor: STFT with a trailing dimension of size 2 holding the real
            and imaginary parts (``torch.view_as_real`` of the complex STFT).

        Raises:
            AssertionError: If any sample lies outside [-1, 1] (NaNs also fail).
        """
        # Raised explicitly instead of via `assert` so the range check is not
        # stripped under `python -O`. The exception type stays AssertionError so
        # callers relying on the previous behaviour are unaffected.
        samples = y.detach()
        lowest, highest = torch.min(samples), torch.max(samples)
        if not (lowest >= -1 and highest <= 1):
            raise AssertionError(
                "Input waveform must lie in [-1, 1], "
                f"got range [{float(lowest)}, {float(highest)}]",
            )

        # Reflect-pad both ends; the temporary channel dimension is added because
        # reflect padding is applied to the last dimension of a (B, C, T) tensor.
        pad = int((self.n_fft - self.hop_size) / 2)
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (pad, pad),
            mode="reflect",
        )
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        return torch.view_as_real(spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear magnitude spectrogram of a batch of waves.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1].

        Returns:
            torch.Tensor: Magnitude spectrogram (L2 norm over the real/imaginary axis).
        """
        spec = self._spectrogram(y)
        # torch.norm is deprecated; vector_norm with ord=2 over the last axis is
        # the documented equivalent.
        return torch.linalg.vector_norm(spec, ord=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes mel-spectrograms from a batch of waves.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

        Returns:
            torch.FloatTensor: Magnitude spectrogram of shape (B, n_fft // 2 + 1, frames)
            torch.FloatTensor: Log-compressed mel-spectrogram of shape (B, n_mel_channels, frames)
        """
        spec = self._spectrogram(y)

        # Magnitude with a small epsilon inside the square root so the gradient
        # stays finite where the spectrum is exactly zero.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        mel = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(mel)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Applies dynamic range compression to magnitudes.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: Log-compressed magnitudes.
        """
        return self.dynamic_range_compression_torch(magnitudes)

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: float = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Applies dynamic range compression to x.

        Computes ``log(clamp(x, min=clip_val) * C)``.

        Args:
            x (torch.Tensor): Input tensor.
            C (float): Compression factor applied after clamping.
            clip_val (float): Lower bound applied to ``x`` before the logarithm.

        Returns:
            torch.Tensor: Output tensor.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        r"""Computes the mel-spectrogram of a single waveform without tracking gradients.

        Args:
            audio (torch.Tensor): A single waveform; a leading batch dimension is
                added before calling ``forward`` and removed from the result.

        Returns:
            torch.Tensor: Log-compressed mel-spectrogram with the batch dimension removed.
        """
        audio_tensor = audio.unsqueeze(0)
        with torch.no_grad():
            _, melspec = self.forward(audio_tensor)
        return melspec.squeeze(0)