# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass from typing import Any, Dict, List, Optional, Tuple, Union import torch import torch.nn as nn from diffusers.configuration_utils import ConfigMixin, register_to_config from .lora.peft import PeftAdapterMixin from diffusers.models.attention_processor import AttentionProcessor from diffusers.models.modeling_utils import ModelMixin from diffusers.utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers #from .controlnet import BaseOutput, zero_module from diffusers.utils import BaseOutput from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings from diffusers.models.modeling_outputs import Transformer2DModelOutput from .transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock import numpy as np logger = logging.get_logger(__name__) # pylint: disable=invalid-name def zero_module(module): for p in module.parameters(): nn.init.zeros_(p) return module def get_1d_rotary_pos_embed( dim: int, pos: Union[np.ndarray, int], theta: float = 10000.0, use_real=False, linear_factor=1.0, ntk_factor=1.0, repeat_interleave_real=True, freqs_dtype=torch.float32, # torch.float32, torch.float64 (flux) ): """ Precompute the frequency tensor for complex exponentials (cis) with given dimensions. This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the end index 'end'. The 'theta' parameter scales the frequencies. The returned tensor contains complex values in complex64 data type. Args: dim (`int`): Dimension of the frequency tensor. pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar theta (`float`, *optional*, defaults to 10000.0): Scaling factor for frequency computation. Defaults to 10000.0. use_real (`bool`, *optional*): If True, return real part and imaginary part separately. Otherwise, return complex numbers. linear_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the context extrapolation. Defaults to 1.0. ntk_factor (`float`, *optional*, defaults to 1.0): Scaling factor for the NTK-Aware RoPE. Defaults to 1.0. repeat_interleave_real (`bool`, *optional*, defaults to `True`): If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`. Otherwise, they are concateanted with themselves. freqs_dtype (`torch.float32` or `torch.float64`, *optional*, defaults to `torch.float32`): the dtype of the frequency tensor. Returns: `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2] """ assert dim % 2 == 0 if isinstance(pos, int): pos = torch.arange(pos) if isinstance(pos, np.ndarray): pos = torch.from_numpy(pos) # type: ignore # [S] theta = theta * ntk_factor freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype)[: (dim // 2)] / dim)) / linear_factor # [D/2] freqs = freqs.to(pos.device) freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2] if use_real and repeat_interleave_real: # flux, hunyuan-dit, cogvideox freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D] freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D] return freqs_cos, freqs_sin elif use_real: # stable audio freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1).float() # [S, D] freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1).float() # [S, D] return freqs_cos, freqs_sin else: # lumina freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 # [S, D/2] return freqs_cis class FluxPosEmbed(nn.Module): # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11 def __init__(self, theta: int, axes_dim: List[int]): super().__init__() self.theta = theta self.axes_dim = axes_dim def forward(self, ids: torch.Tensor) -> torch.Tensor: n_axes = ids.shape[-1] cos_out = [] sin_out = [] pos = ids.squeeze().float() is_mps = ids.device.type == "mps" freqs_dtype = torch.float32 if is_mps else torch.float64 for i in range(n_axes): cos, sin = get_1d_rotary_pos_embed( self.axes_dim[i], pos[:, i], repeat_interleave_real=True, use_real=True, freqs_dtype=freqs_dtype ) cos_out.append(cos) sin_out.append(sin) freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device) freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device) return freqs_cos, freqs_sin @dataclass class FluxControlNetOutput(BaseOutput): controlnet_block_samples: Tuple[torch.Tensor] controlnet_single_block_samples: Tuple[torch.Tensor] class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin): _supports_gradient_checkpointing = True @register_to_config def __init__( self, patch_size: int = 1, in_channels: int = 64, num_layers: int = 19, num_single_layers: int = 38, attention_head_dim: int = 128, num_attention_heads: int = 24, joint_attention_dim: int = 4096, pooled_projection_dim: int = 768, guidance_embeds: bool = False, axes_dims_rope: List[int] = [16, 56, 56], num_mode: int = None, ): super().__init__() self.out_channels = in_channels self.inner_dim = num_attention_heads * attention_head_dim self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope) text_time_guidance_cls = ( CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings ) self.time_text_embed = text_time_guidance_cls( embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim ) self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim) self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim) self.transformer_blocks = nn.ModuleList( [ FluxTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, ) for i in range(num_layers) ] ) self.single_transformer_blocks = nn.ModuleList( [ FluxSingleTransformerBlock( dim=self.inner_dim, num_attention_heads=num_attention_heads, attention_head_dim=attention_head_dim, ) for i in range(num_single_layers) ] ) # controlnet_blocks self.controlnet_blocks = nn.ModuleList([]) for _ in range(len(self.transformer_blocks)): self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) self.controlnet_single_blocks = nn.ModuleList([]) for _ in range(len(self.single_transformer_blocks)): self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim))) self.union = num_mode is not None if self.union: self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim) self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim)) self.gradient_checkpointing = False @property # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors def attn_processors(self): r""" Returns: `dict` of attention processors: A dictionary containing all attention processors used in the model with indexed by its weight name. """ # set recursively processors = {} def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): if hasattr(module, "get_processor"): processors[f"{name}.processor"] = module.get_processor() for sub_name, child in module.named_children(): fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) return processors for name, module in self.named_children(): fn_recursive_add_processors(name, module, processors) return processors # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor def set_attn_processor(self, processor): r""" Sets the attention processor to use to compute attention. Parameters: processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): The instantiated processor class or a dictionary of processor classes that will be set as the processor for **all** `Attention` layers. If `processor` is a dict, the key needs to define the path to the corresponding cross attention processor. This is strongly recommended when setting trainable attention processors. """ count = len(self.attn_processors.keys()) if isinstance(processor, dict) and len(processor) != count: raise ValueError( f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" f" number of attention layers: {count}. Please make sure to pass {count} processor classes." ) def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): if hasattr(module, "set_processor"): if not isinstance(processor, dict): module.set_processor(processor) else: module.set_processor(processor.pop(f"{name}.processor")) for sub_name, child in module.named_children(): fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) for name, module in self.named_children(): fn_recursive_attn_processor(name, module, processor) def _set_gradient_checkpointing(self, module, value=False): if hasattr(module, "gradient_checkpointing"): module.gradient_checkpointing = value @classmethod def from_transformer( cls, transformer, num_layers: int = 4, num_single_layers: int = 10, attention_head_dim: int = 128, num_attention_heads: int = 24, load_weights_from_transformer=True, ): config = transformer.config config["num_layers"] = num_layers config["num_single_layers"] = num_single_layers config["attention_head_dim"] = attention_head_dim config["num_attention_heads"] = num_attention_heads controlnet = cls(**config) if load_weights_from_transformer: controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict()) controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict()) controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict()) controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict()) controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False) controlnet.single_transformer_blocks.load_state_dict( transformer.single_transformer_blocks.state_dict(), strict=False ) controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder) return controlnet def forward( self, hidden_states: torch.Tensor, controlnet_cond: torch.Tensor, controlnet_mode: torch.Tensor = None, conditioning_scale: float = 1.0, encoder_hidden_states: torch.Tensor = None, t5_encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[torch.FloatTensor, Transformer2DModelOutput]: """ The [`FluxTransformer2DModel`] forward method. Args: hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input `hidden_states`. controlnet_cond (`torch.Tensor`): The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. controlnet_mode (`torch.Tensor`): The mode tensor of shape `(batch_size, 1)`. conditioning_scale (`float`, defaults to `1.0`): The scale factor for ControlNet outputs. encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`): Conditional embeddings (embeddings computed from the input conditions such as prompts) to use. pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected from the embeddings of input conditions. timestep ( `torch.LongTensor`): Used to indicate denoising step. block_controlnet_hidden_states: (`list` of `torch.Tensor`): A list of tensors that if specified are added to the residuals of transformer blocks. joint_attention_kwargs (`dict`, *optional*): A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under `self.processor` in [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). return_dict (`bool`, *optional*, defaults to `True`): Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain tuple. Returns: If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a `tuple` where the first element is the sample tensor. """ if joint_attention_kwargs is not None: joint_attention_kwargs = joint_attention_kwargs.copy() lora_scale = joint_attention_kwargs.pop("scale", 1.0) else: lora_scale = 1.0 if USE_PEFT_BACKEND: # weight the lora layers by setting `lora_scale` for each PEFT layer scale_lora_layers(self, lora_scale) else: if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None: logger.warning( "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective." ) hidden_states = self.x_embedder(hidden_states) # add hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond) timestep = timestep.to(hidden_states.dtype) * 1000 if guidance is not None: guidance = guidance.to(hidden_states.dtype) * 1000 else: guidance = None temb = ( self.time_text_embed(timestep, pooled_projections) if guidance is None else self.time_text_embed(timestep, guidance, pooled_projections) ) encoder_hidden_states = self.context_embedder(encoder_hidden_states) if t5_encoder_hidden_states is not None: encoder_hidden_states = torch.cat([encoder_hidden_states, t5_encoder_hidden_states], dim=1) if txt_ids.ndim == 3: logger.warning( "Passing `txt_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) txt_ids = txt_ids[0] if self.union: # union mode if controlnet_mode is None: raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union") # union mode emb controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode) encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1) txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0) if img_ids.ndim == 3: logger.warning( "Passing `img_ids` 3d torch.Tensor is deprecated." "Please remove the batch dimension and pass it as a 2d torch Tensor" ) img_ids = img_ids[0] ids = torch.cat((txt_ids, img_ids), dim=0) image_rotary_emb = self.pos_embed(ids) block_samples = () for index_block, block in enumerate(self.transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, encoder_hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: encoder_hidden_states, hidden_states = block( hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) block_samples = block_samples + (hidden_states,) hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1) single_block_samples = () for index_block, block in enumerate(self.single_transformer_blocks): if self.training and self.gradient_checkpointing: def create_custom_forward(module, return_dict=None): def custom_forward(*inputs): if return_dict is not None: return module(*inputs, return_dict=return_dict) else: return module(*inputs) return custom_forward ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} hidden_states = torch.utils.checkpoint.checkpoint( create_custom_forward(block), hidden_states, temb, image_rotary_emb, **ckpt_kwargs, ) else: hidden_states = block( hidden_states=hidden_states, temb=temb, image_rotary_emb=image_rotary_emb, ) single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],) # controlnet block controlnet_block_samples = () for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks): block_sample = controlnet_block(block_sample) controlnet_block_samples = controlnet_block_samples + (block_sample,) controlnet_single_block_samples = () for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks): single_block_sample = controlnet_block(single_block_sample) controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,) # scaling controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples] controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples] controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples controlnet_single_block_samples = ( None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples ) if USE_PEFT_BACKEND: # remove `lora_scale` from each PEFT layer unscale_lora_layers(self, lora_scale) if not return_dict: return (controlnet_block_samples, controlnet_single_block_samples) return FluxControlNetOutput( controlnet_block_samples=controlnet_block_samples, controlnet_single_block_samples=controlnet_single_block_samples, ) class FluxMultiControlNetModel(ModelMixin): r""" `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be compatible with `FluxControlNetModel`. Args: controlnets (`List[FluxControlNetModel]`): Provides additional conditioning to the unet during the denoising process. You must set multiple `FluxControlNetModel` as a list. """ def __init__(self, controlnets): super().__init__() self.nets = nn.ModuleList(controlnets) def forward( self, hidden_states: torch.FloatTensor, controlnet_cond: List[torch.tensor], controlnet_mode: List[torch.tensor], conditioning_scale: List[float], encoder_hidden_states: torch.Tensor = None, t5_encoder_hidden_states: torch.Tensor = None, pooled_projections: torch.Tensor = None, timestep: torch.LongTensor = None, img_ids: torch.Tensor = None, txt_ids: torch.Tensor = None, guidance: torch.Tensor = None, joint_attention_kwargs: Optional[Dict[str, Any]] = None, return_dict: bool = True, ) -> Union[FluxControlNetOutput, Tuple]: # ControlNet-Union with multiple conditions # only load one ControlNet for saving memories if len(self.nets) == 1 and self.nets[0].union: controlnet = self.nets[0] for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)): block_samples, single_block_samples = controlnet( hidden_states=hidden_states, controlnet_cond=image, controlnet_mode=mode[:, None], conditioning_scale=scale, timestep=timestep, guidance=guidance, pooled_projections=pooled_projections, encoder_hidden_states=encoder_hidden_states, t5_encoder_hidden_states=t5_encoder_hidden_states, txt_ids=txt_ids, img_ids=img_ids, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) # merge samples if i == 0: control_block_samples = block_samples control_single_block_samples = single_block_samples else: control_block_samples = [ control_block_sample + block_sample for control_block_sample, block_sample in zip(control_block_samples, block_samples) ] control_single_block_samples = [ control_single_block_sample + block_sample for control_single_block_sample, block_sample in zip( control_single_block_samples, single_block_samples ) ] # Regular Multi-ControlNets # load all ControlNets into memories else: for i, (image, mode, scale, controlnet) in enumerate( zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets) ): block_samples, single_block_samples = controlnet( hidden_states=hidden_states, controlnet_cond=image, controlnet_mode=mode[:, None], conditioning_scale=scale, timestep=timestep, guidance=guidance, pooled_projections=pooled_projections, encoder_hidden_states=encoder_hidden_states, txt_ids=txt_ids, img_ids=img_ids, joint_attention_kwargs=joint_attention_kwargs, return_dict=return_dict, ) # merge samples if i == 0: control_block_samples = block_samples control_single_block_samples = single_block_samples else: if block_samples is not None and control_block_samples is not None: control_block_samples = [ control_block_sample + block_sample for control_block_sample, block_sample in zip(control_block_samples, block_samples) ] if single_block_samples is not None and control_single_block_samples is not None: control_single_block_samples = [ control_single_block_sample + block_sample for control_single_block_sample, block_sample in zip( control_single_block_samples, single_block_samples ) ] return control_block_samples, control_single_block_samples