Excitometer / model /utils.py
szili2011's picture
Upload 17 files
c248b9b verified
import numpy as np
import torchaudio
import torchaudio.transforms as transforms
def load_audio(file_path):
"""
Load an audio file and return the waveform and sample rate.
Args:
file_path (str): Path to the audio file.
Returns:
waveform (Tensor): Audio waveform.
sample_rate (int): Sample rate of the audio.
"""
waveform, sample_rate = torchaudio.load(file_path)
return waveform, sample_rate
def preprocess_audio(waveform, sample_rate, target_sample_rate=16000, target_length=16000):
"""
Preprocess audio by resampling and truncating or padding to a fixed length.
Args:
waveform (Tensor): Audio waveform.
sample_rate (int): Original sample rate.
target_sample_rate (int): Desired sample rate for the output waveform.
target_length (int): Desired length of the output waveform in samples.
Returns:
waveform (Tensor): Preprocessed audio waveform.
"""
# Resample audio if necessary
if sample_rate != target_sample_rate:
resampler = transforms.Resample(orig_freq=sample_rate, new_freq=target_sample_rate)
waveform = resampler(waveform)
# Truncate or pad waveform
if waveform.size(1) > target_length:
waveform = waveform[:, :target_length]
elif waveform.size(1) < target_length:
pad_length = target_length - waveform.size(1)
waveform = torch.nn.functional.pad(waveform, (0, pad_length))
return waveform
def extract_features(waveform, sample_rate, n_mfcc=13):
"""
Extract MFCC features from the audio waveform.
Args:
waveform (Tensor): Audio waveform.
sample_rate (int): Sample rate of the audio.
n_mfcc (int): Number of MFCC features to extract.
Returns:
mfcc_features (Tensor): Extracted MFCC features.
"""
mfcc = transforms.MFCC(
sample_rate=sample_rate,
n_mfcc=n_mfcc,
melkwargs={"n_fft": 400, "hop_length": 160, "n_mels": 23, "center": False}
)
mfcc_features = mfcc(waveform)
return mfcc_features
def plot_waveform(waveform, sample_rate, title="Waveform"):
"""
Plot the waveform of an audio signal.
Args:
waveform (Tensor): Audio waveform.
sample_rate (int): Sample rate of the audio.
title (str): Title of the plot.
"""
import matplotlib.pyplot as plt
waveform_np = waveform.squeeze().cpu().numpy()
time = np.arange(waveform_np.shape[0]) / sample_rate
plt.figure(figsize=(10, 4))
plt.plot(time, waveform_np)
plt.title(title)
plt.xlabel("Time (s)")
plt.ylabel("Amplitude")
plt.show()