from librosa.filters import mel as librosa_mel_fn
import torch


class AudioProcessor:
    r"""A class used to process audio signals and convert them into different representations.

    Attributes:
        hann_window (dict): A dictionary to store the Hann window for different configurations.
        mel_basis (dict): A dictionary to store the Mel basis for different configurations.

    Methods:
        name_mel_basis(spec, n_fft, fmax): Generate a name for the Mel basis based on the FFT size, maximum frequency, data type, and device.
        amp_to_db(magnitudes, C=1, clip_val=1e-5): Convert amplitude to decibels (dB).
        db_to_amp(magnitudes, C=1): Convert decibels (dB) to amplitude.
        wav_to_spec(y, n_fft, hop_length, win_length, center=False): Convert a waveform to a spectrogram and compute the magnitude.
        wav_to_energy(y, n_fft, hop_length, win_length, center=False): Convert a waveform to a spectrogram and compute the energy.
        spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax): Convert a spectrogram to a Mel spectrogram.
        wav_to_mel(y, n_fft, num_mels, sample_rate, hop_length, win_length, fmin, fmax, center=False): Convert a waveform to a Mel spectrogram.
    """

    def __init__(self):
        self.hann_window = {}
        self.mel_basis = {}

    @staticmethod
    def name_mel_basis(spec: torch.Tensor, n_fft: int, fmax: int) -> str:
        """Generate a name for the Mel basis based on the FFT size, maximum frequency, data type, and device.

        Args:
            spec (torch.Tensor): The spectrogram tensor.
            n_fft (int): The FFT size.
            fmax (int): The maximum frequency.

        Returns:
            str: The generated name for the Mel basis.
        """
        n_fft_len = f"{n_fft}_{fmax}_{spec.dtype}_{spec.device}"
        return n_fft_len
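
    # Example key (illustrative values, not prescribed by this class): for
    # n_fft=1024 and fmax=8000 on a float32 CPU tensor, name_mel_basis
    # returns "1024_8000_torch.float32_cpu".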

    @staticmethod
    def amp_to_db(magnitudes: torch.Tensor, C: int = 1, clip_val: float = 1e-5) -> torch.Tensor:
        r"""Convert amplitude to decibels (dB).

        Args:
            magnitudes (Tensor): The amplitude magnitudes to convert.
            C (int, optional): A constant value used in the conversion. Defaults to 1.
            clip_val (float, optional): A value to clamp the magnitudes to avoid taking the log of zero. Defaults to 1e-5.

        Returns:
            Tensor: The converted magnitudes in dB.
        """
        return torch.log(torch.clamp(magnitudes, min=clip_val) * C)
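
    # Note: despite the "dB" naming, amp_to_db/db_to_amp use the natural log
    # (a dynamic-range-compression convention common in TTS code) rather than
    # the conventional 20 * log10 definition. The two are exact inverses:
    #   db_to_amp(amp_to_db(x)) == exp(log(clamp(x, min=clip_val) * C)) / C
    #                           == clamp(x, min=clip_val)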

    @staticmethod
    def db_to_amp(magnitudes: torch.Tensor, C: int = 1) -> torch.Tensor:
        r"""Convert decibels (dB) to amplitude.

        Args:
            magnitudes (Tensor): The dB magnitudes to convert.
            C (int, optional): A constant value used in the conversion. Defaults to 1.

        Returns:
            Tensor: The converted magnitudes in amplitude.
        """
        return torch.exp(magnitudes) / C

    def wav_to_spec(
        self,
        y: torch.Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a spectrogram and compute the magnitude.

        Args:
            y (Tensor): The input waveform.
            n_fft (int): The FFT size.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            Tensor: The magnitude of the computed spectrogram.
        """
        y = y.squeeze(1)

        # Cache the Hann window per (win_length, dtype, device) configuration.
        dtype_device = str(y.dtype) + "_" + str(y.device)
        wnsize_dtype_device = str(win_length) + "_" + dtype_device
        if wnsize_dtype_device not in self.hann_window:
            self.hann_window[wnsize_dtype_device] = torch.hann_window(win_length).to(dtype=y.dtype, device=y.device)

        # Reflect-pad the signal so that, with center=False, the frames still
        # cover the whole waveform.
        y = torch.nn.functional.pad(
            y.unsqueeze(1),
            (int((n_fft - hop_length) / 2), int((n_fft - hop_length) / 2)),
            mode="reflect",
        )
        y = y.squeeze(1)

        spec = torch.stft(
            y,
            n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=self.hann_window[wnsize_dtype_device],
            center=center,
            pad_mode="reflect",
            normalized=False,
            onesided=True,
            return_complex=True,
        )
        spec = torch.view_as_real(spec)

        # Compute the magnitude, with a small epsilon for numerical stability.
        spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
        return spec
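
    # Shape sketch (illustrative values, not prescribed by this class): with
    # center=False, the manual reflect padding of (n_fft - hop_length) / 2
    # samples per side makes the frame count come out to T // hop_length for
    # an input of length T. For example, a [1, 1, 22050] waveform with
    # n_fft=1024 and hop_length=256 yields a [1, 513, 86] magnitude
    # spectrogram (n_fft // 2 + 1 = 513 frequency bins, 22050 // 256 = 86
    # frames).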

    def wav_to_energy(
        self,
        y: torch.Tensor,
        n_fft: int,
        hop_length: int,
        win_length: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a spectrogram and compute the energy.

        Args:
            y (Tensor): The input waveform.
            n_fft (int): The FFT size.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            Tensor: The energy of the computed spectrogram.
        """
        spec = self.wav_to_spec(y, n_fft, hop_length, win_length, center=center)

        # Per-frame energy: the L2 norm across frequency bins.
        spec = torch.norm(spec, dim=1, keepdim=True).squeeze(0)

        # Normalize the energy to zero mean and unit variance.
        return (spec - spec.mean()) / spec.std()
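
    # Shape sketch (illustrative): for a [1, 1, 22050] waveform with
    # n_fft=1024 and hop_length=256, wav_to_energy returns a z-scored
    # [1, 86] energy sequence; note that squeeze(0) drops the batch
    # dimension only when B == 1.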

    def spec_to_mel(
        self,
        spec: torch.Tensor,
        n_fft: int,
        num_mels: int,
        sample_rate: int,
        fmin: int,
        fmax: int,
    ) -> torch.Tensor:
        r"""Convert a spectrogram to a Mel spectrogram.

        Args:
            spec (torch.Tensor): The input spectrogram of shape [B, C, T].
            n_fft (int): The FFT size.
            num_mels (int): The number of Mel bands.
            sample_rate (int): The sample rate of the audio.
            fmin (int): The minimum frequency.
            fmax (int): The maximum frequency.

        Returns:
            torch.Tensor: The computed Mel spectrogram of shape [B, num_mels, T].
        """
        # Cache the Mel filter bank per (n_fft, fmax, dtype, device) configuration.
        mel_basis_key = self.name_mel_basis(spec, n_fft, fmax)
        if mel_basis_key not in self.mel_basis:
            mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
            self.mel_basis[mel_basis_key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)

        # Project the linear-frequency bins onto the Mel scale and compress.
        mel = torch.matmul(self.mel_basis[mel_basis_key], spec)
        mel = self.amp_to_db(mel)
        return mel
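
    # Note: the cache key from name_mel_basis covers n_fft, fmax, dtype and
    # device, but not num_mels, fmin, or sample_rate, so calls that differ
    # only in those parameters would reuse the first cached filter bank.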

    def wav_to_mel(
        self,
        y: torch.Tensor,
        n_fft: int,
        num_mels: int,
        sample_rate: int,
        hop_length: int,
        win_length: int,
        fmin: int,
        fmax: int,
        center: bool = False,
    ) -> torch.Tensor:
        r"""Convert a waveform to a Mel spectrogram.

        Args:
            y (torch.Tensor): The input waveform.
            n_fft (int): The FFT size.
            num_mels (int): The number of Mel bands.
            sample_rate (int): The sample rate of the audio.
            hop_length (int): The hop (stride) size.
            win_length (int): The window size.
            fmin (int): The minimum frequency.
            fmax (int): The maximum frequency.
            center (bool, optional): Whether to pad `y` such that frames are centered. Defaults to False.

        Returns:
            torch.Tensor: The computed Mel spectrogram.
        """
        # Convert the waveform to a spectrogram
        spec = self.wav_to_spec(y, n_fft, hop_length, win_length, center=center)
        # Convert the spectrogram to a Mel spectrogram
        mel = self.spec_to_mel(spec, n_fft, num_mels, sample_rate, fmin, fmax)
        return mel
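

# A minimal usage sketch, not part of the original module: it builds one
# processor and runs a one-second random signal through wav_to_mel. The
# parameter values (22050 Hz, n_fft=1024, 80 Mel bands, etc.) are
# illustrative assumptions, not values prescribed by this class.
if __name__ == "__main__":
    processor = AudioProcessor()
    wav = torch.randn(1, 1, 22050)  # [B, 1, T]: one second of noise at 22.05 kHz
    mel = processor.wav_to_mel(
        wav,
        n_fft=1024,
        num_mels=80,
        sample_rate=22050,
        hop_length=256,
        win_length=1024,
        fmin=0,
        fmax=8000,
    )
    print(mel.shape)  # expected: torch.Size([1, 80, 86])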