"""
A set of utilities to manage and load checkpoints of training experiments.
Author: Paul-Edouard Sarlin (skydes)
"""
import logging
import os
import re
import shutil
from pathlib import Path
import torch
from omegaconf import OmegaConf
from siclib.models import get_model
from siclib.settings import TRAINING_PATH
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
def list_checkpoints(dir_):
    """List all valid checkpoints in a given directory."""
    checkpoints = []
    for p in dir_.glob("checkpoint_*.tar"):
        numbers = re.findall(r"(\d+)", p.name)
        assert len(numbers) <= 2
        if len(numbers) == 0:
            # e.g. checkpoint_best.tar: not indexed by epoch/iteration, skip it.
            continue
        if len(numbers) == 1:
            checkpoints.append((int(numbers[0]), p))
        else:
            # checkpoint_{epoch}_{iteration}.tar: index by the iteration number.
            checkpoints.append((int(numbers[1]), p))
    return checkpoints

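# A minimal sketch of the expected on-disk naming (file names are hypothetical):
#   checkpoint_3_12000.tar             -> (12000, Path(...))
#   checkpoint_3_12500_interrupted.tar -> (12500, Path(...))
#   checkpoint_best.tar                -> skipped (no trailing numbers)
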
def get_last_checkpoint(exper, allow_interrupted=True):
    """Get the last saved checkpoint for a given experiment name."""
    ckpts = list_checkpoints(Path(TRAINING_PATH, exper))
    if not allow_interrupted:
        ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name]
    assert len(ckpts) > 0, f"No checkpoint found for experiment {exper}."
    return sorted(ckpts)[-1][1]

def get_best_checkpoint(exper):
    """Get the checkpoint with the best loss, for a given experiment name."""
    return Path(TRAINING_PATH, exper, "checkpoint_best.tar")

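# Example usage (a sketch; "my_experiment" is a hypothetical experiment name,
# resolved relative to TRAINING_PATH):
#   last = get_last_checkpoint("my_experiment", allow_interrupted=False)
#   best = get_best_checkpoint("my_experiment")
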
def delete_old_checkpoints(dir_, num_keep):
    """Delete all but the num_keep last saved checkpoints."""
    ckpts = list_checkpoints(dir_)
    # Iterate from the most recent to the oldest checkpoint.
    ckpts = sorted(ckpts)[::-1]
    kept = 0
    for ckpt in ckpts:
        # Interrupted checkpoints are kept only if they are the most recent one.
        if ("_interrupted" in str(ckpt[1]) and kept > 0) or kept >= num_keep:
            logger.info(f"Deleting checkpoint {ckpt[1].name}")
            ckpt[1].unlink()
        else:
            kept += 1

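# Example usage (a sketch; the experiment name is hypothetical): keep only the
# three most recent checkpoints of an experiment directory.
#   delete_old_checkpoints(Path(TRAINING_PATH, "my_experiment"), num_keep=3)
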
def load_experiment(exper, conf=None, get_last=False, ckpt=None):
    """Load and return the model of a given experiment."""
    if conf is None:
        conf = {}
    if ckpt is None:
        exper = Path(exper)
        if exper.suffix == ".tar":
            # The experiment is already a path to a checkpoint file.
            ckpt = exper
        else:
            # Resolve the experiment name to a checkpoint path.
            ckpt = get_last_checkpoint(exper) if get_last else get_best_checkpoint(exper)
    ckpt = Path(ckpt)
    logger.info(f"Loading checkpoint {ckpt.name}")
    ckpt = torch.load(str(ckpt), map_location="cpu")

    # Merge the model config stored in the checkpoint with the user-provided overrides.
    loaded_conf = OmegaConf.create(ckpt["conf"])
    OmegaConf.set_struct(loaded_conf, False)
    conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf))
    model = get_model(conf.name)(conf).eval()

    # Warn about parameters that exist in the model but are missing from the checkpoint.
    state_dict = ckpt["model"]
    dict_params = set(state_dict.keys())
    model_params = {n for n, _ in model.named_parameters()}
    diff = model_params - dict_params
    if len(diff) > 0:
        subs = os.path.commonprefix(list(diff)).rstrip(".")
        logger.warning(f"Missing {len(diff)} parameters in {subs}: {diff}")
    model.load_state_dict(state_dict, strict=False)
    return model

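# Example usage (a sketch; the experiment name and the override key are
# illustrative only, not part of this module):
#   model = load_experiment("my_experiment")                     # best checkpoint
#   model = load_experiment("my_experiment", get_last=True)      # most recent checkpoint
#   model = load_experiment("my_experiment", conf={"some_key": "some_value"})
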
def save_experiment(
    model,
    optimizer,
    lr_scheduler,
    conf,
    losses,
    results,
    best_eval,
    epoch,
    iter_i,
    output_dir,
    stop=False,
    distributed=False,
    cp_name=None,
):
    """Save the current model to a checkpoint and return the best result so far."""
    state = (model.module if distributed else model).state_dict()
    checkpoint = {
        "model": state,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "conf": OmegaConf.to_container(conf, resolve=True),
        "epoch": epoch,
        "losses": losses,
        "eval": results,
    }
    if cp_name is None:
        cp_name = f"checkpoint_{epoch}_{iter_i}" + ("_interrupted" if stop else "") + ".tar"
    logger.info(f"Saving checkpoint {cp_name}")
    cp_path = str(output_dir / cp_name)
    torch.save(checkpoint, cp_path)
    # The metric selected by conf.train.best_key is assumed to be lower-is-better.
    if cp_name != "checkpoint_best.tar" and results[conf.train.best_key] < best_eval:
        best_eval = results[conf.train.best_key]
        logger.info(f"New best val: {conf.train.best_key}={best_eval}")
        shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar"))
    delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints)
    return best_eval

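# Example usage (a sketch): a typical call at the end of a validation pass in a
# training loop; all variables below are assumed to exist in the caller.
#   best_eval = save_experiment(
#       model, optimizer, lr_scheduler, conf, losses, results, best_eval,
#       epoch, iter_i, output_dir, stop=False, distributed=False,
#   )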