import sys
import argparse
from typing import Optional
import torch
import logging
from pathlib import Path
import json
from tqdm import tqdm
from rstor.properties import (
    ID, NAME, NB_EPOCHS, TRAIN, VALIDATION, LR, LOSS_MSE,
    METRIC_PSNR, METRIC_SSIM, DEVICE,
    SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU,
    REDUCTION_SUM, SELECTED_METRICS, LOSS
)
from rstor.learning.metrics import compute_metrics
from rstor.learning.loss import compute_loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config

# Weights & Biases is an optional dependency: training must still work without it.
WANDB_AVAILABLE = False
try:
    WANDB_AVAILABLE = True
    import wandb
except ImportError:
    logging.warning("Could not import wandb. Disabling wandb.")


def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
    """Build (or extend) the command-line parser for the training entry point.

    Args:
        parser: an existing parser to extend; a fresh one is created when None.

    Returns:
        The parser with `-e/--exp`, `-o/--output-dir`, `-nowb/--no-wandb`
        and `--cpu` arguments registered.
    """
    if parser is None:
        parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("-e", "--exp", nargs="+", type=int, required=True,
                        help="Experiment id")
    parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR/OUTPUT_FOLDER_NAME,
                        help="Output directory")
    parser.add_argument("-nowb", "--no-wandb", action="store_true",
                        help="Disable weights and biases")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    return parser


def training_loop(
    model,
    optimizer,
    dl_dict: dict,
    config: dict,
    scheduler=None,
    device: str = DEVICE,
    wandb_flag: bool = False,
    output_dir: Optional[Path] = None,
):
    """Run the train/validation epochs loop.

    Args:
        model: the network to optimize (already on `device`).
        optimizer: torch optimizer bound to `model` parameters.
        dl_dict: dataloaders keyed by phase (TRAIN / VALIDATION).
        config: experiment configuration (NB_EPOCHS, LOSS, SELECTED_METRICS...).
        scheduler: optional LR scheduler (only ReduceLROnPlateau is stepped).
        device: torch device string.
        wandb_flag: when True, log per-epoch metrics to wandb.
        output_dir: when provided, dump per-epoch metrics JSON and checkpoints.

    Returns:
        The trained model (moved to CPU at the end when `output_dir` is set,
        as a side effect of saving the last checkpoint).
    """
    best_accuracy = 0.
    chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
    for n_epoch in tqdm(range(config[NB_EPOCHS])):
        current_metrics = {
            TRAIN: 0.,
            VALIDATION: 0.,
            LR: optimizer.param_groups[0]['lr'],
        }
        for met in chosen_metrics:
            current_metrics[met] = 0.
        for phase in [TRAIN, VALIDATION]:
            total_elements = 0
            if phase == TRAIN:
                model.train()
            else:
                model.eval()
            for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                # Gradients are only tracked during the training phase.
                with torch.set_grad_enabled(phase == TRAIN):
                    y_pred = model(x)
                    loss = compute_loss(y_pred, y, mode=config.get(LOSS, LOSS_MSE))
                    if torch.isnan(loss):
                        # Skip the batch entirely: no backward, no metric update.
                        print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
                        continue
                    if phase == TRAIN:
                        loss.backward()
                        # Gradient clipping keeps training stable.
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                        optimizer.step()
                current_metrics[phase] += loss.item()
                if phase == VALIDATION:
                    # REDUCTION_SUM: metrics are summed over the batch, divided
                    # by the total element count once the epoch is complete.
                    metrics_on_batch = compute_metrics(
                        y_pred, y,
                        chosen_metrics=chosen_metrics,
                        reduction=REDUCTION_SUM
                    )
                    total_elements += y_pred.shape[0]
                    for k, v in metrics_on_batch.items():
                        current_metrics[k] += v
            current_metrics[phase] /= (len(dl_dict[phase]))
            if phase == VALIDATION and total_elements > 0:
                # FIX: normalize over the chosen metric keys rather than the
                # leftover `metrics_on_batch` dict of the last batch, which is
                # undefined when the validation loader is empty (or every batch
                # was skipped because of a NaN loss).
                for k in chosen_metrics:
                    current_metrics[k] /= total_elements
                    try:
                        # Tensors are converted to plain floats for JSON/wandb.
                        current_metrics[k] = current_metrics[k].item()
                    except AttributeError:
                        pass
            debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
            for k, v in current_metrics.items():
                if k not in [TRAIN, VALIDATION, LR]:
                    debug_print += f"{k}: {v:.3} |"
            print(debug_print)
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            # Plateau scheduler monitors the averaged validation loss.
            scheduler.step(current_metrics[VALIDATION])
        if output_dir is not None:
            with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
                json.dump(current_metrics, f)
        if wandb_flag:
            wandb.log(current_metrics)
        # FIX: guard against configs where PSNR is not among the selected
        # metrics (previously raised a KeyError on the first epoch).
        if METRIC_PSNR in current_metrics and best_accuracy < current_metrics[METRIC_PSNR]:
            best_accuracy = current_metrics[METRIC_PSNR]
            if output_dir is not None:
                print("new best model saved!")
                torch.save(model.state_dict(), output_dir/"best_model.pt")
    if output_dir is not None:
        # NOTE(review): .cpu() moves the model off-device; this happens after
        # the last epoch so subsequent training is unaffected.
        torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
    return model


def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
    """Prepare one experiment (model, optimizer, dataloaders, scheduler) and train it.

    Args:
        config: full experiment configuration dictionary.
        output_dir: directory where config, metrics and checkpoints are written.
        device: torch device string.
        wandb_flag: when True, initialize/log/finish a wandb run.

    Raises:
        NameError: if the configured scheduler is not implemented.
    """
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Training experiment {config[ID]} on device {device}...")
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir/"config.json", "w") as f:
        json.dump(config, f)
    model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
    model.to(device)
    if wandb_flag:
        import wandb
        wandb.init(
            project=WANDBSPACE,
            entity="balthazarneveu",
            name=config[NAME],
            tags=["debug"],
            # tags=["base"],
            config=config
        )
    scheduler = None
    if config.get(SCHEDULER, False):
        scheduler_config = config[SCHEDULER_CONFIGURATION]
        if config[SCHEDULER] == REDUCELRONPLATEAU:
            scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)
        else:
            raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
    model = training_loop(model, optimizer, dl_dict, config,
                          scheduler=scheduler,
                          device=device, wandb_flag=wandb_flag, output_dir=output_dir)
    if wandb_flag:
        wandb.finish()


def train_main(argv):
    """CLI entry point: parse arguments and train each requested experiment."""
    parser = get_parser()
    args = parser.parse_args(argv)
    if not WANDB_AVAILABLE:
        # Force-disable wandb logging when the package is missing.
        args.no_wandb = True
    device = "cpu" if args.cpu else DEVICE
    for exp in args.exp:
        config = get_experiment_config(exp)
        print(config)
        output_dir = Path(args.output_dir)/config[NAME]
        logging.info(f"Training experiment {config[ID]} on device {device}...")
        train(config, device=device, output_dir=output_dir, wandb_flag=not args.no_wandb)


if __name__ == "__main__":
    train_main(sys.argv[1:])