Spaces:
Running
on
Zero
Running
on
Zero
from .models.autoencoders import create_autoencoder_from_config | |
import os | |
import json | |
import torch | |
from torch.nn.utils import remove_weight_norm | |
def remove_all_weight_norm(model):
    """Remove weight normalization from every weight-normed module in ``model``.

    Walks everything yielded by ``model.named_modules()`` and calls
    ``remove_weight_norm`` on each module that exposes a ``weight_g``
    attribute. Modules are modified in place; nothing is returned.
    """
    # Lazy filter: each module is checked right before it is processed,
    # in the same traversal order as ``named_modules()``.
    weight_normed = (
        submodule
        for _, submodule in model.named_modules()
        if hasattr(submodule, 'weight_g')
    )
    for submodule in weight_normed:
        remove_weight_norm(submodule)
def load_vae(ckpt_path, remove_weight_norm=False):
    """Build an autoencoder from a checkpoint and its sibling ``config.json``.

    The model configuration is read from ``config.json`` located in the same
    directory as ``ckpt_path``. The model is created with
    ``create_autoencoder_from_config`` and populated from the checkpoint's
    ``'state_dict'`` entry, keeping only the keys that start with
    ``"autoencoder."`` (with that prefix stripped).

    Args:
        ckpt_path: Path to a checkpoint file readable by ``torch.load``.
        remove_weight_norm: If truthy, call ``remove_all_weight_norm`` on the
            model after the weights have been loaded. Within this function
            the name shadows the ``remove_weight_norm`` function imported
            from ``torch.nn.utils``; it is only ever used as a flag here.

    Returns:
        The loaded model, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``config.json`` or the checkpoint is missing.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')

    # Load the model configuration. JSON is UTF-8 by specification, so do not
    # depend on the platform's default text encoding.
    with open(config_file, encoding="utf-8") as f:
        model_config = json.load(f)

    # Create the model from the configuration
    model = create_autoencoder_from_config(model_config)

    # Load the state dictionary from the checkpoint (tensors mapped to CPU).
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only pass checkpoints from trusted sources.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    state_dict = checkpoint['state_dict']

    # Keep only the autoencoder weights and strip the "autoencoder." prefix so
    # the keys match the bare model's parameter names.
    prefix = "autoencoder."
    model_dict = {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }

    # Load the state dictionary into the model
    model.load_state_dict(model_dict)

    # Remove weight normalization (the parameter is a flag, not the torch
    # function of the same name)
    if remove_weight_norm:
        remove_all_weight_norm(model)

    # Set the model to evaluation mode
    model.eval()
    return model