File size: 3,902 Bytes
070e26e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import os
import re
import argparse
from pydub import AudioSegment

class GTZAN:
    """Convert the GTZAN genre dataset from AU to MP3 and split it into train/test sets.

    Converted files are written to ``<output_dir>/{train,test}/<genre>/<stem>.mp3``.
    """

    # Extension of the source audio files that are picked up for conversion.
    SOURCE_EXT = ".au"
    # Extension given to the converted files (pydub export format is "mp3").
    TARGET_EXT = ".mp3"
    # Number of consecutive numeric file IDs that make up one test fold.
    FOLD_SIZE = 10
    # First run of digits in a file name is taken as its numeric ID
    # (e.g. "00030" in "blues.00030.au").
    _ID_PATTERN = re.compile(r"\d+")

    def __init__(self, root_dir, output_dir, labels):
        """
        Args:
            root_dir (str): Root directory of the dataset; must contain one
                sub-directory per genre listed in ``labels``.
            output_dir (str): Output directory to save converted MP3 files.
            labels (list): List of genres in the dataset.
        """
        self.root_dir = root_dir
        self.output_dir = output_dir
        self.labels = labels

        # Create output directory structure for MP3 files
        self.create_output_dirs()

    def create_output_dirs(self):
        """Create ``<output_dir>/{train,test}/<genre>`` for every genre (idempotent)."""
        for split in ['train', 'test']:
            for genre in self.labels:
                genre_dir = os.path.join(self.output_dir, split, genre)
                os.makedirs(genre_dir, exist_ok=True)

    def split_train_test(self, audio_names, test_fold):
        """
        Split file names into train and test sets based on ``test_fold``.

        A file goes to the test set when its numeric ID (the first run of
        digits in its name) lies in
        ``[test_fold * FOLD_SIZE, (test_fold + 1) * FOLD_SIZE)``,
        e.g. test_fold=3 selects IDs 30..39.

        Args:
            audio_names (iterable of str): File names (not paths).
            test_fold (int): Index of the fold held out for testing.

        Returns:
            tuple[list[str], list[str]]: ``(train_names, test_names)``, each
            preserving input order.

        Raises:
            ValueError: If a name contains no digits, so no ID can be extracted.
        """
        start = test_fold * self.FOLD_SIZE
        test_ids = range(start, start + self.FOLD_SIZE)

        train_audio_names = []
        test_audio_names = []
        for audio_name in audio_names:
            match = self._ID_PATTERN.search(audio_name)
            if match is None:
                raise ValueError(f"Cannot extract a numeric ID from file name {audio_name!r}")
            audio_id = int(match.group())

            if audio_id in test_ids:
                test_audio_names.append(audio_name)
            else:
                train_audio_names.append(audio_name)

        return train_audio_names, test_audio_names

    def _list_audio_files(self, genre_path):
        """Return the sorted names of visible ``.au`` regular files directly inside ``genre_path``."""
        return sorted(
            name
            for name in os.listdir(genre_path)
            # Skip hidden/metadata files (".DS_Store", macOS "._x.au" forks),
            # non-AU files and sub-directories; sorting makes runs deterministic
            # because os.listdir order is arbitrary.
            if not name.startswith(".")
            and name.lower().endswith(self.SOURCE_EXT)
            and os.path.isfile(os.path.join(genre_path, name))
        )

    def _target_path(self, split, genre, audio_name):
        """Return the output path for ``audio_name``, with only its final extension swapped to ``.mp3``."""
        stem, _ = os.path.splitext(audio_name)
        return os.path.join(self.output_dir, split, genre, stem + self.TARGET_EXT)

    def convert_and_save(self, file_path, target_path):
        """Convert AU format to MP3 and save to target path"""
        audio = AudioSegment.from_file(file_path, format="au")
        audio.export(target_path, format="mp3")
        print(f"Converted and saved {target_path}")

    def process_genre(self, genre, test_fold):
        """Split one genre's ``.au`` files by ``test_fold`` and convert each into its split directory.

        Args:
            genre (str): Genre sub-directory name under ``root_dir``.
            test_fold (int): Fold index passed to ``split_train_test``.
        """
        genre_path = os.path.join(self.root_dir, genre)
        audio_files = self._list_audio_files(genre_path)

        # Split the dataset
        train_files, test_files = self.split_train_test(audio_files, test_fold)

        for split, names in (('train', train_files), ('test', test_files)):
            for audio_name in names:
                self.convert_and_save(
                    os.path.join(genre_path, audio_name),
                    self._target_path(split, genre, audio_name),
                )

    def process_dataset(self):
        """Process every genre in ``labels``, splitting and converting its files.

        The genre at position ``idx`` in ``labels`` uses test fold ``idx % 10``,
        so different genres hold out different ID ranges.
        """
        for idx, genre in enumerate(self.labels):
            print(f"Processing genre: {genre}...")
            test_fold = idx % 10  # Each genre has a different test_fold
            self.process_genre(genre, test_fold)


if __name__ == "__main__":
    # Genre sub-directory names of the standard GTZAN "genres" folder.
    # Order matters: GTZAN.process_dataset assigns test fold idx % 10 by position.
    default_labels = ["blues", "classical", "country", "disco", "hiphop", "jazz", "metal", "pop", "reggae", "rock"]

    # Define argument parser
    parser = argparse.ArgumentParser(description="GTZAN Dataset Converter")
    parser.add_argument('--root_dir', type=str, required=True, help='Root directory of the GTZAN dataset')
    parser.add_argument('--output_dir', type=str, required=True, help='Directory to save the converted MP3 files')
    parser.add_argument('--labels', nargs='+', default=default_labels, metavar='GENRE',
                        help='Genre sub-directories to convert (default: the 10 standard GTZAN genres)')
    args = parser.parse_args()

    # Fail fast before GTZAN.__init__ creates the output directory tree.
    if not os.path.isdir(args.root_dir):
        parser.error(f"--root_dir is not a directory: {args.root_dir}")

    # Initialize the GTZAN processor
    gtzan = GTZAN(args.root_dir, args.output_dir, args.labels)
    gtzan.process_dataset()

### how to use
# python gtzan_converter.py --root_dir /path/to/gtzan/genres --output_dir /path/to/output/directory
# Optionally convert only some genres:  --labels blues jazz rock