Spaces:
Sleeping
Sleeping
File size: 4,983 Bytes
9d61c9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
from typing import Optional, Tuple
import librosa
import torch
from torch.nn import Module
class TacotronSTFT(Module):
    r"""Computes linear and log-mel spectrograms from batches of waveforms.

    The mel filterbank and the Hann window are registered as buffers, so they
    follow the module across ``.to(device)`` calls and are stored in its
    ``state_dict``.
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[int],
        mel_fmin: float = 0.0,
    ):
        r"""TacotronSTFT module that computes mel-spectrograms from a batch of waves.

        Args:
            filter_length (int): FFT size (``n_fft``) used by the STFT and the mel filterbank.
            hop_length (int): Number of samples between successive frames.
            win_length (int): Size of the STFT (Hann) window.
            n_mel_channels (int): Number of mel bins.
            sampling_rate (int): Sampling rate of the input waveforms.
            center (bool): Forwarded to ``torch.stft`` as ``center``. Note that
                ``_spectrogram`` always reflect-pads the signal itself before
                calling ``torch.stft``, independently of this flag.
            mel_fmax (int or None): Maximum frequency for the mel filter bank;
                passed through to ``librosa.filters.mel`` unchanged.
            mel_fmin (float): Minimum frequency for the mel filter bank.
        """
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.fmin = mel_fmin
        self.fmax = mel_fmax
        self.center = center

        # Define the mel filterbank (built once with librosa, stored as float32).
        mel = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel).float()

        # Define the Hann window
        hann_window = torch.hann_window(win_length)

        # Buffers (not parameters): moved with the module, never trained.
        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the complex STFT of a batch of waves as real/imaginary pairs.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1]
                (``forward`` documents the shape as (B, T)).

        Returns:
            torch.Tensor: One-sided STFT with a trailing dimension of size 2
                holding the real and imaginary parts.

        Raises:
            AssertionError: If any sample lies outside [-1, 1].
        """
        samples = y.detach()
        assert torch.min(samples) >= -1, "waveform contains samples below -1"
        assert torch.max(samples) <= 1, "waveform contains samples above 1"

        # Reflect-pad both ends by half of (n_fft - hop); with center=False
        # this yields about one frame per hop of input.
        pad = int((self.n_fft - self.hop_size) / 2)
        y = torch.nn.functional.pad(
            y.unsqueeze(1),  # temporary channel dim for the padding call
            (pad, pad),
            mode="reflect",
        )
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )

        return torch.view_as_real(spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Computes the linear (magnitude) spectrogram of a batch of waves.

        Unlike ``forward``, no epsilon is added before taking the magnitude.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1].

        Returns:
            torch.Tensor: Magnitude spectrogram (L2 norm over the real/imag axis).
        """
        spec = self._spectrogram(y)
        # torch.norm is deprecated; vector_norm computes the same L2 norm.
        return torch.linalg.vector_norm(spec, ord=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Computes mel-spectrograms from a batch of waves.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1]

        Returns:
            torch.FloatTensor: Magnitude spectrogram of shape (B, n_spec_channels, F),
                where F is the number of STFT frames.
            torch.FloatTensor: Log-compressed mel-spectrogram of shape (B, n_mel_channels, F)
        """
        spec = self._spectrogram(y)

        # Magnitude with a small epsilon under the square root so the result
        # (and its gradient) stays finite for all-zero bins.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        mel = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(mel)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Applies dynamic range compression to magnitudes.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: Log-compressed magnitudes (default ``C`` and ``clip_val``).
        """
        return self.dynamic_range_compression_torch(magnitudes)

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: float = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Applies dynamic range compression to x: ``log(clamp(x, min=clip_val) * C)``.

        Args:
            x (torch.Tensor): Input tensor.
            C (float): Compression factor applied after clamping.
            clip_val (float): Lower bound applied to ``x`` before the logarithm.

        Returns:
            torch.Tensor: Output tensor.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        r"""Computes the mel-spectrogram of a single waveform without tracking gradients.

        Args:
            audio (torch.Tensor): A single un-batched waveform in range [-1, 1];
                a leading batch dimension is added before calling ``forward``.

        Returns:
            torch.Tensor: Mel-spectrogram with the batch dimension removed again.
        """
        audio_tensor = audio.unsqueeze(0)
        with torch.no_grad():
            _, melspec = self.forward(audio_tensor)
        return melspec.squeeze(0)
|