File size: 1,673 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from torch.utils.data import DataLoader
from gyraudio.audio_separation.data.mixed import MixedAudioDataset
from typing import Optional, List
from gyraudio.audio_separation.properties import (
    DATA_PATH, AUGMENTATION, SNR_FILTER, SHUFFLE, BATCH_SIZE, TRAIN, VALID, TEST, AUG_TRIM
)
from gyraudio import root_dir
RAW_AUDIO_ROOT = root_dir/"__data_source_separation"/"voice_origin"
MIXED_AUDIO_ROOT = root_dir/"__data_source_separation"/"source_separation"


def get_dataloader(configurations: dict, audio_dataset=MixedAudioDataset):
    dataloaders = {}
    for mode, configuration in configurations.items():
        dataset = audio_dataset(
            configuration[DATA_PATH],
            augmentation_config=configuration[AUGMENTATION],
            snr_filter=configuration[SNR_FILTER]
        )
        dl = DataLoader(
            dataset,
            shuffle=configuration[SHUFFLE],
            batch_size=configuration[BATCH_SIZE],
            collate_fn=dataset.collate_fn
        )
        dataloaders[mode] = dl
    return dataloaders


def get_config_dataloader(
        audio_root=MIXED_AUDIO_ROOT,
        mode: str = TRAIN,
        shuffle: Optional[bool] = None,
        batch_size: Optional[int] = 16,
        snr_filter: Optional[List[float]] = None,
        augmentation: dict = {}):
    audio_folder = audio_root/mode
    assert mode in [TRAIN, VALID, TEST]
    assert audio_folder.exists()
    config = {
        DATA_PATH: audio_folder,
        SHUFFLE: shuffle if shuffle is not None else (True if mode == TRAIN else False),
        AUGMENTATION: augmentation,
        SNR_FILTER: snr_filter,
        BATCH_SIZE: batch_size
    }
    return config