Spaces:
Building
Building
File size: 1,673 Bytes
f6b56a2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from torch.utils.data import DataLoader
from gyraudio.audio_separation.data.mixed import MixedAudioDataset
from typing import Optional, List
from gyraudio.audio_separation.properties import (
DATA_PATH, AUGMENTATION, SNR_FILTER, SHUFFLE, BATCH_SIZE, TRAIN, VALID, TEST, AUG_TRIM
)
from gyraudio import root_dir
# Dataset locations, both resolved relative to the package's `root_dir`.
# NOTE(review): presumably the folder holding the original (unmixed) voice
# recordings -- inferred from the name only; confirm against the data layout.
RAW_AUDIO_ROOT = root_dir/"__data_source_separation"/"voice_origin"
# Default `audio_root` of get_config_dataloader, which appends the mode name
# to it to locate the per-mode data folder.
MIXED_AUDIO_ROOT = root_dir/"__data_source_separation"/"source_separation"
def get_dataloader(configurations: dict, audio_dataset=MixedAudioDataset):
    """Build one ``DataLoader`` per entry of ``configurations``.

    Args:
        configurations: Mapping from a mode key to a configuration dict.
            Each configuration is read at the keys ``DATA_PATH``,
            ``AUGMENTATION``, ``SNR_FILTER``, ``SHUFFLE`` and ``BATCH_SIZE``.
        audio_dataset: Callable invoked as
            ``audio_dataset(path, augmentation_config=..., snr_filter=...)``.
            The object it returns must expose a ``collate_fn`` attribute,
            which is handed to the ``DataLoader``.

    Returns:
        A dict with the same keys as ``configurations``, each mapped to the
        ``DataLoader`` built from the matching configuration.
    """
    def _build_loader(settings):
        # Instantiate the dataset first: its collate_fn is needed by the loader.
        data = audio_dataset(
            settings[DATA_PATH],
            augmentation_config=settings[AUGMENTATION],
            snr_filter=settings[SNR_FILTER],
        )
        return DataLoader(
            data,
            shuffle=settings[SHUFFLE],
            batch_size=settings[BATCH_SIZE],
            collate_fn=data.collate_fn,
        )

    return {
        split: _build_loader(settings)
        for split, settings in configurations.items()
    }
def get_config_dataloader(
        audio_root=MIXED_AUDIO_ROOT,
        mode: str = TRAIN,
        shuffle: Optional[bool] = None,
        batch_size: Optional[int] = 16,
        snr_filter: Optional[List[float]] = None,
        augmentation: Optional[dict] = None):
    """Assemble the configuration dict describing one dataloader.

    Args:
        audio_root: Root folder of the dataset; the data is looked up in
            ``audio_root/mode``.
        mode: One of ``TRAIN``, ``VALID`` or ``TEST``.
        shuffle: Explicit shuffle flag. When ``None``, shuffling is enabled
            only if ``mode == TRAIN``.
        batch_size: Stored as-is under ``BATCH_SIZE``.
        snr_filter: Stored as-is under ``SNR_FILTER``.
        augmentation: Stored under ``AUGMENTATION``. When omitted (``None``),
            a new empty dict is created for this call.

    Returns:
        A dict with the keys ``DATA_PATH``, ``SHUFFLE``, ``AUGMENTATION``,
        ``SNR_FILTER`` and ``BATCH_SIZE``.

    Raises:
        AssertionError: If ``mode`` is not a known mode or the resolved
            folder does not exist. These checks are ``assert`` statements,
            so they are skipped when Python runs with ``-O``.
    """
    audio_folder = audio_root/mode
    assert mode in (TRAIN, VALID, TEST), (
        f"Unknown mode {mode!r}, expected one of {(TRAIN, VALID, TEST)}"
    )
    assert audio_folder.exists(), f"Audio folder not found: {audio_folder}"
    if augmentation is None:
        # A fresh dict per call: a `{}` default would be one shared object
        # stored in every returned config, so a mutation through one config
        # would silently leak into all the others.
        augmentation = {}
    if shuffle is None:
        shuffle = mode == TRAIN
    return {
        DATA_PATH: audio_folder,
        SHUFFLE: shuffle,
        AUGMENTATION: augmentation,
        SNR_FILTER: snr_filter,
        BATCH_SIZE: batch_size,
    }
|