File size: 4,563 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fca4341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
from gyraudio.default_locations import MIXED_AUDIO_ROOT
from gyraudio.audio_separation.properties import (
    TRAIN, TEST, VALID, NAME, EPOCHS, LEARNING_RATE,
    OPTIMIZER, BATCH_SIZE, DATALOADER, AUGMENTATION,
    SHORT_NAME, AUG_TRIM, TRIM_PROB, LENGTH_DIVIDER, LENGTHS, SNR_FILTER
)
from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset
from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset
from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader
from gyraudio.audio_separation.experiment_tracking.experiments_definition import get_experiment_generator
import torch
from typing import Tuple


def get_experience(exp_major: int, exp_minor: int = 0, snr_filter_test=None, dry_run=False) -> Tuple[str, torch.nn.Module, dict, dict]:
    """Get all experience details

    Builds a default configuration, hands it to the experiment generator
    selected by ``exp_major`` (which returns the model and the final
    configuration), then creates the dataloaders named in that configuration.

    Args:
        exp_major (int): Major experience number
        exp_minor (int, optional): Used for HP search. Defaults to 0.
        snr_filter_test (optional): Initial value stored under ``SNR_FILTER``
            in the configuration. The "premix" test dataloader configuration
            reads that entry as its ``snr_filter``. Defaults to None.
        dry_run (bool, optional): Forwarded to the experiment generator as
            ``no_model``. Defaults to False.

    Returns:
        Tuple[str, torch.nn.Module, dict, dict]: short_name, model, config, dataloaders

    Raises:
        NotImplementedError: If the configured dataloader name is neither
            "premix" nor "remix".
        AssertionError: If the configuration ``NAME`` is still None after the
            experiment generator ran.
    """
    # Local import so this function does not rely on a module-level import of logging.
    import logging

    # Default configuration; the experiment generator is given this dict and
    # returns the configuration that is actually used below.
    config = {
        NAME: None,
        OPTIMIZER: {
            NAME: "adam",
            LEARNING_RATE: 0.001
        },
        EPOCHS: 60,
        DATALOADER: {
            NAME: "remix",
        },
        BATCH_SIZE: [16, 16, 16],
        SNR_FILTER: snr_filter_test
    }

    model, config = get_experiment_generator(exp_major=exp_major)(config, no_model=dry_run, minor=exp_minor)

    # POST PROCESSING
    # A [train, test, valid] sequence of batch sizes is expanded into a dict keyed by split.
    if isinstance(config[BATCH_SIZE], (list, tuple)):
        config[BATCH_SIZE] = {
            TRAIN: config[BATCH_SIZE][0],
            TEST:  config[BATCH_SIZE][1],
            VALID:  config[BATCH_SIZE][2],
        }

    # Use the name from the final configuration: the generator may have replaced the default.
    dataloader_name = config[DATALOADER][NAME]
    if dataloader_name == "premix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        dataloaders = get_dataloader({
            TRAIN: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TRAIN,
                shuffle=True,
                batch_size=config[BATCH_SIZE][TRAIN],
                augmentation=config[DATALOADER].get(AUGMENTATION, {})
            ),
            TEST: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TEST,
                shuffle=False,
                batch_size=config[BATCH_SIZE][TEST],
                snr_filter=config[SNR_FILTER]
            )
        })
    elif dataloader_name == "remix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        # NOTE(review): aug_test is built but never passed to any dataloader below.
        # It looks like it was meant for the test dataloader - confirm before wiring it in
        # or removing it (removal would also drop the implicit LENGTH_DIVIDER key check).
        aug_test = {}
        if AUG_TRIM in config[DATALOADER].get(AUGMENTATION, {}):
            aug_test = {
                AUG_TRIM: {LENGTHS: [None, None], LENGTH_DIVIDER: config[DATALOADER][AUGMENTATION]
                           [AUG_TRIM][LENGTH_DIVIDER], TRIM_PROB: -1.}
            }
        try:
            dl_train = get_dataloader(
                {
                    TRAIN: get_config_dataloader(
                        audio_root=mixed_audio_root,
                        mode=TRAIN,
                        shuffle=True,
                        batch_size=config[BATCH_SIZE][TRAIN],
                        augmentation=config[DATALOADER].get(AUGMENTATION, {})
                    )
                },
                audio_dataset=RemixedRandomAudioDataset
            )[TRAIN]
            dl_test = get_dataloader(
                {
                    TEST: get_config_dataloader(
                        audio_root=mixed_audio_root,
                        mode=TEST,
                        shuffle=False,
                        batch_size=config[BATCH_SIZE][TEST]
                    )
                },
                audio_dataset=RemixedFixedAudioDataset
            )[TEST]
        except Exception as exc:
            # Best effort: the experiment description is still returned when the
            # dataloaders cannot be built, but the failure is reported instead of
            # being dropped silently. If either one fails, both are set to None.
            logging.getLogger(__name__).warning(
                "Could not build the remix dataloaders (%s: %s); TRAIN and TEST dataloaders are set to None",
                type(exc).__name__, exc
            )
            dl_train = None
            dl_test = None
        dataloaders = {
            TRAIN: dl_train,
            TEST: dl_test
        }
    else:
        # Report the name that was actually requested, not the built-in default.
        raise NotImplementedError(f"Unknown dataloader {dataloader_name}")
    assert config[NAME] is not None, "The experiment generator must set the configuration NAME"

    short_name = f"{exp_major:04d}_{exp_minor:04d}"
    config[SHORT_NAME] = short_name
    return short_name, model, config, dataloaders


if __name__ == "__main__":
    # Script entry point: print the short name and the configuration of every
    # experiment listed in the parsed `experiments` argument.
    from gyraudio.audio_separation.parser import shared_parser

    cli_args = shared_parser().parse_args()
    for experiment_id in cli_args.experiments:
        # Only the short name and the configuration are displayed; the model
        # and the dataloaders returned alongside them are ignored here.
        exp_short_name, _, exp_config, _ = get_experience(experiment_id)
        print(exp_short_name)
        print(exp_config)