File size: 2,117 Bytes
626eca0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import os
from typing import Tuple

import omegaconf
import torch

from relik.common.utils import from_cache
from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
from relik.reader.relik_reader_core import RelikReaderCoreModel

# File name under which the bare reader model is stored inside a model directory.
CKPT_FILE_NAME = "model.ckpt"
# File name of the OmegaConf configuration stored next to the model.
CONFIG_FILE_NAME = "cfg.yaml"


def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
    """Export the bare reader model and its config from a Lightning checkpoint.

    Loads a ``RelikReaderPLModule`` from ``pl_module_ckpt_path`` and writes two
    files into a freshly created ``output_dir``:

    * ``CKPT_FILE_NAME``: the module's ``relik_reader_core_model``, pickled as
      a whole object with ``torch.save``;
    * ``CONFIG_FILE_NAME``: the ``"cfg"`` entry of the module's
      hyper-parameters, serialized through OmegaConf.

    Args:
        pl_module_ckpt_path: Path to the PyTorch Lightning checkpoint.
        output_dir: Directory to create for the exported files. It must not
            exist yet, so a previous export is never overwritten.

    Raises:
        SystemExit: With exit code 1 when ``output_dir`` already exists.
    """
    if os.path.exists(output_dir):
        print(f"{output_dir} already exists, aborting operation")
        # The bare ``exit`` helper is injected by the ``site`` module and is
        # not guaranteed to exist (e.g. ``python -S``); raising SystemExit
        # directly yields the same exception and exit code without it.
        raise SystemExit(1)

    # Load before creating the directory: if the checkpoint cannot be read, no
    # empty ``output_dir`` is left behind to make the next attempt abort.
    relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
        pl_module_ckpt_path
    )
    os.makedirs(output_dir)

    model_out_path = os.path.join(output_dir, CKPT_FILE_NAME)
    torch.save(relik_pl_module.relik_reader_core_model, model_out_path)

    config_out_path = os.path.join(output_dir, CONFIG_FILE_NAME)
    # Pin the encoding so the written YAML does not depend on the platform's
    # locale default.
    with open(config_out_path, "w", encoding="utf-8") as f:
        omegaconf.OmegaConf.save(
            omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
        )


def load_model_and_conf(
    model_dir_path: str,
) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]:
    """Load an exported reader model together with its configuration.

    ``model_dir_path`` is resolved to a local directory through ``from_cache``
    (presumably fetching the files from the Hugging Face hub when they are not
    available locally, see the TODO below; confirm against ``from_cache``).
    The directory is expected to contain ``CKPT_FILE_NAME`` and
    ``CONFIG_FILE_NAME``.

    Args:
        model_dir_path: Identifier or path of the model directory, passed
            unchanged to ``from_cache``.

    Returns:
        A ``(model, config)`` tuple: the object unpickled from the checkpoint
        file, mapped to CPU, and the OmegaConf config loaded from the YAML file.
    """
    # TODO: quick workaround to load the model from HF hub
    model_dir = from_cache(
        model_dir_path,
        filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
        cache_dir=None,
        force_download=False,
    )

    ckpt_path = os.path.join(model_dir, CKPT_FILE_NAME)
    # The checkpoint stores a whole pickled module rather than a state dict, so
    # it cannot be read with ``weights_only=True`` (the default in newer
    # PyTorch releases). Request full unpickling explicitly.
    # NOTE(security): full unpickling can execute arbitrary code embedded in
    # the file; only load checkpoints that come from a trusted source.
    model = torch.load(
        ckpt_path, map_location=torch.device("cpu"), weights_only=False
    )

    model_cfg_path = os.path.join(model_dir, CONFIG_FILE_NAME)
    model_conf = omegaconf.OmegaConf.load(model_cfg_path)
    return model, model_conf


def parse_arg() -> argparse.Namespace:
    """Parse the command-line options of the conversion script.

    Returns:
        The parsed namespace exposing ``ckpt`` and ``output_dir``; both
        options are mandatory.
    """
    cli = argparse.ArgumentParser()
    # (flags, help text) for every option; all of them are required.
    option_specs = (
        (
            ("--ckpt",),
            "Path to the pytorch lightning ckpt you want to convert.",
        ),
        (
            ("--output-dir", "-o"),
            "The output dir to store the bare models and the config.",
        ),
    )
    for flags, help_text in option_specs:
        cli.add_argument(*flags, help=help_text, required=True)
    return cli.parse_args()


def main():
    """Script entry point: parse the CLI options and run the conversion."""
    cli_args = parse_arg()
    convert_pl_module(
        pl_module_ckpt_path=cli_args.ckpt,
        output_dir=cli_args.output_dir,
    )


# Run the conversion only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()