File size: 7,895 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
from gyraudio.audio_separation.experiment_tracking.experiments import get_experience
from gyraudio.audio_separation.parser import shared_parser
from gyraudio.audio_separation.properties import (
    TRAIN, TEST, EPOCHS, OPTIMIZER, NAME, MAX_STEPS_PER_EPOCH,
    SIGNAL, NOISE, TOTAL, SNR, SCHEDULER, SCHEDULER_CONFIGURATION,
    LOSS, LOSS_L2, LOSS_L1, LOSS_TYPE, COEFFICIENT
)
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT
from torch.optim.lr_scheduler import ReduceLROnPlateau
from gyraudio.audio_separation.experiment_tracking.storage import get_output_folder, save_checkpoint
from gyraudio.audio_separation.metrics import Costs, snr
# from gyraudio.audio_separation.experiment_tracking.storage import load_checkpoint
from pathlib import Path
from gyraudio.io.dump import Dump
import sys
import torch
from tqdm import tqdm
from copy import deepcopy
import wandb
import logging


def launch_training(exp: int, wandb_flag: bool = True, device: str = "cuda", save_dir: Path = None, override=False):
    """Set up and run the training of a single experiment.

    Resolves the experiment definition (model, config, dataloaders), prepares the
    output folder, optionally initializes a wandb run, then delegates to
    ``training_loop``.

    Args:
        exp: experiment identifier passed to ``get_experience``.
        wandb_flag: when True, log the run to Weights & Biases.
        device: torch device string ("cuda" or "cpu").
        save_dir: root directory for experiment outputs (``None`` lets
            ``get_output_folder`` pick its default).
        override: overwrite an existing experiment folder instead of skipping.

    Returns:
        True if training ran, False if the experiment was skipped
        (output folder could not be prepared).
    """
    short_name, model, config, dl = get_experience(exp)
    exists, output_folder = get_output_folder(config, root_dir=save_dir, override=override)
    # Guard clause: nothing to do if the output folder was not prepared.
    if not exists:
        logging.warning(f"Skipping experiment {short_name}")
        return False
    logging.info(f"Experiment {short_name} saved in {output_folder}")
    logging.info(f"Starting training for {short_name}")
    logging.info(f"Config: {config}")
    if wandb_flag:
        wandb.init(
            project="audio-separation",
            entity="teammd",
            name=short_name,
            tags=["debug"],
            config=config
        )
    training_loop(model, config, dl, wandb_flag=wandb_flag, device=device, exp_dir=output_folder)
    if wandb_flag:
        wandb.finish()
    return True


def update_metrics(metrics, phase, pred, gt, pred_noise, gt_noise):
    """Accumulate one step of signal / noise / SNR metrics and return the step loss.

    Args:
        metrics: mapping from phase name to a ``Costs``-like accumulator.
        phase: which accumulator to update (e.g. TRAIN or TEST).
        pred, gt: predicted and ground-truth signal tensors.
        pred_noise, gt_noise: predicted and ground-truth noise tensors.

    Returns:
        The loss produced by the accumulator's ``finish_step``.
    """
    tracker = metrics[phase]
    for prediction, target, key in (
            (pred, gt, SIGNAL),
            (pred_noise, gt_noise, NOISE),
            (pred, gt, SNR)):
        tracker.update(prediction, target, key)
    return tracker.finish_step()


def training_loop(model: torch.nn.Module, config: dict, dl, device: str = "cuda", wandb_flag: bool = False,
                  exp_dir: Path = None):
    """Train and evaluate ``model`` for ``config[EPOCHS]`` epochs.

    Each epoch runs a training pass over ``dl[TRAIN]`` and a no-grad
    validation pass over ``dl[TEST]``, then dumps per-epoch metrics and a
    checkpoint into ``exp_dir`` (and optionally logs to wandb).

    Args:
        model: the separation network; assumed to return a
            (signal, noise) pair when called on a mixed batch.
        config: experiment configuration dict (optimizer, scheduler, loss, ...).
        dl: dict of dataloaders keyed by TRAIN / TEST.
        device: torch device string.
        wandb_flag: when True, log metrics to wandb.
        exp_dir: output folder for metrics JSON files and checkpoints.

    Raises:
        NotImplementedError: for an unsupported optimizer, scheduler or loss name.
    """
    # Copy before popping NAME so the caller's config is not mutated.
    optim_params = deepcopy(config[OPTIMIZER])
    optim_name = optim_params.pop(NAME)
    if optim_name == "adam":
        optimizer = torch.optim.Adam(model.parameters(), **optim_params)
    else:
        # Fail fast instead of hitting a NameError on `optimizer` below.
        raise NotImplementedError(f"Optimizer {optim_name} not implemented")

    scheduler = None
    scheduler_config = config.get(SCHEDULER_CONFIGURATION, {})
    scheduler_name = config.get(SCHEDULER, False)
    if scheduler_name:
        if scheduler_name == "ReduceLROnPlateau":
            # mode='max': the scheduler is stepped on SNR, which we maximize.
            scheduler = ReduceLROnPlateau(optimizer, mode='max', verbose=True, **scheduler_config)
            logging.info(f"Using scheduler {scheduler_name} with config {scheduler_config}")
        else:
            raise NotImplementedError(f"Scheduler {scheduler_name} not implemented")
    max_steps = config.get(MAX_STEPS_PER_EPOCH, None)
    chosen_loss = config.get(LOSS, LOSS_L2)
    if chosen_loss == LOSS_L2:
        costs = {TRAIN:  Costs(TRAIN), TEST: Costs(TEST)}
    elif chosen_loss == LOSS_L1:
        # L1 training loss on signal and noise (equal weights); SNR tracked alongside.
        cost_init = {
            SIGNAL: {
                COEFFICIENT: 0.5,
                LOSS_TYPE: torch.nn.functional.l1_loss
            },
            NOISE: {
                COEFFICIENT: 0.5,
                LOSS_TYPE: torch.nn.functional.l1_loss
            },
            SNR: {
                LOSS_TYPE: snr
            }
        }
        costs = {
            TRAIN:  Costs(TRAIN, costs=cost_init),
            TEST: Costs(TEST)
        }
    else:
        # Fail fast instead of hitting a NameError on `costs` below.
        raise NotImplementedError(f"Loss {chosen_loss} not implemented")
    for epoch in range(config[EPOCHS]):
        costs[TRAIN].reset_epoch()
        costs[TEST].reset_epoch()
        model.to(device)
        # BUG FIX: restore training mode — model.eval() below would otherwise
        # persist into every epoch after the first (dropout/batch-norm disabled).
        model.train()
        # Training loop
        # -----------------------------------------------------------

        metrics = {TRAIN: {}, TEST: {}}
        for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                enumerate(dl[TRAIN]), desc=f"Epoch {epoch}", total=len(dl[TRAIN])):
            # Optional cap on steps per epoch (quick debug runs).
            if max_steps is not None and step_index >= max_steps:
                break
            batch_mix, batch_signal, batch_noise = batch_mix.to(device), batch_signal.to(device), batch_noise.to(device)
            model.zero_grad()
            batch_output_signal, batch_output_noise = model(batch_mix)
            loss = update_metrics(
                costs, TRAIN,
                batch_output_signal, batch_signal,
                batch_output_noise, batch_noise
            )
            loss.backward()
            optimizer.step()
        costs[TRAIN].finish_epoch()

        # Validation loop
        # -----------------------------------------------------------
        model.eval()
        torch.cuda.empty_cache()
        with torch.no_grad():
            for step_index, (batch_mix, batch_signal, batch_noise) in tqdm(
                    enumerate(dl[TEST]), desc=f"Epoch {epoch}", total=len(dl[TEST])):
                if max_steps is not None and step_index >= max_steps:
                    break
                batch_mix, batch_signal, batch_noise = batch_mix.to(
                    device), batch_signal.to(device), batch_noise.to(device)
                batch_output_signal, batch_output_noise = model(batch_mix)
                # Return value unused here; only the accumulator state matters.
                update_metrics(
                    costs, TEST,
                    batch_output_signal, batch_signal,
                    batch_output_noise, batch_noise
                )
        costs[TEST].finish_epoch()
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(costs[TEST].total_metric[SNR])
        print(f"epoch {epoch}:\n{costs[TRAIN]}\n{costs[TEST]}")
        wandblogs = {}
        if wandb_flag:
            for phase in [TRAIN, TEST]:
                wandblogs[f"{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
                wandblogs[f"debug loss/{phase} loss signal"] = costs[phase].total_metric[SIGNAL]
                wandblogs[f"debug loss/{phase} loss total"] = costs[phase].total_metric[TOTAL]
                wandblogs[f"debug loss/{phase} loss noise"] = costs[phase].total_metric[NOISE]
                wandblogs[f"{phase} snr"] = costs[phase].total_metric[SNR]
            # Learning rate is phase-independent: log it once, not per phase.
            wandblogs["learning rate"] = optimizer.param_groups[0]['lr']
            wandb.log(wandblogs)
        metrics[TRAIN] = costs[TRAIN].total_metric
        metrics[TEST] = costs[TEST].total_metric
        Dump.save_json(metrics, exp_dir/f"metrics_{epoch:04d}.json")
        save_checkpoint(model, exp_dir, optimizer, config=config, epoch=epoch)
        torch.cuda.empty_cache()


def main(argv):
    """Parse command-line arguments and launch training for each experiment."""
    default_device = "cuda" if torch.cuda.is_available() else "cpu"
    cuda_note = "\n<<<Cuda available>>>" if default_device == "cuda" else ""
    parser_def = shared_parser(
        help="Launch training \nCheck results at: https://wandb.ai/teammd/audio-separation" + cuda_note)
    parser_def.add_argument("-nowb", "--no-wandb", action="store_true")
    parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT)
    parser_def.add_argument("-f", "--force", action="store_true", help="Override existing experiment")
    parser_def.add_argument("-d", "--device", type=str, default=default_device,
                            help="Training device", choices=["cpu", "cuda"])
    parsed = parser_def.parse_args(argv)
    for experiment_id in parsed.experiments:
        launch_training(
            experiment_id,
            wandb_flag=not parsed.no_wandb,
            save_dir=Path(parsed.output_dir),
            override=parsed.force,
            device=parsed.device,
        )


if __name__ == "__main__":
    # Script entry point: forward CLI arguments, dropping the program name.
    main(sys.argv[1:])