from safetensors.torch import load_file
import torch

from models.controlnet import ControlNetModel


def load_safetensors(
    model,
    safetensors_path,
    strict=True,
    load_weight_increasement=False,
):
    """Load weights from a ``.safetensors`` file into ``model``.

    Args:
        model: Module exposing ``state_dict()`` and ``load_state_dict()``.
        safetensors_path: Path of the safetensors checkpoint to read.
        strict: Forwarded to ``model.load_state_dict`` when the checkpoint
            holds full weights. Not used in increment mode, which always
            loads with ``strict=False``.
        load_weight_increasement: When true, every tensor in the file is
            added to the model's current tensor of the same name and the
            sums are loaded; when false, the file's tensors are loaded
            as-is.

    Raises:
        KeyError: In increment mode, if the file contains a key that is
            not present in ``model.state_dict()``.
    """
    state_dict = load_file(safetensors_path)

    if not load_weight_increasement:
        model.load_state_dict(state_dict, strict=strict)
        return

    pretrained_state_dict = model.state_dict()
    merged_state_dict = {}
    for key, delta in state_dict.items():
        base = pretrained_state_dict[key]
        # load_file() places tensors on the CPU by default, whereas the
        # model may already live on an accelerator; align devices before
        # adding so the sum does not fail with a device-mismatch error.
        merged_state_dict[key] = delta.to(base.device) + base
    # Only the keys present in the file are loaded, so the load must be
    # non-strict here.
    model.load_state_dict(merged_state_dict, strict=False)


controlnet = ControlNetModel()
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
controlnet.to(device)
load_safetensors(controlnet, '/home/ControlNeXt/ControlNeXt-SDXL/controlnet.safetensors')
# Export the inference graph: set eval mode explicitly instead of relying on
# the exporter to switch modes for us.
controlnet.eval()

# Dummy inputs used only to trace the model for export.
image = torch.randn((1, 3, 1024, 1024), dtype=torch.float32).to(device)
timestep = torch.rand(1, dtype=torch.float32).to(device)
dummy_inputs = (image, timestep)

onnx_output_path = '/home/new_onnx/cnext/model.onnx'

# NOTE(review): only the batch axis of 'sample' is named as dynamic and
# 'timestep' has no dynamic axes; confirm this matches how the exported
# model will be called.
torch.onnx.export(
    controlnet,
    dummy_inputs,
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['controlnext_image', 'timestep'],
    output_names=['sample'],
    dynamic_axes={
        'controlnext_image': {0: 'batch_size', 2: 'height', 3: 'width'},
        'sample': {0: 'batch_size'},
    },
)