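"""Training entry point for rstor restoration experiments.

Trains one or several experiments selected by id, with optional
Weights & Biases logging. Example invocation (assuming this file is
saved as train.py; the experiment ids are illustrative):

    python train.py -e 0 1 -nowb
"""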
import sys
import argparse
from typing import Optional
import torch
import logging
from pathlib import Path
import json
from tqdm import tqdm
from rstor.properties import (
ID, NAME, NB_EPOCHS,
TRAIN, VALIDATION, LR,
LOSS_MSE, METRIC_PSNR, METRIC_SSIM,
DEVICE, SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU,
REDUCTION_SUM,
SELECTED_METRICS,
LOSS
)
from rstor.learning.metrics import compute_metrics
from rstor.learning.loss import compute_loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config
WANDB_AVAILABLE = False
try:
    import wandb
    WANDB_AVAILABLE = True
except ImportError:
    logging.warning("Could not import wandb. Disabling wandb.")
def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
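    """Build (or extend) the command-line argument parser for training."""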
if parser is None:
parser = argparse.ArgumentParser(description="Train a model")
parser.add_argument("-e", "--exp", nargs="+", type=int, required=True, help="Experiment id")
parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR/OUTPUT_FOLDER_NAME, help="Output directory")
parser.add_argument("-nowb", "--no-wandb", action="store_true", help="Disable weights and biases")
parser.add_argument("--cpu", action="store_true", help="Force CPU")
return parser
def training_loop(
model,
optimizer,
dl_dict: dict,
config: dict,
scheduler=None,
device: str = DEVICE,
wandb_flag: bool = False,
    output_dir: Optional[Path] = None,
):
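    """Run the train/validation loop for config[NB_EPOCHS] epochs.

    Each epoch accumulates the loss per phase, computes the selected metrics
    on the validation set, optionally steps a ReduceLROnPlateau scheduler on
    the validation loss, dumps the metrics to json, logs to wandb and keeps
    the checkpoint with the best validation PSNR seen so far.
    """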
best_accuracy = 0.
chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
for n_epoch in tqdm(range(config[NB_EPOCHS])):
current_metrics = {
TRAIN: 0.,
VALIDATION: 0.,
LR: optimizer.param_groups[0]['lr'],
}
for met in chosen_metrics:
current_metrics[met] = 0.
for phase in [TRAIN, VALIDATION]:
total_elements = 0
if phase == TRAIN:
model.train()
else:
model.eval()
for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
x, y = x.to(device), y.to(device)
optimizer.zero_grad()
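                # forward pass; gradients are only tracked during the training phase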
with torch.set_grad_enabled(phase == TRAIN):
y_pred = model(x)
loss = compute_loss(y_pred, y, mode=config.get(LOSS, LOSS_MSE))
if torch.isnan(loss):
print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
continue
if phase == TRAIN:
loss.backward()
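                        # clip the gradient norm to 1. to mitigate exploding gradients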
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
optimizer.step()
current_metrics[phase] += loss.item()
if phase == VALIDATION:
metrics_on_batch = compute_metrics(
y_pred,
y,
chosen_metrics=chosen_metrics,
reduction=REDUCTION_SUM
)
total_elements += y_pred.shape[0]
for k, v in metrics_on_batch.items():
current_metrics[k] += v
current_metrics[phase] /= (len(dl_dict[phase]))
            if phase == VALIDATION:
                # average the summed metrics over the number of validation samples
                for k in chosen_metrics:
                    current_metrics[k] /= total_elements
                    try:
                        current_metrics[k] = current_metrics[k].item()
                    except AttributeError:
                        pass  # already a plain float
debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
for k, v in current_metrics.items():
if k not in [TRAIN, VALIDATION, LR]:
debug_print += f"{k}: {v:.3} |"
print(debug_print)
if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
scheduler.step(current_metrics[VALIDATION])
if output_dir is not None:
with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
json.dump(current_metrics, f)
if wandb_flag:
wandb.log(current_metrics)
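        # keep a checkpoint of the model reaching the best validation PSNR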
if best_accuracy < current_metrics[METRIC_PSNR]:
best_accuracy = current_metrics[METRIC_PSNR]
if output_dir is not None:
print("new best model saved!")
torch.save(model.state_dict(), output_dir/"best_model.pt")
if output_dir is not None:
torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
return model
def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
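    """Prepare a single experiment (model, optimizer, dataloaders, optional
    scheduler), then run the training loop and finalize wandb logging."""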
logging.basicConfig(level=logging.INFO)
logging.info(f"Training experiment {config[ID]} on device {device}...")
output_dir.mkdir(parents=True, exist_ok=True)
with open(output_dir/"config.json", "w") as f:
json.dump(config, f)
model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
model.to(device)
    if wandb_flag:
        # wandb is already imported at module level whenever it is available,
        # and wandb_flag is forced off when the import failed
        wandb.init(
            project=WANDBSPACE,
            entity="balthazarneveu",
            name=config[NAME],
            tags=["debug"],
            config=config,
        )
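    # optional learning rate scheduler (only ReduceLROnPlateau is supported)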
scheduler = None
if config.get(SCHEDULER, False):
scheduler_config = config[SCHEDULER_CONFIGURATION]
if config[SCHEDULER] == REDUCELRONPLATEAU:
scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)
else:
raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
model = training_loop(model, optimizer, dl_dict, config, scheduler=scheduler, device=device,
wandb_flag=wandb_flag, output_dir=output_dir)
if wandb_flag:
wandb.finish()
def train_main(argv):
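    """Parse the command line and train every requested experiment id."""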
parser = get_parser()
args = parser.parse_args(argv)
if not WANDB_AVAILABLE:
args.no_wandb = True
device = "cpu" if args.cpu else DEVICE
for exp in args.exp:
config = get_experiment_config(exp)
print(config)
output_dir = Path(args.output_dir)/config[NAME]
logging.info(f"Training experiment {config[ID]} on device {device}...")
train(config, device=device, output_dir=output_dir, wandb_flag=not args.no_wandb)
if __name__ == "__main__":
train_main(sys.argv[1:])