# Quantize-Calibration-int8 / quantize_int8_test.py
# Calibrate Stable Diffusion 1.5's UNet with SmoothQuant, quantize it to INT8
# using NVIDIA AMMO, and export the result to ONNX with external weights.
import torch
import onnx
from pathlib import Path
from diffusers import DiffusionPipeline, StableDiffusionPipeline
from utilities import load_calib_prompts, get_smoothquant_config, filter_func, quantize_lvl
import ammo.torch.quantization as atq
import ammo.torch.opt as ato
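# The `utilities` helpers (load_calib_prompts, get_smoothquant_config, filter_func,
# quantize_lvl) and the ammo.* modules follow NVIDIA's Stable Diffusion INT8
# quantization demo; this script assumes that demo's utilities.py is on the path.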
# pipeline = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
#                                              torch_dtype=torch.float16,
#                                              use_safetensors=True,
#                                              variant="fp16")
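# Load Stable Diffusion 1.5 in fp16 and keep the whole pipeline on the GPU, so
# the calibration prompts below exercise the real denoising loop.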
pipeline = StableDiffusionPipeline.from_pretrained("wyyadd/sd-1.5", torch_dtype=torch.float16)
pipeline.to("cuda")
# pipeline.enable_xformers_memory_efficient_attention()
# pipeline.enable_vae_slicing()
BATCH_SIZE = 4
cali_prompts = load_calib_prompts(batch_size=BATCH_SIZE, calib_data_path="./calibration-prompts.txt")
quant_config = get_smoothquant_config(pipeline.unet, quant_level=3.0)
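# get_smoothquant_config builds an AMMO SmoothQuant calibration config for the
# UNet; quant_level 3.0 matches the level passed to quantize_lvl after calibration
# (in NVIDIA's demo the quant level controls which layer types get quantized).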
def do_calibrate(base, calibration_prompts, **kwargs):
    # Run the pipeline on calibration prompts so AMMO can observe activations.
    # Note: calib_size is compared against the batch index, so it counts batches.
    for i_th, prompts in enumerate(calibration_prompts):
        print(prompts)
        if i_th >= kwargs["calib_size"]:
            return
        base(
            prompt=prompts,
            num_inference_steps=kwargs["n_steps"],
            negative_prompt=[
                "normal quality, low quality, worst quality, low res, blurry, nsfw, nude"
            ]
            * len(prompts),
        ).images
def calibration_loop():
    do_calibrate(
        base=pipeline,
        calibration_prompts=cali_prompts,
        calib_size=384,  # counted in batches by do_calibrate, BATCH_SIZE prompts each
        n_steps=50,
    )
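# atq.quantize replaces the UNet's quantizable layers with quantized equivalents,
# runs forward_loop once to collect activation ranges, and derives INT8 scales
# from them (a summary of AMMO's documented behavior, not code from this repo).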
quantized_model = atq.quantize(pipeline.unet, quant_config, forward_loop=calibration_loop)
ato.save(quantized_model, 'base.unet15_2.int8.pt')
quantize_lvl(quantized_model, quant_level=3.0)
# Disable the quantizers matched by filter_func (layers kept in higher precision).
atq.disable_quantizer(quantized_model, filter_func)
# Move to CPU in fp32 for the ONNX export.
device1 = "cpu"
quantized_model = quantized_model.to(torch.float32).to(device1)
# Export the quantized UNet to ONNX.
# Dummy inputs for SD 1.5: latent sample (B, 4, H/8, W/8), a timestep, and CLIP
# text embeddings (B, 77, 768); 128x128 latents correspond to 1024x1024 images.
sample = torch.randn((1, 4, 128, 128), dtype=torch.float32, device=device1)
timestep = torch.rand(1, dtype=torch.float32, device=device1)
encoder_hidden_state = torch.randn((1, 77, 768), dtype=torch.float32, device=device1)
output_path = Path('/home/tiennv/trang/Convert-_Unet_int8_Rebuild/Diffusion/onnx_unet15')
output_path.mkdir(parents=True, exist_ok=True)
dummy_inputs = (sample, timestep, encoder_hidden_state)
onnx_output_path = output_path / "unet" / "model.onnx"
onnx_output_path.parent.mkdir(parents=True, exist_ok=True)
# Alternative: move to CPU and export with AMMO's SDXL helper (unused here).
# from onnx_utils import ammo_export_sd
# base.unet.to(torch.float32).to("cpu")
# ammo_export_sd(base, 'onnx_dir', 'stabilityai/stable-diffusion-xl-base-1.0')
torch.onnx.export(
    quantized_model,
    dummy_inputs,
    str(onnx_output_path),
    export_params=True,
    opset_version=18,
    do_constant_folding=True,
    input_names=['sample', 'timestep', 'encoder_hidden_state'],
    output_names=['predict_noise'],
    dynamic_axes={
        # Tensors are (B, C, H, W): axis 2 is height, axis 3 is width.
        "sample": {0: "B", 2: "H", 3: "W"},
        "encoder_hidden_state": {0: "B", 1: "S", 2: "D"},
        "predict_noise": {0: "B", 2: "H", 3: "W"},
    },
)
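# Optional sanity check before re-saving (not in the original script); for models
# over 2 GB, onnx.checker must be given the file path rather than a loaded proto:
# onnx.checker.check_model(str(onnx_output_path))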
# Re-save the ONNX model with all weights in one external data file
# (ONNX protobufs are capped at 2 GB, so large models need external data).
unet_opt_graph = onnx.load(str(onnx_output_path))
unet_optimize_path = output_path / "unet_optimize"
unet_optimize_path.mkdir(parents=True, exist_ok=True)
unet_optimize_file = unet_optimize_path / "model.onnx"
onnx.save_model(
unet_opt_graph,
str(unet_optimize_file),
save_as_external_data=True,
all_tensors_to_one_file=True,
location="weights.pb",
)
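
# A minimal smoke test for the exported model, assuming onnxruntime is installed
# (the input names mirror the export above; this snippet is an illustrative
# sketch, not part of the original script):
# import numpy as np
# import onnxruntime as ort
# sess = ort.InferenceSession(str(unet_optimize_file), providers=["CPUExecutionProvider"])
# noise = sess.run(
#     ["predict_noise"],
#     {
#         "sample": np.random.randn(1, 4, 128, 128).astype(np.float32),
#         "timestep": np.random.rand(1).astype(np.float32),
#         "encoder_hidden_state": np.random.randn(1, 77, 768).astype(np.float32),
#     },
# )[0]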