riccorl's picture
first commit
626eca0
raw
history blame
2.12 kB
import argparse
import os
from typing import Tuple
import omegaconf
import torch
from relik.common.utils import from_cache
from relik.reader.lightning_modules.relik_reader_pl_module import RelikReaderPLModule
from relik.reader.relik_reader_core import RelikReaderCoreModel
# File name of the serialized reader model inside an exported model directory;
# written by convert_pl_module and read back by load_model_and_conf.
CKPT_FILE_NAME = "model.ckpt"
# File name of the YAML config stored next to the serialized model.
CONFIG_FILE_NAME = "cfg.yaml"
def convert_pl_module(pl_module_ckpt_path: str, output_dir: str) -> None:
    """Export the bare reader model and its config from a Lightning checkpoint.

    Loads a ``RelikReaderPLModule`` from ``pl_module_ckpt_path`` and writes two
    files into a newly created ``output_dir``:

    * ``CKPT_FILE_NAME``: the module's ``relik_reader_core_model`` attribute,
      serialized with ``torch.save``.
    * ``CONFIG_FILE_NAME``: the ``"cfg"`` entry of the module's ``hparams``,
      saved through OmegaConf.

    Args:
        pl_module_ckpt_path: Path to the PyTorch Lightning checkpoint.
        output_dir: Directory to create and write into. It must not exist yet.

    Raises:
        SystemExit: With exit code 1 if ``output_dir`` already exists (a
            message is printed first).
        FileExistsError: If ``output_dir`` appears while the checkpoint is
            being loaded.
    """
    # Never write into an existing directory: abort before doing any
    # expensive work so earlier exports cannot be overwritten.
    if os.path.exists(output_dir):
        print(f"{output_dir} already exists, aborting operation")
        # Equivalent to exit(1), but does not depend on the interactive
        # `exit` helper that the `site` module may not have installed.
        raise SystemExit(1)

    # Load first and create the directory afterwards, so a checkpoint that
    # fails to load does not leave an empty output_dir behind that would make
    # every retry abort with "already exists".
    relik_pl_module: RelikReaderPLModule = RelikReaderPLModule.load_from_checkpoint(
        pl_module_ckpt_path
    )
    os.makedirs(output_dir)

    # The whole model object is pickled (not just a state dict).
    torch.save(
        relik_pl_module.relik_reader_core_model,
        os.path.join(output_dir, CKPT_FILE_NAME),
    )

    # Explicit UTF-8 so the written YAML does not depend on the platform's
    # default locale encoding.
    config_path = os.path.join(output_dir, CONFIG_FILE_NAME)
    with open(config_path, "w", encoding="utf-8") as f:
        omegaconf.OmegaConf.save(
            omegaconf.OmegaConf.create(relik_pl_module.hparams["cfg"]), f
        )
def load_model_and_conf(
    model_dir_path: str,
) -> Tuple[RelikReaderCoreModel, omegaconf.DictConfig]:
    """Load an exported reader model together with its configuration.

    Args:
        model_dir_path: Identifier handed to ``from_cache`` to obtain the
            directory holding ``CKPT_FILE_NAME`` and ``CONFIG_FILE_NAME``
            (presumably a local path or a hub model id; confirm against
            ``from_cache``).

    Returns:
        A ``(model, model_conf)`` tuple: the object deserialized from the
        checkpoint file with tensors mapped to CPU, and the configuration
        loaded by OmegaConf from the config file.
    """
    # TODO: quick workaround to load the model from HF hub
    model_dir = from_cache(
        model_dir_path,
        filenames=[CKPT_FILE_NAME, CONFIG_FILE_NAME],
        cache_dir=None,
        force_download=False,
    )

    ckpt_path = f"{model_dir}/{CKPT_FILE_NAME}"
    # The checkpoint written by convert_pl_module is a fully pickled model
    # object, not a state dict. PyTorch >= 2.6 defaults `weights_only` to
    # True, which refuses to unpickle such objects, so the full-pickle
    # behaviour is requested explicitly.
    # SECURITY NOTE(review): full unpickling can execute arbitrary code; only
    # load checkpoints from sources you trust.
    model = torch.load(
        ckpt_path,
        map_location=torch.device("cpu"),
        weights_only=False,
    )

    model_cfg_path = f"{model_dir}/{CONFIG_FILE_NAME}"
    model_conf = omegaconf.OmegaConf.load(model_cfg_path)
    return model, model_conf
def parse_arg() -> argparse.Namespace:
    """Parse the command-line options of the conversion script.

    Two options are registered, both mandatory:

    * ``--ckpt``: the Lightning checkpoint to convert.
    * ``--output-dir`` / ``-o``: where the exported files are stored.

    Returns:
        The namespace produced by ``argparse`` from the process arguments,
        exposing ``ckpt`` and ``output_dir``.
    """
    # (option strings, help text) pairs, in the order they are registered.
    option_specs = (
        (
            ("--ckpt",),
            "Path to the pytorch lightning ckpt you want to convert.",
        ),
        (
            ("--output-dir", "-o"),
            "The output dir to store the bare models and the config.",
        ),
    )

    cli = argparse.ArgumentParser()
    for flags, help_text in option_specs:
        cli.add_argument(*flags, help=help_text, required=True)

    namespace = cli.parse_args()
    return namespace
def main():
    """Script entry point: parse the CLI options and run the conversion."""
    cli_args = parse_arg()
    convert_pl_module(
        pl_module_ckpt_path=cli_args.ckpt,
        output_dir=cli_args.output_dir,
    )


# Run the conversion only when executed as a script, not when imported.
if __name__ == "__main__":
    main()