File size: 3,651 Bytes
745d42a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import torch
import onnx
from pathlib import Path
from diffusers import DiffusionPipeline, StableDiffusionPipeline
import torch
from utilities import load_calib_prompts
from utilities import get_smoothquant_config
import ammo.torch.quantization as atq 
import ammo.torch.opt as ato
from utilities import filter_func, quantize_lvl

# Alternative base model (SDXL), kept for reference; this script actually
# quantizes the SD 1.5 checkpoint loaded below.
# pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", 
#                                              torch_dtype=torch.float16, 
#                                              use_safetensors=True, 
#                                              variant="fp16")

# Load the Stable Diffusion 1.5 pipeline in fp16 for calibration inference.
pipeline = StableDiffusionPipeline.from_pretrained("wyyadd/sd-1.5", torch_dtype=torch.float16)

pipeline.to("cuda")
# Optional memory optimizations, disabled here:
# pipeline.enable_xformers_memory_efficient_attention()
# pipeline.enable_vae_slicing()

# Prompts per calibration batch.
BATCH_SIZE = 4
# NOTE(review): assumes load_calib_prompts yields batches (lists) of prompt
# strings read from the text file — confirm against utilities.py.
cali_prompts = load_calib_prompts(batch_size=BATCH_SIZE, calib_data_path="./calibration-prompts.txt")

# SmoothQuant int8 quantization config for the UNet at quant_level 3.0
# (same level passed to quantize_lvl later in this script).
quant_config = get_smoothquant_config(pipeline.unet, quant_level=3.0)
 
def do_calibrate(base, calibration_prompts, **kwargs):
    """Run the diffusion pipeline over calibration prompt batches.

    Executes ``base`` on successive prompt batches so the quantizer can
    observe activation ranges during ``atq.quantize``'s forward loop.

    Args:
        base: diffusion pipeline; invoked as
            ``base(prompt=..., num_inference_steps=..., negative_prompt=...)``
            and expected to expose ``.images`` on its result.
        calibration_prompts: iterable of prompt batches (each a list of str).
        **kwargs: must contain ``calib_size`` (max number of batches to run)
            and ``n_steps`` (denoising steps per batch).
    """
    calib_size = kwargs["calib_size"]
    n_steps = kwargs["n_steps"]
    # One fixed negative prompt, replicated once per prompt in the batch.
    negative = "normal quality, low quality, worst quality, low res, blurry, nsfw, nude"
    for i_th, prompts in enumerate(calibration_prompts):
        # Fix: check the budget BEFORE printing — the original printed the
        # first out-of-budget batch even though it was never calibrated.
        if i_th >= calib_size:
            return
        print(prompts)
        base(
            prompt=prompts,
            num_inference_steps=n_steps,
            negative_prompt=[negative] * len(prompts),
        ).images

def calibration_loop():
    """Forward loop handed to ``atq.quantize``.

    Runs the module-level pipeline over the pre-loaded calibration prompts
    so activation statistics are collected for int8 calibration.
    """
    calib_kwargs = {"calib_size": 384, "n_steps": 50}
    do_calibrate(base=pipeline, calibration_prompts=cali_prompts, **calib_kwargs)
    
    
# Calibrate and quantize the UNet, then persist the quantized checkpoint.
quantized_model = atq.quantize(pipeline.unet, quant_config, forward_loop=calibration_loop)
ato.save(quantized_model, 'base.unet15_2.int8.pt')

# Re-apply the quantization level used for calibration, then disable the
# quantizers matched by filter_func (layers to keep in higher precision).
quantize_lvl(quantized_model, quant_level=3.0)
atq.disable_quantizer(quantized_model, filter_func)

# ONNX export is done on CPU in fp32.
device1 = "cpu"
quantized_model = quantized_model.to(torch.float32).to(device1)

# Dummy inputs for tracing: latent sample, timestep, and CLIP text embedding
# (77 tokens x 768 dims matches the SD 1.5 text encoder).
# NOTE(review): 128x128 latents imply 1024x1024 images; SD 1.5 is commonly
# exported with 64x64 latents — confirm the intended resolution.
sample = torch.randn((1, 4, 128, 128), dtype=torch.float32, device=device1)
timestep = torch.rand(1, dtype=torch.float32, device=device1)
encoder_hidden_state = torch.randn((1, 77, 768), dtype=torch.float32, device=device1)

# (Redundant re-imports of `onnx` and `pathlib.Path` removed — both are
# already imported at the top of the file.)

output_path = Path('/home/tiennv/trang/Convert-_Unet_int8_Rebuild/Diffusion/onnx_unet15')
output_path.mkdir(parents=True, exist_ok=True)

dummy_inputs = (sample, timestep, encoder_hidden_state)

onnx_output_path = output_path / "unet" / "model.onnx"
onnx_output_path.parent.mkdir(parents=True, exist_ok=True)


# Alternative export path via ammo's helper, kept for reference:
# from onnx_utils import ammo_export_sd
# base.unet.to(torch.float32).to("cpu")
# ammo_export_sd(base, 'onnx_dir', 'stabilityai/stable-diffusion-xl-base-1.0')

# Export the quantized UNet to ONNX with dynamic batch/spatial/sequence axes.
torch.onnx.export(
    quantized_model,
    dummy_inputs,
    str(onnx_output_path),
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['sample', 'timestep', 'encoder_hidden_state'],
    output_names=['predict_noise'],
    dynamic_axes={
        "sample": {0: "B", 2: "W", 3: 'H'},
        "encoder_hidden_state": {0: "B", 1: "S", 2: 'D'},
        "predict_noise": {0: 'B', 2: "W", 3: 'H'}
    }
)

# Optimize and save the ONNX model, consolidating all weights into a single
# external data file (required for models larger than the 2 GB protobuf limit).
unet_opt_graph = onnx.load(str(onnx_output_path))
unet_optimize_path = output_path / "unet_optimize"
unet_optimize_path.mkdir(parents=True, exist_ok=True)
unet_optimize_file = unet_optimize_path / "model.onnx"

onnx.save_model(
    unet_opt_graph,
    str(unet_optimize_file),
    save_as_external_data=True,
    all_tensors_to_one_file=True,
    location="weights.pb",
)