# Tool to build a controllable checkpoint from pre-trained weights:
# weights named 'control_*' are initialized from the matching
# 'model.diffusion_*' weights of a pre-trained model; weights with no
# pre-trained counterpart keep their scratch initialization.
import sys
import os
import torch
from omegaconf import OmegaConf
from ldm.util import instantiate_from_config

# Command-line interface:
#   python <this_script> <input_ckpt> <output_ckpt> <model_config.yaml>
if len(sys.argv) != 4:
    print(f'Usage: {sys.argv[0]} <input_ckpt> <output_ckpt> <config>')
    sys.exit(1)

input_path = sys.argv[1]   # pre-trained checkpoint to copy weights from
output_path = sys.argv[2]  # where the converted checkpoint is written
config_path = sys.argv[3]  # OmegaConf YAML describing the target model
def get_node_name(name, parent_name):
    """Determine whether a weight belongs under a given parent prefix.

    In the model, every weight whose name begins with ``parent_name``
    (e.g. 'control_' for 'control_model.middle_block_out.0.bias') needs
    to be controlled.

    Args:
        name: Full weight name, e.g. 'control_model.middle_block_out.0.bias'.
        parent_name: Prefix to test for, e.g. 'control_'.

    Returns:
        (True, suffix) where ``suffix`` is ``name`` with the prefix removed
        (e.g. 'model.middle_block_out.0.bias'), or (False, '') when ``name``
        does not strictly extend ``parent_name``.
    """
    # Strict extension: a name equal to the prefix itself does not match.
    if len(name) <= len(parent_name) or not name.startswith(parent_name):
        return False, ''
    return True, name[len(parent_name):]
# Instantiate the target model from its config and take its
# randomly-initialized state dict as the "scratch" reference.
configs = OmegaConf.load(config_path)
model = instantiate_from_config(configs["model"])
scratch_dict = model.state_dict()

# Load pre-trained weights.
# NOTE(review): assumes the checkpoint stores weights under a
# 'state_dict' key — confirm against the producing training script.
pretrained_weights = torch.load(input_path)['state_dict']

# Generate target weights: control-branch keys ('control_*') are copied
# from the corresponding 'model.diffusion_*' pre-trained weight; all
# other keys are copied under their own name.
target_dict = {}
for k in scratch_dict.keys():
    is_control, name = get_node_name(k, 'control_')
    if is_control:
        # Needs to be controlled: map 'control_<x>' -> 'model.diffusion_<x>'.
        copy_k = 'model.diffusion_' + name
    else:
        # Doesn't need to be controlled: copy under the same key.
        copy_k = k
    if copy_k in pretrained_weights:
        # Weight exists in the pre-trained checkpoint: copy from it.
        target_dict[k] = pretrained_weights[copy_k].clone()
    else:
        # Weight has no pre-trained counterpart: keep scratch init.
        target_dict[k] = scratch_dict[k].clone()
        print(f'These weights are newly added: {k}')

# Save: strict load guarantees every key of the model was produced.
model.load_state_dict(target_dict, strict=True)
torch.save(model.state_dict(), output_path)
print('Done.')