File size: 1,470 Bytes
67f650b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os

import torch
from safetensors.torch import load_file

from models.controlnet import ControlNetModel

def load_safetensors(model, safetensors_path, strict=True, load_weight_increasement=False):
    """Load weights from a ``.safetensors`` file into ``model``.

    Args:
        model: Module whose state dict is overwritten via ``load_state_dict``.
        safetensors_path: Path of the ``.safetensors`` file to read.
        strict: Forwarded to ``model.load_state_dict`` in plain-load mode.
            Ignored in increment mode, which always loads with
            ``strict=False``.
        load_weight_increasement: If False, the file's tensors are loaded
            as-is. If True, each tensor in the file is added to the model's
            current tensor of the same name, and the sums are loaded. Only the
            keys present in the file are touched.

    Raises:
        KeyError: In increment mode, if the file contains a key that the
            model's state dict does not have.
    """
    state_dict = load_file(safetensors_path)

    if not load_weight_increasement:
        model.load_state_dict(state_dict, strict=strict)
        return

    pretrained_state_dict = model.state_dict()
    merged_state_dict = {}
    for key, delta in state_dict.items():
        if key not in pretrained_state_dict:
            raise KeyError(
                f"{key!r} from {safetensors_path} has no matching entry "
                "in the model's state dict"
            )
        base = pretrained_state_dict[key]
        # load_file() returns CPU tensors by default, while the model may
        # already live on another device (e.g. CUDA). Move the delta first so
        # the addition does not fail with a device mismatch.
        merged_state_dict[key] = delta.to(base.device) + base
    # strict=False: the file only needs to cover a subset of the model's keys
    # (presumably the fine-tuned ones — confirm against the weight format).
    model.load_state_dict(merged_state_dict, strict=False)
        
# Build the ControlNeXt ControlNet, load its weights and export it to ONNX.
CONTROLNET_WEIGHTS_PATH = '/home/ControlNeXt/ControlNeXt-SDXL/controlnet.safetensors'
onnx_output_path = '/home/new_onnx/cnext/model.onnx'

controlnet = ControlNetModel()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
controlnet.to(device)

load_safetensors(controlnet, CONTROLNET_WEIGHTS_PATH)
# Export inference behaviour explicitly: puts modules such as dropout and
# batch-norm (if the model has any) into eval mode.
controlnet.eval()

# Dummy inputs used only for tracing; their values do not matter.
image = torch.randn((1, 3, 1024, 1024), dtype=torch.float32).to(device)
# NOTE(review): timestep has a static shape of (1,) in the exported graph;
# this assumes the model broadcasts it over a dynamic batch — confirm.
timestep = torch.rand(1, dtype=torch.float32).to(device)

dummy_inputs = (image, timestep)

# torch.onnx.export does not create missing parent directories; it fails when
# opening the output file, so create the directory up front.
output_dir = os.path.dirname(onnx_output_path)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)

# Tracing needs no autograd state; disabling it reduces memory use.
with torch.no_grad():
    torch.onnx.export(
        controlnet,
        dummy_inputs,
        onnx_output_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=['controlnext_image', 'timestep'],
        output_names=['sample'],
        dynamic_axes={
            'controlnext_image': {0: 'batch_size', 2: 'height', 3: 'width'},
            'sample': {0: 'batch_size'},
        },
    )