File size: 1,977 Bytes
67f650b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from models.unet import UNet2DConditionModel
import torch
from ip_adapter import IPAdapterXL
from safetensors.torch import load_file
import onnx
from pathlib import Path

# Root directory that receives every exported ONNX artifact.
output_path = Path('/home/new_onnx/unet')

# Build the UNet from the "unet" subfolder of the pretrained checkpoint.
# NOTE(review): "neta-art/neta-xl-2.0" is presumably a Hugging Face Hub repo
# id -- confirm against UNet2DConditionModel.from_pretrained in models.unet.
unet = UNet2DConditionModel.from_pretrained("neta-art/neta-xl-2.0", subfolder="unet")

# Overlay the weights stored in the ControlNeXt safetensors file.
# strict=False means keys that are missing from, or unexpected in, the file
# are tolerated instead of raising (assuming standard nn.Module semantics).
state_dict = load_file('/home/ControlNeXt/ControlNeXt-SDXL/unet.safetensors')
unet.load_state_dict(state_dict, strict=False)

# Wrap the UNet with the IP-Adapter on CPU. The image encoder and the adapter
# checkpoint are both given the same identifier.
# NOTE(review): verify that IPAdapterXL accepts this value for both arguments.
device = 'cpu'
image_encoder_path = ip_ckpt = "h94/IP-Adapter"
ip_model = IPAdapterXL(unet, image_encoder_path, ip_ckpt, device, num_tokens=4)

# From here on, use the UNet instance held by the adapter wrapper.
unet = ip_model.unet

# Random example tensors; they are packed into `dummy_inputs` in the
# positional order shown at the bottom of this block.
sample = torch.randn(1, 4, 128, 128)
timestep = torch.rand(1, dtype=torch.float32)
# NOTE(review): sequence length 81 is presumably 77 text tokens plus the
# 4 IP-Adapter image tokens -- confirm.
encoder_hidden_state = torch.randn(1, 81, 2048)
mid_block_additional_residual = torch.randn(1, 320, 128, 128, dtype=torch.float32)
mid_block_additional_residual_scale = torch.ones(1, dtype=torch.float32)

dummy_inputs = (
    sample,
    timestep,
    encoder_hidden_state,
    mid_block_additional_residual,
    mid_block_additional_residual_scale,
)

# Export the UNet to ONNX at <output_path>/unet/model.onnx.
onnx_output_path = output_path / "unet" / "model.onnx"
# The exporter does not create missing directories, so make sure the
# destination folder exists before exporting.
onnx_output_path.parent.mkdir(parents=True, exist_ok=True)
torch.onnx.export(
    unet,
    dummy_inputs,
    str(onnx_output_path),  # pass the path as a string to ensure compatibility
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    # Names are assigned to the entries of `dummy_inputs` in positional order.
    input_names=['sample', 'timestep', 'encoder_hidden_state', 'control_out', 'control_scale'],
    output_names=['predict_noise'],
    # Only the axes listed here are exported with symbolic (variable) sizes;
    # 'timestep' and 'control_scale' are not listed and keep the dummy shape.
    dynamic_axes={
        "sample": {0: "B"},
        "encoder_hidden_state": {0: "B", 1: "1B", 2: '2B'},
        "control_out": {0: "B"},
        "predict_noise": {0: 'B'},
    },
)

# Reload the exported model and re-save it under <output_path>/unet_optimize
# with all tensor data consolidated into a single external file.
# NOTE(review): despite the "optimize" naming, no graph optimization pass is
# applied in this block; the model is only re-serialized.
unet_opt_graph = onnx.load(str(onnx_output_path))
unet_optimize = output_path / "unet_optimize" / "model.onnx"
# onnx.save_model does not create missing directories.
unet_optimize.parent.mkdir(parents=True, exist_ok=True)
# Remove a weights file left over from a previous run: onnx's external-data
# writer appends to an existing file, which would leave stale bytes in it.
weights_file = "weights.pb"
(unet_optimize.parent / weights_file).unlink(missing_ok=True)
onnx.save_model(
    unet_opt_graph,
    str(unet_optimize),
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location=weights_file,  # path is relative to the saved model's directory
)