Spaces:
Running
Running
import sys | |
import argparse | |
from typing import Optional | |
import torch | |
import logging | |
from pathlib import Path | |
import json | |
from tqdm import tqdm | |
from rstor.properties import ( | |
ID, NAME, NB_EPOCHS, | |
TRAIN, VALIDATION, LR, | |
LOSS_MSE, METRIC_PSNR, METRIC_SSIM, | |
DEVICE, SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU, | |
REDUCTION_SUM, | |
SELECTED_METRICS, | |
LOSS | |
) | |
from rstor.learning.metrics import compute_metrics | |
from rstor.learning.loss import compute_loss | |
from torch.optim.lr_scheduler import ReduceLROnPlateau | |
from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME | |
from rstor.learning.experiments import get_training_content | |
from rstor.learning.experiments_definition import get_experiment_config | |
# Whether Weights & Biases logging can be used by this script.
# The flag is raised only after the import has actually succeeded, so a
# missing wandb installation leaves it False and wandb stays disabled.
WANDB_AVAILABLE = False
try:
    import wandb
except ImportError:
    logging.warning("Could not import wandb. Disabling wandb.")
else:
    WANDB_AVAILABLE = True
def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
    """Build, or extend, the command-line parser of the training script.

    Args:
        parser: Existing parser to add the training options to. When omitted,
            a new parser described as "Train a model" is created.

    Returns:
        The parser carrying the ``--exp``, ``--output-dir``, ``--no-wandb``
        and ``--cpu`` options.
    """
    if parser is None:
        parser = argparse.ArgumentParser(description="Train a model")
    # One (flags, keyword options) entry per option, registered in this order.
    option_table = (
        (("-e", "--exp"), {"nargs": "+", "type": int, "required": True, "help": "Experiment id"}),
        (("-o", "--output-dir"), {"type": str, "default": ROOT_DIR/OUTPUT_FOLDER_NAME,
                                  "help": "Output directory"}),
        (("-nowb", "--no-wandb"), {"action": "store_true", "help": "Disable weights and biases"}),
        (("--cpu",), {"action": "store_true", "help": "Force CPU"}),
    )
    for flags, options in option_table:
        parser.add_argument(*flags, **options)
    return parser
def training_loop(
    model,
    optimizer,
    dl_dict: dict,
    config: dict,
    scheduler=None,
    device: str = DEVICE,
    wandb_flag: bool = False,
    output_dir: Path = None,
):
    """Train ``model`` for ``config[NB_EPOCHS]`` epochs, validating after each one.

    Each epoch runs a TRAIN phase then a VALIDATION phase over the matching
    entries of ``dl_dict``. Batches whose loss is NaN are reported and skipped.
    The per-phase loss stored in the epoch metrics is the mean over the batches
    that actually contributed; the selected metrics are summed over the
    validation batches and divided by the number of validated samples.

    Args:
        model: Network to optimize; called as ``model(x)`` on every batch.
        optimizer: Optimizer updating the parameters of ``model``.
        dl_dict: Maps TRAIN and VALIDATION to iterables yielding ``(x, y)``
            batches.
        config: Experiment configuration. NB_EPOCHS is required; LOSS
            (default LOSS_MSE) and SELECTED_METRICS (default PSNR and SSIM)
            are optional.
        scheduler: Optional learning-rate scheduler. Only ReduceLROnPlateau
            instances are stepped, using the epoch validation loss.
        device: Device the batches are moved to.
        wandb_flag: When True, the epoch metrics are sent to ``wandb.log``.
        output_dir: When given, receives ``metrics_<epoch>.json`` for every
            epoch, ``best_model.pt`` whenever the validation PSNR improves and
            ``last_model.pt`` at the end.

    Returns:
        The trained model. When ``output_dir`` is given, the model has been
        moved to CPU by the final checkpoint save.
    """
    best_accuracy = 0.
    chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
    loss_mode = config.get(LOSS, LOSS_MSE)
    # The best checkpoint is ranked by PSNR, which only exists when selected.
    track_best = METRIC_PSNR in chosen_metrics
    if not track_best and output_dir is not None:
        logging.warning(
            f"{METRIC_PSNR} is not among the selected metrics: best_model.pt will not be saved."
        )
    for n_epoch in tqdm(range(config[NB_EPOCHS])):
        current_metrics = {
            TRAIN: 0.,
            VALIDATION: 0.,
            LR: optimizer.param_groups[0]['lr'],
        }
        for met in chosen_metrics:
            current_metrics[met] = 0.
        for phase in [TRAIN, VALIDATION]:
            is_training = phase == TRAIN
            total_elements = 0  # validated samples, denominator of the metrics
            valid_batches = 0  # batches with a finite loss, denominator of the loss
            if is_training:
                model.train()
            else:
                model.eval()
            for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(is_training):
                    y_pred = model(x)
                    loss = compute_loss(y_pred, y, mode=loss_mode)
                    if torch.isnan(loss):
                        print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
                        continue
                    if is_training:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                        optimizer.step()
                current_metrics[phase] += loss.item()
                valid_batches += 1
                if not is_training:
                    metrics_on_batch = compute_metrics(
                        y_pred,
                        y,
                        chosen_metrics=chosen_metrics,
                        reduction=REDUCTION_SUM
                    )
                    total_elements += y_pred.shape[0]
                    for k, v in metrics_on_batch.items():
                        current_metrics[k] += v
            # Average over the batches that contributed; an empty loader or an
            # all-NaN phase leaves the loss at 0 instead of dividing by zero.
            if valid_batches:
                current_metrics[phase] /= valid_batches
            if not is_training:
                # Normalize every selected metric, not the keys of whichever
                # batch happened to be processed last.
                for k in chosen_metrics:
                    value = current_metrics[k]
                    if total_elements:
                        value = value / total_elements
                    # Tensors/arrays become plain numbers so json.dump accepts them.
                    if hasattr(value, "item"):
                        value = value.item()
                    current_metrics[k] = value
            debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
            for k, v in current_metrics.items():
                if k not in [TRAIN, VALIDATION, LR]:
                    debug_print += f"{k}: {v:.3} |"
            print(debug_print)
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(current_metrics[VALIDATION])
        if output_dir is not None:
            with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
                json.dump(current_metrics, f)
        if wandb_flag:
            wandb.log(current_metrics)
        if track_best and best_accuracy < current_metrics[METRIC_PSNR]:
            best_accuracy = current_metrics[METRIC_PSNR]
            if output_dir is not None:
                print("new best model saved!")
                torch.save(model.state_dict(), output_dir/"best_model.pt")
    if output_dir is not None:
        # NOTE: model.cpu() moves the model in place, so the returned model is on CPU.
        torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
    return model
def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
    """Run the training experiment described by ``config``.

    Creates ``output_dir`` and writes the configuration to ``config.json``
    inside it, obtains the model, optimizer and dataloaders from
    ``get_training_content``, optionally opens a Weights & Biases run and
    builds a ReduceLROnPlateau scheduler, then delegates to ``training_loop``.

    Args:
        config: Experiment configuration; ID and NAME are read here, SCHEDULER
            and SCHEDULER_CONFIGURATION when a scheduler is requested.
        output_dir: Directory receiving the configuration and training outputs.
        device: Device the model is moved to and trained on.
        wandb_flag: When True, a wandb run is started before training and
            finished afterwards.

    Raises:
        NameError: If ``config[SCHEDULER]`` is set to anything other than
            REDUCELRONPLATEAU.
    """
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Training experiment {config[ID]} on device {device}...")

    # Persist the configuration next to the training outputs.
    output_dir.mkdir(parents=True, exist_ok=True)
    config_path = output_dir/"config.json"
    with open(config_path, "w") as config_file:
        json.dump(config, config_file)

    model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
    model.to(device)

    if wandb_flag:
        import wandb
        run_settings = dict(
            project=WANDBSPACE,
            entity="balthazarneveu",
            name=config[NAME],
            tags=["debug"],
            config=config,
        )
        wandb.init(**run_settings)

    scheduler = None
    if config.get(SCHEDULER, False):
        scheduler_config = config[SCHEDULER_CONFIGURATION]
        if config[SCHEDULER] != REDUCELRONPLATEAU:
            raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
        scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)

    training_loop(
        model,
        optimizer,
        dl_dict,
        config,
        scheduler=scheduler,
        device=device,
        wandb_flag=wandb_flag,
        output_dir=output_dir,
    )

    if wandb_flag:
        wandb.finish()
def train_main(argv):
    """Command-line entry point: train every experiment requested in ``argv``.

    Args:
        argv: Command-line arguments without the program name
            (for instance ``sys.argv[1:]``).
    """
    # Configure logging before the first message is emitted: calling
    # logging.info() on an unconfigured root logger installs a default handler
    # at WARNING level, after which a later basicConfig(level=INFO) is a no-op
    # and every INFO message is silently dropped. The explicit setLevel also
    # covers the case where a handler was already installed at import time.
    logging.basicConfig(level=logging.INFO)
    logging.getLogger().setLevel(logging.INFO)

    args = get_parser().parse_args(argv)
    # wandb is used only when it could be imported and was not disabled by the user.
    use_wandb = WANDB_AVAILABLE and not args.no_wandb
    device = "cpu" if args.cpu else DEVICE
    for exp in args.exp:
        config = get_experiment_config(exp)
        print(config)
        output_dir = Path(args.output_dir)/config[NAME]
        # train() logs the "Training experiment ..." start-up message itself.
        train(config, device=device, output_dir=output_dir, wandb_flag=use_wandb)
# Script entry point: forward the command-line arguments (program name
# stripped) to train_main when the file is executed directly.
if __name__ == "__main__":
    train_main(sys.argv[1:])