import os

import h5py
import numpy as np
from sortedcontainers import SortedList
from torch.utils.data import Dataset
from tqdm import tqdm

from data.utils import load


class SeparationDataset(Dataset):
    def __init__(self, dataset, partition, instruments, sr, channels, shapes, random_hops, hdf_dir,
                 audio_transform=None, in_memory=False):
        '''
        Initialises a source separation dataset backed by a preprocessed HDF5 file.

        :param dataset: Dict of partitions, each a list of examples with audio paths for "mix" and each instrument
        :param partition: Name of the partition to use (e.g. "train")
        :param instruments: List of source instrument names to separate
        :param sr: Target sampling rate
        :param channels: Number of audio channels (1 for mono)
        :param shapes: Dict with the model's frame counts ("input_frames", "output_frames", "output_start_frame")
        :param random_hops: If False, sample examples evenly from the whole audio signal in steps of output_frames.
                            If True, randomly sample a position from the audio
        :param hdf_dir: Directory in which to store the preprocessed HDF5 file
        :param audio_transform: Optional transform applied to each (audio, targets) pair
        :param in_memory: If True, load the whole HDF5 file into memory on first access
        '''
        super(SeparationDataset, self).__init__()

        self.hdf_dataset = None
        os.makedirs(hdf_dir, exist_ok=True)
        self.hdf_dir = os.path.join(hdf_dir, partition + ".hdf5")

        self.random_hops = random_hops
        self.sr = sr
        self.channels = channels
        self.shapes = shapes
        self.audio_transform = audio_transform
        self.in_memory = in_memory
        self.instruments = instruments

        # PREPARE HDF FILE

        # Check if HDF file exists already
        if not os.path.exists(self.hdf_dir):
            # Create HDF file
            with h5py.File(self.hdf_dir, "w") as f:
                f.attrs["sr"] = sr
                f.attrs["channels"] = channels
                f.attrs["instruments"] = instruments

                print("Adding audio files to dataset (preprocessing)...")
                for idx, example in enumerate(tqdm(dataset[partition])):
                    # Load mix, converting to the target sampling rate
                    mix_audio, _ = load(example["mix"], sr=self.sr, mono=(self.channels == 1))

                    source_audios = []
                    for source in instruments:
                        # Read in source audio and convert to target sampling rate
                        source_audio, _ = load(example[source], sr=self.sr, mono=(self.channels == 1))
                        source_audios.append(source_audio)
                    # Stack sources along the channel axis: shape (num_instruments * channels, num_samples)
                    source_audios = np.concatenate(source_audios, axis=0)
                    assert source_audios.shape[1] == mix_audio.shape[1]

                    # Add to HDF5 file
                    grp = f.create_group(str(idx))
                    grp.create_dataset("inputs", shape=mix_audio.shape, dtype=mix_audio.dtype, data=mix_audio)
                    grp.create_dataset("targets", shape=source_audios.shape, dtype=source_audios.dtype, data=source_audios)
                    grp.attrs["length"] = mix_audio.shape[1]
                    grp.attrs["target_length"] = source_audios.shape[1]

        # Check whether sr, channels and instruments of the (possibly pre-existing) HDF file
        # comply with the requested settings, otherwise raise an error
        with h5py.File(self.hdf_dir, "r") as f:
            if f.attrs["sr"] != sr or \
                    f.attrs["channels"] != channels or \
                    list(f.attrs["instruments"]) != instruments:
                raise ValueError("Tried to load existing HDF file, but sampling rate, channels or instruments are not as expected. Did you load an outdated HDF file?")
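        # The sampling scheme below maps a flat dataset index to a (song, window) pair:
        # cumulative window counts per song go into a SortedList, and bisect_right then
        # recovers the song for a given index in O(log n). As a worked illustration (not
        # part of the original code): with window counts [2, 3], start_pos becomes
        # SortedList([2, 5]); index 3 falls into song 1 at window offset 3 - 2 = 1.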
        # HDF FILE READY

        # SET SAMPLING POSITIONS

        # Go through HDF and collect lengths of all audio files
        with h5py.File(self.hdf_dir, "r") as f:
            lengths = [f[str(song_idx)].attrs["target_length"] for song_idx in range(len(f))]

            # Number of output windows needed to cover each song (one extra window for the remainder)
            lengths = [(l // self.shapes["output_frames"]) + 1 for l in lengths]

        self.start_pos = SortedList(np.cumsum(lengths))
        self.length = self.start_pos[-1]

    def __getitem__(self, index):
        # Open HDF5 lazily, so each DataLoader worker gets its own file handle
        if self.hdf_dataset is None:
            driver = "core" if self.in_memory else None  # Load HDF5 fully into memory if desired
            self.hdf_dataset = h5py.File(self.hdf_dir, "r", driver=driver)

        # Find out which song and which slice of targets we want to read
        audio_idx = self.start_pos.bisect_right(index)
        if audio_idx > 0:
            index = index - self.start_pos[audio_idx - 1]

        # Check length of audio signal
        audio_length = self.hdf_dataset[str(audio_idx)].attrs["length"]
        target_length = self.hdf_dataset[str(audio_idx)].attrs["target_length"]

        # Determine position where to start targets
        if self.random_hops:
            start_target_pos = np.random.randint(0, max(target_length - self.shapes["output_frames"] + 1, 1))
        else:
            # Map item index to sample position within song
            start_target_pos = index * self.shapes["output_frames"]

        # READ INPUTS
        # Check front padding
        start_pos = start_target_pos - self.shapes["output_start_frame"]
        if start_pos < 0:
            # Pad manually since audio signal was too short
            pad_front = abs(start_pos)
            start_pos = 0
        else:
            pad_front = 0

        # Check back padding
        end_pos = start_target_pos - self.shapes["output_start_frame"] + self.shapes["input_frames"]
        if end_pos > audio_length:
            # Pad manually since audio signal was too short
            pad_back = end_pos - audio_length
            end_pos = audio_length
        else:
            pad_back = 0

        # Read and pad mixture input
        audio = self.hdf_dataset[str(audio_idx)]["inputs"][:, start_pos:end_pos].astype(np.float32)
        if pad_front > 0 or pad_back > 0:
            audio = np.pad(audio, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)

        # Read and pad targets over the same window
        targets = self.hdf_dataset[str(audio_idx)]["targets"][:, start_pos:end_pos].astype(np.float32)
        if pad_front > 0 or pad_back > 0:
            targets = np.pad(targets, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)

        # Split the stacked target channels back into one entry per instrument
        targets = {inst: targets[idx * self.channels:(idx + 1) * self.channels] for idx, inst in enumerate(self.instruments)}

        if self.audio_transform is not None:
            audio, targets = self.audio_transform(audio, targets)

        return audio, targets

    def __len__(self):
        return self.length
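

# A minimal usage sketch (not part of the original module): it assumes a dataset dict
# of the form {"train": [{"mix": path, "bass": path, ...}, ...]} and a hypothetical
# `shapes` dict matching the separator model's input/output sizes. All paths and frame
# counts below are placeholders; the lazy HDF5 open in __getitem__ is what makes this
# safe to use with multiple DataLoader workers, since each worker opens its own handle.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = {"train": [{"mix": "song1/mix.wav",          # placeholder paths
                          "bass": "song1/bass.wav",
                          "drums": "song1/drums.wav",
                          "other": "song1/other.wav",
                          "vocals": "song1/vocals.wav"}]}
    shapes = {"input_frames": 16384,        # placeholder: samples fed to the model (with context)
              "output_frames": 8192,        # placeholder: samples predicted per example
              "output_start_frame": 4096}   # placeholder: offset of the output window within the input

    train_data = SeparationDataset(dataset, "train", ["bass", "drums", "other", "vocals"],
                                   sr=44100, channels=2, shapes=shapes,
                                   random_hops=True, hdf_dir="hdf")
    loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=1)

    for audio, targets in loader:
        # audio: (batch, channels, input_frames); targets: dict mapping each instrument
        # to a (batch, channels, input_frames) tensor over the same window
        print(audio.shape, {k: v.shape for k, v in targets.items()})
        break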