# Author: Bethie
# Commit 67f650b (verified): "Code convert ONNX"
import os

import torch
from safetensors.torch import load_file

from models.controlnet import ControlNetModel
def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
    """Load weights from a ``.safetensors`` file into ``model``.

    Args:
        model: The ``torch.nn.Module`` whose state dict is (re)loaded.
        safetensors_path: Path to the ``.safetensors`` checkpoint.
        strict: Forwarded to ``model.load_state_dict`` when loading full
            weights. Ignored in increment mode, which always loads with
            ``strict=False``.
        load_weight_increasement: If True, every tensor in the file is
            treated as a delta. The value loaded for each key is
            ``current_model_weight + delta``. Keys absent from the file keep
            their current values.

    Raises:
        KeyError: In increment mode, if the file contains keys that the
            model's state dict does not have.
    """
    state_dict = load_file(safetensors_path)

    if not load_weight_increasement:
        model.load_state_dict(state_dict, strict=strict)
        return

    pretrained_state_dict = model.state_dict()
    missing = [k for k in state_dict if k not in pretrained_state_dict]
    if missing:
        raise KeyError(
            f"keys in {safetensors_path!r} not found in model state dict: {missing}"
        )

    # load_file returns CPU tensors by default, while the model may already
    # live on a GPU. Move each delta to its base weight's device so the
    # addition does not fail with a device mismatch.
    merged = {
        k: pretrained_state_dict[k] + delta.to(pretrained_state_dict[k].device)
        for k, delta in state_dict.items()
    }
    # strict=False: the file only carries deltas for a subset of the weights.
    model.load_state_dict(merged, strict=False)
# --- Build the ControlNeXt model and load its weights -----------------------
controlnet = ControlNetModel()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
controlnet.to(device)
load_safetensors(controlnet, '/home/ControlNeXt/ControlNeXt-SDXL/controlnet.safetensors')
# Export in inference mode; set explicitly rather than relying on the
# exporter's default training-mode handling.
controlnet.eval()

# --- Dummy inputs used to trace the graph -----------------------------------
image = torch.randn((1, 3, 1024, 1024), dtype=torch.float32).to(device)
timestep = torch.rand(1, dtype=torch.float32).to(device)
dummy_inputs = (image, timestep)

onnx_output_path = '/home/new_onnx/cnext/model.onnx'
# The exporter writes straight to this path and will not create missing
# parent directories, so make sure they exist first.
os.makedirs(os.path.dirname(onnx_output_path), exist_ok=True)

# --- Export -------------------------------------------------------------------
# Batch size and spatial size of the image input are exported as dynamic
# axes; the timestep input keeps the fixed shape of the dummy tensor.
torch.onnx.export(
    controlnet,
    dummy_inputs,
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['controlnext_image', 'timestep'],
    output_names=['sample'],
    dynamic_axes={
        'controlnext_image': {0: 'batch_size', 2: 'height', 3: 'width'},
        'sample': {0: 'batch_size'},
    },
)