Bethie committed on
Commit a63a2f3 · verified · 1 Parent(s): 745d42a

Upload code to quantize int8 ONNX weights.

Files changed (1):
  1. utilities.py +569 -0
utilities.py ADDED
@@ -0,0 +1,569 @@
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

from collections import OrderedDict
from cuda import cudart
from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
from diffusers.utils.torch_utils import randn_tensor
from enum import Enum, auto
import gc
from io import BytesIO
import numpy as np
import onnx
from onnx import numpy_helper
import onnx_graphsurgeon as gs
import os
from PIL import Image
from polygraphy.backend.common import bytes_from_path
from polygraphy.backend.trt import (
    CreateConfig,
    ModifyNetworkOutputs,
    Profile,
    engine_from_bytes,
    engine_from_network,
    network_from_onnx_path,
    save_engine
)
import random
import re
import requests
from scipy import integrate
import tensorrt as trt
import torch
import types

TRT_LOGGER = trt.Logger(trt.Logger.ERROR)

# Map of numpy dtype -> torch dtype
numpy_to_torch_dtype_dict = {
    np.uint8 : torch.uint8,
    np.int8 : torch.int8,
    np.int16 : torch.int16,
    np.int32 : torch.int32,
    np.int64 : torch.int64,
    np.float16 : torch.float16,
    np.float32 : torch.float32,
    np.float64 : torch.float64,
    np.complex64 : torch.complex64,
    np.complex128 : torch.complex128
}
# np.bool was removed in numpy 1.24; use a proper version comparison rather
# than a lexicographic string compare (which would misorder e.g. "1.9.0").
if np.lib.NumpyVersion(np.version.full_version) >= "1.24.0":
    numpy_to_torch_dtype_dict[np.bool_] = torch.bool
else:
    numpy_to_torch_dtype_dict[np.bool] = torch.bool

# Map of torch dtype -> numpy dtype
torch_to_numpy_dtype_dict = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()}

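# Illustrative check, not part of the original commit: the two dicts above
# round-trip dtypes between numpy and torch, e.g. when allocating torch
# buffers for TensorRT bindings whose dtypes are reported as numpy types.
assert numpy_to_torch_dtype_dict[np.float16] is torch.float16
assert torch_to_numpy_dtype_dict[torch.float16] is np.float16
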
def unload_model(model):
    if model:
        del model
        torch.cuda.empty_cache()
        gc.collect()

def replace_lora_layers(model):
    def lora_forward(self, x, scale=None):
        return self._torch_forward(x)

    for name, module in model.named_modules():
        if isinstance(module, LoRACompatibleConv):
            in_channels = module.in_channels
            out_channels = module.out_channels
            kernel_size = module.kernel_size
            stride = module.stride
            padding = module.padding
            dilation = module.dilation
            groups = module.groups
            bias = module.bias

            new_conv = torch.nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=bias is not None,
            )

            new_conv.weight.data = module.weight.data.clone().to(module.weight.data.device)
            if bias is not None:
                new_conv.bias.data = module.bias.data.clone().to(module.bias.data.device)

            # Replace the LoRACompatibleConv layer with the Conv2d layer
            path = name.split(".")
            sub_module = model
            for p in path[:-1]:
                sub_module = getattr(sub_module, p)
            setattr(sub_module, path[-1], new_conv)
            new_conv._torch_forward = new_conv.forward
            new_conv.forward = types.MethodType(lora_forward, new_conv)

        elif isinstance(module, LoRACompatibleLinear):
            in_features = module.in_features
            out_features = module.out_features
            bias = module.bias

            new_linear = torch.nn.Linear(in_features, out_features, bias=bias is not None)

            new_linear.weight.data = module.weight.data.clone().to(module.weight.data.device)
            if bias is not None:
                new_linear.bias.data = module.bias.data.clone().to(module.bias.data.device)

            # Replace the LoRACompatibleLinear layer with the Linear layer
            path = name.split(".")
            sub_module = model
            for p in path[:-1]:
                sub_module = getattr(sub_module, p)
            setattr(sub_module, path[-1], new_linear)
            new_linear._torch_forward = new_linear.forward
            new_linear.forward = types.MethodType(lora_forward, new_linear)

def merge_loras(model, lora_dict, lora_alphas, lora_scales):
    assert len(lora_scales) == len(lora_dict)
    for path, lora in lora_dict.items():
        print(f"[I] Fusing LoRA: {path}, scale {lora_scales[path]}")
        model.load_attn_procs(lora, network_alphas=lora_alphas[path])
        model.fuse_lora(lora_scale=lora_scales[path])
    return model

def CUASSERT(cuda_ret):
    err = cuda_ret[0]
    if err != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t")
    if len(cuda_ret) > 1:
        return cuda_ret[1]
    return None

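# Illustrative sketch, not part of the original commit: CUASSERT unwraps the
# (error_code, value) tuples returned by the cuda-python cudart bindings,
# raising on any error. A typical use is creating the stream passed to
# Engine.infer below.
def _example_create_stream():
    stream = CUASSERT(cudart.cudaStreamCreate())  # raises RuntimeError on CUDA error
    return stream
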
class PIPELINE_TYPE(Enum):
    TXT2IMG = auto()
    IMG2IMG = auto()
    INPAINT = auto()
    CONTROLNET = auto()
    XL_BASE = auto()
    XL_REFINER = auto()

    def is_txt2img(self):
        return self == self.TXT2IMG

    def is_img2img(self):
        return self == self.IMG2IMG

    def is_inpaint(self):
        return self == self.INPAINT

    def is_controlnet(self):
        return self == self.CONTROLNET

    def is_sd_xl_base(self):
        return self == self.XL_BASE

    def is_sd_xl_refiner(self):
        return self == self.XL_REFINER

    def is_sd_xl(self):
        return self.is_sd_xl_base() or self.is_sd_xl_refiner()

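# For example, PIPELINE_TYPE.XL_BASE.is_sd_xl() is True, while
# PIPELINE_TYPE.TXT2IMG.is_sd_xl() is False.
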
class Engine():
    def __init__(
        self,
        engine_path,
    ):
        self.engine_path = engine_path
        self.engine = None
        self.context = None
        self.buffers = OrderedDict()
        self.tensors = OrderedDict()
        self.cuda_graph_instance = None # cuda graph

    def __del__(self):
        del self.engine
        del self.context
        del self.buffers
        del self.tensors

    def refit(self, refit_weights, is_fp16):
        # Initialize refitter
        refitter = trt.Refitter(self.engine, TRT_LOGGER)

        refitted_weights = set()
        # iterate through all tensorrt refittable weights
        for trt_weight_name in refitter.get_all_weights():
            if trt_weight_name not in refit_weights:
                continue

            # get weight from state dict
            trt_datatype = trt.DataType.FLOAT
            if is_fp16:
                refit_weights[trt_weight_name] = refit_weights[trt_weight_name].half()
                trt_datatype = trt.DataType.HALF

            # trt.Weights and trt.TensorLocation
            trt_wt_tensor = trt.Weights(trt_datatype, refit_weights[trt_weight_name].data_ptr(), torch.numel(refit_weights[trt_weight_name]))
            trt_wt_location = trt.TensorLocation.DEVICE if refit_weights[trt_weight_name].is_cuda else trt.TensorLocation.HOST

            # apply refit
            refitter.set_named_weights(trt_weight_name, trt_wt_tensor, trt_wt_location)
            refitted_weights.add(trt_weight_name)

        assert set(refitted_weights) == set(refit_weights.keys())
        if not refitter.refit_cuda_engine():
            print("Error: failed to refit new weights.")
            exit(1)

        print(f"[I] Total refitted weights {len(refitted_weights)}.")

    def build(self,
        onnx_path,
        fp16=True,
        tf32=False,
        int8=False,
        input_profile=None,
        enable_refit=False,
        enable_all_tactics=False,
        timing_cache=None,
        update_output_names=None,
        **extra_build_args
    ):
        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
        p = Profile()
        if input_profile:
            for name, dims in input_profile.items():
                assert len(dims) == 3
                p.add(name, min=dims[0], opt=dims[1], max=dims[2])

        if not enable_all_tactics:
            extra_build_args['tactic_sources'] = []

        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
        if update_output_names:
            print(f"Updating network outputs to {update_output_names}")
            network = ModifyNetworkOutputs(network, update_output_names)
        engine = engine_from_network(
            network,
            config=CreateConfig(fp16=fp16,
                tf32=tf32,
                int8=int8,
                refittable=enable_refit,
                profiles=[p],
                load_timing_cache=timing_cache,
                **extra_build_args
            ),
            save_timing_cache=timing_cache
        )
        save_engine(engine, path=self.engine_path)

    def load(self):
        print(f"Loading TensorRT engine: {self.engine_path}")
        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))

    def activate(self, reuse_device_memory=None):
        if reuse_device_memory:
            self.context = self.engine.create_execution_context_without_device_memory()
            self.context.device_memory = reuse_device_memory
        else:
            self.context = self.engine.create_execution_context()

    def allocate_buffers(self, shape_dict=None, device='cuda'):
        for idx in range(self.engine.num_io_tensors):
            binding = self.engine[idx]
            if shape_dict and binding in shape_dict:
                shape = shape_dict[binding]
            else:
                shape = self.engine.get_binding_shape(binding)
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            if self.engine.binding_is_input(binding):
                self.context.set_binding_shape(idx, shape)
            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
            self.tensors[binding] = tensor

    def infer(self, feed_dict, stream, use_cuda_graph=False):

        for name, buf in feed_dict.items():
            self.tensors[name].copy_(buf)

        for name, tensor in self.tensors.items():
            self.context.set_tensor_address(name, tensor.data_ptr())

        if use_cuda_graph:
            if self.cuda_graph_instance is not None:
                CUASSERT(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
                CUASSERT(cudart.cudaStreamSynchronize(stream))
            else:
                # do inference before CUDA graph capture
                noerror = self.context.execute_async_v3(stream)
                if not noerror:
                    raise ValueError("ERROR: inference failed.")
                # capture cuda graph
                CUASSERT(cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal))
                self.context.execute_async_v3(stream)
                self.graph = CUASSERT(cudart.cudaStreamEndCapture(stream))
                self.cuda_graph_instance = CUASSERT(cudart.cudaGraphInstantiate(self.graph, 0))
        else:
            noerror = self.context.execute_async_v3(stream)
            if not noerror:
                raise ValueError("ERROR: inference failed.")

        return self.tensors

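# Illustrative sketch, not part of the original commit: the expected Engine
# lifecycle. The file paths, the "latent" tensor name, and the shapes are
# hypothetical placeholders, not names defined in this file.
def _example_engine_lifecycle(feed_dict):
    engine = Engine("unet.plan")
    engine.build("unet.onnx", fp16=True, input_profile={
        "latent": [(1, 4, 64, 64), (2, 4, 64, 64), (4, 4, 64, 64)],  # min, opt, max
    })
    engine.load()
    engine.activate()
    engine.allocate_buffers(shape_dict={"latent": (2, 4, 64, 64)}, device='cuda')
    stream = CUASSERT(cudart.cudaStreamCreate())
    return engine.infer(feed_dict, stream)
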
def save_image(images, image_path_dir, image_name_prefix):
    """
    Save the generated images to png files.
    """
    images = ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
    for i in range(images.shape[0]):
        image_path = os.path.join(image_path_dir, image_name_prefix+str(i+1)+'-'+str(random.randint(1000,9999))+'.png')
        print(f"Saving image {i+1} / {images.shape[0]} to: {image_path}")
        Image.fromarray(images[i]).save(image_path)

def preprocess_image(image):
    """
    image: PIL.Image.Image
    """
    w, h = image.size
    w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h))
    image = np.array(image).astype(np.float32) / 255.0
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image).contiguous()
    return 2.0 * image - 1.0

def prepare_mask_and_masked_image(image, mask):
    """
    image: PIL.Image.Image
    mask: PIL.Image.Image
    """
    if isinstance(image, Image.Image):
        image = np.array(image.convert("RGB"))
        image = image[None].transpose(0, 3, 1, 2)
        image = torch.from_numpy(image).to(dtype=torch.float32).contiguous() / 127.5 - 1.0
    if isinstance(mask, Image.Image):
        mask = np.array(mask.convert("L"))
        mask = mask.astype(np.float32) / 255.0
        mask = mask[None, None]
        mask[mask < 0.5] = 0
        mask[mask >= 0.5] = 1
        mask = torch.from_numpy(mask).to(dtype=torch.float32).contiguous()

    masked_image = image * (mask < 0.5)

    return mask, masked_image

def download_image(url):
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")

def get_refit_weights(state_dict, onnx_opt_path, weight_name_mapping, weight_shape_mapping):
    onnx_opt_dir = os.path.dirname(onnx_opt_path)
    onnx_opt_model = onnx.load(onnx_opt_path)
    # Create initializer data hashes
    initializer_hash_mapping = {}
    for initializer in onnx_opt_model.graph.initializer:
        initializer_data = numpy_helper.to_array(initializer, base_dir=onnx_opt_dir).astype(np.float16)
        initializer_hash = hash(initializer_data.data.tobytes())
        initializer_hash_mapping[initializer.name] = initializer_hash

    refit_weights = OrderedDict()
    for wt_name, wt in state_dict.items():
        # query initializer to compare
        initializer_name = weight_name_mapping[wt_name]
        initializer_hash = initializer_hash_mapping[initializer_name]

        # get shape transform info
        initializer_shape, is_transpose = weight_shape_mapping[wt_name]
        if is_transpose:
            wt = torch.transpose(wt, 0, 1)
        else:
            wt = torch.reshape(wt, initializer_shape)

        # include weight if hashes differ
        wt_hash = hash(wt.cpu().detach().numpy().astype(np.float16).data.tobytes())
        if initializer_hash != wt_hash:
            refit_weights[initializer_name] = wt.contiguous()
    return refit_weights

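# Illustrative sketch, not part of the original commit: how get_refit_weights
# pairs with Engine.refit. The state_dict, the two mappings, and the ONNX path
# come from the ONNX export step; the names here are hypothetical.
def _example_refit(engine, state_dict, weight_name_mapping, weight_shape_mapping):
    refit_weights = get_refit_weights(
        state_dict, "unet.opt.onnx", weight_name_mapping, weight_shape_mapping)
    engine.refit(refit_weights, is_fp16=True)
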
def load_calib_prompts(batch_size, calib_data_path):
    with open(calib_data_path, "r") as file:
        lst = [line.rstrip("\n") for line in file]
    return [lst[i : i + batch_size] for i in range(0, len(lst), batch_size)]

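# For example, load_calib_prompts(2, "calib_prompts.txt") (a placeholder path)
# on a file with five prompts yields [[p1, p2], [p3, p4], [p5]]; the final
# batch may be shorter than batch_size.
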
def filter_func(name):
    pattern = re.compile(
        r".*(time_emb_proj|time_embedding|conv_in|conv_out|conv_shortcut|add_embedding).*"
    )
    return pattern.match(name) is not None

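# For example, filter_func("down_blocks.0.resnets.0.time_emb_proj") is True
# (the layer is kept out of quantization), while
# filter_func("mid_block.attentions.0.to_q") is False.
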
def quantize_lvl(unet, quant_level=2.5):
    """
    Disable the unwanted quantizers before exporting to ONNX: with the current
    ammo setting, the quantizer amax is loaded for every layer even if a layer
    was not added to the config during calibration.
    """
    for name, module in unet.named_modules():
        if isinstance(module, torch.nn.Conv2d):
            module.input_quantizer.enable()
            module.weight_quantizer.enable()
        elif isinstance(module, torch.nn.Linear):
            if (
                (quant_level >= 2 and "ff.net" in name)
                or (quant_level >= 2.5 and ("to_q" in name or "to_k" in name or "to_v" in name))
                or quant_level == 3
            ):
                module.input_quantizer.enable()
                module.weight_quantizer.enable()
            else:
                module.input_quantizer.disable()
                module.weight_quantizer.disable()

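# Level summary implied by the checks above: 1.0 quantizes Conv2d layers only,
# 2.0 adds the ff.net Linear layers, 2.5 adds the to_q/to_k/to_v attention
# projections, and 3.0 quantizes every Linear layer.
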
def get_smoothquant_config(model, quant_level=3):
    quant_config = {
        "quant_cfg": {},
        "algorithm": "smoothquant",
    }
    for name, module in model.named_modules():
        w_name = f"{name}*weight_quantizer"
        i_name = f"{name}*input_quantizer"

        if (
            w_name in quant_config["quant_cfg"].keys()  # type: ignore
            or i_name in quant_config["quant_cfg"].keys()  # type: ignore
        ):
            continue
        if filter_func(name):
            continue
        if isinstance(module, torch.nn.Linear):
            if (
                (quant_level >= 2 and "ff.net" in name)
                or (quant_level >= 2.5 and ("to_q" in name or "to_k" in name or "to_v" in name))
                or quant_level == 3
            ):
                quant_config["quant_cfg"][w_name] = {"num_bits": 8, "axis": 0}  # type: ignore
                quant_config["quant_cfg"][i_name] = {"num_bits": 8, "axis": -1}  # type: ignore
        elif isinstance(module, torch.nn.Conv2d):
            quant_config["quant_cfg"][w_name] = {"num_bits": 8, "axis": 0}  # type: ignore
            quant_config["quant_cfg"][i_name] = {"num_bits": 8, "axis": None}  # type: ignore
    return quant_config

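# Illustrative result, not part of the original commit: for a hypothetical
# module named "mid_block.attentions.0.to_q" at quant_level=3, the returned
# config contains
#   quant_config["quant_cfg"]["mid_block.attentions.0.to_q*weight_quantizer"]
#       == {"num_bits": 8, "axis": 0}   # per-output-channel weight scales
#   quant_config["quant_cfg"]["mid_block.attentions.0.to_q*input_quantizer"]
#       == {"num_bits": 8, "axis": -1}  # per-last-axis input scales
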
class PercentileAmaxes:
    def __init__(self, total_step, percentile) -> None:
        self.data = {}
        self.total_step = total_step
        self.percentile = percentile
        self.i = 0

    def append(self, item):
        _cur_step = self.i % self.total_step
        if _cur_step not in self.data.keys():
            self.data[_cur_step] = item
        else:
            self.data[_cur_step] = np.maximum(self.data[_cur_step], item)
        self.i += 1

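# Illustrative sketch, not part of the original commit: PercentileAmaxes keeps
# a running elementwise maximum per denoising step, cycling every total_step
# appends across calibration runs.
def _example_amax_tracking(amax_values):
    tracker = PercentileAmaxes(total_step=30, percentile=99.0)
    for amax in amax_values:  # one amax array per step, in step order
        tracker.append(amax)
    return tracker.data  # dict: step index -> elementwise max over runs
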
def add_arguments(parser):
    # Stable Diffusion configuration
    parser.add_argument('--version', type=str, default="1.5", choices=["1.4", "1.5", "dreamshaper-7", "2.0-base", "2.0", "2.1-base", "2.1", "xl-1.0", "xl-turbo"], help="Version of Stable Diffusion")
    parser.add_argument('prompt', nargs='*', help="Text prompt(s) to guide image generation")
    parser.add_argument('--negative-prompt', nargs='*', default=[''], help="The negative prompt(s) to guide the image generation.")
    parser.add_argument('--batch-size', type=int, default=1, choices=[1, 2, 4], help="Batch size (repeat prompt)")
    parser.add_argument('--batch-count', type=int, default=1, help="Number of images to generate in sequence, one at a time.")
    parser.add_argument('--height', type=int, default=512, help="Height of image to generate (must be multiple of 8)")
    parser.add_argument('--width', type=int, default=512, help="Width of image to generate (must be multiple of 8)")
    parser.add_argument('--denoising-steps', type=int, default=30, help="Number of denoising steps")
    parser.add_argument('--scheduler', type=str, default=None, choices=["DDIM", "DDPM", "EulerA", "Euler", "LCM", "LMSD", "PNDM", "UniPC"], help="Scheduler for diffusion process")
    parser.add_argument('--guidance-scale', type=float, default=7.5, help="Value of classifier-free guidance scale (must be greater than 1)")
    parser.add_argument('--lora-scale', type=float, nargs='+', default=None, help="Scale of LoRA weights, default 1 (must be between 0 and 1)")
    parser.add_argument('--lora-path', type=str, nargs='+', default=None, help="Path to LoRA adaptor. Ex: 'latent-consistency/lcm-lora-sdv1-5'")

    # ONNX export
    parser.add_argument('--onnx-opset', type=int, default=18, choices=range(7,19), help="Select ONNX opset version to target for exported models")
    parser.add_argument('--onnx-dir', default='onnx', help="Output directory for ONNX export")

    # Framework model ckpt
    parser.add_argument('--framework-model-dir', default='pytorch_model', help="Directory for HF saved models")

    # TensorRT engine build
    parser.add_argument('--engine-dir', default='engine', help="Output directory for TensorRT engines")
    parser.add_argument('--int8', action='store_true', help="Apply int8 quantization.")
    # NOTE: choices must be an explicit float list; range(1, 4) would reject the 2.5 level.
    parser.add_argument('--quantization-level', type=float, default=3.0, choices=[1.0, 2.0, 2.5, 3.0], help="int8/fp8 quantization level, 1: CNN, 2: CNN+FFN, 2.5: CNN+FFN+QKV, 3: CNN+FC")
    parser.add_argument('--build-static-batch', action='store_true', help="Build TensorRT engines with fixed batch size.")
    parser.add_argument('--build-dynamic-shape', action='store_true', help="Build TensorRT engines with dynamic image shapes.")
    parser.add_argument('--build-enable-refit', action='store_true', help="Enable Refit option in TensorRT engines during build.")
    parser.add_argument('--build-all-tactics', action='store_true', help="Build TensorRT engines using all tactic sources.")
    parser.add_argument('--timing-cache', default=None, type=str, help="Path to the precached timing measurements to accelerate build.")

    # TensorRT inference
    parser.add_argument('--num-warmup-runs', type=int, default=5, help="Number of warmup runs before benchmarking performance")
    parser.add_argument('--use-cuda-graph', action='store_true', help="Enable cuda graph")
    parser.add_argument('--nvtx-profile', action='store_true', help="Enable NVTX markers for performance profiling")
    parser.add_argument('--torch-inference', default='', help="Run inference with PyTorch (using specified compilation mode) instead of TensorRT.")

    parser.add_argument('--seed', type=int, default=None, help="Seed for random generator to get consistent results")
    parser.add_argument('--output-dir', default='output', help="Output directory for logs and image artifacts")
    parser.add_argument('--hf-token', type=str, help="HuggingFace API access token for downloading model checkpoints")
    parser.add_argument('-v', '--verbose', action='store_true', help="Show verbose output")
    return parser

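# Illustrative invocation, not part of the original commit, assuming a demo
# script that calls add_arguments() (the script name is a placeholder):
#   python demo_txt2img.py "a photo of an astronaut" --version xl-1.0 \
#       --int8 --quantization-level 2.5 --build-static-batch
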
def process_pipeline_args(args):
    if args.height % 8 != 0 or args.width % 8 != 0:
        raise ValueError(f"Image height and width have to be divisible by 8 but specified as: {args.height} and {args.width}.")

    max_batch_size = 4
    if args.batch_size > max_batch_size:
        raise ValueError(f"Batch size {args.batch_size} is larger than allowed {max_batch_size}.")

    if args.use_cuda_graph and (not args.build_static_batch or args.build_dynamic_shape):
        raise ValueError("Using CUDA graph requires static dimensions. Enable `--build-static-batch` and do not specify `--build-dynamic-shape`")

    if args.int8 and not args.version.startswith('xl'):
        raise ValueError("int8 quantization is only supported for the SDXL pipeline.")

    kwargs_init_pipeline = {
        'version': args.version,
        'max_batch_size': max_batch_size,
        'denoising_steps': args.denoising_steps,
        'scheduler': args.scheduler,
        'guidance_scale': args.guidance_scale,
        'output_dir': args.output_dir,
        'hf_token': args.hf_token,
        'verbose': args.verbose,
        'nvtx_profile': args.nvtx_profile,
        'use_cuda_graph': args.use_cuda_graph,
        'lora_scale': args.lora_scale,
        'lora_path': args.lora_path,
        'framework_model_dir': args.framework_model_dir,
        'torch_inference': args.torch_inference,
    }

    kwargs_load_engine = {
        'onnx_opset': args.onnx_opset,
        'opt_batch_size': args.batch_size,
        'opt_image_height': args.height,
        'opt_image_width': args.width,
        'static_batch': args.build_static_batch,
        'static_shape': not args.build_dynamic_shape,
        'enable_all_tactics': args.build_all_tactics,
        'enable_refit': args.build_enable_refit,
        'timing_cache': args.timing_cache,
        'int8': args.int8,
        'quantization_level': args.quantization_level,
        'denoising_steps': args.denoising_steps,
    }

    args_run_demo = (args.prompt, args.negative_prompt, args.height, args.width, args.batch_size, args.batch_count, args.num_warmup_runs, args.use_cuda_graph)

    return kwargs_init_pipeline, kwargs_load_engine, args_run_demo
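
# Illustrative sketch, not part of the original commit: wiring the two helpers
# above together in a driver script (the pipeline class itself lives elsewhere).
def _example_parse():
    import argparse
    parser = add_arguments(argparse.ArgumentParser(description="Stable Diffusion demo"))
    args = parser.parse_args()
    kwargs_init_pipeline, kwargs_load_engine, args_run_demo = process_pipeline_args(args)
    return kwargs_init_pipeline, kwargs_load_engine, args_run_demo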