Spaces:
Building
Building
from torch.utils.data import DataLoader | |
from gyraudio.audio_separation.data.mixed import MixedAudioDataset | |
from typing import Optional, List | |
from gyraudio.audio_separation.properties import ( | |
DATA_PATH, AUGMENTATION, SNR_FILTER, SHUFFLE, BATCH_SIZE, TRAIN, VALID, TEST, AUG_TRIM | |
) | |
from gyraudio import root_dir | |
# Both audio roots live under the same data folder inside the package root.
_SEPARATION_DATA_ROOT = root_dir / "__data_source_separation"
# Root folder named "voice_origin".
RAW_AUDIO_ROOT = _SEPARATION_DATA_ROOT / "voice_origin"
# Root folder named "source_separation"; default `audio_root` of get_config_dataloader.
MIXED_AUDIO_ROOT = _SEPARATION_DATA_ROOT / "source_separation"
def get_dataloader(configurations: dict, audio_dataset=MixedAudioDataset):
    """Build one ``DataLoader`` per entry of ``configurations``.

    Args:
        configurations: Mapping from a mode key to a configuration dict.
            Each configuration must provide the ``DATA_PATH``,
            ``AUGMENTATION``, ``SNR_FILTER``, ``SHUFFLE`` and ``BATCH_SIZE``
            entries.
        audio_dataset: Dataset class (or factory) called as
            ``audio_dataset(path, augmentation_config=..., snr_filter=...)``.
            The object it returns must expose a ``collate_fn`` attribute.

    Returns:
        A dict mapping every mode key of ``configurations`` to the
        ``DataLoader`` built from its configuration.
    """

    def _make_loader(cfg):
        # The dataset is created first: its collate_fn is handed to the loader.
        samples = audio_dataset(
            cfg[DATA_PATH],
            augmentation_config=cfg[AUGMENTATION],
            snr_filter=cfg[SNR_FILTER],
        )
        return DataLoader(
            samples,
            shuffle=cfg[SHUFFLE],
            batch_size=cfg[BATCH_SIZE],
            collate_fn=samples.collate_fn,
        )

    return {split: _make_loader(cfg) for split, cfg in configurations.items()}
def get_config_dataloader(
        audio_root=MIXED_AUDIO_ROOT,
        mode: str = TRAIN,
        shuffle: Optional[bool] = None,
        batch_size: Optional[int] = 16,
        snr_filter: Optional[List[float]] = None,
        augmentation: Optional[dict] = None):
    """Assemble the configuration dict consumed by ``get_dataloader`` for one mode.

    Args:
        audio_root: Root folder; the data is read from ``audio_root / mode``.
        mode: One of ``TRAIN``, ``VALID`` or ``TEST``.
        shuffle: Explicit shuffle flag. When ``None``, shuffling is enabled
            only for ``TRAIN``.
        batch_size: Value stored under ``BATCH_SIZE``.
        snr_filter: Value stored under ``SNR_FILTER``.
        augmentation: Value stored under ``AUGMENTATION``. When omitted, a
            fresh empty dict is created for this call.

    Returns:
        A dict with the ``DATA_PATH``, ``SHUFFLE``, ``AUGMENTATION``,
        ``SNR_FILTER`` and ``BATCH_SIZE`` entries.

    Raises:
        AssertionError: If ``mode`` is not a known mode or the folder
            ``audio_root / mode`` does not exist.
    """
    audio_folder = audio_root / mode
    assert mode in [TRAIN, VALID, TEST], f"Unknown mode: {mode!r}"
    assert audio_folder.exists(), f"Missing audio folder: {audio_folder}"
    # A `{}` default would be created once and shared by every returned
    # config; build a new dict per call so configs stay independent.
    if augmentation is None:
        augmentation = {}
    if shuffle is None:
        # Default policy: shuffle the training split only.
        shuffle = mode == TRAIN
    return {
        DATA_PATH: audio_folder,
        SHUFFLE: shuffle,
        AUGMENTATION: augmentation,
        SNR_FILTER: snr_filter,
        BATCH_SIZE: batch_size,
    }