|
import numpy as np
import torch
import torchaudio
import torchaudio.transforms as transforms
|
|
|
|
def load_audio(file_path):
    """
    Read an audio file from disk via ``torchaudio.load``.

    Args:
        file_path (str): Path to the audio file.

    Returns:
        waveform (Tensor): Audio waveform.
        sample_rate (int): Sample rate of the audio.
    """
    signal, rate = torchaudio.load(file_path)
    return (signal, rate)
|
|
|
|
def preprocess_audio(waveform, sample_rate, target_sample_rate=16000, target_length=16000):
    """
    Preprocess audio by resampling and truncating or padding to a fixed length.

    The time axis is taken to be the last dimension of ``waveform``, so both
    ``(time,)`` and ``(channels, time)`` inputs are handled the same way.

    Args:
        waveform (Tensor): Audio waveform with time as the last dimension.
        sample_rate (int): Original sample rate.
        target_sample_rate (int): Desired sample rate for the output waveform.
        target_length (int): Desired length of the output waveform in samples,
            measured after any resampling has been applied.

    Returns:
        waveform (Tensor): Preprocessed audio waveform whose last dimension
            holds exactly ``target_length`` samples.
    """
    # Only build a resampler when the rates actually differ.
    if sample_rate != target_sample_rate:
        resampler = transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
        waveform = resampler(waveform)

    num_samples = waveform.size(-1)
    if num_samples > target_length:
        # Keep the leading samples; everything past target_length is dropped.
        waveform = waveform[..., :target_length]
    elif num_samples < target_length:
        # (0, n) pads only the right-hand side of the last (time) dimension
        # with zeros, leaving any leading dimensions untouched.
        pad_length = target_length - num_samples
        waveform = torch.nn.functional.pad(waveform, (0, pad_length))

    return waveform
|
|
|
|
def extract_features(waveform, sample_rate, n_mfcc=13):
    """
    Compute MFCC features for an audio waveform.

    Args:
        waveform (Tensor): Audio waveform.
        sample_rate (int): Sample rate of the audio.
        n_mfcc (int): Number of MFCC coefficients to compute.

    Returns:
        mfcc_features (Tensor): MFCC features produced by
            ``torchaudio.transforms.MFCC``.
    """
    # Settings forwarded to the underlying mel-spectrogram computation.
    mel_options = {"n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False}
    transform = transforms.MFCC(
        sample_rate=sample_rate,
        n_mfcc=n_mfcc,
        melkwargs=mel_options,
    )
    return transform(waveform)
|
|
|
|
def plot_waveform(waveform, sample_rate, title="Waveform"):
    """
    Plot the waveform of an audio signal.

    Mono input is drawn as a single line. Multi-channel input of shape
    ``(channels, time)`` is drawn as one line per channel over a shared
    time axis.

    Args:
        waveform (Tensor): Audio waveform with time as the last dimension.
        sample_rate (int): Sample rate of the audio.
        title (str): Title of the plot.
    """
    # Imported lazily so matplotlib is only required when plotting.
    import matplotlib.pyplot as plt

    # detach() allows tensors that track gradients; atleast_1d guards against
    # a single-sample input collapsing to a 0-d array after squeeze().
    waveform_np = np.atleast_1d(waveform.detach().squeeze().cpu().numpy())

    # The time axis is the last dimension; for (channels, time) input the
    # first dimension is the channel count, not the number of samples.
    time = np.arange(waveform_np.shape[-1]) / sample_rate

    plt.figure(figsize=(10, 4))
    # Transposing puts time on the first axis so a 2-D array is plotted one
    # column (channel) per line; a 1-D array is unaffected by the transpose.
    plt.plot(time, waveform_np.T)
    plt.title(title)
    plt.xlabel("Time (s)")
    plt.ylabel("Amplitude")
    plt.show()
|
|
|