import os

# Configure the CUDA caching allocator before torch (pulled in by the imports
# below) has any chance to initialize it.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import gc
import math
import time
from typing import Any, Callable, Dict, Optional, Tuple, Type, Union

import numpy as np
import torch
import torch._dynamo
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from PIL import Image as img
from PIL.Image import Image
from torch import Generator

from diffusers import AutoencoderKL, AutoencoderTiny, DiffusionPipeline, FluxPipeline
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.image_processor import VaeImageProcessor
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import (
    Attention,
    AttentionProcessor,
    FluxAttnProcessor2_0,
    FusedFluxAttnProcessor2_0,
)
from diffusers.models.embeddings import (
    CombinedTimestepGuidanceTextProjEmbeddings,
    CombinedTimestepTextProjEmbeddings,
    FluxPosEmbed,
)
from diffusers.models.modeling_outputs import Transformer2DModelOutput
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
from diffusers.utils.import_utils import is_torch_npu_available
from diffusers.utils.torch_utils import maybe_allow_in_graph
from huggingface_hub.constants import HF_HUB_CACHE
from torchao.quantization import fpx_weight_only, int8_weight_only, quantize_
from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast

import ghanta
from pipelines.models import TextToImageRequest

# Module logger; FluxTransformer2DModel.forward emits deprecation warnings through it.
logger = logging.get_logger(__name__)


class BasicQuantization:
    """Uniform affine fake-quantization onto a signed ``bits``-bit integer grid.

    Values are snapped to the nearest grid point and mapped back to the input's
    floating-point domain, so the result can directly replace the input tensor.
    """

    def __init__(self, bits=1):
        self.bits = bits
        # Signed integer range, e.g. bits=8 -> [-128, 127].
        self.qmin = -(2 ** (bits - 1))
        self.qmax = 2 ** (bits - 1) - 1

    def quantize_tensor(self, tensor):
        """Fake-quantize ``tensor`` using its own min/max as the range.

        Returns:
            ``(dequantized, scale, zero_point)`` where ``dequantized`` has the
            same shape as ``tensor``. For a constant tensor (zero range) the
            tensor is returned unchanged (as a copy) with ``scale == 0`` and
            ``zero_point == 0`` instead of dividing by zero.
        """
        t_min, t_max = tensor.min(), tensor.max()
        scale = (t_max - t_min) / (self.qmax - self.qmin)
        if scale == 0:
            # Constant tensor: it is exactly representable, nothing to quantize.
            return tensor.clone(), scale, torch.zeros_like(scale)
        zero_point = self.qmin - torch.round(t_min / scale)
        qtensor = torch.round(tensor / scale + zero_point)
        qtensor = torch.clamp(qtensor, self.qmin, self.qmax)
        return (qtensor - zero_point) * scale, scale, zero_point


class ModelQuantization:
    """Apply :class:`BasicQuantization` in place to every ``nn.Linear`` of ``model``."""

    def __init__(self, model, bits=7):
        self.model = model
        self.quant = BasicQuantization(bits)

    def quantize_model(self):
        """Replace each Linear layer's weight and bias with fake-quantized copies."""
        # The rewrite itself needs no autograd graph.
        with torch.no_grad():
            for _, module in self.model.named_modules():
                if not isinstance(module, torch.nn.Linear):
                    continue
                if getattr(module, "weight", None) is not None:
                    quantized_weight, _, _ = self.quant.quantize_tensor(module.weight)
                    module.weight = torch.nn.Parameter(quantized_weight)
                if getattr(module, "bias", None) is not None:
                    quantized_bias, _, _ = self.quant.quantize_tensor(module.bias)
                    module.bias = torch.nn.Parameter(quantized_bias)


def inicializar_generador(dispositivo: torch.device, respaldo: torch.Generator = None):
    """Create a generator on ``dispositivo`` seeded from that device's global RNG state.

    Args:
        dispositivo: Device the generator should live on.
        respaldo: Fallback generator returned for device types other than
            ``cpu``/``cuda``; when ``None`` a CPU generator is created instead.

    Returns:
        A ``torch.Generator`` (or ``respaldo``).
    """
    if dispositivo.type == "cpu":
        return torch.Generator(device="cpu").set_state(torch.get_rng_state())
    if dispositivo.type == "cuda":
        # Read the RNG state of the requested GPU rather than whichever is current.
        return torch.Generator(device=dispositivo).set_state(torch.cuda.get_rng_state(dispositivo))
    if respaldo is None:
        return inicializar_generador(torch.device("cpu"))
    return respaldo


def calcular_fusion(x: torch.Tensor, info_tome: Dict[str, Any]) -> Tuple[Callable, ...]:
    """Build token merge / unmerge callables for ``x`` via ``ghanta``.

    ``info_tome["size"]`` must hold the original ``(height, width)`` token grid;
    ``info_tome["args"]`` is the settings dict built by ``FluxTransformer2DModel``.
    ``x.shape[1]`` is treated as the current token count (presumably
    ``(batch, tokens, dim)`` — confirm against callers).

    Returns:
        ``(merge_a, merge_c, merge_m, unmerge_a, unmerge_c, unmerge_m)``; the
        blocks apply the ``_a`` pair around attention, ``_c`` around the context
        stream and ``_m`` around the MLP. A pair is ``ghanta.hacer_nada`` when
        its ``m1``/``m2``/``m3`` flag is off or when the current downsampling
        level exceeds ``args["down"]``.
    """
    identidad = ghanta.hacer_nada
    alto_original, ancho_original = info_tome["size"]
    tokens_originales = alto_original * ancho_original
    # Clamp to 1: if x has more tokens than the original grid the floor division
    # is 0, which would otherwise lead to a division by zero below.
    submuestreo = max(1, int(math.ceil(math.sqrt(tokens_originales // x.shape[1]))))
    argumentos = info_tome["args"]

    if submuestreo <= argumentos["down"]:
        ancho = int(math.ceil(ancho_original / submuestreo))
        alto = int(math.ceil(alto_original / submuestreo))
        radio = int(x.shape[1] * argumentos["ratio"])

        # Lazily create the generator, or recreate it on x's device.
        if argumentos["generator"] is None:
            argumentos["generator"] = inicializar_generador(x.device)
        elif argumentos["generator"].device != x.device:
            argumentos["generator"] = inicializar_generador(x.device, respaldo=argumentos["generator"])

        fusion, desfusion = ghanta.emparejamiento_suave_aleatorio_2d(
            x,
            ancho,
            alto,
            argumentos["sx"],
            argumentos["sy"],
            radio,
            sin_aleatoriedad=not argumentos["rando"],
            generador=argumentos["generator"],
        )
    else:
        fusion, desfusion = identidad, identidad

    fusion_a, desfusion_a = (fusion, desfusion) if argumentos["m1"] else (identidad, identidad)
    fusion_c, desfusion_c = (fusion, desfusion) if argumentos["m2"] else (identidad, identidad)
    fusion_m, desfusion_m = (fusion, desfusion) if argumentos["m3"] else (identidad, identidad)
    return fusion_a, fusion_c, fusion_m, desfusion_a, desfusion_c, desfusion_m


def _sin_fusion() -> Tuple[Callable, ...]:
    """Return the six identity callables used when token merging is disabled."""
    return (ghanta.hacer_nada,) * 6


@torch.compile
@maybe_allow_in_graph
class FluxSingleTransformerBlock(nn.Module):
    """Single-stream block: attention and an MLP run in parallel on the same input.

    Both branches consume the same adaptively-normalized tokens; their outputs
    are concatenated, projected back to ``dim``, gated by ``temb`` and added to
    the residual.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
        super().__init__()
        self.mlp_hidden_dim = int(dim * mlp_ratio)

        self.norm = AdaLayerNormZeroSingle(dim)
        self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
        self.act_mlp = nn.GELU(approximate="tanh")
        self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

        processor = FluxAttnProcessor2_0()
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            bias=True,
            processor=processor,
            qk_norm="rms_norm",
            eps=1e-6,
            pre_only=True,
        )

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
        tinfo: Dict[str, Any] = None,
    ):
        """Run the block; ``tinfo`` (optional) enables token merging via ``calcular_fusion``."""
        fusiones = calcular_fusion(hidden_states, tinfo) if tinfo is not None else _sin_fusion()
        merge_a, _, _, unmerge_a, _, _ = fusiones

        residual = hidden_states
        norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
        norm_hidden_states = merge_a(norm_hidden_states)
        mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))

        joint_attention_kwargs = joint_attention_kwargs or {}
        attn_output = self.attn(
            hidden_states=norm_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # Concatenate along the feature axis, then project back to `dim`.
        hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
        gate = gate.unsqueeze(1)
        hidden_states = gate * self.proj_out(hidden_states)
        hidden_states = unmerge_a(residual + hidden_states)
        return hidden_states


@torch.compile
@maybe_allow_in_graph
class FluxTransformerBlock(nn.Module):
    """Dual-stream block over the image stream and the context stream.

    ``hidden_states`` and ``encoder_hidden_states`` are normalized separately,
    attend jointly, and then each passes through its own gated feed-forward.
    ``forward`` returns ``(encoder_hidden_states, hidden_states)``.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, qk_norm="rms_norm", eps=1e-6):
        super().__init__()

        self.norm1 = AdaLayerNormZero(dim)
        self.norm1_context = AdaLayerNormZero(dim)

        if hasattr(F, "scaled_dot_product_attention"):
            processor = FluxAttnProcessor2_0()
        else:
            raise ValueError(
                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
            )
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=processor,
            qk_norm=qk_norm,
            eps=eps,
        )

        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
        self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        self._chunk_size = None
        self._chunk_dim = 0

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor,
        temb: torch.FloatTensor,
        image_rotary_emb=None,
        joint_attention_kwargs=None,
        tinfo: Dict[str, Any] = None,
    ):
        """Run the block; ``tinfo`` (optional) enables token merging via ``calcular_fusion``."""
        fusiones = calcular_fusion(hidden_states, tinfo) if tinfo is not None else _sin_fusion()
        merge_a, merge_c, merge_m, unmerge_a, unmerge_c, unmerge_m = fusiones

        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
        norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
            encoder_hidden_states, emb=temb
        )
        joint_attention_kwargs = joint_attention_kwargs or {}

        norm_hidden_states = merge_a(norm_hidden_states)
        norm_encoder_hidden_states = merge_c(norm_encoder_hidden_states)

        # Joint attention over both streams.
        attn_output, context_attn_output = self.attn(
            hidden_states=norm_hidden_states,
            encoder_hidden_states=norm_encoder_hidden_states,
            image_rotary_emb=image_rotary_emb,
            **joint_attention_kwargs,
        )

        # Image stream: gated attention residual, then gated feed-forward residual.
        attn_output = gate_msa.unsqueeze(1) * attn_output
        hidden_states = unmerge_a(attn_output) + hidden_states

        norm_hidden_states = self.norm2(hidden_states)
        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
        norm_hidden_states = merge_m(norm_hidden_states)

        ff_output = self.ff(norm_hidden_states)
        ff_output = gate_mlp.unsqueeze(1) * ff_output
        hidden_states = unmerge_m(ff_output) + hidden_states

        # Context stream: same pattern with its own modulation parameters.
        context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
        encoder_hidden_states = unmerge_c(context_attn_output) + encoder_hidden_states

        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]

        context_ff_output = self.ff_context(norm_encoder_hidden_states)
        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output

        return encoder_hidden_states, hidden_states


def _create_custom_forward(module, return_dict=None):
    """Wrap ``module`` as a positional-args callable for ``torch.utils.checkpoint``."""

    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)

    return custom_forward


class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
    """FLUX transformer built from this file's dual- and single-stream blocks.

    Also carries ``self.tinfo``, the token-merging settings consumed by
    ``calcular_fusion``.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 1,
        in_channels: int = 64,
        out_channels: Optional[int] = None,
        num_layers: int = 19,
        num_single_layers: int = 38,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 4096,
        pooled_projection_dim: int = 768,
        guidance_embeds: bool = False,
        axes_dims_rope: Tuple[int] = (16, 56, 56),
        generator: Optional[torch.Generator] = None,
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)

        text_time_guidance_cls = (
            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
        )
        self.time_text_embed = text_time_guidance_cls(
            embedding_dim=self.inner_dim, pooled_projection_dim=self.config.pooled_projection_dim
        )

        self.context_embedder = nn.Linear(self.config.joint_attention_dim, self.inner_dim)
        self.x_embedder = nn.Linear(self.config.in_channels, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                FluxTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_layers)
            ]
        )

        self.single_transformer_blocks = nn.ModuleList(
            [
                FluxSingleTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=self.config.num_attention_heads,
                    attention_head_dim=self.config.attention_head_dim,
                )
                for _ in range(self.config.num_single_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        # Token-merging settings read by calcular_fusion.
        self.tinfo = {
            "size": None,  # (height, width) token grid; set in forward for 4-D inputs
            "args": {
                "ratio": 0.5,  # fraction of x.shape[1] handed to ghanta as the merge count
                "down": 1,  # merge only while the downsampling level is <= this
                "sx": 2,
                "sy": 2,
                "rando": False,  # False -> ghanta is called with sin_aleatoriedad=True
                "m1": False,  # merge around attention
                "m2": True,  # merge around the context stream
                "m3": False,  # merge around the MLP
                "generator": generator,
            },
        }

        self.gradient_checkpointing = False

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        processors = {}

        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()
            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        """Set one processor for every attention layer, or a per-layer dict keyed like ``attn_processors``.

        Raises:
            ValueError: if a dict is passed whose size differs from the number of attention layers.
        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    module.set_processor(processor.pop(f"{name}.processor"))
            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
    def fuse_qkv_projections(self):
        """Fuse Q/K/V projections in every Attention module and switch to the fused processor."""
        self.original_attn_processors = None

        for _, attn_processor in self.attn_processors.items():
            if "Added" in str(attn_processor.__class__.__name__):
                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")

        self.original_attn_processors = self.attn_processors

        for module in self.modules():
            if isinstance(module, Attention):
                module.fuse_projections(fuse=True)

        self.set_attn_processor(FusedFluxAttnProcessor2_0())

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
    def unfuse_qkv_projections(self):
        """Restore the processors saved by ``fuse_qkv_projections`` (no-op if it was never called)."""
        original = getattr(self, "original_attn_processors", None)
        if original is not None:
            self.set_attn_processor(original)

    def _set_gradient_checkpointing(self, module, value=False):
        if hasattr(module, "gradient_checkpointing"):
            module.gradient_checkpointing = value

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        pooled_projections: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_ids: torch.Tensor = None,
        txt_ids: torch.Tensor = None,
        guidance: torch.Tensor = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
        controlnet_block_samples=None,
        controlnet_single_block_samples=None,
        return_dict: bool = True,
        controlnet_blocks_repeat: bool = False,
    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
        """Denoise ``hidden_states`` conditioned on text embeddings and the timestep.

        Returns:
            ``Transformer2DModelOutput(sample=...)``, or ``(output,)`` when
            ``return_dict`` is False.
        """
        if joint_attention_kwargs is not None:
            joint_attention_kwargs = joint_attention_kwargs.copy()
            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
        else:
            lora_scale = 1.0

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        else:
            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
                logger.warning(
                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
                )

        hidden_states = self.x_embedder(hidden_states)
        if len(hidden_states.shape) == 4:
            self.tinfo["size"] = (hidden_states.shape[2], hidden_states.shape[3])

        timestep = timestep.to(hidden_states.dtype) * 1000
        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000
        temb = (
            self.time_text_embed(timestep, pooled_projections)
            if guidance is None
            else self.time_text_embed(timestep, guidance, pooled_projections)
        )
        encoder_hidden_states = self.context_embedder(encoder_hidden_states)

        if txt_ids.ndim == 3:
            logger.warning(
                "Passing `txt_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            txt_ids = txt_ids[0]
        if img_ids.ndim == 3:
            logger.warning(
                "Passing `img_ids` 3d torch.Tensor is deprecated."
                "Please remove the batch dimension and pass it as a 2d torch Tensor"
            )
            img_ids = img_ids[0]

        ids = torch.cat((txt_ids, img_ids), dim=0)
        image_rotary_emb = self.pos_embed(ids)

        ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}

        # NOTE(review): self.tinfo is not forwarded to the blocks here, so token
        # merging is currently inactive in both loops.
        for index_block, block in enumerate(self.transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
                    _create_custom_forward(block),
                    hidden_states,
                    encoder_hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            if controlnet_block_samples is not None:
                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
                interval_control = int(np.ceil(interval_control))
                if controlnet_blocks_repeat:
                    hidden_states = (
                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
                    )
                else:
                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]

        # Single-stream blocks operate on [context tokens, image tokens].
        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

        for index_block, block in enumerate(self.single_transformer_blocks):
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                hidden_states = torch.utils.checkpoint.checkpoint(
                    _create_custom_forward(block),
                    hidden_states,
                    temb,
                    image_rotary_emb,
                    **ckpt_kwargs,
                )
            else:
                hidden_states = block(
                    hidden_states=hidden_states,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=joint_attention_kwargs,
                )

            if controlnet_single_block_samples is not None:
                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
                interval_control = int(np.ceil(interval_control))
                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
                    + controlnet_single_block_samples[index_block // interval_control]
                )

        # Drop the context tokens again; only the image tokens are projected out.
        hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]

        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)


def load_single_file_checkpoint(
    pretrained_model_link_or_path,
    force_download=False,
    proxies=None,
    token=None,
    cache_dir=None,
    local_files_only=None,
    revision=None,
):
    """Load a state dict from a local file or a Hub ``repo_id/weights`` link.

    Nested ``"state_dict"`` wrappers are unwrapped before returning.
    """
    # Private diffusers helpers, imported lazily because only this function needs them.
    from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name
    from diffusers.models.modeling_utils import load_state_dict
    from diffusers.utils.hub_utils import _get_model_file

    if not os.path.isfile(pretrained_model_link_or_path):
        repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
        pretrained_model_link_or_path = _get_model_file(
            repo_id,
            weights_name=weights_name,
            force_download=force_download,
            cache_dir=cache_dir,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
        )

    checkpoint = load_state_dict(pretrained_model_link_or_path)

    # some checkpoints contain the model state dict under a "state_dict" key
    while "state_dict" in checkpoint:
        checkpoint = checkpoint["state_dict"]

    return checkpoint


# Placeholder type used only in the annotations of load_pipeline / infer.
Pipeline = None
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True

# Pipeline checkpoint providing the components not overridden in load_pipeline
# (previously black-forest-labs/FLUX.1-schnell @ 741f7c3ce8b383c54771c7003378a50191e9efe9).
ckpt_id = "silentdriver/4b68f38c0b"
ckpt_revision = "36a3cf4a9f733fc5f31257099b56b304fb2eceab"

# Sampling settings shared by the warm-up runs and by infer().
_NUM_INFERENCE_STEPS = 4
_MAX_SEQUENCE_LENGTH = 256


def empty_cache():
    """Run the Python GC, release cached CUDA blocks and reset CUDA peak-memory stats."""
    gc.collect()
    torch.cuda.empty_cache()
    # reset_peak_memory_stats also covers the deprecated reset_max_memory_allocated.
    torch.cuda.reset_peak_memory_stats()


def load_pipeline() -> Pipeline:
    """Load the components, assemble the pipeline on CUDA and warm it up.

    Returns:
        The ready-to-use ``DiffusionPipeline``.
    """
    empty_cache()

    dtype, device = torch.bfloat16, "cuda"

    text_encoder_2 = T5EncoderModel.from_pretrained(
        "silentdriver/aadb864af9",
        revision="060dabc7fa271c26dfa3fd43c16e7c5bf3ac7892",
        torch_dtype=dtype,
    ).to(memory_format=torch.channels_last)

    vae = AutoencoderTiny.from_pretrained(
        "silentdriver/7815792fb4",
        revision="bdb7d88ebe5a1c6b02a3c0c78651dd57a403fdf5",
        torch_dtype=dtype,
    )

    # The transformer is loaded straight from a local HF cache snapshot.
    transformer_path = os.path.join(
        HF_HUB_CACHE,
        "models--silentdriver--7d92df966a/snapshots/add1b8d9a84c728c1209448c4a695759240bad3c",
    )
    generator = torch.Generator(device=device)
    model = FluxTransformer2DModel.from_pretrained(
        transformer_path,
        torch_dtype=dtype,
        use_safetensors=False,
        generator=generator,
    ).to(memory_format=torch.channels_last)

    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

    vae = torch.compile(vae)

    pipeline = DiffusionPipeline.from_pretrained(
        ckpt_id,
        vae=vae,
        revision=ckpt_revision,
        transformer=model,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    ).to(device)

    # Inference only: freeze every trainable component.
    for component in (pipeline.vae, pipeline.transformer, pipeline.text_encoder_2, pipeline.text_encoder):
        component.requires_grad_(False)

    # Warm-up generations, presumably so compilation / cudnn autotuning happens
    # before the first real request.
    for _ in range(3):
        pipeline(
            prompt="blah blah waah waah oneshot oneshot gang gang",
            width=1024,
            height=1024,
            guidance_scale=0.0,
            num_inference_steps=_NUM_INFERENCE_STEPS,
            max_sequence_length=_MAX_SEQUENCE_LENGTH,
        )

    empty_cache()
    return pipeline


@torch.no_grad()
def infer(request: TextToImageRequest, pipeline: Pipeline, generator: Generator) -> Image:
    """Generate one PIL image for ``request`` using ``pipeline`` and ``generator``."""
    result = pipeline(
        request.prompt,
        generator=generator,
        guidance_scale=0.0,
        num_inference_steps=_NUM_INFERENCE_STEPS,
        max_sequence_length=_MAX_SEQUENCE_LENGTH,
        height=request.height,
        width=request.width,
        output_type="pil",
    )
    return result.images[0]