Genre_Classifier / functions.py
DurreSudoku's picture
Upload 3 files
e2eef75 verified
raw
history blame
5.93 kB
import librosa
import numpy as np
import os
import gc
import matplotlib
import matplotlib.pyplot as plt
def load_audio_file(path, sample_rate=22050, resampling_type="kaiser_fast", duration=30):
    """Load an audio file as a mono numpy array using the librosa library.

    Args:
        path (str): Path to audio file.
        sample_rate (int, optional): Sample rate to resample the audio to.
            None keeps the file's native sample rate. Defaults to 22050.
        resampling_type (str, optional): Resampling method passed to
            ``librosa.load``. Defaults to "kaiser_fast".
        duration (int, optional): Length in seconds to pad/shorten the audio to.
            0 (or any non-positive value) returns the original audio length.
            Defaults to 30.

    Returns:
        numpy.ndarray: Mono audio signal.
    """
    # Load and (optionally) resample. librosa returns the rate the signal is
    # actually at, which is the file's native rate when sample_rate is None.
    audio_array, loaded_sr = librosa.load(path, sr=sample_rate, mono=True, res_type=resampling_type)
    if duration > 0:
        # Use the real rate (so sample_rate=None works) and honor the caller's
        # duration instead of a fixed 30 seconds.
        audio_array = pad_audio(audio_array, loaded_sr, duration)
    return audio_array
def add_noise(audio_array, std):
    """Return ``audio_array`` with zero-mean Gaussian noise added.

    Args:
        audio_array (numpy.ndarray): Input signal of any shape.
        std (float): Standard deviation of the noise.

    Returns:
        numpy.ndarray: New array of the same shape; the noise is drawn from
        NumPy's global random state via ``np.random.normal``.
    """
    gaussian = np.random.normal(loc=0, scale=std, size=audio_array.shape)
    return audio_array + gaussian
def pad_audio(audio_array, sample_rate=22050, duration=30):
    """Center-pad or truncate an audio array to exactly ``duration`` seconds.

    Args:
        audio_array (numpy.ndarray): Audio signal (treated as 1-D).
        sample_rate (int, optional): Samples per second of ``audio_array``.
            Defaults to 22050.
        duration (int | float, optional): Target length in seconds. Defaults to 30.

    Returns:
        numpy.ndarray: Array with exactly ``int(duration * sample_rate)`` samples.
        Shorter inputs are zero-padded on both sides (an odd deficit puts the
        extra zero at the end); longer inputs keep only their leading samples.
    """
    target_len = int(duration * sample_rate)
    audio_len = audio_array.size
    if audio_len < target_len:
        deficit = target_len - audio_len
        pad_before = deficit // 2
        # Give the remainder to the end so the total is exact even when the
        # deficit is odd (splitting deficit // 2 on both sides loses a sample).
        audio_array = np.pad(audio_array, (pad_before, deficit - pad_before))
    elif audio_len > target_len:
        audio_array = audio_array[:target_len]
    return audio_array
def log_mel_spectrogram(audio_array, sr=22050, nfft=2048, hop_length=512, window="hann"):
    """Compute a log-scaled (dB) mel spectrogram of an audio signal.

    The window length equals ``nfft``, and power is converted to decibels
    relative to the spectrogram's maximum value (``ref=np.max``).

    Args:
        audio_array (numpy.ndarray): Audio signal.
        sr (int, optional): Sample rate of ``audio_array``. Defaults to 22050.
        nfft (int, optional): FFT size and window length. Defaults to 2048.
        hop_length (int, optional): Hop length between STFT frames. Defaults to 512.
        window (str, optional): Window function for the STFT. Defaults to "hann".

    Returns:
        numpy.ndarray: Mel spectrogram in dB.
    """
    mel_power = librosa.feature.melspectrogram(
        y=audio_array,
        sr=sr,
        n_fft=nfft,
        hop_length=hop_length,
        win_length=nfft,
        window=window,
    )
    return librosa.power_to_db(mel_power, ref=np.max)
def split_spectrogram(spectrogram, output_shape=(128, 256)):
    """Split a spectrogram into equal-width, non-overlapping column chunks.

    Only ``output_shape[1]`` (the chunk width) is used; every chunk keeps all
    rows of ``spectrogram``. Trailing columns that do not fill a whole chunk
    are discarded.

    Args:
        spectrogram (numpy.ndarray): 2-D array, split along axis 1.
        output_shape (tuple, optional): ``(rows, cols)`` of each chunk.
            Defaults to (128, 256).

    Returns:
        list[numpy.ndarray]: Column slices of ``spectrogram``, left to right.

    Raises:
        ValueError: If ``output_shape[1]`` is not positive (a zero width would
            otherwise never advance).
    """
    width = output_shape[1]
    if width <= 0:
        raise ValueError(f"output_shape[1] must be positive, got {width}")
    n_cols = spectrogram.shape[1]
    return [spectrogram[:, start:start + width] for start in range(0, n_cols - width + 1, width)]
def save_spectrogram_as_png(spectrogram, save_path, sample_rate=22050, nfft=2048, hop_length=512):
    """Render a spectrogram to a borderless PNG sized one pixel per bin.

    The figure is ``shape[1] x shape[0]`` pixels (width x height) at 100 DPI,
    with axes hidden and no margins.

    Args:
        spectrogram (numpy.ndarray): 2-D spectrogram to render.
        save_path (str): Output file path.
        sample_rate (int, optional): Sample rate passed to ``specshow``. Defaults to 22050.
        nfft (int, optional): FFT size passed to ``specshow``. Defaults to 2048.
        hop_length (int, optional): Hop length passed to ``specshow``. Defaults to 512.
    """
    # librosa.display is a submodule; import it explicitly so it is available
    # even if the top-level ``import librosa`` did not load it.
    import librosa.display

    shape = spectrogram.shape
    dpi = 100
    # Pin the DPI on both the figure and savefig so the output is exactly
    # shape[1] x shape[0] pixels regardless of matplotlib rcParams.
    fig, ax = plt.subplots(1, 1, figsize=(shape[1] / dpi, shape[0] / dpi), dpi=dpi)
    try:
        fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0)
        ax.set_axis_off()
        librosa.display.specshow(data=spectrogram, sr=sample_rate, n_fft=nfft, hop_length=hop_length, ax=ax)
        fig.savefig(save_path, dpi=dpi, bbox_inches=None, pad_inches=0)
    finally:
        # Always release the figure; leaked figures accumulate across many calls.
        plt.close(fig)
    return
def extract_features(df, audio_dir, save_path,
                     sr=22050, rs_type="kaiser_fast",
                     output_shape=(128, 256), duration=30,
                     nfft=2048, hop_length=512, window="hann", checkpoint_id=0):
    """Load audio files, compute log-mel spectrograms and save their splits as PNG.

    For every row of ``df`` the track is loaded from
    ``<audio_dir>/<first 3 chars of id>/<6-digit id>.mp3``. It is converted to a
    log-mel spectrogram and split into chunks ``output_shape[1]`` columns wide.
    Each chunk is written to ``<save_path>/<genre>/<6-digit id>_<n>.png``, with
    ``n`` starting at 1. Genre folders are created if missing.

    Args:
        df (pandas.DataFrame): Indexed by track id (presumably integer ids,
            zero-padded to 6 digits here) with a "genre_top" column. Should only
            contain samples from one data split (train/val/test).
        audio_dir (str): Directory containing all audio files.
        save_path (str): Directory where spectrogram PNGs are saved.
        sr (int): Sampling rate to set for all loaded audio files.
        rs_type (str): Resampling method used when loading audio files.
        output_shape (tuple): Shape of each spectrogram split.
        duration (int): Length in seconds all audio is padded or cut to.
        nfft (int): Number of samples for every FFT window.
        hop_length (int): Hop length to use for the STFT.
        window (str): Window function to use for the STFT.
        checkpoint_id (int, optional): Track id to resume from. Rows before it
            (label-based ``df.loc`` slice, inclusive) are skipped; 0 processes
            every row. Defaults to 0.
    """
    matplotlib.use("Agg")
    checkpoint = int(checkpoint_id)
    if checkpoint > 0:
        df = df.loc[checkpoint:]
    track_ids = df.index.to_numpy()
    genres = df["genre_top"].to_numpy()
    # Due to a memory leak seen during extraction, garbage collection is forced
    # roughly every 10% of progress. max(1, ...) keeps the interval non-zero
    # when there are fewer than 10 tracks.
    gc_interval = max(1, len(track_ids) // 10)
    for position, (track_id, genre) in enumerate(zip(track_ids, genres)):
        id_string = str(track_id).rjust(6, "0")
        filename = id_string + ".mp3"
        file_path = os.path.join(audio_dir, filename[:3], filename)
        print(id_string, end=" ")
        audio = load_audio_file(file_path, sr, rs_type, duration)
        spectrogram = log_mel_spectrogram(audio, sr=sr, nfft=nfft, hop_length=hop_length, window=window)
        genre_dir = os.path.join(save_path, genre)
        os.makedirs(genre_dir, exist_ok=True)
        for idx, split in enumerate(split_spectrogram(spectrogram, output_shape), start=1):
            image_path = os.path.join(genre_dir, f"{id_string}_{idx}.png")
            save_spectrogram_as_png(split, image_path, sr, nfft, hop_length)
        if position % gc_interval == 0:
            gc.collect()
    return
def image_transformer(dataset, mode):
    """Convert every entry of the "image" column to another image mode.

    The generated PNGs are usually RGBA; this can turn them into RGB,
    grayscale ("L") and other modes accepted by the images' ``convert``
    method (presumably PIL images).

    NOTE(review): ``dataset`` is mutated via item assignment and returned.
    That suggests it is a batch dict as passed by ``datasets.Dataset.map`` or
    ``set_transform`` rather than a ``Dataset`` itself. Confirm against callers.

    Args:
        dataset (object): Mapping with an "image" key holding a sequence of images.
        mode (str): Target mode, e.g. "RGB" or "L" for grayscale.

    Returns:
        object: The same ``dataset`` object with its "image" entry replaced.
    """
    source_images = dataset["image"]
    converted = [picture.convert(mode) for picture in source_images]
    dataset["image"] = converted
    return dataset