from functools import partial
from pathlib import Path
from typing import Optional, Tuple

import torch
from torch.utils.data import Dataset, default_collate

from gyraudio.audio_separation.properties import (
    AUG_AWGN, AUG_RESCALE, AUG_TRIM, LENGTHS, LENGTH_DIVIDER, TRIM_PROB
)


class AudioDataset(Dataset):
    def __init__(
        self,
        data_path: Path,
        augmentation_config: dict = {},
        snr_filter: Optional[list] = None,
        debug: bool = False
    ):
        self.debug = debug
        self.data_path = data_path
        self.augmentation_config = augmentation_config
        self.snr_filter = snr_filter
        self.load_data()
        self.length = len(self.file_list)
        self.collate_fn = None
        if AUG_TRIM in self.augmentation_config:
            self.collate_fn = partial(
                collate_fn_generic,
                lengths_lim=self.augmentation_config[AUG_TRIM][LENGTHS],
                length_divider=self.augmentation_config[AUG_TRIM][LENGTH_DIVIDER],
                trim_prob=self.augmentation_config[AUG_TRIM][TRIM_PROB],
            )

    def filter_data(self, snr) -> bool:
        """Keep a sample if no SNR filter is set, or if its SNR is in the allowed list."""
        if self.snr_filter is None:
            return True
        return snr in self.snr_filter

    def load_data(self):
        raise NotImplementedError("load_data method must be implemented")

    def augment_data(self, mixed_audio_signal, clean_audio_signal, noise_audio_signal):
        if AUG_RESCALE in self.augmentation_config:
            # Random gain in [0.5, 2.0), applied identically to the three signals.
            current_amplitude = 0.5 + 1.5 * torch.rand(1, device=mixed_audio_signal.device)
            mixed_audio_signal *= current_amplitude
            noise_audio_signal *= current_amplitude
            clean_audio_signal *= current_amplitude
        if AUG_AWGN in self.augmentation_config:
            # noise_std = self.augmentation_config[AUG_AWGN]["noise_std"]
            noise_std = 0.01  # hardcoded for now, see the commented line above
            # Sample a per-call noise level, then add white Gaussian noise to the mix.
            current_noise_std = torch.randn(1) * noise_std
            extra_awgn = torch.randn(mixed_audio_signal.shape, device=mixed_audio_signal.device) * current_noise_std
            mixed_audio_signal = mixed_audio_signal + extra_awgn
            # Open question: should we add noise to the noise signal as well?
        return mixed_audio_signal, clean_audio_signal, noise_audio_signal

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        raise NotImplementedError("__getitem__ method must be implemented")
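# Hypothetical sketch (not part of the original module): a minimal subclass
# illustrating the contract AudioDataset expects from its children: load_data()
# must populate self.file_list, and __getitem__ should return the
# (mixed, clean, noise) triplet that augment_data and collate_fn_generic
# operate on. The directory layout and "*.pt" file names below are assumptions
# for illustration only, not the project's actual storage format.
class FolderAudioDataset(AudioDataset):
    def load_data(self):
        # Assume one subdirectory per sample (hypothetical layout).
        self.file_list = sorted(p for p in self.data_path.iterdir() if p.is_dir())

    def __getitem__(self, idx: int):
        sample_dir = self.file_list[idx]
        mixed = torch.load(sample_dir / "mixed.pt")
        clean = torch.load(sample_dir / "clean.pt")
        noise = torch.load(sample_dir / "noise.pt")
        return self.augment_data(mixed, clean, noise)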
def collate_fn_generic(
    batch, lengths_lim, length_divider=1024, trim_prob=0.5
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function allowing random trimming (cropping along the time dimension) of the signals in a batch.

    Args:
        batch (list): A list of triplets, where each triplet contains:
            - mixed_audio_signal
            - clean_audio_signal
            - noise_audio_signal
        lengths_lim (list): [minimum length, maximum length] bounds for the trimmed signal.
        length_divider (int): the trimmed length must be a multiple of this value.
        trim_prob (float): probability of trimming the batch.

    Returns:
        - Tensor: A batch of mixed_audio_signal, trimmed to the same length.
        - Tensor: A batch of clean_audio_signal.
        - Tensor: A batch of noise_audio_signal.
    """
    # default_collate stacks the triplets; all signals share the same length.
    mixed_audio_signal, clean_audio_signal, noise_audio_signal = default_collate(batch)
    length = mixed_audio_signal[0].shape[-1]
    min_length, max_length = lengths_lim
    take_full_signal = torch.rand(1) > trim_prob
    if not take_full_signal:
        # Random crop: pick a start, then a length within [min_length, max_length]
        # that still fits in the signal, rounded down to a multiple of length_divider.
        start = int(torch.randint(0, length - min_length, (1,)))
        trim_length = int(torch.randint(min_length, min(max_length, length - start - 1) + 1, (1,)))
        trim_length -= trim_length % length_divider
        end = start + trim_length
    else:
        # Keep the full signal, only rounding its length down to a multiple of length_divider.
        start = 0
        end = length - length % length_divider
    mixed_audio_signal = mixed_audio_signal[..., start:end]
    clean_audio_signal = clean_audio_signal[..., start:end]
    noise_audio_signal = noise_audio_signal[..., start:end]
    return mixed_audio_signal, clean_audio_signal, noise_audio_signal
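# Hypothetical smoke test (not part of the original module): exercises
# collate_fn_generic on synthetic mono signals. trim_prob=1.0 forces the
# trimming branch, so the output length is a multiple of length_divider.
if __name__ == "__main__":
    torch.manual_seed(0)
    # Batch of 4 triplets (mixed, clean, noise), each a mono signal of 8192 samples.
    batch = [
        (torch.randn(1, 8192), torch.randn(1, 8192), torch.randn(1, 8192))
        for _ in range(4)
    ]
    mixed, clean, noise = collate_fn_generic(
        batch, lengths_lim=[2048, 4096], length_divider=1024, trim_prob=1.0
    )
    print(mixed.shape, clean.shape, noise.shape)
    assert mixed.shape[-1] % 1024 == 0
    assert mixed.shape == clean.shape == noise.shape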