# phylo-diffusion: ldm/loading_utils.py
# Commit 642d5e2 ("added few ldm files") by mridulk; 1.2 kB.
#based on https://github.com/CompVis/taming-transformers
import yaml
from omegaconf import OmegaConf
import torch
from ldm.util import instantiate_from_config
######### loaders
def load_config(config_path, display=False):
    """Read an OmegaConf configuration from ``config_path`` and return it.

    When ``display`` is truthy, the configuration is first converted to plain
    Python containers and printed to stdout as YAML.
    """
    cfg = OmegaConf.load(config_path)
    if not display:
        return cfg
    plain = OmegaConf.to_container(cfg)
    print(yaml.dump(plain))
    return cfg
def load_model_from_config(config, ckpt, device="cuda"):
    """Instantiate ``config.model``, optionally load checkpoint weights, and
    return the model in eval mode.

    Args:
        config: Loaded configuration; its ``model`` node is passed to
            ``instantiate_from_config`` to build the module.
        ckpt: Path to a checkpoint whose top-level object has a
            ``"state_dict"`` entry, or ``None`` to skip weight loading and
            return the freshly instantiated model.
        device: Target passed to ``model.to`` after loading. The default
            ``"cuda"`` matches the previous unconditional ``.cuda()`` call.

    Returns:
        The model, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: if the loaded checkpoint has no ``"state_dict"`` key.
    """
    sd = None
    if ckpt is not None:
        print(f"Loading model from {ckpt}")
        # Deserialize onto the CPU. Without map_location, tensors saved from a
        # GPU are restored onto that same device, which fails on hosts that
        # lack it. The weights are copied into the (still CPU-resident)
        # model below and moved to `device` afterwards anyway.
        # NOTE(review): torch.load unpickles arbitrary objects; only load
        # trusted checkpoints. Recent PyTorch releases default to
        # weights_only=True, which may reject Lightning checkpoints that
        # carry non-tensor metadata. Confirm against the installed version.
        pl_sd = torch.load(ckpt, map_location="cpu")
        sd = pl_sd["state_dict"]
    model = instantiate_from_config(config.model)
    if sd is not None:
        # strict=False tolerates key mismatches. Report them instead of
        # silently discarding the result.
        missing, unexpected = model.load_state_dict(sd, strict=False)
        if missing:
            print(f"Missing keys when loading {ckpt}: {missing}")
        if unexpected:
            print(f"Unexpected keys when loading {ckpt}: {unexpected}")
    model.to(device)
    model.eval()
    return model
def load_model(config_path, ckpt_path=None):
    """Read the config at ``config_path`` and build a model from it.

    The configuration is read with ``load_config`` (without printing). The
    model is then built, and its weights loaded from ``ckpt_path``, by
    ``load_model_from_config``.

    Args:
        config_path: Path to an OmegaConf-readable config file.
        ckpt_path: Checkpoint path forwarded unchanged to
            ``load_model_from_config``.

    Returns:
        The model returned by ``load_model_from_config``.
    """
    config = load_config(config_path)
    return load_model_from_config(config, ckpt_path)