# Copyright 2024 HunyuanDiT Authors, Qixun Wang and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple, Union
import torch
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import PeftAdapterMixin
from diffusers.models.attention import FeedForward
from diffusers.models.attention_processor import Attention, AttentionProcessor
from diffusers.models.embeddings import (
GaussianFourierProjection,
TimestepEmbedding,
Timesteps,
)
from diffusers.models.modeling_utils import ModelMixin
from diffusers.models.normalization import (
AdaLayerNormContinuous,
FP32LayerNorm,
LayerNorm,
)
from diffusers.utils import (
USE_PEFT_BACKEND,
is_torch_version,
logging,
scale_lora_layers,
unscale_lora_layers,
)
from diffusers.utils.torch_utils import maybe_allow_in_graph
from torch import nn
from ..attention_processor import FusedTripoSGAttnProcessor2_0, TripoSGAttnProcessor2_0
from .modeling_outputs import Transformer1DModelOutput
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@maybe_allow_in_graph
class DiTBlock(nn.Module):
r"""
    Transformer block used in the TripoSG model, adapted from Hunyuan-DiT
    (https://github.com/Tencent/HunyuanDiT). Supports a long skip connection and QK normalization.
Parameters:
dim (`int`):
The number of channels in the input and output.
        num_attention_heads (`int`):
            The number of heads to use for multi-head attention.
        cross_attention_dim (`int`, *optional*):
            The size of the encoder_hidden_states vector for cross attention.
        dropout (`float`, *optional*, defaults to 0.0):
            The dropout probability to use.
        activation_fn (`str`, *optional*, defaults to `"gelu"`):
            Activation function to be used in the feed-forward block.
norm_elementwise_affine (`bool`, *optional*, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, *optional*, defaults to 1e-6):
A small constant added to the denominator in normalization layers to prevent division by zero.
        final_dropout (`bool`, *optional*, defaults to `False`):
Whether to apply a final dropout after the last feed-forward layer.
ff_inner_dim (`int`, *optional*):
The size of the hidden layer in the feed-forward block. Defaults to `None`.
ff_bias (`bool`, *optional*, defaults to `True`):
Whether to use bias in the feed-forward block.
skip (`bool`, *optional*, defaults to `False`):
Whether to use skip connection. Defaults to `False` for down-blocks and mid-blocks.
qk_norm (`bool`, *optional*, defaults to `True`):
Whether to use normalization in QK calculation. Defaults to `True`.
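
    Example:
        A minimal forward-pass sketch (all shapes and argument values are illustrative assumptions;
        `cross_attention_norm_type` is set to `None` here, matching how the model in this file
        instantiates the block):

        ```py
        >>> import torch

        >>> block = DiTBlock(
        ...     dim=512,
        ...     num_attention_heads=8,
        ...     use_cross_attention=True,
        ...     cross_attention_dim=768,
        ...     cross_attention_norm_type=None,
        ...     ff_inner_dim=2048,
        ... )
        >>> hidden_states = torch.randn(2, 64, 512)  # (batch, tokens, dim)
        >>> encoder_hidden_states = torch.randn(2, 77, 768)  # conditioning tokens
        >>> block(hidden_states, encoder_hidden_states=encoder_hidden_states).shape
        torch.Size([2, 64, 512])
        ```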
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
use_self_attention: bool = True,
use_cross_attention: bool = False,
self_attention_norm_type: Optional[str] = None, # ada layer norm
cross_attention_dim: Optional[int] = None,
cross_attention_norm_type: Optional[str] = "fp32_layer_norm",
# parallel second cross attention
use_cross_attention_2: bool = False,
cross_attention_2_dim: Optional[int] = None,
cross_attention_2_norm_type: Optional[str] = None,
dropout=0.0,
activation_fn: str = "gelu",
norm_type: str = "fp32_layer_norm", # TODO
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
final_dropout: bool = False,
ff_inner_dim: Optional[int] = None, # int(dim * 4) if None
ff_bias: bool = True,
skip: bool = False,
skip_concat_front: bool = False, # [x, skip] or [skip, x]
skip_norm_last: bool = False, # this is an error
qk_norm: bool = True,
qkv_bias: bool = True,
):
super().__init__()
self.use_self_attention = use_self_attention
self.use_cross_attention = use_cross_attention
self.use_cross_attention_2 = use_cross_attention_2
self.skip_concat_front = skip_concat_front
self.skip_norm_last = skip_norm_last
# Define 3 blocks. Each block has its own normalization layer.
        # NOTE: when a new version comes, check norm2 and norm3
# 1. Self-Attn
if use_self_attention:
if (
self_attention_norm_type == "fp32_layer_norm"
or self_attention_norm_type is None
):
self.norm1 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
else:
raise NotImplementedError
self.attn1 = Attention(
query_dim=dim,
cross_attention_dim=None,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="rms_norm" if qk_norm else None,
eps=1e-6,
bias=qkv_bias,
processor=TripoSGAttnProcessor2_0(),
)
# 2. Cross-Attn
if use_cross_attention:
assert cross_attention_dim is not None
self.norm2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="rms_norm" if qk_norm else None,
cross_attention_norm=cross_attention_norm_type,
eps=1e-6,
bias=qkv_bias,
processor=TripoSGAttnProcessor2_0(),
)
# 2'. Parallel Second Cross-Attn
if use_cross_attention_2:
assert cross_attention_2_dim is not None
self.norm2_2 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.attn2_2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_2_dim,
dim_head=dim // num_attention_heads,
heads=num_attention_heads,
qk_norm="rms_norm" if qk_norm else None,
cross_attention_norm=cross_attention_2_norm_type,
eps=1e-6,
bias=qkv_bias,
processor=TripoSGAttnProcessor2_0(),
)
# 3. Feed-forward
self.norm3 = FP32LayerNorm(dim, norm_eps, norm_elementwise_affine)
self.ff = FeedForward(
dim,
dropout=dropout, ### 0.0
activation_fn=activation_fn, ### approx GeLU
final_dropout=final_dropout, ### 0.0
inner_dim=ff_inner_dim, ### int(dim * mlp_ratio)
bias=ff_bias,
)
# 4. Skip Connection
if skip:
self.skip_norm = FP32LayerNorm(dim, norm_eps, elementwise_affine=True)
self.skip_linear = nn.Linear(2 * dim, dim)
else:
self.skip_linear = None
# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
# Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim
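    # NOTE: `forward` below does not currently consume the stored chunk size; the feed-forward is
    # always applied over the full sequence.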
def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
temb: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[torch.Tensor] = None,
skip: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
) -> torch.Tensor:
# Prepare attention kwargs
attention_kwargs = attention_kwargs or {}
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Long Skip Connection
if self.skip_linear is not None:
cat = torch.cat(
(
[skip, hidden_states]
if self.skip_concat_front
else [hidden_states, skip]
),
dim=-1,
)
if self.skip_norm_last:
# don't do this
hidden_states = self.skip_linear(cat)
hidden_states = self.skip_norm(hidden_states)
else:
cat = self.skip_norm(cat)
hidden_states = self.skip_linear(cat)
# 1. Self-Attention
if self.use_self_attention:
norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(
norm_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
hidden_states = hidden_states + attn_output
# 2. Cross-Attention
if self.use_cross_attention:
if self.use_cross_attention_2:
hidden_states = (
hidden_states
+ self.attn2(
self.norm2(hidden_states),
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
+ self.attn2_2(
self.norm2_2(hidden_states),
encoder_hidden_states=encoder_hidden_states_2,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
)
else:
hidden_states = hidden_states + self.attn2(
self.norm2(hidden_states),
encoder_hidden_states=encoder_hidden_states,
image_rotary_emb=image_rotary_emb,
**attention_kwargs,
)
# FFN Layer ### TODO: switch norm2 and norm3 in the state dict
mlp_inputs = self.norm3(hidden_states)
hidden_states = hidden_states + self.ff(mlp_inputs)
return hidden_states
class TripoSGDiTModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
"""
TripoSG: Diffusion model with a Transformer backbone.
Inherit ModelMixin and ConfigMixin to be compatible with the sampler StableDiffusionPipeline of diffusers.
Parameters:
num_attention_heads (`int`, *optional*, defaults to 16):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 88):
The number of channels in each head.
in_channels (`int`, *optional*):
The number of channels in the input and output (specify if the input is **continuous**).
patch_size (`int`, *optional*):
The size of the patch to use for the input.
activation_fn (`str`, *optional*, defaults to `"geglu"`):
Activation function to use in feed-forward.
sample_size (`int`, *optional*):
The width of the latent images. This is fixed during training since it is used to learn a number of
position embeddings.
dropout (`float`, *optional*, defaults to 0.0):
The dropout probability to use.
cross_attention_dim (`int`, *optional*):
The number of dimension in the clip text embedding.
hidden_size (`int`, *optional*):
The size of hidden layer in the conditioning embedding layers.
num_layers (`int`, *optional*, defaults to 1):
The number of layers of Transformer blocks to use.
mlp_ratio (`float`, *optional*, defaults to 4.0):
The ratio of the hidden layer size to the input size.
learn_sigma (`bool`, *optional*, defaults to `True`):
Whether to predict variance.
cross_attention_dim_t5 (`int`, *optional*):
The number dimensions in t5 text embedding.
pooled_projection_dim (`int`, *optional*):
The size of the pooled projection.
text_len (`int`, *optional*):
The length of the clip text embedding.
text_len_t5 (`int`, *optional*):
The length of the T5 text embedding.
use_style_cond_and_image_meta_size (`bool`, *optional*):
Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
width: int = 2048,
in_channels: int = 64,
num_layers: int = 21,
cross_attention_dim: int = 768,
cross_attention_2_dim: int = 1024,
):
super().__init__()
self.out_channels = in_channels
self.num_heads = num_attention_heads
self.inner_dim = width
self.mlp_ratio = 4.0
time_embed_dim, timestep_input_dim = self._set_time_proj(
"positional",
inner_dim=self.inner_dim,
flip_sin_to_cos=False,
freq_shift=0,
time_embedding_dim=None,
)
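        # NOTE: the naming is flipped relative to the usual diffusers convention: `self.time_embed`
        # (created in `_set_time_proj`) holds the sinusoidal `Timesteps` projection, while
        # `self.time_proj` below is the `TimestepEmbedding` MLP that maps it to `inner_dim`.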
self.time_proj = TimestepEmbedding(
timestep_input_dim, time_embed_dim, act_fn="gelu", out_dim=self.inner_dim
)
self.proj_in = nn.Linear(self.config.in_channels, self.inner_dim, bias=True)
self.blocks = nn.ModuleList(
[
DiTBlock(
dim=self.inner_dim,
num_attention_heads=self.config.num_attention_heads,
use_self_attention=True,
use_cross_attention=True,
self_attention_norm_type="fp32_layer_norm",
cross_attention_dim=self.config.cross_attention_dim,
cross_attention_norm_type=None,
use_cross_attention_2=True,
cross_attention_2_dim=self.config.cross_attention_2_dim,
cross_attention_2_norm_type=None,
activation_fn="gelu",
norm_type="fp32_layer_norm", # TODO
norm_eps=1e-5,
ff_inner_dim=int(self.inner_dim * self.mlp_ratio),
skip=layer > num_layers // 2,
skip_concat_front=True,
skip_norm_last=True, # this is an error
qk_norm=True, # See http://arxiv.org/abs/2302.05442 for details.
qkv_bias=False,
)
for layer in range(num_layers)
]
)
self.norm_out = LayerNorm(self.inner_dim)
self.proj_out = nn.Linear(self.inner_dim, self.out_channels, bias=True)
self.gradient_checkpointing = False
def _set_gradient_checkpointing(self, module, value=False):
self.gradient_checkpointing = value
def _set_time_proj(
self,
time_embedding_type: str,
inner_dim: int,
flip_sin_to_cos: bool,
freq_shift: float,
        time_embedding_dim: Optional[int],
) -> Tuple[int, int]:
if time_embedding_type == "fourier":
time_embed_dim = time_embedding_dim or inner_dim * 2
if time_embed_dim % 2 != 0:
raise ValueError(
f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}."
)
self.time_embed = GaussianFourierProjection(
time_embed_dim // 2,
set_W_to_weight=False,
log=False,
flip_sin_to_cos=flip_sin_to_cos,
)
timestep_input_dim = time_embed_dim
elif time_embedding_type == "positional":
time_embed_dim = time_embedding_dim or inner_dim * 4
self.time_embed = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
timestep_input_dim = inner_dim
else:
raise ValueError(
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
)
return time_embed_dim, timestep_input_dim
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedTripoSGAttnProcessor2_0
def fuse_qkv_projections(self):
"""
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
are fused. For cross-attention modules, key and value projection matrices are fused.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
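
        Example (an illustrative sketch; `model` here stands for any instantiated `TripoSGDiTModel`):

        ```py
        >>> model.fuse_qkv_projections()
        >>> # ... run inference with fused projections ...
        >>> model.unfuse_qkv_projections()
        ```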
"""
self.original_attn_processors = None
for _, attn_processor in self.attn_processors.items():
if "Added" in str(attn_processor.__class__.__name__):
raise ValueError(
"`fuse_qkv_projections()` is not supported for models having added KV projections."
)
self.original_attn_processors = self.attn_processors
for module in self.modules():
if isinstance(module, Attention):
module.fuse_projections(fuse=True)
self.set_attn_processor(FusedTripoSGAttnProcessor2_0())
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
def unfuse_qkv_projections(self):
"""Disables the fused QKV projection if enabled.
<Tip warning={true}>
This API is 🧪 experimental.
</Tip>
"""
if self.original_attn_processors is not None:
self.set_attn_processor(self.original_attn_processors)
@property
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model,
            indexed by their weight names.
"""
# set recursively
processors = {}
def fn_recursive_add_processors(
name: str,
module: torch.nn.Module,
processors: Dict[str, AttentionProcessor],
):
if hasattr(module, "get_processor"):
processors[f"{name}.processor"] = module.get_processor()
for sub_name, child in module.named_children():
fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
return processors
for name, module in self.named_children():
fn_recursive_add_processors(name, module, processors)
return processors
# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
def set_attn_processor(
self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]
):
r"""
Sets the attention processor to use to compute attention.
Parameters:
processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
The instantiated processor class or a dictionary of processor classes that will be set as the processor
for **all** `Attention` layers.
If `processor` is a dict, the key needs to define the path to the corresponding cross attention
processor. This is strongly recommended when setting trainable attention processors.
"""
count = len(self.attn_processors.keys())
if isinstance(processor, dict) and len(processor) != count:
raise ValueError(
f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
)
def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
if hasattr(module, "set_processor"):
if not isinstance(processor, dict):
module.set_processor(processor)
else:
module.set_processor(processor.pop(f"{name}.processor"))
for sub_name, child in module.named_children():
fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
for name, module in self.named_children():
fn_recursive_attn_processor(name, module, processor)
def set_default_attn_processor(self):
"""
Disables custom attention processors and sets the default attention implementation.
"""
self.set_attn_processor(TripoSGAttnProcessor2_0())
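    # Example (a minimal inference sketch under assumed inputs, not the full MIDI/TripoSG pipeline):
    # with the default config the model maps a sequence of latent tokens, plus two conditioning
    # streams, to a prediction of the same shape. The token counts below are illustrative.
    #
    #   model = TripoSGDiTModel()                # width=2048, in_channels=64, num_layers=21
    #   latents = torch.randn(1, 512, 64)        # (batch, num_tokens, in_channels)
    #   cond = torch.randn(1, 257, 768)          # first stream (cross_attention_dim)
    #   cond_2 = torch.randn(1, 257, 1024)       # second, parallel stream (cross_attention_2_dim)
    #   timestep = torch.tensor([999])
    #   sample = model(
    #       latents,
    #       timestep,
    #       encoder_hidden_states=cond,
    #       encoder_hidden_states_2=cond_2,
    #   ).sample                                 # shape (1, 512, 64)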
def forward(
self,
hidden_states: Optional[torch.Tensor],
timestep: Union[int, float, torch.LongTensor],
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_hidden_states_2: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
):
"""
        The [`TripoSGDiTModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch size, num tokens, in_channels)`):
                The input sequence of latent tokens.
            timestep (`torch.LongTensor`, *optional*):
                Used to indicate the denoising step.
            encoder_hidden_states (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the first cross-attention stream.
            encoder_hidden_states_2 (`torch.Tensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
                Conditional embeddings for the second, parallel cross-attention stream.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`Transformer1DModelOutput`] instead of a plain tuple.
        """
if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
else:
lora_scale = 1.0
if USE_PEFT_BACKEND:
# weight the lora layers by setting `lora_scale` for each PEFT layer
scale_lora_layers(self, lora_scale)
else:
if (
attention_kwargs is not None
and attention_kwargs.get("scale", None) is not None
):
logger.warning(
"Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
)
_, N, _ = hidden_states.shape
temb = self.time_embed(timestep).to(hidden_states.dtype)
temb = self.time_proj(temb)
temb = temb.unsqueeze(dim=1) # unsqueeze to concat with hidden_states
hidden_states = self.proj_in(hidden_states)
# N + 1 token
hidden_states = torch.cat([temb, hidden_states], dim=1)
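        # U-Net style long skip connections: the first `num_layers // 2` blocks push their outputs
        # onto the stack; blocks after the middle pop them and fuse them via `skip_linear`.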
skips = []
for layer, block in enumerate(self.blocks):
skip = None if layer <= self.config.num_layers // 2 else skips.pop()
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = (
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
encoder_hidden_states_2,
temb,
image_rotary_emb,
skip,
attention_kwargs,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
encoder_hidden_states_2=encoder_hidden_states_2,
temb=temb,
image_rotary_emb=image_rotary_emb,
skip=skip,
attention_kwargs=attention_kwargs,
) # (N, L, D)
if layer < self.config.num_layers // 2:
skips.append(hidden_states)
# final layer
hidden_states = self.norm_out(hidden_states)
hidden_states = hidden_states[:, -N:]
hidden_states = self.proj_out(hidden_states)
if USE_PEFT_BACKEND:
# remove `lora_scale` from each PEFT layer
unscale_lora_layers(self, lora_scale)
if not return_dict:
return (hidden_states,)
return Transformer1DModelOutput(sample=hidden_states)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
def enable_forward_chunking(
self, chunk_size: Optional[int] = None, dim: int = 0
) -> None:
"""
Sets the attention processor to use [feed forward
chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
Parameters:
chunk_size (`int`, *optional*):
The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
over each tensor of dim=`dim`.
dim (`int`, *optional*, defaults to `0`):
The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
or dim=1 (sequence length).
"""
if dim not in [0, 1]:
raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
# By default chunk size is 1
chunk_size = chunk_size or 1
def fn_recursive_feed_forward(
module: torch.nn.Module, chunk_size: int, dim: int
):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, chunk_size, dim)
# Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
def disable_forward_chunking(self):
def fn_recursive_feed_forward(
module: torch.nn.Module, chunk_size: int, dim: int
):
if hasattr(module, "set_chunk_feed_forward"):
module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
for child in module.children():
fn_recursive_feed_forward(child, chunk_size, dim)
for module in self.children():
fn_recursive_feed_forward(module, None, 0)