File size: 2,697 Bytes
f6b56a2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import logging
from pathlib import Path
from typing import Tuple

import torch

from gyraudio.audio_separation.properties import SHORT_NAME, MODEL, OPTIMIZER, CURRENT_EPOCH, CONFIGURATION
from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT


def get_output_folder(config: dict, root_dir: Path = EXPERIMENT_STORAGE_ROOT, override: bool = False) -> Tuple[bool, Path]:
    """Resolve the output folder of an experiment and tell whether it may be used.

    The folder is ``root_dir / config["short_name"]``. It is created (including
    missing parents) when it does not exist yet.

    Args:
        config: Experiment configuration; must contain the experiment short name.
        root_dir: Directory under which experiment folders are stored.
        override: When True, an already existing folder may be reused.

    Returns:
        A ``(usable, output_folder)`` tuple. ``usable`` is False only when the
        folder already exists and ``override`` is False; it is True when the
        folder was just created or when ``override`` is True.
    """
    # NOTE(review): the path uses the literal "short_name" key while the log
    # messages below use the SHORT_NAME constant - presumably the same key; confirm.
    output_folder = root_dir/config["short_name"]
    if not output_folder.exists():
        output_folder.mkdir(parents=True, exist_ok=True)
        return True, output_folder
    if override:
        logging.warning(f"Experiment {config[SHORT_NAME]} will be OVERRIDDEN")
        return True, output_folder
    logging.info(f"Experiment {config[SHORT_NAME]} already exists. Override is set to False. Skipping.")
    return False, output_folder


def checkpoint_paths(exp_dir: Path, epoch=None):
    """Resolve the model and optimizer checkpoint paths of an experiment.

    Args:
        exp_dir: Experiment folder holding the ``model_*.pt`` files.
        epoch: Epoch whose checkpoints are requested. When None, the existing
            model checkpoint with the highest numeric epoch suffix is selected.

    Returns:
        A ``(model_checkpoint, optimizer_checkpoint, epoch)`` tuple, where
        ``epoch`` is the requested epoch or the one parsed from the selected file.

    Raises:
        AssertionError: if ``epoch`` is None and no model checkpoint with a
            numeric epoch suffix exists in ``exp_dir``.
    """
    if epoch is None:
        # Keep only files whose last "_"-separated token is an epoch number,
        # so stray files such as "model_best.pt" cannot break the int() parsing.
        checkpoints = [
            path for path in exp_dir.glob("model_*.pt")
            if path.stem.split("_")[-1].isdecimal()
        ]
        assert len(checkpoints) > 0, f"No checkpoints found in {exp_dir}"
        # Rank by the numeric epoch rather than by file name: a lexicographic
        # sort would place "model_10000.pt" before "model_9999.pt".
        # The file name is a tie-breaker to keep the choice deterministic.
        model_checkpoint = max(
            checkpoints,
            key=lambda path: (int(path.stem.split("_")[-1]), path.name)
        )
        epoch = int(model_checkpoint.stem.split("_")[-1])
        # Use .name (not .stem) so the ".pt" suffix is preserved and the result
        # matches the naming used when an explicit epoch is given.
        optimizer_checkpoint = exp_dir/model_checkpoint.name.replace("model", "optimizer")
    else:
        model_checkpoint = exp_dir/f"model_{epoch:04d}.pt"
        optimizer_checkpoint = exp_dir/f"optimizer_{epoch:04d}.pt"
    return model_checkpoint, optimizer_checkpoint, epoch


def load_checkpoint(model, exp_dir: Path, optimizer=None, epoch: int = None,
                    device="cuda" if torch.cuda.is_available() else "cpu"):
    """Restore a model (and optionally an optimizer) from an experiment folder.

    Args:
        model: Object whose ``load_state_dict`` receives the stored model state.
        exp_dir: Experiment folder containing the checkpoints.
        optimizer: Optional object whose ``load_state_dict`` receives the stored
            optimizer state. When None, the optimizer checkpoint is not read.
        epoch: Epoch to load; forwarded to ``checkpoint_paths``, which resolves
            the epoch itself when None is given.
        device: Device the stored tensors are mapped to. The default is chosen
            once, when the function is defined: "cuda" if available, else "cpu".

    Returns:
        A ``(model, optimizer, epoch, config)`` tuple. ``config`` is the
        configuration stored in the optimizer checkpoint, or an empty dict when
        no optimizer was provided.
    """
    model_path, optimizer_path, epoch = checkpoint_paths(exp_dir, epoch=epoch)
    target_device = torch.device(device)

    model_payload = torch.load(model_path, map_location=target_device)
    model.load_state_dict(model_payload[MODEL])

    # Without an optimizer there is nothing more to restore: the configuration
    # lives in the optimizer checkpoint, which is deliberately left unread.
    if optimizer is None:
        return model, optimizer, epoch, {}

    optimizer_payload = torch.load(optimizer_path, map_location=target_device)
    optimizer.load_state_dict(optimizer_payload[OPTIMIZER])
    return model, optimizer, epoch, optimizer_payload[CONFIGURATION]


def save_checkpoint(model, exp_dir: Path, optimizer=None, config: dict = None, epoch: int = None):
    """Save the model state and the training metadata of an experiment.

    Two files are written, at the paths returned by ``checkpoint_paths``:
    one holding the model state dict, and one holding the current epoch, the
    configuration and the optimizer state dict.

    Args:
        model: Object providing ``state_dict()``.
        exp_dir: Experiment folder the checkpoints are written to.
        optimizer: Optional object providing ``state_dict()``. When None, the
            optimizer entry of the second file is stored as None.
        config: Configuration stored alongside the optimizer state. Defaults to
            an empty dict.
        epoch: Epoch used to name the files; forwarded to ``checkpoint_paths``,
            which resolves the epoch itself when None is given.
    """
    # A fresh dict per call instead of a shared mutable default argument.
    if config is None:
        config = {}
    model_checkpoint, optimizer_checkpoint, epoch = checkpoint_paths(exp_dir, epoch=epoch)
    torch.save(
        {
            MODEL: model.state_dict(),
        },
        model_checkpoint
    )
    # `optimizer` is optional in the signature, so do not dereference it blindly;
    # epoch and configuration are still saved when no optimizer is given.
    optimizer_state = optimizer.state_dict() if optimizer is not None else None
    torch.save(
        {
            CURRENT_EPOCH: epoch,
            CONFIGURATION: config,
            OPTIMIZER: optimizer_state
        },
        optimizer_checkpoint
    )
    print(f"Checkpoint saved:\n   - model: {model_checkpoint}\n   - checkpoint: {optimizer_checkpoint}")