from dataclasses import dataclass
from typing import Optional

import librosa
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn
import torchaudio.transforms as T
from torchmetrics.audio import (
    ComplexScaleInvariantSignalNoiseRatio,
    ScaleInvariantSignalDistortionRatio,
    ScaleInvariantSignalNoiseRatio,
    SpeechReverberationModulationEnergyRatio,
)

from models.config import PreprocessingConfig, PreprocessingConfigUnivNet, get_lang_map
from training.preprocess.audio_processor import AudioProcessor


@dataclass
class MetricsResult:
    r"""A data class that holds the results of the computed metrics.

    Attributes:
        energy (torch.Tensor): The MSE between the predicted and target energy contours.
        si_sdr (torch.Tensor): The scale-invariant signal-to-distortion ratio.
        si_snr (torch.Tensor): The scale-invariant signal-to-noise ratio.
        c_si_snr (torch.Tensor): The complex scale-invariant signal-to-noise ratio.
        mcd (torch.Tensor): The Mel cepstral distortion.
        spec_dist (torch.Tensor): The spectrogram distance.
        f0_rmse (float): The RMSE between the predicted and target F0 contours.
        jitter (float): The jitter of the predicted waveform.
        shimmer (float): The shimmer of the predicted waveform.
    """

    energy: torch.Tensor
    si_sdr: torch.Tensor
    si_snr: torch.Tensor
    c_si_snr: torch.Tensor
    mcd: torch.Tensor
    spec_dist: torch.Tensor
    f0_rmse: float
    jitter: float
    shimmer: float


class Metrics:
    r"""A class that computes various audio metrics.

    Args:
        lang (str): The language code. Defaults to "en".
        preprocess_config (Optional[PreprocessingConfig]): The preprocessing configuration.
            Defaults to None, in which case a UnivNet config for the given language is used.

    Attributes:
        hop_length (int): The hop length for the STFT.
        filter_length (int): The filter length for the STFT.
        mel_fmin (int): The minimum frequency for the Mel scale.
        win_length (int): The window length for the STFT.
        sample_rate (int): The sampling rate of the audio.
        audio_processor (AudioProcessor): The audio processor.
        mse_loss (nn.MSELoss): The mean squared error loss.
        si_sdr (ScaleInvariantSignalDistortionRatio): The scale-invariant signal-to-distortion ratio.
        si_snr (ScaleInvariantSignalNoiseRatio): The scale-invariant signal-to-noise ratio.
        c_si_snr (ComplexScaleInvariantSignalNoiseRatio): The complex scale-invariant signal-to-noise ratio.
        reverb_modulation_energy_ratio (SpeechReverberationModulationEnergyRatio): The speech
            reverberation modulation energy ratio.
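
    Example:
        A minimal usage sketch (illustrative only: the variable names are
        assumptions, and the mel shape ``[1, n_mels, frames]`` is inferred
        from how the tensors are used in ``__call__``, not enforced here)::

            metrics = Metrics(lang="en")
            result = metrics(wav_pred, wav_target, mel_pred, mel_target)
            print(result.mcd.item(), result.f0_rmse)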
""" def __init__( self, lang: str = "en", preprocess_config: Optional[PreprocessingConfig] = None, ): lang_map = get_lang_map(lang) preprocess_config = preprocess_config or PreprocessingConfigUnivNet( lang_map.processing_lang_type, ) self.hop_length = preprocess_config.stft.hop_length self.filter_length = preprocess_config.stft.filter_length self.mel_fmin = preprocess_config.stft.mel_fmin self.win_length = preprocess_config.stft.win_length self.sample_rate = preprocess_config.sampling_rate self.audio_processor = AudioProcessor() self.mse_loss = nn.MSELoss() self.si_sdr = ScaleInvariantSignalDistortionRatio() self.si_snr = ScaleInvariantSignalNoiseRatio() self.c_si_snr = ComplexScaleInvariantSignalNoiseRatio(zero_mean=False) self.reverb_modulation_energy_ratio = SpeechReverberationModulationEnergyRatio( self.sample_rate, ) def calculate_mcd( self, wav_targets: torch.Tensor, wav_predictions: torch.Tensor, n_mfcc: int = 13, ) -> torch.Tensor: """Calculate Mel Cepstral Distortion.""" mfcc_transform = T.MFCC( sample_rate=self.sample_rate, n_mfcc=n_mfcc, melkwargs={ "n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False, }, ).to(wav_targets.device) wav_predictions = wav_predictions.to(wav_targets.device) ref_mfcc = mfcc_transform(wav_targets) synth_mfcc = mfcc_transform(wav_predictions) mcd = torch.mean( torch.sqrt( torch.sum((ref_mfcc - synth_mfcc) ** 2, dim=0), ), ) return mcd def calculate_spectrogram_distance( self, wav_targets: torch.Tensor, wav_predictions: torch.Tensor, n_fft: int = 2048, hop_length: int = 512, ) -> torch.Tensor: """Calculate spectrogram distance.""" spec_transform = T.Spectrogram( n_fft=n_fft, hop_length=hop_length, power=None, ).to(wav_targets.device) wav_predictions = wav_predictions.to(wav_targets.device) # Compute the spectrograms S1 = spec_transform(wav_targets) S2 = spec_transform(wav_predictions) # Compute the magnitude spectrograms S1_mag = torch.abs(S1) S2_mag = torch.abs(S2) # Compute the Euclidean distance dist = torch.dist(S1_mag.flatten(), S2_mag.flatten()) return dist def calculate_f0_rmse( self, wav_targets: torch.Tensor, wav_predictions: torch.Tensor, frame_length: int = 2048, hop_length: int = 512, ) -> float: """Calculate F0 RMSE.""" wav_targets_ = wav_targets.detach().cpu().numpy() wav_predictions_ = wav_predictions.detach().cpu().numpy() # Compute the F0 contour for each audio signal f0_audio1 = torch.from_numpy( librosa.yin( wav_targets_, fmin=float(librosa.note_to_hz("C2")), fmax=float(librosa.note_to_hz("C7")), sr=self.sample_rate, frame_length=frame_length, hop_length=hop_length, ), ) f0_audio2 = torch.from_numpy( librosa.yin( wav_predictions_, fmin=float(librosa.note_to_hz("C2")), fmax=float(librosa.note_to_hz("C7")), sr=self.sample_rate, frame_length=frame_length, hop_length=hop_length, ), ) # Assuming f0_audio1 and f0_audio2 are PyTorch tensors rmse = torch.sqrt(torch.mean((f0_audio1 - f0_audio2) ** 2)).item() return rmse def calculate_jitter_shimmer( self, audio: torch.Tensor, ) -> tuple[float, float]: r"""Calculate jitter and shimmer of an audio signal. Jitter and shimmer are two metrics used in speech signal processing to measure the quality of voice signals. Jitter refers to the short-term variability of a signal's fundamental frequency (F0). It is often used as an indicator of voice disorders, as high levels of jitter can indicate a lack of control over the vocal folds. Shimmer, on the other hand, refers to the short-term variability in amplitude of the voice signal. 
    def calculate_jitter_shimmer(
        self,
        audio: torch.Tensor,
    ) -> tuple[float, float]:
        r"""Calculate the jitter and shimmer of an audio signal.

        Jitter and shimmer are two metrics used in speech signal processing to
        measure the quality of voice signals.

        Jitter refers to the short-term variability of a signal's fundamental
        frequency (F0). It is often used as an indicator of voice disorders,
        as high levels of jitter can indicate a lack of control over the vocal
        folds.

        Shimmer, on the other hand, refers to the short-term variability in
        the amplitude of the voice signal. Like jitter, high levels of shimmer
        can be indicative of voice disorders, as they can suggest a lack of
        control over the vocal tract.

        Summary:
            Jitter is the short-term variability of a signal's fundamental frequency (F0).
            Shimmer is the short-term variability in the amplitude of the voice signal.

        Args:
            audio (torch.Tensor): The audio signal to analyze.

        Returns:
            tuple[float, float]: The calculated jitter and shimmer values.
        """
        # Magnitude spectrogram, used as a proxy for the amplitude envelope
        spectrogram = T.Spectrogram(
            n_fft=self.filter_length * 2,
            hop_length=self.hop_length * 2,
            power=None,
        ).to(audio.device)
        amplitude = torch.abs(spectrogram(audio))

        # Estimate the F0 contour with the YIN algorithm, as in calculate_f0_rmse
        # (the previous T.Vad call trims silence and does not produce an F0 contour)
        f0 = torch.from_numpy(
            librosa.yin(
                audio.detach().cpu().numpy(),
                fmin=float(librosa.note_to_hz("C2")),
                fmax=float(librosa.note_to_hz("C7")),
                sr=self.sample_rate,
            ),
        )

        # Epsilon to avoid division by zero
        epsilon = 1e-10

        # Jitter: mean relative frame-to-frame change of the F0 contour
        jitter = torch.mean(
            torch.abs(torch.diff(f0, dim=-1)) / (f0[..., :-1] + epsilon),
        ).item()

        # Shimmer: mean relative frame-to-frame change of the amplitude
        shimmer = torch.mean(
            torch.abs(torch.diff(amplitude, dim=-1)) / (amplitude[..., :-1] + epsilon),
        ).item()

        return jitter, shimmer

    def wav_metrics(self, wav_predictions: torch.Tensor):
        r"""Compute the waveform-only metrics.

        Args:
            wav_predictions (torch.Tensor): The predicted waveforms.

        Returns:
            tuple[float, float, float]: The speech reverberation modulation
            energy ratio, jitter, and shimmer.
        """
        ermr = self.reverb_modulation_energy_ratio(wav_predictions).item()
        jitter, shimmer = self.calculate_jitter_shimmer(wav_predictions)

        return (
            ermr,
            jitter,
            shimmer,
        )

    def __call__(
        self,
        wav_predictions: torch.Tensor,
        wav_targets: torch.Tensor,
        mel_predictions: torch.Tensor,
        mel_targets: torch.Tensor,
    ) -> MetricsResult:
        r"""Compute the metrics.

        Args:
            wav_predictions (torch.Tensor): The predicted waveforms.
            wav_targets (torch.Tensor): The target waveforms.
            mel_predictions (torch.Tensor): The predicted Mel spectrograms.
            mel_targets (torch.Tensor): The target Mel spectrograms.

        Returns:
            MetricsResult: The computed metrics.
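
        Example:
            A shape sketch (an assumption inferred from how the tensors are
            used below, e.g. the ``[1, F, T, 2]`` stacking for the complex
            SI-SNR; not enforced by this method)::

                result = metrics(
                    wav_pred,    # 1-D waveform
                    wav_target,  # 1-D waveform, same length
                    mel_pred,    # [1, n_mels, frames]
                    mel_target,  # [1, n_mels, frames]
                )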
""" wav_predictions_energy = self.audio_processor.wav_to_energy( wav_predictions.unsqueeze(0), self.filter_length, self.hop_length, self.win_length, ) wav_targets_energy = self.audio_processor.wav_to_energy( wav_targets.unsqueeze(0), self.filter_length, self.hop_length, self.win_length, ) energy: torch.Tensor = self.mse_loss(wav_predictions_energy, wav_targets_energy) self.si_sdr.to(wav_predictions.device) self.si_snr.to(wav_predictions.device) self.c_si_snr.to(wav_predictions.device) # New Metrics si_sdr: torch.Tensor = self.si_sdr(mel_predictions, mel_targets) si_snr: torch.Tensor = self.si_snr(mel_predictions, mel_targets) # New shape: [1, F, T, 2] mel_predictions_complex = torch.stack( (mel_predictions, torch.zeros_like(mel_predictions)), dim=-1, ) mel_targets_complex = torch.stack( (mel_targets, torch.zeros_like(mel_targets)), dim=-1, ) c_si_snr: torch.Tensor = self.c_si_snr( mel_predictions_complex, mel_targets_complex, ) mcd = self.calculate_mcd(wav_targets, wav_predictions) spec_dist = self.calculate_spectrogram_distance(wav_targets, wav_predictions) f0_rmse = self.calculate_f0_rmse(wav_targets, wav_predictions) jitter, shimmer = self.calculate_jitter_shimmer(wav_predictions) return MetricsResult( energy, si_sdr, si_snr, c_si_snr, mcd, spec_dist, f0_rmse, jitter, shimmer, ) def plot_spectrograms( self, mel_target: np.ndarray, mel_prediction: np.ndarray, sr: int = 22050, ): r"""Plots the mel spectrograms for the target and the prediction.""" fig, axs = plt.subplots(2, 1, sharex=True, sharey=True, dpi=80) img1 = librosa.display.specshow( mel_target, x_axis="time", y_axis="mel", sr=sr, ax=axs[0], ) axs[0].set_title("Target spectrogram") fig.colorbar(img1, ax=axs[0], format="%+2.0f dB") img2 = librosa.display.specshow( mel_prediction, x_axis="time", y_axis="mel", sr=sr, ax=axs[1], ) axs[1].set_title("Prediction spectrogram") fig.colorbar(img2, ax=axs[1], format="%+2.0f dB") # Adjust the spacing between subplots fig.subplots_adjust(hspace=0.5) return fig def plot_spectrograms_fast( self, mel_target: np.ndarray, mel_prediction: np.ndarray, sr: int = 22050, ): r"""Plots the mel spectrograms for the target and the prediction.""" fig, axs = plt.subplots(2, 1, sharex=True, sharey=True) axs[0].specgram( mel_target, aspect="auto", Fs=sr, cmap=plt.get_cmap("magma"), # type: ignore ) axs[0].set_title("Target spectrogram") axs[1].specgram( mel_prediction, aspect="auto", Fs=sr, cmap=plt.get_cmap("magma"), # type: ignore ) axs[1].set_title("Prediction spectrogram") # Adjust the spacing between subplots fig.subplots_adjust(hspace=0.5) return fig