# diffuse_edge1_acc / src/pipeline.py
# Author: manbeast3b
# Commit: "Update src/pipeline.py" (94ae420, verified)
from diffusers import FluxPipeline, AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from huggingface_hub.constants import HF_HUB_CACHE
from transformers import T5EncoderModel, T5TokenizerFast, CLIPTokenizer, CLIPTextModel
import torch
import torch._dynamo
import gc
from PIL import Image as img
from PIL.Image import Image
from pipelines.models import TextToImageRequest
from torch import Generator
import time
from diffusers import DiffusionPipeline
from torchao.quantization import quantize_, int8_weight_only, fpx_weight_only
import os
os.environ['PYTORCH_CUDA_ALLOC_CONF']="expandable_segments:True"
import torch
import math
from typing import Type, Dict, Any, Tuple, Callable, Optional, Union
import ghanta
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor,
FluxAttnProcessor2_0,
FusedFluxAttnProcessor2_0,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
from diffusers.models.modeling_outputs import Transformer2DModelOutput
class BasicQuantization:
    """Uniform affine (min/max) fake-quantizer.

    Maps a tensor onto the signed integer grid [qmin, qmax] for `bits` bits, then
    immediately dequantizes it, so the result has the input's dtype but at most
    2**bits distinct values.
    """

    def __init__(self, bits=1):
        """
        Args:
            bits: number of quantization bits (>= 1).

        Raises:
            ValueError: if `bits` < 1 (the integer grid would be empty/fractional).
        """
        if bits < 1:
            raise ValueError(f"bits must be >= 1, got {bits}")
        self.bits = bits
        # Signed range, e.g. bits=7 -> [-64, 63].
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        """Fake-quantize `tensor` using its own min/max as the range.

        Returns:
            (dequantized_tensor, scale, zero_point). `scale` and `zero_point` are
            0-dim tensors such that dequantized = (q - zero_point) * scale.
            For a constant tensor (max == min) the input is returned unchanged
            (as a copy) with scale=1 and zero_point=0.
        """
        t_min = tensor.min()
        t_max = tensor.max()
        value_range = t_max - t_min
        if value_range == 0:
            # A constant tensor is exactly representable; the generic path would
            # divide by a zero scale and produce NaN.
            return tensor.clone(), torch.ones_like(value_range), torch.zeros_like(value_range)
        scale = value_range / (self.qmax - self.qmin)
        zero_point = self.qmin - torch.round(t_min / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point


class ModelQuantization:
    """Fake-quantizes the weights and biases of every nn.Linear in a model, in place."""

    def __init__(self, model, bits=7):
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        """Replace each nn.Linear weight/bias in `self.model` with a fake-quantized copy."""
        # Parameter rewriting should not be recorded by autograd.
        with torch.no_grad():
            for module in self.model.modules():
                if not isinstance(module, torch.nn.Linear):
                    continue
                # The previous check, hasattr(module, 'weightML'), is never true for
                # nn.Linear, so nothing was ever quantized.
                quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                module.weight = torch.nn.Parameter(quantized_weight)
                if module.bias is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = torch.nn.Parameter(quantized_bias)
def inicializar_generador(dispositivo: torch.device, respaldo: Optional[torch.Generator] = None):
    """Create a torch.Generator on `dispositivo`, seeded from that device's global RNG state.

    Args:
        dispositivo: device the generator should live on.
        respaldo: fallback generator returned for device types other than "cpu"/"cuda".
            If it is None, a CPU generator is created instead.

    Returns:
        A torch.Generator (or `respaldo` itself for unsupported device types).
    """
    if dispositivo.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
    if dispositivo.type == "cuda":
        # Read the RNG state of the target GPU. With no argument,
        # torch.cuda.get_rng_state() reads the *current* device, which may differ
        # from `dispositivo` on multi-GPU hosts.
        return torch.Generator(device=dispositivo).set_state(torch.cuda.get_rng_state(dispositivo))
    if respaldo is None:
        return inicializar_generador(torch.device("cpu"))
    return respaldo
def calcular_fusion(x: torch.Tensor, info_tome: Dict[str, Any]) -> Tuple[Callable, ...]:
    """Build the token merge/unmerge callables for one transformer block (ToMe-style).

    Args:
        x: activations whose dim 1 is treated as the token axis (only x.shape[1] is
            read here); presumably (batch, tokens, channels) -- confirm against callers.
        info_tome: dict with
            "size": (height, width) of the original token grid;
            "args": merge settings "ratio", "down", "sx", "sy", "rando",
                    "m1", "m2", "m3", "generator". The "generator" entry may be
                    replaced in place with one created on x.device.

    Returns:
        (merge_a, merge_c, merge_m, unmerge_a, unmerge_c, unmerge_m). A pair is the
        no-op `ghanta.hacer_nada` when its flag ("m1"/"m2"/"m3") is false, and every
        pair is the no-op when the inferred downsampling exceeds args["down"].

    Raises:
        ValueError: if info_tome["size"] has not been set.
    """
    if info_tome["size"] is None:
        raise ValueError("info_tome['size'] must be set to (height, width) before computing token fusion")
    alto_original, ancho_original = info_tome["size"]
    tokens_originales = alto_original * ancho_original
    submuestreo = int(math.ceil(math.sqrt(tokens_originales // x.shape[1])))
    argumentos = info_tome["args"]
    nada = ghanta.hacer_nada
    if submuestreo <= argumentos["down"]:
        ancho = int(math.ceil(ancho_original / submuestreo))
        alto = int(math.ceil(alto_original / submuestreo))
        radio = int(x.shape[1] * argumentos["ratio"])
        # Lazily create (or move) the generator so it lives on the same device as x.
        if argumentos["generator"] is None:
            argumentos["generator"] = inicializar_generador(x.device)
        elif argumentos["generator"].device != x.device:
            argumentos["generator"] = inicializar_generador(x.device, respaldo=argumentos["generator"])
        usar_aleatoriedad = argumentos["rando"]
        fusion, desfusion = ghanta.emparejamiento_suave_aleatorio_2d(
            x, ancho, alto, argumentos["sx"], argumentos["sy"], radio,
            sin_aleatoriedad=not usar_aleatoriedad, generador=argumentos["generator"]
        )
    else:
        # The bare name `hacer_nada` is not defined in this module (NameError);
        # the no-op helper lives in `ghanta`.
        fusion, desfusion = (nada, nada)
    fusion_a, desfusion_a = (fusion, desfusion) if argumentos["m1"] else (nada, nada)
    fusion_c, desfusion_c = (fusion, desfusion) if argumentos["m2"] else (nada, nada)
    fusion_m, desfusion_m = (fusion, desfusion) if argumentos["m3"] else (nada, nada)
    return fusion_a, fusion_c, fusion_m, desfusion_a, desfusion_c, desfusion_m
@torch.compile
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    """Single-stream FLUX block.

    An attention branch and an MLP branch are computed from the same AdaLN-Zero
    normalized input, concatenated along dim 2, projected back to `dim`, gated and
    added to the residual. Optional ToMe-style token merging is driven by `tinfo`.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)
        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        # Consumes the concatenation of the attention output (dim) and MLP hidden (mlp_hidden_dim).
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
        processor = FluxAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
        tinfo: Optional[Dict[str, Any]] = None,
    ):
        """Apply the block.

        Args:
            hidden_states: token activations; dim 2 is the channel axis (see torch.cat below).
            temb: conditioning embedding fed to the AdaLN-Zero norm.
            image_rotary_emb: rotary embedding passed through to attention.
            joint_attention_kwargs: extra kwargs forwarded to the attention call.
            tinfo: optional token-merging info for `calcular_fusion`; when None no merging happens.

        Returns:
            Updated hidden_states, same token count as the input.
        """
        if tinfo is not None:
            merge_a, _, _, unmerge_a, _, _ = calcular_fusion(hidden_states, tinfo)
        else:
            merge_a = unmerge_a = ghanta.hacer_nada
        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        norm_hidden_states = merge_a(norm_hidden_states)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        # Unmerge the (possibly merged) branch output back to full length *before*
        # adding the full-length residual. The previous form unmerge(residual + branch)
        # mixed token counts, and it was inconsistent with FluxTransformerBlock.
        hidden_states = residual + unmerge_a(hidden_states)
        # Guard against fp16 overflow (largest finite fp16 value is 65504).
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)
        return hidden_states
@torch.compile
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    """Dual-stream FLUX block.

    `hidden_states` and `encoder_hidden_states` each get their own AdaLN-Zero
    modulation and feed-forward, but share one joint attention call. Optional
    ToMe-style token merging is driven by `tinfo` via `calcular_fusion`.
    """
    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
        super().__init__()
        # One AdaLN-Zero per stream; forward unpacks each into the normalized input
        # plus attention gate and MLP shift/scale/gate.
        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)
        if hasattr(F, "scaled_dot_product_attention"):
            processor = FluxAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        # added_kv_proj_dim / context_pre_only=False: the encoder stream gets its own
        # projections and its attention output is returned alongside the main one.
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )
        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
        # Not read anywhere in the code shown in this file.
        self._chunk_size = None
        self._chunk_dim = 0
    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
        tinfo: Optional[Dict[str, Any]] = None, # optional token-merging info for calcular_fusion
    ):
        """Apply the block to both streams.

        Args:
            hidden_states: main-stream token activations.
            encoder_hidden_states: context-stream token activations.
            temb: conditioning embedding for both AdaLN-Zero norms.
            image_rotary_emb: rotary embedding passed through to attention.
            joint_attention_kwargs: extra kwargs forwarded to the attention call.
            tinfo: when not None, merge/unmerge callables come from calcular_fusion;
                otherwise all six are the no-op ghanta.hacer_nada.

        Returns:
            (encoder_hidden_states, hidden_states) -- note the order is reversed
            relative to the arguments.
        """
        if tinfo is not None:
            m_a, m_c, mom, u_a, u_c, u_m = calcular_fusion(hidden_states, tinfo)
        else:
            m_a, m_c, mom, u_a, u_c, u_m = (ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada, ghanta.hacer_nada)
        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}
        norm_hidden_states = m_a(norm_hidden_states)
        # NOTE(review): m_c was built by calcular_fusion from `hidden_states`, yet it is
        # applied to the encoder stream here. This presumably assumes compatible token
        # layouts between the streams -- confirm before passing a non-None tinfo.
        norm_encoder_hidden_states = m_c(norm_encoder_hidden_states)
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )
        # Main stream: gated attention residual (unmerged back to full length first).
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = u_a(attn_output) + hidden_states
        # Main stream: modulated LayerNorm -> feed-forward -> gated residual.
        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        norm_hidden_states = mom(norm_hidden_states)
        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = u_m(ff_output) + hidden_states
        # Context stream: gated attention residual, then modulated feed-forward residual.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = u_c(context_attn_output) + encoder_hidden_states
        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
        return encoder_hidden_states, hidden_states
class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """FLUX transformer: `num_layers` dual-stream blocks followed by `num_single_layers`
    single-stream blocks, with rotary position embeddings built from text/image ids.

    Besides the standard configuration, the model keeps a ToMe-style token-merging
    configuration in `self.tinfo` (consumed by `calcular_fusion`). `forward` records
    the token-grid size there for 4-D inputs but does not pass `tinfo` to the blocks.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int] = (16, 56, 56),
        generator: Optional[torch.Generator] = None,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )
        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)
        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )
        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_layers)
            ]
        )
        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
        # Token-merging settings read by calcular_fusion. "size" is the (height, width)
        # token grid, filled in by forward for 4-D inputs.
        self.tinfo = {
            "size": None,
            "args": {
                "ratio": 0.5,
                "down": 1,
                "sx": 2,
                "sy": 2,
                "rando": False,
                "m1": False,
                "m2": True,
                "m3": False,
                "generator": generator,
            },
        }
        self.gradient_checkpointing = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            keyed by `"<module path>.processor"`.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)
        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        """Set one processor for every attention layer, or one per layer from a dict.

        Raises:
            ValueError: if a dict is given whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())
        if isinstance(processor, dict):
            if len(processor) != count:
                raise ValueError(
                    f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                    f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
                )
            # Work on a copy: entries are popped below, and the caller's dict (e.g.
            # self.original_attn_processors) must stay intact for reuse.
            processor = dict(processor)

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
    def fuse_qkv_projections(self):
        """Fuse Q/K/V projections of every Attention module; remember the previous processors."""
        self.original_attn_processors = None
        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
        self.original_attn_processors = self.attn_processors
        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)
        self.set_attn_processor(FusedFluxAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Restore the processors saved by fuse_qkv_projections (no-op if it was never called)."""
        # getattr: the attribute only exists after fuse_qkv_projections() has run.
        original = getattr(self, "original_attn_processors", None)
        if original is not None:
            self.set_attn_processor(original)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """Run the transformer.

        Args:
            hidden_states: latent tokens, embedded by `x_embedder`.
            encoder_hidden_states: text tokens, embedded by `context_embedder`.
            pooled_projections: pooled text embedding for the time/text embedder.
            timestep: diffusion timestep; scaled by 1000 before embedding.
            img_ids / txt_ids: position ids; 3-D inputs are deprecated and reduced to [0].
            guidance: optional guidance value (used only with guidance embeddings).
            joint_attention_kwargs: forwarded to attention; a "scale" key sets the LoRA scale.
            controlnet_block_samples / controlnet_single_block_samples: optional residuals
                added after dual-stream / single-stream blocks.
            return_dict: return a Transformer2DModelOutput instead of a tuple.
            controlnet_blocks_repeat: index controlnet samples cyclically instead of by interval.

        Returns:
            `Transformer2DModelOutput(sample=...)`, or `(sample,)` when return_dict is False.
        """
        # No module-level `logger` exists in this file, so the deprecation warnings
        # below used to raise NameError. `logging` here is diffusers.utils.logging.
        logger = logging.get_logger(__name__)

        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0
        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )
        hidden_states = self.x_embedder(hidden_states)
        if len(hidden_states.shape) == 4:
            self.tinfo["size"] = (hidden_states.shape[2], hidden_states.shape[3])
        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]
        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        # Helpers for gradient checkpointing, defined once rather than per block.
        def create_custom_forward(module, return_dict=None):
            def custom_forward(*inputs):
                if return_dict is not None:
                    return module(*inputs, return_dict=return_dict)
                return module(*inputs)

            return custom_forward

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        # Single-stream blocks operate on [text tokens | image tokens] concatenated along dim 1.
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )
            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )
        # Drop the text tokens; only the image tokens are projected to the output.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)
        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)
        if not return_dict:
            return (output,)
        return Transformer2DModelOutput(sample=output)
def load_single_file_checkpoint(
    pretrained_model_link_or_path,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    """Load a single-file checkpoint into a state dict.

    Args:
        pretrained_model_link_or_path: a local file path, or a Hub link that is
            resolved to (repo_id, weights_name) and downloaded/cached.
        force_download, proxies, token, cache_dir, local_files_only, revision:
            forwarded to the Hub download helper when the path is not a local file.

    Returns:
        The checkpoint mapping; nested "state_dict" entries are unwrapped.
    """
    # These diffusers helpers were referenced without ever being imported in this file.
    from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
    from diffusers.models.modeling_utils import load_state_dict
    from diffusers.utils.hub_utils import _get_model_file

    # Two leftover `pdb.set_trace()` breakpoints were removed here; they halted every call.
    if not os.path.isfile(pretrained_model_link_or_path):
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        pretrained_model_link_or_path = _get_model_file(
            repo_id,
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )
    checkpoint = load_state_dict(pretrained_model_link_or_path)
    # some checkpoints contain the model state dict under a "state_dict" key
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]
    return checkpoint
# Hub repository and pinned revision of the base pipeline; load_pipeline overrides
# the VAE, transformer and second text encoder on top of it.
ckpt_id = "silentdriver/4b68f38c0b"
ckpt_revision = "36a3cf4a9f733fc5f31257099b56b304fb2eceab"

# Placeholder used only in type annotations below (evaluates to None at runtime).
Pipeline = None

# Global backend tuning: allow TF32 matmuls, keep cuDNN on and let it autotune kernels.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
def empty_cache():
    """Run Python garbage collection and release cached CUDA memory.

    Also resets CUDA peak-memory statistics. CUDA calls are skipped when CUDA is
    unavailable, so this is safe to call on CPU-only hosts.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    # reset_peak_memory_stats() already resets the max-allocated counter; the old
    # extra reset_max_memory_allocated() call is deprecated and redundant.
    torch.cuda.reset_peak_memory_stats()
def load_pipeline() -> Pipeline:
    """Build the FLUX DiffusionPipeline on CUDA in bfloat16 and warm it up.

    Loads a custom T5 text encoder, a tiny autoencoder (compiled with torch.compile)
    and the local FluxTransformer2DModel snapshot. It plugs them into the base
    pipeline `ckpt_id`@`ckpt_revision`, freezes all components, and runs three
    warm-up generations.

    Returns:
        The ready-to-use pipeline.
    """
    empty_cache()
    dtype, device = torch.bfloat16, "cuda"
    # A leftover `pdb.set_trace()` was removed here; it halted loading.
    text_encoder_2 = T5EncoderModel.from_pretrained(
        "silentdriver/aadb864af9",
        revision="060dabc7fa271c26dfa3fd43c16e7c5bf3ac7892",
        torch_dtype=torch.bfloat16,
    ).to(memory_format=torch.channels_last)
    vae = AutoencoderTiny.from_pretrained(
        "silentdriver/7815792fb4",
        revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5",
        torch_dtype=dtype,
    )
    path = os.path.join(HF_HUB_CACHE, "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c")
    generator = torch.Generator(device=device)
    model = FluxTransformer2DModel.from_pretrained(
        path, torch_dtype=dtype, use_safetensors=False, generator=generator
    ).to(memory_format=torch.channels_last)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    vae = torch.compile(vae)
    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        revision=ckpt_revision,
        transformer=model,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    ).to(device)
    # Inference only: freeze every component.
    for component in (pipeline.vae, pipeline.transformer, pipeline.text_encoder_2, pipeline.text_encoder):
        component.requires_grad_(False)
    # Warm-up runs (also trigger torch.compile / cuDNN autotuning before real requests).
    for _ in range(3):
        pipeline(prompt="blah blah waah waah oneshot oneshot gang gang", width=1024, height=1024, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256)
    empty_cache()
    return pipeline
@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate a single PIL image for `request`.

    Runs 4 denoising steps with guidance_scale 0.0 and a 256-token text limit, at
    the height/width given by the request, and returns the first output image.
    """
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    first_image = result.images[0]
    return first_image