File size: 5,926 Bytes
e2eef75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import gc
import os

import librosa
import librosa.display  # specshow lives in the display submodule; a bare "import librosa" does not load it
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

def load_audio_file(path, sample_rate=22050, resampling_type="kaiser_fast", duration=30):
    """Load an audio file as a mono numpy array using the librosa library.

    Args:
        path (str): Path to audio file.
        sample_rate (int, optional): Sample rate to resample the audio file to.
            None uses the file's original sample rate. Defaults to 22050.
        resampling_type (str, optional): Method to use for resampling. Defaults to "kaiser_fast".
        duration (int, optional): Length in seconds to pad/shorten the audio to.
            0 returns the original audio length. Defaults to 30.

    Returns:
        numpy.ndarray: Audio file as a 1-D numpy array.
    """
    # Load an audio file with librosa. Resamples the file to a specified sample rate.
    audio_array, _ = librosa.load(path, sr=sample_rate, mono=True, res_type=resampling_type)
    if duration > 0:
        # Bug fix: previously hard-coded 30 here, ignoring the duration argument.
        audio_array = pad_audio(audio_array, sample_rate, duration)
    return audio_array

def add_noise(audio_array, std):
    """Return the signal with additive zero-mean Gaussian noise of the given std."""
    perturbation = np.random.normal(loc=0, scale=std, size=audio_array.shape)
    return audio_array + perturbation

def pad_audio(audio_array, sample_rate=22050, duration=30):
    """Pad or truncate audio to exactly ``duration * sample_rate`` samples.

    Shorter arrays are zero-padded (roughly centered); longer arrays are cut
    at the end.

    Args:
        audio_array (numpy.ndarray): 1-D audio signal.
        sample_rate (int, optional): Samples per second. Defaults to 22050.
        duration (int, optional): Target length in seconds. Defaults to 30.

    Returns:
        numpy.ndarray: Array of exactly ``duration * sample_rate`` samples.
    """
    target_len = duration * sample_rate
    audio_len = audio_array.size

    if audio_len < target_len:
        # Bug fix: padding both sides by (diff)//2 left the result one sample
        # short whenever the difference was odd. Pad asymmetrically so the
        # output length is always exact.
        pad_left = (target_len - audio_len) // 2
        pad_right = target_len - audio_len - pad_left
        audio_array = np.pad(audio_array, (pad_left, pad_right))
    elif audio_len > target_len:
        audio_array = audio_array[:target_len]
    return audio_array

def log_mel_spectrogram(audio_array, sr=22050, nfft=2048, hop_length=512, window="hann"):
    """Compute a log-scaled (dB) mel spectrogram of a mono audio signal.

    Power is converted to decibels relative to the spectrogram's maximum.
    """
    mel_power = librosa.feature.melspectrogram(
        y=audio_array,
        sr=sr,
        n_fft=nfft,
        hop_length=hop_length,
        win_length=nfft,
        window=window,
    )
    return librosa.power_to_db(mel_power, ref=np.max)

def split_spectrogram(spectrogram, output_shape=(128, 256)):
    """Split a spectrogram into equal-width, non-overlapping column chunks.

    Trailing columns that do not fill a complete chunk are discarded.
    """
    chunk_width = output_shape[1]
    num_chunks = spectrogram.shape[1] // chunk_width
    return [
        spectrogram[:, k * chunk_width:(k + 1) * chunk_width]
        for k in range(num_chunks)
    ]

def save_spectrogram_as_png(spectrogram, save_path, sample_rate=22050, nfft=2048, hop_length=512):
    """Render a spectrogram to disk as a borderless PNG.

    Args:
        spectrogram (numpy.ndarray): 2D array (frequency bins x frames) to render.
        save_path (str): Destination file path for the PNG.
        sample_rate (int, optional): Sample rate forwarded to specshow. Defaults to 22050.
        nfft (int, optional): FFT size forwarded to specshow. Defaults to 2048.
        hop_length (int, optional): Hop length forwarded to specshow. Defaults to 512.
    """
    rows, cols = spectrogram.shape
    # Figure size is in inches; at matplotlib's default 100 dpi this yields
    # roughly one pixel per spectrogram cell — TODO confirm dpi is not
    # overridden in this project's matplotlibrc.
    fig, ax = plt.subplots(1, 1, figsize=(cols / 100, rows / 100))
    # Strip all margins and axes so the PNG is just the spectrogram image.
    fig.subplots_adjust(top=1.0, bottom=0, right=1.0, left=0, hspace=0, wspace=0)
    ax.set_axis_off()
    librosa.display.specshow(data=spectrogram, sr=sample_rate, n_fft=nfft, hop_length=hop_length, ax=ax)
    # Bug fix: save via the figure handle rather than pyplot's implicit
    # "current figure" global, which could point at a different figure if
    # any other plotting happened in between.
    fig.savefig(save_path, bbox_inches=None, pad_inches=0)
    plt.close(fig)
    return

def extract_features(df, audio_dir, save_path, 
                     sr=22050, rs_type="kaiser_fast", 
                     output_shape=(128,256), duration=30
                     , nfft=2048, hop_length=512, window="hann", checkpoint_id=0):
    """
    Loads audio files, computes log-mel-spectrogram and saves it as png.
    Args:
        df (pandas.DataFrame): DataFrame indexed by track id with a "genre_top" column. Should only contain samples from specific data split (train/val/test).
        audio_dir (str): Directory containing all audio files.
        save_path (str): Path to where spectrograms will be saved.
        sr (int): Sampling rate to set for all loaded audio files.
        rs_type (str): Resampling method used when loading audio file to specific sampling rate.
        output_shape (tuple): Shape of each spectrogram split.
        duration (int): Set to standardize length of all audio files. Longer or shorter will be cut or padded respectively.
        nfft (int): Number of samples for every fft window.
        hop_length (int): Hop length to use for STFT.
        window (str): Window function to use for STFT.
        checkpoint_id (int, optional): Write the id of a track to start from there. Defaults to 0.
    """
    # Non-interactive backend: figures are only rendered to files, never shown.
    matplotlib.use("Agg")

    # Resume from a checkpoint id (assumes the index is sorted — TODO confirm).
    if int(checkpoint_id) > 0:
        df = df.loc[checkpoint_id:]

    id_list = df.index.ravel()
    genre_list = df["genre_top"].ravel()

    # Due to some weird memory leak, garbage collection is manually performed
    # every 10% of progress. Bug fix: clamp the interval to at least 1 —
    # for fewer than 10 tracks the old int(len*0.1) gave a step of 0 and
    # id_list[::0] raised ValueError. A set makes the per-track lookup O(1).
    gc_interval = max(1, int(len(id_list) * 0.1))
    gc_checkpoints = set(id_list[::gc_interval])

    for track_id, genre in zip(id_list, genre_list):
        id_string = str(track_id).rjust(6, "0")
        filename = id_string + ".mp3"
        # Audio files are bucketed into folders named by the first 3 id digits.
        folder_name = filename[:3]
        file_path = os.path.join(audio_dir, folder_name, filename)

        print(id_string, end=" ")
        audio = load_audio_file(file_path, sr, rs_type, duration)

        spectrogram = log_mel_spectrogram(audio, sr=sr, nfft=nfft, hop_length=hop_length, window=window)
        spec_splits = split_spectrogram(spectrogram, output_shape)

        for idx, split in enumerate(spec_splits):
            image_name = id_string + "_" + str(idx+1) +".png"

            # Spectrograms are grouped into one folder per genre label.
            image_path = os.path.join(save_path, genre, image_name)
            save_spectrogram_as_png(split, image_path, sr, nfft, hop_length)

        if track_id in gc_checkpoints:
            gc.collect() 
    return


def image_transformer(dataset, mode):
    """
    Convert images from a Huggingface Dataset object to a different mode.
    The generated PNGs are usually RGBA; this converts them to e.g. RGB or grayscale.

    Args:
        dataset (object): Huggingface Dataset object (mapping with an "image" entry).
        mode (str): Target mode string, e.g. "RGB", or "L" for grayscale.

    Returns:
        object: The same dataset with every image converted in place.
    """
    converted = []
    for img in dataset["image"]:
        converted.append(img.convert(mode))
    dataset["image"] = converted
    return dataset