File size: 4,983 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from typing import Optional, Tuple

import librosa
import torch
from torch.nn import Module


class TacotronSTFT(Module):
    r"""Compute linear and log-mel spectrograms from a batch of waveforms.

    The mel filter bank (built with ``librosa.filters.mel``) and the Hann
    window are registered as buffers, so they follow the module across
    ``.to(...)`` calls and are included in its ``state_dict``.
    """

    def __init__(
        self,
        filter_length: int,
        hop_length: int,
        win_length: int,
        n_mel_channels: int,
        sampling_rate: int,
        center: bool,
        mel_fmax: Optional[float],
        mel_fmin: float = 0.0,
    ):
        r"""TacotronSTFT module that computes mel-spectrograms from a batch of waves.

        Args:
            filter_length (int): FFT size (``n_fft``); the STFT produces
                ``filter_length // 2 + 1`` frequency bins.
            hop_length (int): Number of samples between successive frames.
            win_length (int): Length of the Hann analysis window.
            n_mel_channels (int): Number of mel bins.
            sampling_rate (int): Sampling rate of the input waveforms.
            center (bool): Forwarded to ``torch.stft``. The input is always
                reflect-padded by ``(filter_length - hop_length) // 2`` samples
                per side beforehand, independently of this flag.
            mel_fmax (float or None): Maximum frequency of the mel filter bank.
                ``None`` lets librosa use ``sampling_rate / 2``.
            mel_fmin (float): Minimum frequency of the mel filter bank.
                Defaults to 0.0.
        """
        super().__init__()

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.n_fft = filter_length
        self.hop_size = hop_length
        self.win_size = win_length
        self.fmin = mel_fmin
        self.fmax = mel_fmax
        self.center = center

        # Mel filter bank, shape (n_mel_channels, filter_length // 2 + 1).
        mel = librosa.filters.mel(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )

        mel_basis = torch.from_numpy(mel).float()

        # Hann analysis window of length win_length.
        hann_window = torch.hann_window(win_length)

        self.register_buffer("mel_basis", mel_basis)
        self.register_buffer("hann_window", hann_window)

    def _spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Compute the complex STFT of a batch of waves as a real tensor.

        Args:
            y (torch.Tensor): Input waveforms, presumably of shape (B, T) with
                values in [-1, 1].

        Returns:
            torch.Tensor: STFT with real and imaginary parts in the last
            dimension, i.e. ``(..., n_fft // 2 + 1, frames, 2)``.

        Raises:
            AssertionError: If any sample lies outside [-1, 1]. These checks
                are plain ``assert`` statements and are skipped under ``-O``.
        """
        assert torch.min(y) >= -1, "waveform values must be >= -1"
        assert torch.max(y) <= 1, "waveform values must be <= 1"

        # Reflect-pad (n_fft - hop) // 2 samples on each side; the padding is
        # applied on a (B, 1, T) view and the channel axis removed again.
        pad = (self.n_fft - self.hop_size) // 2
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (pad, pad),
            mode="reflect",
        ).squeeze(1)

        # NOTE: with center=True, torch.stft adds its own n_fft // 2 reflect
        # padding per side on top of the manual padding above.
        spec = torch.stft(
            y,
            self.n_fft,
            hop_length=self.hop_size,
            win_length=self.win_size,
            window=self.hann_window,  # type: ignore
            center=self.center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        return torch.view_as_real(spec)

    def linear_spectrogram(self, y: torch.Tensor) -> torch.Tensor:
        r"""Compute the linear magnitude spectrogram of a batch of waves.

        Unlike :meth:`forward`, no epsilon is added before the square root.

        Args:
            y (torch.Tensor): Input waveforms with values in [-1, 1].

        Returns:
            torch.Tensor: Magnitudes of shape ``(..., n_fft // 2 + 1, frames)``.
        """
        spec = self._spectrogram(y)
        # L2 norm over the (real, imag) axis gives the complex magnitude.
        return torch.norm(spec, p=2, dim=-1)

    def forward(self, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Compute magnitude and log-mel spectrograms from a batch of waves.

        Args:
            y (torch.FloatTensor): Input waveforms with shape (B, T) in range [-1, 1].

        Returns:
            torch.FloatTensor: Magnitude spectrogram of shape
            (B, n_fft // 2 + 1, frames).
            torch.FloatTensor: Log-mel spectrogram of shape
            (B, n_mel_channels, frames).
        """
        spec = self._spectrogram(y)

        # Complex magnitude; the small epsilon keeps sqrt (and its gradient)
        # finite for all-zero bins.
        spec = torch.sqrt(spec.pow(2).sum(-1) + (1e-9))

        # (n_mels, F) @ (B, F, frames) -> (B, n_mels, frames)
        mel = torch.matmul(self.mel_basis, spec)  # type: ignore
        mel = self.spectral_normalize_torch(mel)

        return spec, mel

    def spectral_normalize_torch(self, magnitudes: torch.Tensor) -> torch.Tensor:
        r"""Apply dynamic range compression (log) to magnitudes.

        Args:
            magnitudes (torch.Tensor): Input magnitudes.

        Returns:
            torch.Tensor: ``log(clamp(magnitudes, min=1e-5))``.
        """
        return self.dynamic_range_compression_torch(magnitudes)

    def dynamic_range_compression_torch(
        self,
        x: torch.Tensor,
        C: float = 1,
        clip_val: float = 1e-5,
    ) -> torch.Tensor:
        r"""Apply dynamic range compression ``log(clamp(x, min=clip_val) * C)``.

        Args:
            x (torch.Tensor): Input tensor.
            C (float): Multiplicative factor applied before the log.
            clip_val (float): Lower clamp bound, which avoids ``log(0)``.

        Returns:
            torch.Tensor: Output tensor.
        """
        return torch.log(torch.clamp(x, min=clip_val) * C)

    # NOTE: audio np.ndarray changed to torch.FloatTensor!
    def get_mel_from_wav(self, audio: torch.Tensor) -> torch.Tensor:
        r"""Compute the log-mel spectrogram of a single waveform without gradients.

        Args:
            audio (torch.Tensor): One waveform; presumably 1-D of shape (T,)
                with values in [-1, 1]. A batch axis is added internally.

        Returns:
            torch.Tensor: Log-mel spectrogram with the batch axis removed,
            i.e. (n_mel_channels, frames) for a 1-D input.
        """
        audio_tensor = audio.unsqueeze(0)
        with torch.no_grad():
            _, melspec = self.forward(audio_tensor)
        return melspec.squeeze(0)