File size: 13,470 Bytes
9d61c9b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
from dataclasses import dataclass
from typing import Optional

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torchaudio.transforms as T
from torchmetrics.audio import (
    ComplexScaleInvariantSignalNoiseRatio,
    ScaleInvariantSignalDistortionRatio,
    ScaleInvariantSignalNoiseRatio,
    SpeechReverberationModulationEnergyRatio,
)

from models.config import PreprocessingConfig, PreprocessingConfigUnivNet, get_lang_map
from training.preprocess.audio_processor import AudioProcessor


@dataclass
class MetricsResult:
    r"""Bundle of every value produced by one evaluation pass of :class:`Metrics`.

    Attributes:
        energy (torch.Tensor): Mean squared error between the frame energies of the
            predicted and the target waveforms.
        si_sdr (torch.Tensor): Scale-invariant signal-to-distortion ratio.
        si_snr (torch.Tensor): Scale-invariant signal-to-noise ratio.
        c_si_snr (torch.Tensor): Complex scale-invariant signal-to-noise ratio.
        mcd (torch.Tensor): Mel cepstral distortion.
        spec_dist (torch.Tensor): Euclidean distance between magnitude spectrograms.
        f0_rmse (float): Root-mean-square error between the two F0 contours.
        jitter (float): Jitter of the predicted waveform.
        shimmer (float): Shimmer of the predicted waveform.
    """

    # --- tensor-valued metrics -------------------------------------------
    energy: torch.Tensor
    si_sdr: torch.Tensor
    si_snr: torch.Tensor
    c_si_snr: torch.Tensor
    mcd: torch.Tensor
    spec_dist: torch.Tensor

    # --- plain Python floats ---------------------------------------------
    f0_rmse: float
    jitter: float
    shimmer: float


class Metrics:
    r"""A class that computes various audio metrics.

    Args:
        lang (str): language parameter. Defaults to "en".
        preprocess_config (Optional[PreprocessingConfig]): The preprocessing configuration.
            Defaults to None, in which case a ``PreprocessingConfigUnivNet`` built from the
            language map of ``lang`` is used.

    Attributes:
        hop_length (int): The hop length for the STFT.
        filter_length (int): The filter length for the STFT.
        mel_fmin (int): The minimum frequency for the Mel scale.
        win_length (int): The window length for the STFT.
        sample_rate (int): The sampling rate from the preprocessing configuration.
        audio_processor (AudioProcessor): The audio processor.
        mse_loss (nn.MSELoss): The mean squared error loss.
        si_sdr (ScaleInvariantSignalDistortionRatio): The scale-invariant signal-to-distortion ratio.
        si_snr (ScaleInvariantSignalNoiseRatio): The scale-invariant signal-to-noise ratio.
        c_si_snr (ComplexScaleInvariantSignalNoiseRatio): The complex scale-invariant signal-to-noise ratio.
        reverb_modulation_energy_ratio (SpeechReverberationModulationEnergyRatio): The
            speech-to-reverberation modulation energy ratio (SRMR) metric.
    """

    def __init__(
        self,
        lang: str = "en",
        preprocess_config: Optional[PreprocessingConfig] = None,
    ):
        lang_map = get_lang_map(lang)
        preprocess_config = preprocess_config or PreprocessingConfigUnivNet(
            lang_map.processing_lang_type,
        )

        self.hop_length = preprocess_config.stft.hop_length
        self.filter_length = preprocess_config.stft.filter_length
        self.mel_fmin = preprocess_config.stft.mel_fmin
        self.win_length = preprocess_config.stft.win_length
        self.sample_rate = preprocess_config.sampling_rate

        self.audio_processor = AudioProcessor()
        self.mse_loss = nn.MSELoss()
        self.si_sdr = ScaleInvariantSignalDistortionRatio()
        self.si_snr = ScaleInvariantSignalNoiseRatio()
        self.c_si_snr = ComplexScaleInvariantSignalNoiseRatio(zero_mean=False)
        self.reverb_modulation_energy_ratio = SpeechReverberationModulationEnergyRatio(
            self.sample_rate,
        )

    @staticmethod
    def _trim_frames(a, b):
        """Trim two frame-aligned arrays/tensors to the shorter length along the last axis."""
        # Vocoder output is frequently a few samples longer/shorter than the target;
        # without trimming, element-wise comparisons fail to broadcast.
        n_frames = min(a.shape[-1], b.shape[-1])
        return a[..., :n_frames], b[..., :n_frames]

    def calculate_mcd(
        self,
        wav_targets: torch.Tensor,
        wav_predictions: torch.Tensor,
        n_mfcc: int = 13,
    ) -> torch.Tensor:
        """Calculate Mel Cepstral Distortion.

        MFCCs are computed for both waveforms; the Euclidean distance between the
        coefficient vectors of each frame is taken and then averaged over all frames
        (and over any leading batch dimensions).

        Args:
            wav_targets (torch.Tensor): Reference waveform(s), shape ``(..., time)``.
            wav_predictions (torch.Tensor): Predicted waveform(s); moved to the device of
                ``wav_targets``.
            n_mfcc (int): Number of MFCC coefficients. Defaults to 13.

        Returns:
            torch.Tensor: Scalar mean per-frame cepstral distance.
        """
        # NOTE(review): n_fft=400 / hop=160 correspond to 25 ms / 10 ms only at 16 kHz;
        # they are kept as-is so values stay comparable with earlier runs.
        mfcc_transform = T.MFCC(
            sample_rate=self.sample_rate,
            n_mfcc=n_mfcc,
            melkwargs={
                "n_fft": 400,
                "hop_length": 160,
                "n_mels": 23,
                "center": False,
            },
        ).to(wav_targets.device)
        wav_predictions = wav_predictions.to(wav_targets.device)

        ref_mfcc = mfcc_transform(wav_targets)
        synth_mfcc = mfcc_transform(wav_predictions)
        ref_mfcc, synth_mfcc = self._trim_frames(ref_mfcc, synth_mfcc)

        # MFCC output is (..., n_mfcc, time): reduce over the coefficient axis (-2)
        # so the distance is computed per frame. dim=0 would reduce over the batch
        # axis for batched input.
        mcd = torch.mean(
            torch.sqrt(
                torch.sum((ref_mfcc - synth_mfcc) ** 2, dim=-2),
            ),
        )

        return mcd

    def calculate_spectrogram_distance(
        self,
        wav_targets: torch.Tensor,
        wav_predictions: torch.Tensor,
        n_fft: int = 2048,
        hop_length: int = 512,
    ) -> torch.Tensor:
        """Calculate the Euclidean distance between two magnitude spectrograms.

        Args:
            wav_targets (torch.Tensor): Reference waveform(s).
            wav_predictions (torch.Tensor): Predicted waveform(s); moved to the device of
                ``wav_targets``.
            n_fft (int): FFT size. Defaults to 2048.
            hop_length (int): STFT hop length. Defaults to 512.

        Returns:
            torch.Tensor: Scalar L2 distance between the flattened magnitude spectrograms.
        """
        spec_transform = T.Spectrogram(
            n_fft=n_fft,
            hop_length=hop_length,
            power=None,
        ).to(wav_targets.device)
        wav_predictions = wav_predictions.to(wav_targets.device)

        # Complex spectrograms (power=None) -> magnitudes
        target_mag = torch.abs(spec_transform(wav_targets))
        prediction_mag = torch.abs(spec_transform(wav_predictions))
        target_mag, prediction_mag = self._trim_frames(target_mag, prediction_mag)

        return torch.dist(target_mag.flatten(), prediction_mag.flatten())

    def calculate_f0_rmse(
        self,
        wav_targets: torch.Tensor,
        wav_predictions: torch.Tensor,
        frame_length: int = 2048,
        hop_length: int = 512,
    ) -> float:
        """Calculate the RMSE between the F0 contours of two waveforms.

        F0 is estimated frame-wise with ``librosa.yin`` in the C2-C7 range.

        Args:
            wav_targets (torch.Tensor): Reference waveform.
            wav_predictions (torch.Tensor): Predicted waveform.
            frame_length (int): YIN frame length. Defaults to 2048.
            hop_length (int): YIN hop length. Defaults to 512.

        Returns:
            float: Root-mean-square F0 error in Hz.
        """
        fmin = float(librosa.note_to_hz("C2"))
        fmax = float(librosa.note_to_hz("C7"))

        def _f0(wav: torch.Tensor) -> torch.Tensor:
            return torch.from_numpy(
                librosa.yin(
                    wav.detach().cpu().numpy(),
                    fmin=fmin,
                    fmax=fmax,
                    sr=self.sample_rate,
                    frame_length=frame_length,
                    hop_length=hop_length,
                ),
            )

        f0_targets, f0_predictions = self._trim_frames(
            _f0(wav_targets),
            _f0(wav_predictions),
        )

        return torch.sqrt(torch.mean((f0_targets - f0_predictions) ** 2)).item()

    def calculate_jitter_shimmer(
        self,
        audio: torch.Tensor,
        frame_length: int = 2048,
    ) -> tuple[float, float]:
        r"""Calculate jitter and shimmer of an audio signal.

        Jitter is the short-term variability of a signal's fundamental frequency (F0);
        shimmer is the short-term variability of its amplitude. Both are commonly used as
        voice-quality indicators.

        Here they are computed frame-wise:

        * jitter  = mean(|Δ period|) / mean(period), with period = 1 / F0 and F0 estimated
          by ``librosa.yin`` (C2-C7 range);
        * shimmer = mean(|Δ amplitude|) / mean(amplitude), with amplitude the frame RMS
          from ``librosa.feature.rms`` using the same framing.

        NOTE: YIN makes no voicing decision, so silent/unvoiced frames are included; the
        values are frame-level approximations, not cycle-level clinical measures.

        Args:
            audio (torch.Tensor): The audio signal to analyze. It is treated as a single
                signal (e.g. ``(T,)`` or ``(1, T)``) and flattened.
            frame_length (int): Analysis frame length in samples. Defaults to 2048.

        Returns:
            tuple[float, float]: The calculated jitter and shimmer values (0.0 each when
            fewer than two frames are available).
        """
        signal = audio.detach().cpu().numpy().astype(np.float32).reshape(-1)

        f0 = librosa.yin(
            signal,
            fmin=float(librosa.note_to_hz("C2")),
            fmax=float(librosa.note_to_hz("C7")),
            sr=self.sample_rate,
            frame_length=frame_length,
            hop_length=self.hop_length,
        )
        amplitude = librosa.feature.rms(
            y=signal,
            frame_length=frame_length,
            hop_length=self.hop_length,
        )[0]
        f0, amplitude = self._trim_frames(f0, amplitude)

        if f0.shape[-1] < 2:
            return 0.0, 0.0

        # Epsilon guards against division by zero (e.g. all-silent input).
        epsilon = 1e-10
        periods = 1.0 / f0  # yin restricts F0 to [fmin, fmax], so it is strictly positive
        jitter = float(np.mean(np.abs(np.diff(periods))) / (np.mean(periods) + epsilon))
        shimmer = float(
            np.mean(np.abs(np.diff(amplitude))) / (np.mean(amplitude) + epsilon),
        )

        return jitter, shimmer

    def wav_metrics(self, wav_predictions: torch.Tensor) -> tuple[float, float, float]:
        r"""Compute reference-free metrics for the predicted waveforms.

        Args:
            wav_predictions (torch.Tensor): The predicted waveforms.

        Returns:
            tuple[float, float, float]: SRMR, jitter and shimmer.
        """
        # Keep the metric's internal state on the same device as the input,
        # as is done for the other torchmetrics objects in __call__.
        self.reverb_modulation_energy_ratio.to(wav_predictions.device)
        ermr = self.reverb_modulation_energy_ratio(wav_predictions).item()
        jitter, shimmer = self.calculate_jitter_shimmer(wav_predictions)

        return (
            ermr,
            jitter,
            shimmer,
        )

    def __call__(
        self,
        wav_predictions: torch.Tensor,
        wav_targets: torch.Tensor,
        mel_predictions: torch.Tensor,
        mel_targets: torch.Tensor,
    ) -> MetricsResult:
        r"""Compute the metrics.

        Args:
            wav_predictions (torch.Tensor): The predicted waveforms.
            wav_targets (torch.Tensor): The target waveforms.
            mel_predictions (torch.Tensor): The predicted Mel spectrograms.
            mel_targets (torch.Tensor): The target Mel spectrograms.

        Returns:
            MetricsResult: The computed metrics.
        """
        # Waveforms get a leading batch axis for the energy computation
        # (assumes 1-D waveforms here — TODO confirm against callers).
        wav_predictions_energy = self.audio_processor.wav_to_energy(
            wav_predictions.unsqueeze(0),
            self.filter_length,
            self.hop_length,
            self.win_length,
        )

        wav_targets_energy = self.audio_processor.wav_to_energy(
            wav_targets.unsqueeze(0),
            self.filter_length,
            self.hop_length,
            self.win_length,
        )

        energy: torch.Tensor = self.mse_loss(wav_predictions_energy, wav_targets_energy)

        self.si_sdr.to(wav_predictions.device)
        self.si_snr.to(wav_predictions.device)
        self.c_si_snr.to(wav_predictions.device)

        si_sdr: torch.Tensor = self.si_sdr(mel_predictions, mel_targets)
        si_snr: torch.Tensor = self.si_snr(mel_predictions, mel_targets)

        # The complex SI-SNR expects a trailing (real, imag) axis; the mels are real,
        # so the imaginary part is zero. New shape: [..., F, T, 2].
        mel_predictions_complex = torch.stack(
            (mel_predictions, torch.zeros_like(mel_predictions)),
            dim=-1,
        )
        mel_targets_complex = torch.stack(
            (mel_targets, torch.zeros_like(mel_targets)),
            dim=-1,
        )
        c_si_snr: torch.Tensor = self.c_si_snr(
            mel_predictions_complex,
            mel_targets_complex,
        )

        mcd = self.calculate_mcd(wav_targets, wav_predictions)
        spec_dist = self.calculate_spectrogram_distance(wav_targets, wav_predictions)
        f0_rmse = self.calculate_f0_rmse(wav_targets, wav_predictions)
        jitter, shimmer = self.calculate_jitter_shimmer(wav_predictions)

        return MetricsResult(
            energy,
            si_sdr,
            si_snr,
            c_si_snr,
            mcd,
            spec_dist,
            f0_rmse,
            jitter,
            shimmer,
        )

    def plot_spectrograms(
        self,
        mel_target: np.ndarray,
        mel_prediction: np.ndarray,
        sr: int = 22050,
    ):
        r"""Plot the mel spectrograms for the target and the prediction.

        Args:
            mel_target (np.ndarray): Target mel spectrogram ``(n_mels, frames)``.
            mel_prediction (np.ndarray): Predicted mel spectrogram ``(n_mels, frames)``.
            sr (int): Sampling rate used for the axis labels. Defaults to 22050.

        Returns:
            matplotlib.figure.Figure: Figure with the target (top) and prediction (bottom).
        """
        # librosa.display is a submodule that older librosa releases do not load on
        # `import librosa`; import it explicitly.
        import librosa.display

        fig, axs = plt.subplots(2, 1, sharex=True, sharey=True, dpi=80)

        # The colorbar label assumes dB-scaled input — TODO confirm with callers.
        img1 = librosa.display.specshow(
            mel_target,
            x_axis="time",
            y_axis="mel",
            sr=sr,
            ax=axs[0],
        )
        axs[0].set_title("Target spectrogram")
        fig.colorbar(img1, ax=axs[0], format="%+2.0f dB")

        img2 = librosa.display.specshow(
            mel_prediction,
            x_axis="time",
            y_axis="mel",
            sr=sr,
            ax=axs[1],
        )
        axs[1].set_title("Prediction spectrogram")
        fig.colorbar(img2, ax=axs[1], format="%+2.0f dB")

        # Adjust the spacing between subplots
        fig.subplots_adjust(hspace=0.5)

        return fig

    def plot_spectrograms_fast(
        self,
        mel_target: np.ndarray,
        mel_prediction: np.ndarray,
        sr: int = 22050,
    ):
        r"""Plot the mel spectrograms for the target and the prediction (no librosa).

        2-D inputs are treated as precomputed spectrograms ``(n_mels, frames)`` and drawn
        directly with ``imshow``. 1-D inputs are treated as waveforms and rendered with
        ``Axes.specgram``, as before.

        Args:
            mel_target (np.ndarray): Target spectrogram (2-D) or waveform (1-D).
            mel_prediction (np.ndarray): Predicted spectrogram (2-D) or waveform (1-D).
            sr (int): Sampling rate, used only for 1-D (waveform) input. Defaults to 22050.

        Returns:
            matplotlib.figure.Figure: Figure with the target (top) and prediction (bottom).
        """
        fig, axs = plt.subplots(2, 1, sharex=True, sharey=True)
        cmap = plt.get_cmap("magma")  # type: ignore

        panels = (
            (axs[0], mel_target, "Target spectrogram"),
            (axs[1], mel_prediction, "Prediction spectrogram"),
        )
        for ax, data, title in panels:
            data = np.asarray(data)
            if data.ndim == 2:
                # Axes.specgram expects a 1-D signal; a 2-D mel would be resized and
                # flattened into a meaningless plot, so show it as an image instead.
                ax.imshow(data, aspect="auto", origin="lower", cmap=cmap)
            else:
                ax.specgram(
                    data,
                    aspect="auto",
                    Fs=sr,
                    cmap=cmap,
                )
            ax.set_title(title)

        # Adjust the spacing between subplots
        fig.subplots_adjust(hspace=0.5)

        return fig