import random
from collections import OrderedDict
from typing import Dict, Optional, Tuple

from transformers import CLIPTextConfig
from diffusers import UNet2DConditionModel
from optimum.exporters.onnx.model_configs import (
    DummySeq2SeqDecoderTextInputGenerator,
    DummyVisionInputGenerator,
    NormalizedConfig,
    VisionOnnxConfig,
)
from optimum.exporters.openvino import main_export
from optimum.utils.input_generators import DEFAULT_DUMMY_SHAPES, DummyInputGenerator

# IMPORTANT: the downloaded model cache folder must contain a scheduler config,
# otherwise the export fails.
class CustomDummyTimestepInputGenerator(DummyInputGenerator):
    """
    Generates dummy timestep inputs, including the LCM guidance embedding
    ("timestep_cond") and the SDXL-style additional conditions.
    """

    SUPPORTED_INPUT_NAMES = (
        "timestep",
        "timestep_cond",
        "text_embeds",
        "time_ids",
    )

    def __init__(
        self,
        task: str,
        normalized_config: NormalizedConfig,
        batch_size: int = DEFAULT_DUMMY_SHAPES["batch_size"],
        time_cond_proj_dim: int = 256,
        random_batch_size_range: Optional[Tuple[int, int]] = None,
        **kwargs,
    ):
        self.task = task
        self.vocab_size = normalized_config.vocab_size
        self.text_encoder_projection_dim = normalized_config.text_encoder_projection_dim
        self.time_ids = 5 if normalized_config.requires_aesthetics_score else 6
        if random_batch_size_range:
            low, high = random_batch_size_range
            self.batch_size = random.randint(low, high)
        else:
            self.batch_size = batch_size
        # Width of the LCM guidance embedding; fall back to the argument default when
        # the UNet config does not define time_cond_proj_dim.
        self.time_cond_proj_dim = normalized_config.get("time_cond_proj_dim", time_cond_proj_dim)

    def generate(self, input_name: str, framework: str = "pt", int_dtype: str = "int64", float_dtype: str = "fp32"):
        shape = [self.batch_size]
        if input_name == "timestep":
            return self.random_int_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=int_dtype)
        if input_name == "timestep_cond":
            shape.append(self.time_cond_proj_dim)
            return self.random_float_tensor(shape, min_value=-1.0, max_value=1.0, framework=framework, dtype=float_dtype)
        shape.append(self.text_encoder_projection_dim if input_name == "text_embeds" else self.time_ids)
        return self.random_float_tensor(shape, max_value=self.vocab_size, framework=framework, dtype=float_dtype)
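# The generator above produces the following dummy tensor shapes:
#   timestep      -> [batch_size]                      (integer step index)
#   timestep_cond -> [batch_size, time_cond_proj_dim]  (LCM guidance embedding)
#   text_embeds   -> [batch_size, projection_dim]      (SDXL-style extra condition)
#   time_ids      -> [batch_size, 5 or 6]              (SDXL-style extra condition)
# It is wired into the export through DUMMY_INPUT_GENERATOR_CLASSES below.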
class LCMUNetOnnxConfig(VisionOnnxConfig):
    ATOL_FOR_VALIDATION = 1e-3
    # The ONNX export of a CLIPText architecture, another Stable Diffusion component,
    # requires the Trilu operator, available since opset 14.
    DEFAULT_ONNX_OPSET = 14

    NORMALIZED_CONFIG_CLASS = NormalizedConfig.with_args(
        image_size="sample_size",
        num_channels="in_channels",
        hidden_size="cross_attention_dim",
        vocab_size="norm_num_groups",
        allow_new=True,
    )

    DUMMY_INPUT_GENERATOR_CLASSES = (
        DummyVisionInputGenerator,
        CustomDummyTimestepInputGenerator,
        DummySeq2SeqDecoderTextInputGenerator,
    )

    @property
    def inputs(self) -> Dict[str, Dict[int, str]]:
        common_inputs = OrderedDict({
            "sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
            "timestep": {0: "steps"},
            "encoder_hidden_states": {0: "batch_size", 1: "sequence_length"},
            "timestep_cond": {0: "batch_size"},
        })
        # TODO: add text_image, image and image_embeds
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            common_inputs["text_embeds"] = {0: "batch_size"}
            common_inputs["time_ids"] = {0: "batch_size"}
        return common_inputs

    @property
    def outputs(self) -> Dict[str, Dict[int, str]]:
        return {
            "out_sample": {0: "batch_size", 1: "num_channels", 2: "height", 3: "width"},
        }

    @property
    def torch_to_onnx_output_map(self) -> Dict[str, str]:
        return {
            "sample": "out_sample",
        }

    def generate_dummy_inputs(self, framework: str = "pt", **kwargs):
        dummy_inputs = super().generate_dummy_inputs(framework=framework, **kwargs)
        # The seq2seq text generator produces encoder outputs as a tuple; keep only
        # the hidden states tensor the UNet expects.
        dummy_inputs["encoder_hidden_states"] = dummy_inputs["encoder_hidden_states"][0]
        if getattr(self._normalized_config, "addition_embed_type", None) == "text_time":
            dummy_inputs["added_cond_kwargs"] = {
                "text_embeds": dummy_inputs.pop("text_embeds"),
                "time_ids": dummy_inputs.pop("time_ids"),
            }
        return dummy_inputs

    def ordered_inputs(self, model) -> Dict[str, Dict[int, str]]:
        # The parent implementation reorders inputs from the model signature, which
        # breaks the order when timestep_cond is involved, so return the original
        # mapping unchanged.
        return self.inputs
model_id = "SimianLuo/LCM_Dreamshaper_v7"
text_encoder_config = CLIPTextConfig.from_pretrained(model_id, subfolder="text_encoder")
unet_config = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").config
# The custom dummy input generator expects these two fields on the UNet config.
unet_config.text_encoder_projection_dim = text_encoder_config.projection_dim
unet_config.requires_aesthetics_score = False
custom_onnx_configs = {
    "unet": LCMUNetOnnxConfig(config=unet_config, task="semantic-segmentation"),
}
main_export(
    model_name_or_path=model_id,
    output="./",
    task="stable-diffusion",
    fp16=False,
    int8=False,
    custom_onnx_configs=custom_onnx_configs,
)
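# Optional sanity check (a minimal sketch, not part of the original export flow):
# read the exported UNet IR back with the OpenVINO runtime and confirm that the
# "timestep_cond" input survived the export. The path below assumes main_export
# wrote the UNet to ./unet/openvino_model.xml; adjust it if your optimum version
# uses a different output layout.
from openvino.runtime import Core

ov_unet = Core().read_model("./unet/openvino_model.xml")
print("UNet inputs:", [port.get_any_name() for port in ov_unet.inputs])
print("UNet outputs:", [port.get_any_name() for port in ov_unet.outputs])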