# Wave_U_Net_audio/data/dataset.py
# Uploaded by hieupt (commit efcaf28, verified)
import os
import h5py
import numpy as np
from sortedcontainers import SortedList
from torch.utils.data import Dataset
from tqdm import tqdm
from data.utils import load
class SeparationDataset(Dataset):
    '''
    Source separation dataset that preprocesses audio files into an HDF5 cache
    and serves fixed-size (mixture, per-instrument targets) excerpts from it.
    '''
    def __init__(self, dataset, partition, instruments, sr, channels, shapes, random_hops, hdf_dir, audio_transform=None, in_memory=False):
        '''
        Initialises a source separation dataset.

        :param dataset: Dict of partitions; dataset[partition] is a list of examples,
                        each mapping "mix" and every instrument name to an audio path
        :param partition: Partition key to use (e.g. "train", "val", "test")
        :param instruments: List of source (instrument) names to separate
        :param sr: Target sampling rate audio is converted to
        :param channels: Number of audio channels (1 = mono)
        :param shapes: Dict describing the model geometry; this class reads the keys
                       "output_frames", "output_start_frame" and "input_frames"
        :param random_hops: If False, sample examples evenly from the whole audio signal
                            according to the output frame count. If True, randomly sample
                            a position from the audio
        :param hdf_dir: Directory in which the preprocessed HDF5 cache is stored
        :param audio_transform: Optional callable applied as
                                audio, targets = audio_transform(audio, targets)
        :param in_memory: If True, load the whole HDF5 file into memory on first access
        '''
        super(SeparationDataset, self).__init__()
        self.hdf_dataset = None  # opened lazily in __getitem__ (one handle per worker)
        os.makedirs(hdf_dir, exist_ok=True)
        self.hdf_dir = os.path.join(hdf_dir, partition + ".hdf5")
        self.random_hops = random_hops
        self.sr = sr
        self.channels = channels
        self.shapes = shapes
        self.audio_transform = audio_transform
        self.in_memory = in_memory
        self.instruments = instruments

        # PREPARE HDF FILE
        # Build the cache only if it does not exist yet; otherwise reuse it.
        # (hdf_dir itself was already created above with exist_ok=True.)
        if not os.path.exists(self.hdf_dir):
            # Create HDF file
            with h5py.File(self.hdf_dir, "w") as f:
                f.attrs["sr"] = sr
                f.attrs["channels"] = channels
                f.attrs["instruments"] = instruments

                print("Adding audio files to dataset (preprocessing)...")
                for idx, example in enumerate(tqdm(dataset[partition])):
                    # Load mixture at the target sampling rate / channel count
                    mix_audio, _ = load(example["mix"], sr=self.sr, mono=(self.channels == 1))

                    source_audios = []
                    for source in instruments:
                        # In this case, read in audio and convert to target sampling rate
                        source_audio, _ = load(example[source], sr=self.sr, mono=(self.channels == 1))
                        source_audios.append(source_audio)
                    # Stack all sources along the channel axis:
                    # shape (num_instruments * channels, num_frames)
                    source_audios = np.concatenate(source_audios, axis=0)
                    # Explicit check instead of `assert`, which is stripped under -O
                    if source_audios.shape[1] != mix_audio.shape[1]:
                        raise ValueError(
                            "Mix and source lengths differ for example " + str(idx))

                    # Add example to the HDF5 cache
                    grp = f.create_group(str(idx))
                    grp.create_dataset("inputs", shape=mix_audio.shape, dtype=mix_audio.dtype, data=mix_audio)
                    grp.create_dataset("targets", shape=source_audios.shape, dtype=source_audios.dtype, data=source_audios)
                    grp.attrs["length"] = mix_audio.shape[1]
                    grp.attrs["target_length"] = source_audios.shape[1]

        # Whether the file was freshly written or pre-existing, verify that its
        # sr/channels/instruments match what this dataset expects.
        # NOTE(review): depending on the h5py version, string attrs may round-trip
        # as bytes/numpy arrays — confirm the `instruments` comparison holds for
        # caches written by other environments.
        with h5py.File(self.hdf_dir, "r") as f:
            if f.attrs["sr"] != sr or \
                    f.attrs["channels"] != channels or \
                    list(f.attrs["instruments"]) != instruments:
                raise ValueError(
                    "Tried to load existing HDF file, but sampling rate and channel or instruments are not as expected. Did you load an out-dated HDF file?")

        # HDF FILE READY

        # SET SAMPLING POSITIONS
        # Go through HDF and collect lengths of all audio files
        with h5py.File(self.hdf_dir, "r") as f:
            lengths = [f[str(song_idx)].attrs["target_length"] for song_idx in range(len(f))]

            # Number of non-overlapping output windows per song; the last,
            # possibly partial, window is zero-padded at read time (hence the +1)
            lengths = [(l // self.shapes["output_frames"]) + 1 for l in lengths]

        # Cumulative window counts: start_pos[i] is the first global index past song i
        self.start_pos = SortedList(np.cumsum(lengths))
        self.length = self.start_pos[-1]

    def __getitem__(self, index):
        # Open HDF5 lazily so each DataLoader worker gets its own file handle
        if self.hdf_dataset is None:
            driver = "core" if self.in_memory else None  # Load HDF5 fully into memory if desired
            self.hdf_dataset = h5py.File(self.hdf_dir, 'r', driver=driver)

        # Find out which song this global index falls into, and the window
        # index within that song
        audio_idx = self.start_pos.bisect_right(index)
        if audio_idx > 0:
            index = index - self.start_pos[audio_idx - 1]

        # Check length of audio signal
        audio_length = self.hdf_dataset[str(audio_idx)].attrs["length"]
        target_length = self.hdf_dataset[str(audio_idx)].attrs["target_length"]

        # Determine position where to start targets
        if self.random_hops:
            start_target_pos = np.random.randint(0, max(target_length - self.shapes["output_frames"] + 1, 1))
        else:
            # Map item index to sample position within song
            start_target_pos = index * self.shapes["output_frames"]

        # READ INPUTS
        # The input window extends "output_start_frame" samples of context before
        # the target window; zero-pad manually wherever the song is too short.
        start_pos = start_target_pos - self.shapes["output_start_frame"]
        if start_pos < 0:
            # Pad manually since audio signal was too short
            pad_front = abs(start_pos)
            start_pos = 0
        else:
            pad_front = 0

        # Check back padding
        end_pos = start_target_pos - self.shapes["output_start_frame"] + self.shapes["input_frames"]
        if end_pos > audio_length:
            # Pad manually since audio signal was too short
            pad_back = end_pos - audio_length
            end_pos = audio_length
        else:
            pad_back = 0

        # Read mixture and targets over the same window, zero-padded to input size
        audio = self.hdf_dataset[str(audio_idx)]["inputs"][:, start_pos:end_pos].astype(np.float32)
        if pad_front > 0 or pad_back > 0:
            audio = np.pad(audio, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)

        targets = self.hdf_dataset[str(audio_idx)]["targets"][:, start_pos:end_pos].astype(np.float32)
        if pad_front > 0 or pad_back > 0:
            targets = np.pad(targets, [(0, 0), (pad_front, pad_back)], mode="constant", constant_values=0.0)

        # Split the stacked target array back into one entry per instrument,
        # `channels` rows each, in the same order as self.instruments
        targets = {inst : targets[idx*self.channels:(idx+1)*self.channels] for idx, inst in enumerate(self.instruments)}

        # hasattr guard kept for backward compatibility with unpickled old instances
        # that may predate the audio_transform attribute
        if hasattr(self, "audio_transform") and self.audio_transform is not None:
            audio, targets = self.audio_transform(audio, targets)

        return audio, targets

    def __len__(self):
        # Total number of fixed-size windows across all songs in the partition
        return self.length