# LHM/scripts/convert_hf.py

# Copyright (c) 2023-2024, Zexin He
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
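
# Converts a training checkpoint (model.safetensors) into a Hugging Face
# `save_pretrained` directory under ./exps/releases/. Typical invocation
# (the step value is illustrative):
#   python scripts/convert_hf.py --config ./assets/config.yaml convert.global_step=2000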
import argparse
import sys
import traceback
from tempfile import TemporaryDirectory

import safetensors.torch
import torch.nn as nn
from accelerate import Accelerator
from megfile import (
    smart_copy,
    smart_exists,
    smart_listdir,
    smart_makedirs,
    smart_path_join,
)
from omegaconf import OmegaConf

sys.path.append(".")

from LHM.models import model_dict
from LHM.utils.hf_hub import wrap_model_hub
from LHM.utils.proxy import no_proxy


@no_proxy
def auto_load_model(cfg, model: nn.Module) -> int:
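    """Load the requested (or newest) checkpoint under the experiment's
    checkpoint root into `model`, returning the loaded global step."""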
    ckpt_root = smart_path_join(
        cfg.saver.checkpoint_root,
        cfg.experiment.parent,
        cfg.experiment.child,
    )
    if not smart_exists(ckpt_root):
        raise FileNotFoundError(f"Checkpoint root not found: {ckpt_root}")
    ckpt_dirs = smart_listdir(ckpt_root)
    if len(ckpt_dirs) == 0:
        raise FileNotFoundError(f"No checkpoint found in {ckpt_root}")
    ckpt_dirs.sort()
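    # Pick the explicitly requested global step if given; otherwise fall back
    # to the last entry after sorting (directory names are assumed to sort by step).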
    load_step = (
        f"{cfg.convert.global_step}"
        if cfg.convert.global_step is not None
        else ckpt_dirs[-1]
    )
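    # Remote (s3) checkpoints are first staged to a local temporary file,
    # since safetensors loads from a local path.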
    load_model_path = smart_path_join(ckpt_root, load_step, "model.safetensors")
    if load_model_path.startswith("s3"):
        tmpdir = TemporaryDirectory()
        tmp_model_path = smart_path_join(tmpdir.name, "tmp.safetensors")
        smart_copy(load_model_path, tmp_model_path)
        load_model_path = tmp_model_path
print(f"Loading from {load_model_path}")
    try:
        safetensors.torch.load_model(model, load_model_path, strict=True)
    except Exception:
        # Strict loading failed (e.g. missing or unexpected keys); log the
        # error and retry with strict=False so compatible weights still load.
        traceback.print_exc()
        safetensors.torch.load_model(model, load_model_path, strict=False)
    return int(load_step)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./assets/config.yaml")
    args, unknown = parser.parse_known_args()
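
    # Remaining CLI arguments are parsed by OmegaConf as dotlist overrides,
    # e.g. `convert.global_step=2000`, and merged over the YAML config.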
    cfg = OmegaConf.load(args.config)
    cli_cfg = OmegaConf.from_cli(unknown)
    cfg = OmegaConf.merge(cfg, cli_cfg)
"""
[cfg.convert]
global_step: int
save_dir: str
"""
    accelerator = Accelerator()
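
    # NOTE: the model type is hard-coded; the commented-out line below shows
    # the generic variant driven by cfg.experiment.type.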
    # hf_model_cls = wrap_model_hub(model_dict[cfg.experiment.type])
    hf_model_cls = wrap_model_hub(model_dict["human_lrm_sapdino_bh_sd3_5"])
    hf_model = hf_model_cls(OmegaConf.to_container(cfg.model))
    loaded_step = auto_load_model(cfg, hf_model)
    dump_path = smart_path_join(
        "./exps/releases",
        cfg.experiment.parent,
        cfg.experiment.child,
        f"step_{loaded_step:06d}",
    )
print(f"Saving locally to {dump_path}")
smart_makedirs(dump_path, exist_ok=True)
hf_model.save_pretrained(
save_directory=dump_path,
config=hf_model.config,
)
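
    # The exported directory can then be reloaded through the huggingface_hub
    # mixin interface (assuming wrap_model_hub attaches the standard mixin):
    #   model = hf_model_cls.from_pretrained(dump_path)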