Spaces:
Building
Building
from torch.utils.data import Dataset | |
from pathlib import Path | |
from typing import Optional | |
import torch | |
from torch.utils.data import default_collate | |
from typing import Tuple | |
from functools import partial | |
from gyraudio.audio_separation.properties import ( | |
AUG_AWGN, AUG_RESCALE, AUG_TRIM, LENGTHS, LENGTH_DIVIDER, TRIM_PROB | |
) | |
class AudioDataset(Dataset):
    """Base class for datasets yielding (mixed, clean, noise) audio signals.

    Subclasses must implement :meth:`load_data` (which is expected to set
    ``self.file_list``, since its length is read right after the call) and
    :meth:`__getitem__`.

    Args:
        data_path: Location of the data, stored as-is for subclasses.
        augmentation_config: Mapping of augmentation name to its settings.
            ``None`` (the default) means no augmentation. When ``AUG_TRIM``
            is present, ``self.collate_fn`` is set to a trimming collate
            function configured from ``LENGTHS``, ``LENGTH_DIVIDER`` and
            ``TRIM_PROB``; otherwise ``self.collate_fn`` is ``None``.
        snr_filter: Collection of accepted SNR values (anything supporting
            ``in``), or ``None`` to accept every sample.
        debug: Debug flag, stored as-is for subclasses.
    """

    def __init__(
        self,
        data_path: Path,
        augmentation_config: Optional[dict] = None,
        snr_filter: Optional[list] = None,
        debug: bool = False
    ):
        self.debug = debug
        self.data_path = data_path
        # A fresh dict per instance: a `{}` default would be shared by every
        # dataset created without an explicit configuration.
        self.augmentation_config = {} if augmentation_config is None else augmentation_config
        self.snr_filter = snr_filter
        self.load_data()
        self.length = len(self.file_list)
        self.collate_fn = None
        if AUG_TRIM in self.augmentation_config:
            trim_config = self.augmentation_config[AUG_TRIM]
            self.collate_fn = partial(
                collate_fn_generic,
                lengths_lim=trim_config[LENGTHS],
                length_divider=trim_config[LENGTH_DIVIDER],
                trim_prob=trim_config[TRIM_PROB],
            )

    def filter_data(self, snr):
        """Return True when `snr` passes the SNR filter (or no filter is set)."""
        if self.snr_filter is None:
            return True
        return snr in self.snr_filter

    def load_data(self):
        """Populate the dataset index. Must be overridden by subclasses."""
        raise NotImplementedError("load_data method must be implemented")

    def augment_data(self, mixed_audio_signal, clean_audio_signal, noise_audio_signal):
        """Apply the configured augmentations to one signal triplet.

        - ``AUG_RESCALE``: multiplies all three signals by one shared random
          gain drawn uniformly from [0.5, 2.0). The multiplication is done
          in place, so the tensors passed in are modified.
        - ``AUG_AWGN``: adds white Gaussian noise to the mixed signal only.

        Returns:
            The (mixed, clean, noise) signals after augmentation.
        """
        device = mixed_audio_signal.device
        if AUG_RESCALE in self.augmentation_config:
            current_amplitude = 0.5 + 1.5*torch.rand(1, device=device)
            mixed_audio_signal *= current_amplitude
            noise_audio_signal *= current_amplitude
            clean_audio_signal *= current_amplitude
        if AUG_AWGN in self.augmentation_config:
            # TODO: read the standard deviation from
            # self.augmentation_config[AUG_AWGN]["noise_std"] instead of hard-coding it.
            noise_std = 0.01
            # Drawn on the signal's device: a CPU tensor of shape (1,) cannot
            # be multiplied with a tensor living on another device.
            current_noise_std = torch.randn(1, device=device) * noise_std
            extra_awgn = torch.randn(mixed_audio_signal.shape, device=device) * current_noise_std
            mixed_audio_signal = mixed_audio_signal+extra_awgn
            # Open question: should we add noise to the noise signal as well?
        return mixed_audio_signal, clean_audio_signal, noise_audio_signal

    def __len__(self):
        """Return the number of entries found by `load_data`."""
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        """Return the sample at `idx`. Must be overridden by subclasses."""
        raise NotImplementedError("__getitem__ method must be implemented")
def collate_fn_generic(
    batch, lengths_lim, length_divider=1024, trim_prob=0.5
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Collate function to allow trimming (=crop the time dimension) of the signals in a batch.

    The same crop is applied to the three signals of every batch element.

    Args:
        batch (list): A list of tuples (triplets), where each tuple contains:
            - mixed_audio_signal
            - clean_audio_signal
            - noise_audio_signal
            The signals are stacked with `default_collate`.
        lengths_lim (list): A pair holding the minimum (index 0) and the
            maximum (index 1) trimmed length.
        length_divider (int): The output length is rounded down to a
            multiple of this value.
        trim_prob (float): Trimming probability.

    Returns:
        - Tensor: A batch of mixed_audio_signal, trimmed to the same length.
        - Tensor: A batch of clean_audio_signal
        - Tensor: A batch of noise_audio_signal
    """
    mixed_audio_signal, clean_audio_signal, noise_audio_signal = default_collate(batch)
    # Length of the time (last) dimension, read from the first batch element.
    length = mixed_audio_signal[0].shape[-1]
    min_length, max_length = lengths_lim
    take_full_signal = bool(torch.rand(1) > trim_prob)
    # A signal that is not longer than min_length leaves no room for a random
    # crop (the start range would be empty), so it is kept whole instead.
    if not take_full_signal and length > min_length:
        start = torch.randint(0, length-min_length, (1,)).item()
        # NOTE: the "-1" means a random crop never reaches the last sample.
        longest_crop = min(max_length, length-start-1)
        trim_length = torch.randint(min_length, longest_crop+1, (1,)).item()
        trim_length -= trim_length % length_divider
        end = start + trim_length
    else:
        start = 0
        end = length - length % length_divider
    mixed_audio_signal = mixed_audio_signal[..., start:end]
    clean_audio_signal = clean_audio_signal[..., start:end]
    noise_audio_signal = noise_audio_signal[..., start:end]
    return mixed_audio_signal, clean_audio_signal, noise_audio_signal