Spaces:
Runtime error
Runtime error
File size: 2,186 Bytes
75c6e9a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 |
from typing import List
import pytorch_lightning as pl
import torch.nn as nn
def get_callbacks(
    task_name: str,
    config_yaml: str,
    workspace: str,
    checkpoints_dir: str,
    statistics_path: str,
    logger: pl.loggers.TensorBoardLogger,
    model: nn.Module,
    evaluate_device: str,
) -> List[pl.Callback]:
    r"""Get callbacks of a task and config yaml file.

    Dispatches on ``task_name`` to a task-specific factory in
    ``bytesep.callbacks``. Each factory is imported inside its branch, so only
    the selected task's callback module is imported.

    Args:
        task_name: str, one of 'musdb18', 'voicebank-demand', 'vctk-musdb18',
            'violin-piano', 'piano-symphony'
        config_yaml: str, config yaml file
        workspace: str, containing useful files such as audios for evaluation
        checkpoints_dir: str, directory to save checkpoints
        statistics_path: str, where statistics are saved
        logger: pl.loggers.TensorBoardLogger
        model: nn.Module
        evaluate_device: str

    Returns:
        callbacks: List[pl.Callback]

    Raises:
        NotImplementedError: if ``task_name`` is not one of the supported tasks.
    """
    # Every task-specific factory takes exactly the same keyword arguments.
    common_kwargs = dict(
        config_yaml=config_yaml,
        workspace=workspace,
        checkpoints_dir=checkpoints_dir,
        statistics_path=statistics_path,
        logger=logger,
        model=model,
        evaluate_device=evaluate_device,
    )

    if task_name == 'musdb18':
        from bytesep.callbacks.musdb18 import get_musdb18_callbacks

        return get_musdb18_callbacks(**common_kwargs)

    if task_name == 'voicebank-demand':
        from bytesep.callbacks.voicebank_demand import get_voicebank_demand_callbacks

        return get_voicebank_demand_callbacks(**common_kwargs)

    if task_name in ('vctk-musdb18', 'violin-piano', 'piano-symphony'):
        from bytesep.callbacks.instruments_callbacks import get_instruments_callbacks

        return get_instruments_callbacks(**common_kwargs)

    # Same exception type as before, but now says which task was rejected.
    raise NotImplementedError(
        "get_callbacks: unsupported task_name {!r}".format(task_name)
    )
|