"""A set of utilities to manage and load checkpoints of training experiments.

Author: Paul-Edouard Sarlin (skydes)
"""

import logging
import os
import re
import shutil
from pathlib import Path

import torch
from omegaconf import OmegaConf

from siclib.models import get_model
from siclib.settings import TRAINING_PATH

logger = logging.getLogger(__name__)


def list_checkpoints(dir_):
    """List all valid checkpoints in a given directory."""
    checkpoints = []
    for p in dir_.glob("checkpoint_*.tar"):
        # Names follow checkpoint_{iter}.tar or checkpoint_{epoch}_{iter}.tar,
        # optionally with an _interrupted suffix. Numberless names such as
        # checkpoint_best.tar are skipped.
        numbers = re.findall(r"(\d+)", p.name)
        assert len(numbers) <= 2
        if len(numbers) == 0:
            continue
        # Index each checkpoint by its iteration count, the last number in the name.
        checkpoints.append((int(numbers[-1]), p))
    return checkpoints


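# Minimal usage sketch (the directory name and its contents are hypothetical):
# a folder holding checkpoint_0_500.tar, checkpoint_1_1000.tar, and
# checkpoint_best.tar would yield iteration-indexed entries such as
#   >>> list_checkpoints(Path("outputs/my_run"))
#   [(500, Path(".../checkpoint_0_500.tar")), (1000, Path(".../checkpoint_1_1000.tar"))]

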
def get_last_checkpoint(exper, allow_interrupted=True):
    """Get the last saved checkpoint for a given experiment name."""
    ckpts = list_checkpoints(Path(TRAINING_PATH, exper))
    if not allow_interrupted:
        ckpts = [(n, p) for (n, p) in ckpts if "_interrupted" not in p.name]
    assert len(ckpts) > 0
    return sorted(ckpts)[-1][1]


def get_best_checkpoint(exper):
    """Get the checkpoint with the best loss, for a given experiment name."""
    return Path(TRAINING_PATH, exper, "checkpoint_best.tar")


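# Usage sketch (the experiment name "my_run" is hypothetical and is resolved
# relative to TRAINING_PATH):
#   >>> get_last_checkpoint("my_run")   # newest checkpoint_*.tar, by iteration
#   >>> get_best_checkpoint("my_run")   # path to checkpoint_best.tar
# Note that get_best_checkpoint only builds the expected path; it does not
# check that the file exists.

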
def delete_old_checkpoints(dir_, num_keep):
    """Delete all but the num_keep last saved checkpoints."""
    ckpts = list_checkpoints(dir_)
    ckpts = sorted(ckpts)[::-1]  # newest first
    kept = 0
    for ckpt in ckpts:
        # Keep at most num_keep checkpoints; an interrupted checkpoint is
        # only kept if it is the most recent one.
        if ("_interrupted" in str(ckpt[1]) and kept > 0) or kept >= num_keep:
            logger.info(f"Deleting checkpoint {ckpt[1].name}")
            ckpt[1].unlink()
        else:
            kept += 1


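# Behavior sketch with hypothetical files, newest first: given num_keep=2 and
# checkpoints at iterations 3000 (interrupted), 2000, 1000, and 500,
#   >>> delete_old_checkpoints(Path("outputs/my_run"), num_keep=2)
# keeps the interrupted checkpoint at 3000 and the one at 2000, and deletes
# those at 1000 and 500.

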
def load_experiment(exper, conf=None, get_last=False, ckpt=None):
    """Load and return the model of a given experiment."""
    if conf is None:
        conf = {}

    exper = Path(exper)
    # An explicitly passed checkpoint path takes precedence; otherwise resolve
    # it from the experiment name (or use exper directly if it is a .tar file).
    if ckpt is None:
        if exper.suffix == ".tar":
            ckpt = exper
        else:
            ckpt = get_last_checkpoint(exper) if get_last else get_best_checkpoint(exper)
    logger.info(f"Loading checkpoint {ckpt.name}")
    ckpt = torch.load(str(ckpt), map_location="cpu")

    loaded_conf = OmegaConf.create(ckpt["conf"])
    OmegaConf.set_struct(loaded_conf, False)
    # Override the stored model configuration with the user-provided one.
    conf = OmegaConf.merge(loaded_conf.model, OmegaConf.create(conf))
    model = get_model(conf.name)(conf).eval()

    state_dict = ckpt["model"]

    # Warn about parameters expected by the model but missing from the
    # checkpoint; with strict=False they keep their initial values.
    dict_params = set(state_dict.keys())
    model_params = set(map(lambda n: n[0], model.named_parameters()))
    diff = model_params - dict_params
    if len(diff) > 0:
        subs = os.path.commonprefix(list(diff)).rstrip(".")
        logger.warning(f"Missing {len(diff)} parameters in {subs}: {diff}")
    model.load_state_dict(state_dict, strict=False)
    return model


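# Usage sketch (the experiment name and config override are hypothetical):
#   >>> model = load_experiment("my_run", conf={"some_param": 0.5})
# loads checkpoint_best.tar from the experiment directory under TRAINING_PATH,
# merges the stored model config with the override, and returns the model in
# eval mode.

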
def save_experiment(
    model,
    optimizer,
    lr_scheduler,
    conf,
    losses,
    results,
    best_eval,
    epoch,
    iter_i,
    output_dir,
    stop=False,
    distributed=False,
    cp_name=None,
):
    """Save the current model to a checkpoint and return the best result so far."""
    state = (model.module if distributed else model).state_dict()
    checkpoint = {
        "model": state,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict(),
        "conf": OmegaConf.to_container(conf, resolve=True),
        "epoch": epoch,
        "losses": losses,
        "eval": results,
    }
    if cp_name is None:
        cp_name = f"checkpoint_{epoch}_{iter_i}" + ("_interrupted" if stop else "") + ".tar"
    logger.info(f"Saving checkpoint {cp_name}")
    cp_path = str(output_dir / cp_name)
    torch.save(checkpoint, cp_path)

    # The tracked metric conf.train.best_key is assumed to be lower-is-better.
    if cp_name != "checkpoint_best.tar" and results[conf.train.best_key] < best_eval:
        best_eval = results[conf.train.best_key]
        logger.info(f"New best val: {conf.train.best_key}={best_eval}")
        shutil.copy(cp_path, str(output_dir / "checkpoint_best.tar"))
    delete_old_checkpoints(output_dir, conf.train.keep_last_checkpoints)
    return best_eval
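

# Training-loop sketch (all names below are illustrative, not part of this module):
#   >>> best_eval = save_experiment(
#   ...     model, optimizer, lr_scheduler, conf, losses, results, best_eval,
#   ...     epoch, iter_i, Path("outputs/my_run"),
#   ... )
# writes checkpoint_{epoch}_{iter_i}.tar, refreshes checkpoint_best.tar if the
# tracked metric improved, and prunes old checkpoints.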