OpenSound's picture
Upload 84 files
b9d6819 verified
raw
history blame
1.25 kB
from .models.autoencoders import create_autoencoder_from_config
import os
import json
import torch
from torch.nn.utils import remove_weight_norm
def remove_all_weight_norm(model):
    """Strip weight normalization from every submodule of *model* that has it.

    Walks ``model.named_modules()`` and calls ``remove_weight_norm`` on each
    module exposing a ``weight_g`` attribute. The model is modified in place
    and nothing is returned.
    """
    for _, submodule in model.named_modules():
        # Modules without a ``weight_g`` attribute are left untouched.
        if not hasattr(submodule, 'weight_g'):
            continue
        remove_weight_norm(submodule)
def load_vae(ckpt_path, remove_weight_norm=False):
    """Build an autoencoder from a checkpoint and its sibling ``config.json``.

    The model is created from the ``config.json`` located in the same
    directory as ``ckpt_path``, then populated with the checkpoint weights
    whose keys start with ``"autoencoder."`` (the prefix is stripped before
    loading). Checkpoint entries without that prefix are ignored.

    Args:
        ckpt_path: Path to the checkpoint file. The loaded object must
            contain a ``'state_dict'`` entry.
        remove_weight_norm: When true, weight normalization is removed from
            the loaded model via ``remove_all_weight_norm``.

    Returns:
        The loaded model, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``config.json`` is missing next to the
            checkpoint.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
    """
    config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json')

    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale-specific default encoding when reading the configuration.
    with open(config_file, encoding='utf-8') as f:
        model_config = json.load(f)

    # Create the model from the configuration.
    model = create_autoencoder_from_config(model_config)

    # NOTE(security): torch.load deserializes the file with pickle-based
    # machinery; only load checkpoints that come from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location='cpu')
    full_state = checkpoint['state_dict']

    # Keep only the autoencoder weights and strip their key prefix so the
    # names match what the bare model expects.
    prefix = "autoencoder."
    prefix_len = len(prefix)
    model_dict = {
        key[prefix_len:]: value
        for key, value in full_state.items()
        if key.startswith(prefix)
    }
    model.load_state_dict(model_dict)

    # Inside this function the ``remove_weight_norm`` flag shadows the torch
    # helper of the same name, so removal goes through the module-level
    # ``remove_all_weight_norm`` wrapper.
    if remove_weight_norm:
        remove_all_weight_norm(model)

    # Set the model to evaluation mode before handing it back.
    model.eval()
    return model