import sys
import argparse
from typing import Optional
import torch
import logging
from pathlib import Path
import json
from tqdm import tqdm
from rstor.properties import (
    ID, NAME, NB_EPOCHS,
    TRAIN, VALIDATION, LR,
    LOSS_MSE, METRIC_PSNR, METRIC_SSIM,
    DEVICE, SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU,
    REDUCTION_SUM,
    SELECTED_METRICS,
    LOSS
)
from rstor.learning.metrics import compute_metrics
from rstor.learning.loss import compute_loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config
WANDB_AVAILABLE = False
try:
    import wandb
    WANDB_AVAILABLE = True  # only flag wandb as usable once the import has succeeded
except ImportError:
    logging.warning("Could not import wandb. Disabling wandb.")
def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
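    """Define (or extend) the command line interface of the training script."""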
    if parser is None:
        parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("-e", "--exp", nargs="+", type=int, required=True, help="Experiment id")
    parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR/OUTPUT_FOLDER_NAME, help="Output directory")
    parser.add_argument("-nowb", "--no-wandb", action="store_true", help="Disable weights and biases")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    return parser
def training_loop(
    model,
    optimizer,
    dl_dict: dict,
    config: dict,
    scheduler=None,
    device: str = DEVICE,
    wandb_flag: bool = False,
    output_dir: Optional[Path] = None,
):
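    """Run the train/validation phases for every epoch, tracking the chosen metrics
    and checkpointing the best model (highest validation PSNR)."""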
    best_accuracy = 0.
    chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
    for n_epoch in tqdm(range(config[NB_EPOCHS])):
        current_metrics = {
            TRAIN: 0.,
            VALIDATION: 0.,
            LR: optimizer.param_groups[0]['lr'],
        }
        for met in chosen_metrics:
            current_metrics[met] = 0.
        for phase in [TRAIN, VALIDATION]:
            total_elements = 0
            if phase == TRAIN:
                model.train()
            else:
                model.eval()
            for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                # Gradients are only needed during the training phase.
                with torch.set_grad_enabled(phase == TRAIN):
                    y_pred = model(x)
                    loss = compute_loss(y_pred, y, mode=config.get(LOSS, LOSS_MSE))
                    if torch.isnan(loss):
                        print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
                        continue
                    if phase == TRAIN:
                        loss.backward()
                        # Clip the gradient norm to 1 to avoid exploding gradients.
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                        optimizer.step()
                current_metrics[phase] += loss.item()
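                # Validation metrics are summed per batch and normalized once the epoch ends.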
                if phase == VALIDATION:
                    metrics_on_batch = compute_metrics(
                        y_pred,
                        y,
                        chosen_metrics=chosen_metrics,
                        reduction=REDUCTION_SUM
                    )
                    total_elements += y_pred.shape[0]
                    for k, v in metrics_on_batch.items():
                        current_metrics[k] += v
            current_metrics[phase] /= len(dl_dict[phase])
            if phase == VALIDATION:
                for k in chosen_metrics:
                    current_metrics[k] /= total_elements
                    try:
                        current_metrics[k] = current_metrics[k].item()
                    except AttributeError:
                        pass  # already a plain float
            debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
            for k, v in current_metrics.items():
                if k not in [TRAIN, VALIDATION, LR]:
                    debug_print += f"{k}: {v:.3} |"
            print(debug_print)
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(current_metrics[VALIDATION])
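        # End-of-epoch bookkeeping: dump metrics to disk, log to wandb, checkpoint on best PSNR.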
        if output_dir is not None:
            with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
                json.dump(current_metrics, f)
        if wandb_flag:
            wandb.log(current_metrics)
        if best_accuracy < current_metrics[METRIC_PSNR]:
            best_accuracy = current_metrics[METRIC_PSNR]
            if output_dir is not None:
                print("new best model saved!")
                torch.save(model.state_dict(), output_dir/"best_model.pt")
    if output_dir is not None:
        torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
    return model
def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
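    """Dump the configuration, build the training content and launch the training loop."""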
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Training experiment {config[ID]} on device {device}...")
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir/"config.json", "w") as f:
        json.dump(config, f)
    model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
    model.to(device)
    if wandb_flag:
        import wandb
        wandb.init(
            project=WANDBSPACE,
            entity="balthazarneveu",
            name=config[NAME],
            tags=["debug"],
            # tags=["base"],
            config=config
        )
    scheduler = None
    if config.get(SCHEDULER, False):
        scheduler_config = config[SCHEDULER_CONFIGURATION]
        if config[SCHEDULER] == REDUCELRONPLATEAU:
            scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)
        else:
            raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
    model = training_loop(model, optimizer, dl_dict, config, scheduler=scheduler, device=device,
                          wandb_flag=wandb_flag, output_dir=output_dir)
    if wandb_flag:
        wandb.finish()
def train_main(argv):
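    """Entry point: parse the command line and train each requested experiment in sequence."""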
    parser = get_parser()
    args = parser.parse_args(argv)
    if not WANDB_AVAILABLE:
        args.no_wandb = True
    device = "cpu" if args.cpu else DEVICE
    for exp in args.exp:
        config = get_experiment_config(exp)
        print(config)
        output_dir = Path(args.output_dir)/config[NAME]
        logging.info(f"Training experiment {config[ID]} on device {device}...")
        train(config, device=device, output_dir=output_dir, wandb_flag=not args.no_wandb)
if __name__ == "__main__":
    train_main(sys.argv[1:])
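# Example invocation (hypothetical: assumes this file is saved as train.py and that
# experiment ids 1 and 2 are defined in rstor.learning.experiments_definition):
#   python train.py -e 1 2 -o ./__output --cpu -nowb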