# icedit/diffusers/models/transformers/transformer_allegro.py
# Copyright 2024 The RhymesAI and The HuggingFace Team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import is_torch_version, logging
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import FeedForward
from ..attention_processor import AllegroAttnProcessor2_0, Attention
from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
from ..modeling_outputs import Transformer2DModelOutput
from ..modeling_utils import ModelMixin
from ..normalization import AdaLayerNormSingle

logger = logging.get_logger(__name__)


@maybe_allow_in_graph
class AllegroTransformerBlock(nn.Module):
r"""
Transformer block used in [Allegro](https://github.com/rhymes-ai/Allegro) model.
Args:
dim (`int`):
The number of channels in the input and output.
num_attention_heads (`int`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`):
The number of channels in each head.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
        cross_attention_dim (`int`, *optional*, defaults to `None`):
            The dimension of the cross-attention features. When `None`, the hidden dimension `dim` is used.
        activation_fn (`str`, defaults to `"geglu"`):
            Activation function to be used in feed-forward.
attention_bias (`bool`, defaults to `False`):
Whether or not to use bias in attention projection layers.
norm_elementwise_affine (`bool`, defaults to `True`):
Whether to use learnable elementwise affine parameters for normalization.
norm_eps (`float`, defaults to `1e-5`):
Epsilon value for normalization layers.
"""
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
attention_bias: bool = False,
norm_elementwise_affine: bool = True,
norm_eps: float = 1e-5,
):
super().__init__()
# 1. Self Attention
self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
cross_attention_dim=None,
processor=AllegroAttnProcessor2_0(),
)
# 2. Cross Attention
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
processor=AllegroAttnProcessor2_0(),
)
# 3. Feed Forward
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
)
# 4. Scale-shift
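        # Learnable 6 x dim modulation table (AdaLN-single, PixArt-style): added to the timestep
        # embedding in `forward` and chunked into shift/scale/gate for the self-attention
        # and feed-forward branches.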
        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

    def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
        temb: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb=None,
) -> torch.Tensor:
# 0. Self-Attention
batch_size = hidden_states.shape[0]
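        # Combine the learned scale-shift table with the per-sample conditioning `temb`
        # (reshaped to [batch_size, 6, dim]) and split it into the six modulation tensors.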
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
self.scale_shift_table[None] + temb.reshape(batch_size, 6, -1)
).chunk(6, dim=1)
norm_hidden_states = self.norm1(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
norm_hidden_states = norm_hidden_states.squeeze(1)
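        # Self-attention over all video tokens; the rotary position embeddings are applied to
        # query and key inside the attention processor.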
attn_output = self.attn1(
norm_hidden_states,
encoder_hidden_states=None,
attention_mask=attention_mask,
image_rotary_emb=image_rotary_emb,
)
attn_output = gate_msa * attn_output
hidden_states = attn_output + hidden_states
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
# 1. Cross-Attention
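        # Cross-attention against the projected text embeddings; rotary embeddings are disabled
        # here and no AdaLN gate is applied to this branch.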
if self.attn2 is not None:
norm_hidden_states = hidden_states
attn_output = self.attn2(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states,
attention_mask=encoder_attention_mask,
image_rotary_emb=None,
)
hidden_states = attn_output + hidden_states
# 2. Feed-forward
norm_hidden_states = self.norm2(hidden_states)
norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
ff_output = self.ff(norm_hidden_states)
ff_output = gate_mlp * ff_output
hidden_states = ff_output + hidden_states
# TODO(aryan): maybe following line is not required
if hidden_states.ndim == 4:
hidden_states = hidden_states.squeeze(1)
        return hidden_states


class AllegroTransformer3DModel(ModelMixin, ConfigMixin):
    """
A 3D Transformer model for video-like data.
Args:
patch_size (`int`, defaults to `2`):
The size of spatial patches to use in the patch embedding layer.
patch_size_t (`int`, defaults to `1`):
The size of temporal patches to use in the patch embedding layer.
num_attention_heads (`int`, defaults to `24`):
The number of heads to use for multi-head attention.
attention_head_dim (`int`, defaults to `96`):
The number of channels in each head.
in_channels (`int`, defaults to `4`):
The number of channels in the input.
out_channels (`int`, *optional*, defaults to `4`):
The number of channels in the output.
num_layers (`int`, defaults to `32`):
The number of layers of Transformer blocks to use.
dropout (`float`, defaults to `0.0`):
The dropout probability to use.
cross_attention_dim (`int`, defaults to `2304`):
The dimension of the cross attention features.
attention_bias (`bool`, defaults to `True`):
Whether or not to use bias in the attention projection layers.
sample_height (`int`, defaults to `90`):
The height of the input latents.
sample_width (`int`, defaults to `160`):
The width of the input latents.
sample_frames (`int`, defaults to `22`):
The number of frames in the input latents.
activation_fn (`str`, defaults to `"gelu-approximate"`):
Activation function to use in feed-forward.
norm_elementwise_affine (`bool`, defaults to `False`):
Whether or not to use elementwise affine in normalization layers.
norm_eps (`float`, defaults to `1e-6`):
The epsilon value to use in normalization layers.
caption_channels (`int`, defaults to `4096`):
Number of channels to use for projecting the caption embeddings.
interpolation_scale_h (`float`, defaults to `2.0`):
Scaling factor to apply in 3D positional embeddings across height dimension.
interpolation_scale_w (`float`, defaults to `2.0`):
Scaling factor to apply in 3D positional embeddings across width dimension.
interpolation_scale_t (`float`, defaults to `2.2`):
Scaling factor to apply in 3D positional embeddings across time dimension.
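
    Example (a minimal shape-check sketch with illustrative sizes, not tied to any released
    checkpoint; the Allegro pipeline normally builds this module from a pretrained config and
    also passes `image_rotary_emb`):

        ```python
        import torch

        model = AllegroTransformer3DModel(num_layers=2)
        latents = torch.randn(1, 4, 8, 32, 32)     # (batch, channels, frames, height, width)
        prompt_embeds = torch.randn(1, 512, 4096)  # pre-computed text-encoder embeddings
        timestep = torch.tensor([10])

        with torch.no_grad():
            sample = model(latents, prompt_embeds, timestep).sample
        # `sample` has the same shape as `latents`.
        ```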
"""
@register_to_config
def __init__(
self,
patch_size: int = 2,
patch_size_t: int = 1,
num_attention_heads: int = 24,
attention_head_dim: int = 96,
in_channels: int = 4,
out_channels: int = 4,
num_layers: int = 32,
dropout: float = 0.0,
cross_attention_dim: int = 2304,
attention_bias: bool = True,
sample_height: int = 90,
sample_width: int = 160,
sample_frames: int = 22,
activation_fn: str = "gelu-approximate",
norm_elementwise_affine: bool = False,
norm_eps: float = 1e-6,
caption_channels: int = 4096,
interpolation_scale_h: float = 2.0,
interpolation_scale_w: float = 2.0,
interpolation_scale_t: float = 2.2,
):
super().__init__()
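        # Transformer hidden size: num_attention_heads * attention_head_dim (24 * 96 = 2304 with
        # the default config).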
self.inner_dim = num_attention_heads * attention_head_dim
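        # Fall back to interpolation scales derived from the sample resolution / frame count when
        # `None` is passed explicitly (the registered defaults above are plain floats).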
interpolation_scale_t = (
interpolation_scale_t
if interpolation_scale_t is not None
else ((sample_frames - 1) // 16 + 1)
if sample_frames % 2 == 1
else sample_frames // 16
)
interpolation_scale_h = interpolation_scale_h if interpolation_scale_h is not None else sample_height / 30
interpolation_scale_w = interpolation_scale_w if interpolation_scale_w is not None else sample_width / 40
# 1. Patch embedding
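        # 2D patch embedding applied frame by frame; `pos_embed_type=None` skips absolute position
        # embeddings, since positional information is injected via rotary embeddings in attention.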
self.pos_embed = PatchEmbed(
height=sample_height,
width=sample_width,
patch_size=patch_size,
in_channels=in_channels,
embed_dim=self.inner_dim,
pos_embed_type=None,
)
# 2. Transformer blocks
self.transformer_blocks = nn.ModuleList(
[
AllegroTransformerBlock(
self.inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
norm_elementwise_affine=norm_elementwise_affine,
norm_eps=norm_eps,
)
for _ in range(num_layers)
]
)
# 3. Output projection & norm
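        # Final LayerNorm with a learned 2-row shift/scale table, then a linear projection back to
        # patch_size * patch_size * out_channels values per token (undone by the unpatchify step).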
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * out_channels)
# 4. Timestep embeddings
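        # AdaLN-single conditioning shared by all blocks: returns both the 6 * inner_dim modulation
        # input consumed by every block and the embedded timestep used for the output modulation.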
self.adaln_single = AdaLayerNormSingle(self.inner_dim, use_additional_conditions=False)
# 5. Caption projection
self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=self.inner_dim)
        self.gradient_checkpointing = False

def _set_gradient_checkpointing(self, module, value=False):
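        # Hook called by `ModelMixin.enable_gradient_checkpointing` / `disable_gradient_checkpointing`.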
        self.gradient_checkpointing = value

    def forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
timestep: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
return_dict: bool = True,
):
batch_size, num_channels, num_frames, height, width = hidden_states.shape
p_t = self.config.patch_size_t
p = self.config.patch_size
post_patch_num_frames = num_frames // p_t
post_patch_height = height // p
post_patch_width = width // p
# ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
# we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
# we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
# expects mask of shape:
# [batch, key_tokens]
# adds singleton query_tokens dimension:
# [batch, 1, key_tokens]
# this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
# [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
        attention_mask_vid, attention_mask_img = None, None
if attention_mask is not None and attention_mask.ndim == 4:
# assume that mask is expressed as:
# (1 = keep, 0 = discard)
# convert mask into a bias that can be added to attention scores:
# (keep = +0, discard = -10000.0)
# b, frame+use_image_num, h, w -> a video with images
# b, 1, h, w -> only images
attention_mask = attention_mask.to(hidden_states.dtype)
attention_mask = attention_mask[:, :num_frames] # [batch_size, num_frames, height, width]
if attention_mask.numel() > 0:
attention_mask = attention_mask.unsqueeze(1) # [batch_size, 1, num_frames, height, width]
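                # Downsample the keep/discard mask to the post-patch token grid so it matches the
                # flattened sequence length.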
attention_mask = F.max_pool3d(attention_mask, kernel_size=(p_t, p, p), stride=(p_t, p, p))
attention_mask = attention_mask.flatten(1).view(batch_size, 1, -1)
attention_mask = (
(1 - attention_mask.bool().to(hidden_states.dtype)) * -10000.0 if attention_mask.numel() > 0 else None
)
# convert encoder_attention_mask to a bias the same way we do for attention_mask
if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
encoder_attention_mask = (1 - encoder_attention_mask.to(self.dtype)) * -10000.0
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
# 1. Timestep embeddings
timestep, embedded_timestep = self.adaln_single(
timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
)
# 2. Patch embeddings
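        # Fold frames into the batch dimension, embed 2D patches per frame, then unfold and merge
        # the frame and spatial axes into a single token sequence of length
        # num_frames * (height // patch_size) * (width // patch_size).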
hidden_states = hidden_states.permute(0, 2, 1, 3, 4).flatten(0, 1)
hidden_states = self.pos_embed(hidden_states)
hidden_states = hidden_states.unflatten(0, (batch_size, -1)).flatten(1, 2)
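        # Project the text-encoder embeddings to the transformer width and flatten any extra
        # prompt dimension into the sequence axis.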
encoder_hidden_states = self.caption_projection(encoder_hidden_states)
encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, encoder_hidden_states.shape[-1])
# 3. Transformer blocks
for i, block in enumerate(self.transformer_blocks):
# TODO(aryan): Implement gradient checkpointing
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(block),
hidden_states,
encoder_hidden_states,
timestep,
attention_mask,
encoder_attention_mask,
image_rotary_emb,
**ckpt_kwargs,
)
else:
hidden_states = block(
hidden_states=hidden_states,
encoder_hidden_states=encoder_hidden_states,
temb=timestep,
attention_mask=attention_mask,
encoder_attention_mask=encoder_attention_mask,
image_rotary_emb=image_rotary_emb,
)
# 4. Output normalization & projection
shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
hidden_states = self.norm_out(hidden_states)
# Modulation
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.squeeze(1)
# 5. Unpatchify
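        # Split each token back into its (p_t, p, p, out_channels) patch and rearrange to
        # (batch_size, out_channels, num_frames, height, width).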
hidden_states = hidden_states.reshape(
batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p, p, -1
)
hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
if not return_dict:
return (output,)
return Transformer2DModelOutput(sample=output)