import torch
from safetensors.torch import load_file

from models.controlnet import ControlNetModel
def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
    """Load a .safetensors checkpoint into `model`.

    When `load_weight_increasement` is True, the checkpoint is treated as a set of
    weight deltas and is added on top of the model's current parameters first.
    """
    state_dict = load_file(safetensors_path)
    if load_weight_increasement:
        pretrained_state_dict = model.state_dict()
        for k in state_dict.keys():
            state_dict[k] = state_dict[k] + pretrained_state_dict[k]
        model.load_state_dict(state_dict, strict=False)
    else:
        model.load_state_dict(state_dict, strict=strict)


# Instantiate the ControlNeXt ControlNet and move it to the export device.
controlnet = ControlNetModel()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
controlnet.to(device)

# Load the pretrained ControlNeXt-SDXL weights.
load_safetensors(controlnet, '/home/ControlNeXt/ControlNeXt-SDXL/controlnet.safetensors')
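# Assumption: ControlNetModel is a standard torch.nn.Module, so switching to eval
# mode before export avoids tracing any training-only behaviour (e.g. dropout).
controlnet.eval()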

# Dummy inputs matching the export signature: an SDXL-resolution conditioning
# image and a single timestep value.
image = torch.randn((1, 3, 1024, 1024), dtype=torch.float32).to(device)
timestep = torch.rand(1, dtype=torch.float32).to(device)

dummy_inputs = (image, timestep)

onnx_output_path = '/home/new_onnx/cnext/model.onnx'
torch.onnx.export(
    controlnet,
    dummy_inputs,
    onnx_output_path,
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['controlnext_image', 'timestep'],
    output_names=['sample'],
    # Allow variable batch size and image resolution at inference time.
    dynamic_axes={
        'controlnext_image': {0: 'batch_size', 2: 'height', 3: 'width'},
        'sample': {0: 'batch_size'},
    },
)
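
# Optional sanity check: a minimal sketch, assuming the `onnx` and `onnxruntime`
# packages are installed (they are not required by the export itself). It validates
# the exported graph and runs it once with random inputs on the CPU provider,
# using the input/output names declared in the export call above.
import numpy as np
import onnx
import onnxruntime as ort

onnx.checker.check_model(onnx.load(onnx_output_path))

session = ort.InferenceSession(onnx_output_path, providers=['CPUExecutionProvider'])
outputs = session.run(
    None,
    {
        'controlnext_image': np.random.randn(1, 3, 1024, 1024).astype(np.float32),
        'timestep': np.random.rand(1).astype(np.float32),
    },
)
print([o.shape for o in outputs])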