hieupt committed on
Commit
52a0f2f
·
verified ·
1 Parent(s): 9b0805b

Upload musdb.py

Browse files
Files changed (1) hide show
  1. data/musdb.py +127 -0
data/musdb.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import musdb
2
+ import os
3
+ import numpy as np
4
+ import glob
5
+
6
+ from data.utils import load, write_wav
7
+
8
+
9
def get_musdbhq(database_path):
    '''
    Retrieve audio file paths for MUSDB HQ dataset, writing an accompaniment file for each track that lacks one
    :param database_path: MUSDB HQ root directory, searched for "train" and "test" subfolders containing one folder per track
    :return: list with two entries (train samples first, then test samples), each a list of samples, each sample a
             dictionary mapping "mix", "bass", "drums", "other", "vocals" and "accompaniment" to the respective audio path
    '''
    subsets = list()

    for subset in ["train", "test"]:
        print("Loading " + subset + " set...")
        # Only directories can be tracks: a stray regular file (README, archive, ...) next to the track folders
        # would otherwise be turned into a sample whose stem paths do not exist
        tracks = [path for path in glob.glob(os.path.join(database_path, subset, "*")) if os.path.isdir(path)]
        samples = list()

        # Go through tracks in sorted order so that the sample order does not depend on the file system
        for track_folder in sorted(tracks):
            example = dict()
            for stem in ["mix", "bass", "drums", "other", "vocals"]:
                # The mixture is stored as "mixture.wav" on disk but exposed under the key "mix"
                filename = stem if stem != "mix" else "mixture"
                audio_path = os.path.join(track_folder, filename + ".wav")
                example[stem] = audio_path

            # Add other instruments to form accompaniment
            acc_path = os.path.join(track_folder, "accompaniment.wav")

            # An existing accompaniment file is reused as-is, so the summation only happens once per track
            if not os.path.exists(acc_path):
                print("Writing accompaniment to " + track_folder)
                stem_audio = []
                for stem in ["bass", "drums", "other"]:
                    # NOTE(review): sr=None / mono=False presumably keep the native sampling rate and channels
                    # - confirm against data.utils.load
                    audio, sr = load(example[stem], sr=None, mono=False)
                    stem_audio.append(audio)
                # Clip the summed stems to [-1, 1] before writing
                acc_audio = np.clip(sum(stem_audio), -1.0, 1.0)
                # sr is the value returned for the last loaded stem ("other")
                write_wav(acc_path, acc_audio, sr)

            example["accompaniment"] = acc_path

            samples.append(example)

        subsets.append(samples)

    return subsets
50
+
51
def get_musdb(database_path):
    '''
    Prepare the MUSDB dataset: write every track's sources, accompaniment and mixture as wav files and collect their paths
    :param database_path: MUSDB root directory
    :return: list with two entries (train samples first, then test samples), each a list of samples, each sample a
             dictionary mapping "mix", "accompaniment", "bass", "drums", "other" and "vocals" to the respective audio path
    '''
    source_names = ["bass", "drums", "other", "vocals"]
    mus = musdb.DB(root=database_path, is_wav=False)

    subsets = []

    for subset in ["train", "test"]:
        samples = []

        # Go through tracks
        for track in sorted(mus.load_mus_tracks(subset)):
            # Drop the last four characters of the track path (presumably the file extension - confirm)
            track_path = track.path[:-4]
            mix_path = track_path + "_mix.wav"
            acc_path = track_path + "_accompaniment.wav"

            # An existing mixture file is taken as the sign that this track was fully written before
            if os.path.exists(mix_path):
                print("WARNING: Skipping track " + mix_path + " since it exists already")

                # Only register the paths, nothing is decoded or written again
                paths = {"mix" : mix_path, "accompaniment" : acc_path}
                for key in source_names:
                    paths[key] = track_path + "_" + key + ".wav"

                samples.append(paths)
                continue

            rate = track.rate

            # Write out each instrument and remember its audio for the sums below
            paths = {}
            stem_audio = {}
            for stem in source_names:
                stem_path = track_path + "_" + stem + ".wav"
                stem_signal = track.targets[stem].audio
                write_wav(stem_path, stem_signal, rate)
                stem_audio[stem] = stem_signal
                paths[stem] = stem_path

            # Accompaniment is the clipped sum of everything except the vocals
            non_vocal_sum = sum(signal for name, signal in stem_audio.items() if name != "vocals")
            acc_audio = np.clip(non_vocal_sum, -1.0, 1.0)
            write_wav(acc_path, acc_audio, rate)
            paths["accompaniment"] = acc_path

            # Mixture is taken directly from the track
            mix_audio = track.audio
            write_wav(mix_path, mix_audio, rate)
            paths["mix"] = mix_path

            # Report how far accompaniment + vocals deviates from the mixture
            diff_signal = np.abs(mix_audio - acc_audio - stem_audio["vocals"])
            print("Maximum absolute deviation from source additivity constraint: " + str(np.max(diff_signal)))
            print("Mean absolute deviation from source additivity constraint: " + str(np.mean(diff_signal)))

            samples.append(paths)

        subsets.append(samples)

    print("DONE preparing dataset!")
    return subsets
114
+
115
def get_musdb_folds(root_path, version="HQ"):
    '''
    Partition MUSDB into training, validation and test folds
    :param root_path: dataset root directory, passed on to the dataset loader
    :param version: "HQ" loads the data with get_musdbhq, any other value with get_musdb
    :return: dictionary with keys "train" (75 samples drawn without replacement from the train subset, as returned by
             np.random.choice), "val" (list of the remaining train subset samples) and "test" (the test subset)
    '''
    loader = get_musdbhq if version == "HQ" else get_musdb
    train_val_list, test_list = loader(root_path)

    # Fixed seed so that the partitioning is always the same on each run (this reseeds NumPy's global generator)
    np.random.seed(1337)
    train_list = np.random.choice(train_val_list, 75, replace=False)

    # Everything that was not drawn for training becomes validation data
    val_list = [sample for sample in train_val_list if sample not in train_list]

    return {"train" : train_list, "val" : val_list, "test" : test_list}