hieupt commited on
Commit
efcaf28
·
verified ·
1 Parent(s): acf2c04

Upload dataset.py

Browse files
Files changed (1) hide show
  1. data/dataset.py +152 -0
data/dataset.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import h5py
4
+ import numpy as np
5
+ from sortedcontainers import SortedList
6
+ from torch.utils.data import Dataset
7
+ from tqdm import tqdm
8
+
9
+ from data.utils import load
10
+
11
+
12
+ class SeparationDataset(Dataset):
13
+ def __init__(self, dataset, partition, instruments, sr, channels, shapes, random_hops, hdf_dir, audio_transform=None, in_memory=False):
14
+ '''
15
+ Initialises a source separation dataset
16
+ :param data: HDF audio data object
17
+ :param input_size: Number of input samples for each example
18
+ :param context_front: Number of extra context samples to prepend to input
19
+ :param context_back: NUmber of extra context samples to append to input
20
+ :param hop_size: Skip hop_size - 1 sample positions in the audio for each example (subsampling the audio)
21
+ :param random_hops: If False, sample examples evenly from whole audio signal according to hop_size parameter. If True, randomly sample a position from the audio
22
+ '''
23
+
24
+ super(SeparationDataset, self).__init__()
25
+
26
+ self.hdf_dataset = None
27
+ os.makedirs(hdf_dir, exist_ok=True)
28
+ self.hdf_dir = os.path.join(hdf_dir, partition + ".hdf5")
29
+
30
+ self.random_hops = random_hops
31
+ self.sr = sr
32
+ self.channels = channels
33
+ self.shapes = shapes
34
+ self.audio_transform = audio_transform
35
+ self.in_memory = in_memory
36
+ self.instruments = instruments
37
+
38
+ # PREPARE HDF FILE
39
+
40
+ # Check if HDF file exists already
41
+ if not os.path.exists(self.hdf_dir):
42
+ # Create folder if it did not exist before
43
+ if not os.path.exists(hdf_dir):
44
+ os.makedirs(hdf_dir)
45
+
46
+ # Create HDF file
47
+ with h5py.File(self.hdf_dir, "w") as f:
48
+ f.attrs["sr"] = sr
49
+ f.attrs["channels"] = channels
50
+ f.attrs["instruments"] = instruments
51
+
52
+ print("Adding audio files to dataset (preprocessing)...")
53
+ for idx, example in enumerate(tqdm(dataset[partition])):
54
+ # Load mix
55
+ mix_audio, _ = load(example["mix"], sr=self.sr, mono=(self.channels == 1))
56
+
57
+ source_audios = []
58
+ for source in instruments:
59
+ # In this case, read in audio and convert to target sampling rate
60
+ source_audio, _ = load(example[source], sr=self.sr, mono=(self.channels == 1))
61
+ source_audios.append(source_audio)
62
+ source_audios = np.concatenate(source_audios, axis=0)
63
+ assert(source_audios.shape[1] == mix_audio.shape[1])
64
+
65
+ # Add to HDF5 file
66
+ grp = f.create_group(str(idx))
67
+ grp.create_dataset("inputs", shape=mix_audio.shape, dtype=mix_audio.dtype, data=mix_audio)
68
+ grp.create_dataset("targets", shape=source_audios.shape, dtype=source_audios.dtype, data=source_audios)
69
+ grp.attrs["length"] = mix_audio.shape[1]
70
+ grp.attrs["target_length"] = source_audios.shape[1]
71
+
72
+ # In that case, check whether sr and channels are complying with the audio in the HDF file, otherwise raise error
73
+ with h5py.File(self.hdf_dir, "r") as f:
74
+ if f.attrs["sr"] != sr or \
75
+ f.attrs["channels"] != channels or \
76
+ list(f.attrs["instruments"]) != instruments:
77
+ raise ValueError(
78
+ "Tried to load existing HDF file, but sampling rate and channel or instruments are not as expected. Did you load an out-dated HDF file?")
79
+
80
+ # HDF FILE READY
81
+
82
+ # SET SAMPLING POSITIONS
83
+
84
+ # Go through HDF and collect lengths of all audio files
85
+ with h5py.File(self.hdf_dir, "r") as f:
86
+ lengths = [f[str(song_idx)].attrs["target_length"] for song_idx in range(len(f))]
87
+
88
+ # Subtract input_size from lengths and divide by hop size to determine number of starting positions
89
+ lengths = [(l // self.shapes["output_frames"]) + 1 for l in lengths]
90
+
91
+ self.start_pos = SortedList(np.cumsum(lengths))
92
+ self.length = self.start_pos[-1]
93
+
94
+ def __getitem__(self, index):
95
+ # Open HDF5
96
+ if self.hdf_dataset is None:
97
+ driver = "core" if self.in_memory else None # Load HDF5 fully into memory if desired
98
+ self.hdf_dataset = h5py.File(self.hdf_dir, 'r', driver=driver)
99
+
100
+ # Find out which slice of targets we want to read
101
+ audio_idx = self.start_pos.bisect_right(index)
102
+ if audio_idx > 0:
103
+ index = index - self.start_pos[audio_idx - 1]
104
+
105
+ # Check length of audio signal
106
+ audio_length = self.hdf_dataset[str(audio_idx)].attrs["length"]
107
+ target_length = self.hdf_dataset[str(audio_idx)].attrs["target_length"]
108
+
109
+ # Determine position where to start targets
110
+ if self.random_hops:
111
+ start_target_pos = np.random.randint(0, max(target_length - self.shapes["output_frames"] + 1, 1))
112
+ else:
113
+ # Map item index to sample position within song
114
+ start_target_pos = index * self.shapes["output_frames"]
115
+
116
+ # READ INPUTS
117
+ # Check front padding
118
+ start_pos = start_target_pos - self.shapes["output_start_frame"]
119
+ if start_pos < 0:
120
+ # Pad manually since audio signal was too short
121
+ pad_front = abs(start_pos)
122
+ start_pos = 0
123
+ else:
124
+ pad_front = 0
125
+
126
+ # Check back padding
127
+ end_pos = start_target_pos - self.shapes["output_start_frame"] + self.shapes["input_frames"]
128
+ if end_pos > audio_length:
129
+ # Pad manually since audio signal was too short
130
+ pad_back = end_pos - audio_length
131
+ end_pos = audio_length
132
+ else:
133
+ pad_back = 0
134
+
135
+ # Read and return
136
+ audio = self.hdf_dataset[str(audio_idx)]["inputs"][:, start_pos:end_pos].astype(np.float32)
137
+ if pad_front > 0 or pad_back > 0:
138
+ audio = np.pad(audio, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)
139
+
140
+ targets = self.hdf_dataset[str(audio_idx)]["targets"][:, start_pos:end_pos].astype(np.float32)
141
+ if pad_front > 0 or pad_back > 0:
142
+ targets = np.pad(targets, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)
143
+
144
+ targets = {inst : targets[idx*self.channels:(idx+1)*self.channels] for idx, inst in enumerate(self.instruments)}
145
+
146
+ if hasattr(self, "audio_transform") and self.audio_transform is not None:
147
+ audio, targets = self.audio_transform(audio, targets)
148
+
149
+ return audio, targets
150
+
151
+ def __len__(self):
152
+ return self.length