# Training entry point for audio source-separation experiments.
import logging
import sys
from copy import deepcopy
from pathlib import Path
from typing import Optional

import torch
import wandb
from torch.optim.lr_scheduler import ReduceLROnPlateau
from tqdm import tqdm

from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder, save_checkpoint
# from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
from gyraudio.audio_separation.metrics import Costs, snr
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.properties import (
    TRAIN, TEST, EPOCHS, OPTIMIZER, NAME, MAX_STEPS_PER_EPOCH,
    SIGNAL, NOISE, TOTAL, SNR, SCHEDULER, SCHEDULER_CONFIGURATION,
    LOSS, LOSS_L2, LOSS_L1, LOSS_TYPE, COEFFICIENT
)
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from gyraudio.io.dump import Dump
def launch_training(exp: int, wandb_flag: bool = True, device: str = "cuda",
                    save_dir: Optional[Path] = None, override=False):
    """Resolve experiment *exp*, create its output folder and run the training loop.

    Args:
        exp: experiment index, resolved by ``get_experience``.
        wandb_flag: when True, log the run to Weights & Biases.
        device: torch device string ("cuda" or "cpu").
        save_dir: root directory for experiment outputs (``None`` lets
            ``get_output_folder`` pick its default).
        override: forwarded to ``get_output_folder`` to overwrite an
            existing experiment folder.

    Returns:
        True if training ran, False if the experiment was skipped
        (``get_output_folder`` refused to provide a folder, e.g. it
        already exists and ``override`` is False — TODO confirm semantics).
    """
    short_name, model, config, dl = get_experience(exp)
    exists, output_folder = get_output_folder(config, root_dir=save_dir, override=override)
    # Guard clause: no usable output folder means we skip this experiment.
    if not exists:
        logging.warning(f"Skipping experiment {short_name}")
        return False
    logging.info(f"Experiment {short_name} saved in {output_folder}")
    print(short_name)
    print(config)
    logging.info(f"Starting training for {short_name}")
    logging.info(f"Config: {config}")
    if wandb_flag:
        wandb.init(
            project="audio-separation",
            entity="teammd",
            name=short_name,
            tags=["debug"],
            config=config
        )
    training_loop(model, config, dl, wandb_flag=wandb_flag, device=device, exp_dir=output_folder)
    if wandb_flag:
        wandb.finish()
    return True
def update_metrics(metrics, phase, pred, gt, pred_noise, gt_noise):
    """Accumulate one step of signal/noise/SNR metrics for *phase* and return the step loss."""
    tracker = metrics[phase]
    # Same three updates as before, expressed as data: (estimate, reference, metric key).
    for estimate, reference, key in ((pred, gt, SIGNAL),
                                     (pred_noise, gt_noise, NOISE),
                                     (pred, gt, SNR)):
        tracker.update(estimate, reference, key)
    return tracker.finish_step()
def training_loop(model: torch.nn.Module, config: dict, dl, device: str = "cuda", wandb_flag: bool = False,
                  exp_dir: Optional[Path] = None):
    """Train *model* for ``config[EPOCHS]`` epochs, validating after each epoch.

    Args:
        model: separation network; ``model(mix)`` returns ``(signal, noise)`` estimates.
        config: experiment configuration (optimizer, scheduler, loss, epochs, ...).
        dl: dict of dataloaders keyed by TRAIN/TEST yielding (mix, signal, noise) batches.
        device: torch device string.
        wandb_flag: when True, log per-epoch metrics to Weights & Biases.
        exp_dir: experiment folder where metrics JSON and checkpoints are written.

    Raises:
        NotImplementedError: unknown optimizer name, scheduler name or loss type.
    """
    optim_params = deepcopy(config[OPTIMIZER])  # deepcopy: NAME is popped below
    optim_name = optim_params.pop(NAME)
    if optim_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), **optim_params)
    else:
        # Fail fast instead of hitting a NameError on `optimizer` further down.
        raise NotImplementedError(f"Optimizer {optim_name} not implemented")
    scheduler = None
    scheduler_config = config.get(SCHEDULER_CONFIGURATION, {})
    scheduler_name = config.get(SCHEDULER, False)
    if scheduler_name:
        if scheduler_name == "ReduceLROnPlateau":
            # mode='max': the monitored quantity is the validation SNR (higher is better).
            scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, **scheduler_config)
            logging.info(f"Using scheduler {scheduler_name} with config {scheduler_config}")
        else:
            raise NotImplementedError(f"Scheduler {scheduler_name} not implemented")
    max_steps = config.get(MAX_STEPS_PER_EPOCH, None)
    chosen_loss = config.get(LOSS, LOSS_L2)
    if chosen_loss == LOSS_L2:
        costs = {TRAIN: Costs(TRAIN), TEST: Costs(TEST)}
    elif chosen_loss == LOSS_L1:
        # L1 on signal and noise (equal weights), SNR tracked as a metric.
        cost_init = {
            SIGNAL: {
                COEFFICIENT: 0.5,
                LOSS_TYPE: torch.nn.functional.l1_loss
            },
            NOISE: {
                COEFFICIENT: 0.5,
                LOSS_TYPE: torch.nn.functional.l1_loss
            },
            SNR: {
                LOSS_TYPE: snr
            }
        }
        costs = {
            TRAIN: Costs(TRAIN, costs=cost_init),
            TEST: Costs(TEST)
        }
    else:
        # Fail fast instead of hitting a NameError on `costs` further down.
        raise NotImplementedError(f"Loss {chosen_loss} not implemented")
    model.to(device)  # hoisted: loop-invariant, no need to move the model every epoch
    for epoch in range(config[EPOCHS]):
        costs[TRAIN].reset_epoch()
        costs[TEST].reset_epoch()
        metrics = {TRAIN: {}, TEST: {}}
        # Training loop
        # -----------------------------------------------------------
        # BUGFIX: restore training mode — model.eval() set during the previous
        # epoch's validation would otherwise persist (dropout/batch-norm frozen).
        model.train()
        for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                enumerate(dl[TRAIN]), desc=f"Epoch {epoch}", total=len(dl[TRAIN])):
            if max_steps is not None and step_index >= max_steps:
                break
            batch_mix = batch_mix.to(device)
            batch_signal = batch_signal.to(device)
            batch_noise = batch_noise.to(device)
            model.zero_grad()
            batch_output_signal, batch_output_noise = model(batch_mix)
            loss = update_metrics(
                costs, TRAIN,
                batch_output_signal, batch_signal,
                batch_output_noise, batch_noise
            )
            loss.backward()
            optimizer.step()
        costs[TRAIN].finish_epoch()
        # Validation loop
        # -----------------------------------------------------------
        model.eval()
        torch.cuda.empty_cache()
        with torch.no_grad():
            for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                    enumerate(dl[TEST]), desc=f"Epoch {epoch}", total=len(dl[TEST])):
                if max_steps is not None and step_index >= max_steps:
                    break
                batch_mix = batch_mix.to(device)
                batch_signal = batch_signal.to(device)
                batch_noise = batch_noise.to(device)
                batch_output_signal, batch_output_noise = model(batch_mix)
                update_metrics(
                    costs, TEST,
                    batch_output_signal, batch_signal,
                    batch_output_noise, batch_noise
                )
        costs[TEST].finish_epoch()
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            # Step on validation SNR (matches the scheduler's mode='max').
            scheduler.step(costs[TEST].total_metric[SNR])
        print(f"epoch {epoch}:\n{costs[TRAIN]}\n{costs[TEST]}")
        if wandb_flag:
            wandblogs = {}
            for phase in [TRAIN, TEST]:
                wandblogs[f"{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
                wandblogs[f"debug loss/{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
                wandblogs[f"debug loss/{phase} loss total"] = costs[phase].total_metric[TOTAL]
                wandblogs[f"debug loss/{phase} loss noise"] = costs[phase].total_metric[NOISE]
                wandblogs[f"{phase} snr"] = costs[phase].total_metric[SNR]
            wandblogs["learning rate"] = optimizer.param_groups[0]['lr']
            wandb.log(wandblogs)
        # Persist per-epoch metrics and a checkpoint next to the experiment config.
        metrics[TRAIN] = costs[TRAIN].total_metric
        metrics[TEST] = costs[TEST].total_metric
        Dump.save_json(metrics, exp_dir/f"metrics_{epoch:04d}.json")
        save_checkpoint(model, exp_dir, optimizer, config=config, epoch=epoch)
    torch.cuda.empty_cache()
def main(argv):
    """Parse CLI arguments and launch training for every requested experiment."""
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    cuda_banner = "\n<<<Cuda available>>>" if default_device == "cuda" else ""
    parser = shared_parser(
        help="Launch training \nCheck results at: https://wandb.ai/teammd/audio-separation" + cuda_banner)
    parser.add_argument("-nowb", "--no-wandb", action="store_true")
    parser.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser.add_argument("-f", "--force", action="store_true", help="Override existing experiment")
    parser.add_argument("-d", "--device", type=str, default=default_device,
                        help="Training device", choices=["cpu", "cuda"])
    args = parser.parse_args(argv)
    # Run every experiment requested on the command line, one after the other.
    for exp in args.experiments:
        launch_training(
            exp,
            wandb_flag=not args.no_wandb,
            save_dir=Path(args.output_dir),
            override=args.force,
            device=args.device,
        )
# Script entry point: forward CLI arguments (without the program name) to main().
if __name__ == "__main__":
    main(sys.argv[1:])