|
import os |
|
import re |
|
import argparse |
|
from pydub import AudioSegment |
|
|
|
class GTZAN: |
|
def __init__(self, root_dir, output_dir, labels): |
|
""" |
|
Args: |
|
root_dir (str): Root directory of the dataset. |
|
output_dir (str): Output directory to save converted MP3 files. |
|
labels (list): List of genres in the dataset. |
|
""" |
|
self.root_dir = root_dir |
|
self.output_dir = output_dir |
|
self.labels = labels |
|
|
|
|
|
self.create_output_dirs() |
|
|
|
def create_output_dirs(self): |
|
"""Create directories to store train and test audio files""" |
|
for split in ['train', 'test']: |
|
for genre in self.labels: |
|
genre_dir = os.path.join(self.output_dir, split, genre) |
|
os.makedirs(genre_dir, exist_ok=True) |
|
|
|
def split_train_test(self, audio_names, test_fold): |
|
""" |
|
Split the dataset into train and test sets based on test_fold. |
|
E.g., test_ids = [30, 31, 32, ..., 39]. |
|
""" |
|
test_audio_names = [] |
|
train_audio_names = [] |
|
|
|
test_ids = range(test_fold * 10, (test_fold + 1) * 10) |
|
|
|
for audio_name in audio_names: |
|
|
|
audio_id = int(re.search(r'\d+', audio_name).group()) |
|
|
|
if audio_id in test_ids: |
|
test_audio_names.append(audio_name) |
|
else: |
|
train_audio_names.append(audio_name) |
|
|
|
return train_audio_names, test_audio_names |
|
|
|
def convert_and_save(self, file_path, target_path): |
|
"""Convert AU format to MP3 and save to target path""" |
|
audio = AudioSegment.from_file(file_path, format="au") |
|
audio.export(target_path, format="mp3") |
|
print(f"Converted and saved {target_path}") |
|
|
|
def process_genre(self, genre, test_fold): |
|
"""Process a single genre, split the dataset, and convert formats""" |
|
genre_path = os.path.join(self.root_dir, genre) |
|
audio_files = os.listdir(genre_path) |
|
|
|
|
|
train_files, test_files = self.split_train_test(audio_files, test_fold) |
|
|
|
|
|
for audio_name in train_files: |
|
file_path = os.path.join(genre_path, audio_name) |
|
target_path = os.path.join(self.output_dir, 'train', genre, audio_name.replace('.au', '.mp3')) |
|
self.convert_and_save(file_path, target_path) |
|
|
|
|
|
for audio_name in test_files: |
|
file_path = os.path.join(genre_path, audio_name) |
|
target_path = os.path.join(self.output_dir, 'test', genre, audio_name.replace('.au', '.mp3')) |
|
self.convert_and_save(file_path, target_path) |
|
|
|
def process_dataset(self): |
|
"""Process the entire GTZAN dataset and split it into train and test sets""" |
|
for idx, genre in enumerate(self.labels): |
|
print(f"Processing genre: {genre}...") |
|
test_fold = idx % 10 |
|
self.process_genre(genre, test_fold) |
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
parser = argparse.ArgumentParser(description="GTZAN Dataset Converter") |
|
parser.add_argument('--root_dir', type=str, required=True, help='Root directory of the GTZAN dataset') |
|
parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the converted MP3 files') |
|
args = parser.parse_args() |
|
|
|
|
|
labels = ["blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock"] |
|
|
|
|
|
gtzan = GTZAN(args.root_dir, args.output_dir, labels) |
|
gtzan.process_dataset() |
|
|
|
|
|
|
|
|