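"""Training entry point for rstor experiments.

Select one or several experiment ids with -e/--exp (defined in
rstor.learning.experiments_definition). For each experiment, the configuration,
per-epoch metrics and model checkpoints are written to the output directory,
with optional Weights & Biases logging.
"""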
import sys
import argparse
from typing import Optional
import torch
import logging
from pathlib import Path
import json
from tqdm import tqdm
from rstor.properties import (
    ID, NAME, NB_EPOCHS,
    TRAIN, VALIDATION, LR,
    LOSS_MSE, METRIC_PSNR, METRIC_SSIM,
    DEVICE, SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU,
    REDUCTION_SUM,
    SELECTED_METRICS,
    LOSS
)
from rstor.learning.metrics import compute_metrics
from rstor.learning.loss import compute_loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config
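# wandb is optional: training falls back to console and JSON logging when it is not installed.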
WANDB_AVAILABLE = False
try:
    import wandb
    WANDB_AVAILABLE = True  # only mark wandb as available once the import has succeeded
except ImportError:
    logging.warning("Could not import wandb. Disabling wandb.")


def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
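    """Build (or extend) the command line argument parser for training."""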
    if parser is None:
        parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("-e", "--exp", nargs="+", type=int, required=True, help="Experiment ids")
    parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR/OUTPUT_FOLDER_NAME, help="Output directory")
    parser.add_argument("-nowb", "--no-wandb", action="store_true", help="Disable Weights & Biases logging")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    return parser


def training_loop(
    model,
    optimizer,
    dl_dict: dict,
    config: dict,
    scheduler=None,
    device: str = DEVICE,
    wandb_flag: bool = False,
    output_dir: Path = None,
):
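    """Run the optimization over the training and validation dataloaders.

    Logs the loss and the selected metrics (PSNR/SSIM by default) at every epoch,
    optionally to Weights & Biases, and keeps the checkpoint with the best validation PSNR.
    """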
    best_accuracy = 0.
    chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
    for n_epoch in tqdm(range(config[NB_EPOCHS])):
        current_metrics = {
            TRAIN: 0.,
            VALIDATION: 0.,
            LR: optimizer.param_groups[0]['lr'],
        }
        for met in chosen_metrics:
            current_metrics[met] = 0.
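        # Alternate between a training pass (gradients enabled) and a validation pass.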
        for phase in [TRAIN, VALIDATION]:
            total_elements = 0
            if phase == TRAIN:
                model.train()
            else:
                model.eval()
            for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == TRAIN):
                    y_pred = model(x)
                    loss = compute_loss(y_pred, y, mode=config.get(LOSS, LOSS_MSE))
                    if torch.isnan(loss):
                        print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
                        continue
                    if phase == TRAIN:
                        loss.backward()
                        # Clip gradients to avoid exploding updates.
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                        optimizer.step()
                current_metrics[phase] += loss.item()
                if phase == VALIDATION:
                    metrics_on_batch = compute_metrics(
                        y_pred,
                        y,
                        chosen_metrics=chosen_metrics,
                        reduction=REDUCTION_SUM
                    )
                    total_elements += y_pred.shape[0]
                    for k, v in metrics_on_batch.items():
                        current_metrics[k] += v
            # Average the accumulated loss over the number of batches.
            current_metrics[phase] /= (len(dl_dict[phase]))
            if phase == VALIDATION:
                # Metrics were accumulated with a sum reduction: normalize by the number of images.
                for k, v in metrics_on_batch.items():
                    current_metrics[k] /= total_elements
                    try:
                        current_metrics[k] = current_metrics[k].item()
                    except AttributeError:
                        pass
                debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
                for k, v in current_metrics.items():
                    if k not in [TRAIN, VALIDATION, LR]:
                        debug_print += f"{k}: {v:.3} |"
                print(debug_print)
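        # End of epoch: adjust the learning rate, dump metrics and checkpoints.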
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(current_metrics[VALIDATION])
        if output_dir is not None:
            with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
                json.dump(current_metrics, f)
        if wandb_flag:
            wandb.log(current_metrics)
        if best_accuracy < current_metrics[METRIC_PSNR]:
            best_accuracy = current_metrics[METRIC_PSNR]
            if output_dir is not None:
                print("new best model saved!")
                torch.save(model.state_dict(), output_dir/"best_model.pt")
    if output_dir is not None:
        torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
    return model


def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
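    """Prepare a single experiment (model, optimizer, dataloaders) and run the training loop."""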
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Training experiment {config[ID]} on device {device}...")
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir/"config.json", "w") as f:
        json.dump(config, f)
    model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
    model.to(device)
    if wandb_flag:
        import wandb
        wandb.init(
            project=WANDBSPACE,
            entity="balthazarneveu",
            name=config[NAME],
            tags=["debug"],
            # tags=["base"],
            config=config
        )
    scheduler = None
    if config.get(SCHEDULER, False):
        scheduler_config = config[SCHEDULER_CONFIGURATION]
        if config[SCHEDULER] == REDUCELRONPLATEAU:
            scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)
        else:
            raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
    model = training_loop(model, optimizer, dl_dict, config, scheduler=scheduler, device=device,
                          wandb_flag=wandb_flag, output_dir=output_dir)
    if wandb_flag:
        wandb.finish()


def train_main(argv):
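    """Command line entry point: parse arguments and train every requested experiment."""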
    parser = get_parser()
    args = parser.parse_args(argv)
    if not WANDB_AVAILABLE:
        args.no_wandb = True
    device = "cpu" if args.cpu else DEVICE
    for exp in args.exp:
        config = get_experiment_config(exp)
        print(config)
        output_dir = Path(args.output_dir)/config[NAME]
        logging.info(f"Training experiment {config[ID]} on device {device}...")
        train(config, device=device, output_dir=output_dir, wandb_flag=not args.no_wandb)


if __name__ == "__main__":
    train_main(sys.argv[1:])