Spaces:
Building
Building
from gyraudio.default_locations import MIXED_AUDIO_ROOT | |
from gyraudio.audio_separation.properties import ( | |
TRAIN, TEST, VALID, NAME, EPOCHS, LEARNING_RATE, | |
OPTIMIZER, BATCH_SIZE, DATALOADER, AUGMENTATION, | |
SHORT_NAME, AUG_TRIM, TRIM_PROB, LENGTH_DIVIDER, LENGTHS, SNR_FILTER | |
) | |
from gyraudio.audio_separation.data.remixed_fixed import RemixedFixedAudioDataset | |
from gyraudio.audio_separation.data.remixed_rnd import RemixedRandomAudioDataset | |
from gyraudio.audio_separation.data import get_dataloader, get_config_dataloader | |
from gyraudio.audio_separation.experiment_tracking.experiments_definition import get_experiment_generator | |
import torch | |
from typing import Tuple | |
def get_experience(exp_major: int, exp_minor: int = 0, snr_filter_test=None, dry_run=False) -> Tuple[str, torch.nn.Module, dict, dict]:
    """Get all experience details.

    Builds a default configuration, lets the experiment generator selected by
    ``exp_major`` customize it (and build the model), then creates the
    dataloaders matching the resulting configuration.

    Args:
        exp_major (int): Major experience number.
        exp_minor (int, optional): Used for HP search. Defaults to 0.
        snr_filter_test (optional): Stored in the configuration under
            ``SNR_FILTER`` and forwarded as ``snr_filter`` to the test
            dataloader configuration of the "premix" dataloader.
            Defaults to None.
        dry_run (bool, optional): Forwarded to the experiment generator as
            ``no_model``. Defaults to False.

    Returns:
        Tuple[str, torch.nn.Module, dict, dict]: short_name, model, config, dataloaders

    Raises:
        NotImplementedError: If the configured dataloader name is neither
            "premix" nor "remix".
        AssertionError: If the experiment generator left the experiment
            name unset.
    """
    # Function-scope import: keeps this block self-contained.
    import logging

    config = {
        NAME: None,
        OPTIMIZER: {
            NAME: "adam",
            LEARNING_RATE: 0.001
        },
        EPOCHS: 60,
        DATALOADER: {
            NAME: "remix",
        },
        BATCH_SIZE: [16, 16, 16],
        SNR_FILTER: snr_filter_test
    }
    model, config = get_experiment_generator(exp_major=exp_major)(config, no_model=dry_run, minor=exp_minor)

    # POST PROCESSING
    # Batch sizes given as a [train, test, valid] sequence are expanded into a per-split dict.
    if isinstance(config[BATCH_SIZE], (list, tuple)):
        config[BATCH_SIZE] = {
            TRAIN: config[BATCH_SIZE][0],
            TEST: config[BATCH_SIZE][1],
            VALID: config[BATCH_SIZE][2],
        }

    # Read the name back from the config: the experiment generator may have changed it.
    dataloader_name = config[DATALOADER][NAME]
    if dataloader_name == "premix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        dataloaders = get_dataloader({
            TRAIN: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TRAIN,
                shuffle=True,
                batch_size=config[BATCH_SIZE][TRAIN],
                augmentation=config[DATALOADER].get(AUGMENTATION, {})
            ),
            TEST: get_config_dataloader(
                audio_root=mixed_audio_root,
                mode=TEST,
                shuffle=False,
                batch_size=config[BATCH_SIZE][TEST],
                snr_filter=config[SNR_FILTER]
            )
        })
    elif dataloader_name == "remix":
        mixed_audio_root = MIXED_AUDIO_ROOT
        # NOTE(review): aug_test is built but never passed to the test dataloader below.
        # It was presumably meant to be its `augmentation` argument - confirm before wiring or removing.
        aug_test = {}
        if AUG_TRIM in config[DATALOADER].get(AUGMENTATION, {}):
            aug_test = {
                AUG_TRIM: {
                    LENGTHS: [None, None],
                    LENGTH_DIVIDER: config[DATALOADER][AUGMENTATION][AUG_TRIM][LENGTH_DIVIDER],
                    TRIM_PROB: -1.
                }
            }
        try:
            dl_train = get_dataloader(
                {
                    TRAIN: get_config_dataloader(
                        audio_root=mixed_audio_root,
                        mode=TRAIN,
                        shuffle=True,
                        batch_size=config[BATCH_SIZE][TRAIN],
                        augmentation=config[DATALOADER].get(AUGMENTATION, {})
                    )
                },
                audio_dataset=RemixedRandomAudioDataset
            )[TRAIN]
            # NOTE(review): unlike "premix", config[SNR_FILTER] is not forwarded here - confirm this is intended.
            dl_test = get_dataloader(
                {
                    TEST: get_config_dataloader(
                        audio_root=mixed_audio_root,
                        mode=TEST,
                        shuffle=False,
                        batch_size=config[BATCH_SIZE][TEST]
                    )
                },
                audio_dataset=RemixedFixedAudioDataset
            )[TEST]
        except Exception:
            # Best effort: the configuration stays usable without dataloaders,
            # but the failure is reported instead of being silently dropped.
            logging.getLogger(__name__).warning(
                "Could not build the remix dataloaders, falling back to None",
                exc_info=True
            )
            dl_train = None
            dl_test = None
        dataloaders = {
            TRAIN: dl_train,
            TEST: dl_test
        }
    else:
        raise NotImplementedError(f"Unknown dataloader {dataloader_name}")

    assert config[NAME] is not None, "The experiment generator must set the experiment name"
    short_name = f"{exp_major:04d}_{exp_minor:04d}"
    config[SHORT_NAME] = short_name
    return short_name, model, config, dataloaders
if __name__ == "__main__":
    # Script entry point: print the short name and the configuration of every
    # experiment listed on the command line.
    from gyraudio.audio_separation.parser import shared_parser

    cli_args = shared_parser().parse_args()
    for exp_id in cli_args.experiments:
        # The model and the dataloaders are built as well, but only the
        # identifier and the configuration are displayed.
        exp_short_name, _model, exp_config, _loaders = get_experience(exp_id)
        print(exp_short_name)
        print(exp_config)