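"""TensorRT acceleration for StreamDiffusion.

Exports the pipeline's UNet and VAE to ONNX, builds TensorRT engines from
them, and swaps the engines into a ``StreamDiffusion`` instance.
"""
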
import gc
import os
from typing import Optional

import torch
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import (
    retrieve_latents,
)
from polygraphy import cuda

from ...pipeline import StreamDiffusion
from .builder import EngineBuilder, create_onnx_path
from .engine import AutoencoderKLEngine, UNet2DConditionModelEngine
from .models import VAE, BaseModel, UNet, VAEEncoder


class TorchVAEEncoder(torch.nn.Module):
    """Wraps the VAE encoder so ONNX export sees a single image -> latents forward pass."""

    def __init__(self, vae: AutoencoderKL):
        super().__init__()
        self.vae = vae

    def forward(self, x: torch.Tensor):
        return retrieve_latents(self.vae.encode(x))


def compile_vae_encoder(
    vae: TorchVAEEncoder,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the VAE encoder to ONNX and build a TensorRT engine from it."""
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **(engine_build_options or {}),
    )


def compile_vae_decoder(
    vae: AutoencoderKL,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the VAE decoder to ONNX and build a TensorRT engine from it."""
    vae = vae.to(torch.device("cuda"))
    builder = EngineBuilder(model_data, vae, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **(engine_build_options or {}),
    )


def compile_unet(
    unet: UNet2DConditionModel,
    model_data: BaseModel,
    onnx_path: str,
    onnx_opt_path: str,
    engine_path: str,
    opt_batch_size: int = 1,
    engine_build_options: Optional[dict] = None,
):
    """Export the UNet (cast to fp16 on CUDA) to ONNX and build a TensorRT engine."""
    unet = unet.to(torch.device("cuda"), dtype=torch.float16)
    builder = EngineBuilder(model_data, unet, device=torch.device("cuda"))
    builder.build(
        onnx_path,
        onnx_opt_path,
        engine_path,
        opt_batch_size=opt_batch_size,
        **(engine_build_options or {}),
    )


def accelerate_with_tensorrt(
    stream: StreamDiffusion,
    engine_dir: str,
    max_batch_size: int = 2,
    min_batch_size: int = 1,
    use_cuda_graph: bool = False,
    engine_build_options: Optional[dict] = None,
):
    """Swap the UNet and VAE of ``stream`` for TensorRT engines, building them if needed."""
    # Copy the options so the caller's dict is never mutated.
    engine_build_options = dict(engine_build_options or {})
    if engine_build_options.get("opt_batch_size") is None:
        engine_build_options["opt_batch_size"] = max_batch_size

    text_encoder = stream.text_encoder
    unet = stream.unet
    vae = stream.vae
    del stream.unet, stream.vae, stream.pipe.unet, stream.pipe.vae
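
    # Keep the VAE config and dtype so they can be re-attached to the TensorRT wrapper below.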
    vae_config = vae.config
    vae_dtype = vae.dtype
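
    # Move the PyTorch models off the GPU and release its memory before the engines are built.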
    unet.to(torch.device("cpu"))
    vae.to(torch.device("cpu"))
    gc.collect()
    torch.cuda.empty_cache()

    onnx_dir = os.path.join(engine_dir, "onnx")
    os.makedirs(onnx_dir, exist_ok=True)

    unet_engine_path = f"{engine_dir}/unet.engine"
    vae_encoder_engine_path = f"{engine_dir}/vae_encoder.engine"
    vae_decoder_engine_path = f"{engine_dir}/vae_decoder.engine"
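
    # Shape and profile metadata that drives ONNX export and TensorRT engine building.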
    unet_model = UNet(
        fp16=True,
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
        embedding_dim=text_encoder.config.hidden_size,
        unet_dim=unet.config.in_channels,
    )
    vae_decoder_model = VAE(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )
    vae_encoder_model = VAEEncoder(
        device=stream.device,
        max_batch_size=max_batch_size,
        min_batch_size=min_batch_size,
    )
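
    # Engines are cached on disk: only build the ones whose files are missing.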
    if not os.path.exists(unet_engine_path):
        compile_unet(
            unet,
            unet_model,
            create_onnx_path("unet", onnx_dir, opt=False),
            create_onnx_path("unet", onnx_dir, opt=True),
            unet_engine_path,
            **engine_build_options,
        )
    else:
        del unet

    if not os.path.exists(vae_decoder_engine_path):
        # ONNX export traces forward(), so point it at the decoder.
        vae.forward = vae.decode
        compile_vae_decoder(
            vae,
            vae_decoder_model,
            create_onnx_path("vae_decoder", onnx_dir, opt=False),
            create_onnx_path("vae_decoder", onnx_dir, opt=True),
            vae_decoder_engine_path,
            **engine_build_options,
        )

    if not os.path.exists(vae_encoder_engine_path):
        vae_encoder = TorchVAEEncoder(vae).to(torch.device("cuda"))
        compile_vae_encoder(
            vae_encoder,
            vae_encoder_model,
            create_onnx_path("vae_encoder", onnx_dir, opt=False),
            create_onnx_path("vae_encoder", onnx_dir, opt=True),
            vae_encoder_engine_path,
            **engine_build_options,
        )
    del vae
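
    # One CUDA stream is shared by the UNet and VAE engines.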
    cuda_stream = cuda.Stream()

    stream.unet = UNet2DConditionModelEngine(unet_engine_path, cuda_stream, use_cuda_graph=use_cuda_graph)
    stream.vae = AutoencoderKLEngine(
        vae_encoder_engine_path,
        vae_decoder_engine_path,
        cuda_stream,
        stream.pipe.vae_scale_factor,
        use_cuda_graph=use_cuda_graph,
    )
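
    # Re-attach the original VAE metadata so pipeline code that reads vae.config / vae.dtype keeps working.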
    setattr(stream.vae, "config", vae_config)
    setattr(stream.vae, "dtype", vae_dtype)

    gc.collect()
    torch.cuda.empty_cache()

    return stream
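

# Example usage (a minimal sketch; the pipeline setup and the "engines"
# directory below are illustrative, not part of this module):
#
#     import torch
#     from diffusers import StableDiffusionPipeline
#     from streamdiffusion import StreamDiffusion
#     from streamdiffusion.acceleration.tensorrt import accelerate_with_tensorrt
#
#     pipe = StableDiffusionPipeline.from_pretrained("KBlueLeaf/kohaku-v2.1").to("cuda")
#     stream = StreamDiffusion(pipe, t_index_list=[32, 45], torch_dtype=torch.float16)
#     stream = accelerate_with_tensorrt(stream, "engines", max_batch_size=2)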