|
from typing import Any, List, Tuple, Optional, Union, Dict |
|
from einops import rearrange |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
import numpy as np |
|
|
|
from diffusers.models import ModelMixin |
|
from diffusers.configuration_utils import ConfigMixin, register_to_config |
|
|
|
from .activation_layers import get_activation_layer |
|
from .norm_layers import get_norm_layer |
|
from .embed_layers import TimestepEmbedder, PatchEmbed, TextProjection |
|
from .attention import attention, get_cu_seqlens |
|
from .posemb_layers import apply_rotary_emb |
|
from .mlp_layers import MLP, MLPEmbedder, FinalLayer |
|
from .modulate_layers import ModulateDiT, modulate, apply_gate |
|
from .token_refiner import SingleTokenRefiner |
|
from ...enhance_a_video.enhance import get_feta_scores |
|
from ...enhance_a_video.globals import is_enhance_enabled_single, is_enhance_enabled_double, set_num_frames |
|
from .norm_layers import RMSNorm |
|
|
|
from contextlib import contextmanager |
|
|
|
@contextmanager |
|
def init_weights_on_device(device=torch.device("meta"), include_buffers: bool = False):
|
|
|
old_register_parameter = torch.nn.Module.register_parameter |
|
if include_buffers: |
|
old_register_buffer = torch.nn.Module.register_buffer |
|
|
|
def register_empty_parameter(module, name, param): |
|
old_register_parameter(module, name, param) |
|
if param is not None: |
|
param_cls = type(module._parameters[name]) |
|
kwargs = module._parameters[name].__dict__ |
|
kwargs["requires_grad"] = param.requires_grad |
|
module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs) |
|
|
|
def register_empty_buffer(module, name, buffer, persistent=True): |
|
old_register_buffer(module, name, buffer, persistent=persistent) |
|
if buffer is not None: |
|
module._buffers[name] = module._buffers[name].to(device) |
|
|
|
def patch_tensor_constructor(fn): |
|
def wrapper(*args, **kwargs): |
|
kwargs["device"] = device |
|
return fn(*args, **kwargs) |
|
|
|
return wrapper |
|
|
|
if include_buffers: |
|
tensor_constructors_to_patch = { |
|
torch_function_name: getattr(torch, torch_function_name) |
|
for torch_function_name in ["empty", "zeros", "ones", "full"] |
|
} |
|
else: |
|
tensor_constructors_to_patch = {} |
|
|
|
try: |
|
torch.nn.Module.register_parameter = register_empty_parameter |
|
if include_buffers: |
|
torch.nn.Module.register_buffer = register_empty_buffer |
|
for torch_function_name in tensor_constructors_to_patch.keys(): |
|
setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name))) |
|
yield |
|
finally: |
|
torch.nn.Module.register_parameter = old_register_parameter |
|
if include_buffers: |
|
torch.nn.Module.register_buffer = old_register_buffer |
|
for torch_function_name, old_torch_function in tensor_constructors_to_patch.items(): |
|
setattr(torch, torch_function_name, old_torch_function) |
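# Illustrative usage sketch (an assumption, not part of the original module): the context
# manager above lets large modules be constructed without allocating real weight storage,
# which is useful right before loading a state dict with `assign=True`.
#
#     with init_weights_on_device(device=torch.device("meta")):
#         layer = nn.Linear(4096, 4096)                 # parameters live on the meta device
#     layer.load_state_dict(real_state_dict, assign=True)  # `real_state_dict` is assumed to exist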
|
|
|
class MMDoubleStreamBlock(nn.Module): |
|
""" |
|
A multimodal dit block with seperate modulation for |
|
text and image/video, see more details (SD3): https://arxiv.org/abs/2403.03206 |
|
(Flux.1): https://github.com/black-forest-labs/flux |
|
""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
heads_num: int, |
|
mlp_width_ratio: float, |
|
mlp_act_type: str = "gelu_tanh", |
|
qk_norm: bool = True, |
|
qk_norm_type: str = "rms", |
|
qkv_bias: bool = False, |
|
dtype: Optional[torch.dtype] = None, |
|
device: Optional[torch.device] = None, |
|
attention_mode: str = "sdpa", |
|
): |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__() |
|
|
|
self.attention_mode = attention_mode |
|
|
|
self.deterministic = False |
|
self.heads_num = heads_num |
|
head_dim = hidden_size // heads_num |
|
mlp_hidden_dim = int(hidden_size * mlp_width_ratio) |
|
|
|
self.img_mod = ModulateDiT( |
|
hidden_size, |
|
factor=6, |
|
act_layer=get_activation_layer("silu"), |
|
**factory_kwargs, |
|
) |
|
self.img_norm1 = nn.LayerNorm( |
|
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
|
) |
|
|
|
self.img_attn_qkv = nn.Linear( |
|
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs |
|
) |
|
qk_norm_layer = get_norm_layer(qk_norm_type) |
|
self.img_attn_q_norm = ( |
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
|
if qk_norm |
|
else nn.Identity() |
|
) |
|
self.img_attn_k_norm = ( |
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
|
if qk_norm |
|
else nn.Identity() |
|
) |
|
self.img_attn_proj = nn.Linear( |
|
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs |
|
) |
|
|
|
self.img_norm2 = nn.LayerNorm( |
|
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
|
) |
|
self.img_mlp = MLP( |
|
hidden_size, |
|
mlp_hidden_dim, |
|
act_layer=get_activation_layer(mlp_act_type), |
|
bias=True, |
|
**factory_kwargs, |
|
) |
|
|
|
self.txt_mod = ModulateDiT( |
|
hidden_size, |
|
factor=6, |
|
act_layer=get_activation_layer("silu"), |
|
**factory_kwargs, |
|
) |
|
self.txt_norm1 = nn.LayerNorm( |
|
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
|
) |
|
|
|
self.txt_attn_qkv = nn.Linear( |
|
hidden_size, hidden_size * 3, bias=qkv_bias, **factory_kwargs |
|
) |
|
self.txt_attn_q_norm = ( |
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
|
if qk_norm |
|
else nn.Identity() |
|
) |
|
self.txt_attn_k_norm = ( |
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
|
if qk_norm |
|
else nn.Identity() |
|
) |
|
self.txt_attn_proj = nn.Linear( |
|
hidden_size, hidden_size, bias=qkv_bias, **factory_kwargs |
|
) |
|
|
|
self.txt_norm2 = nn.LayerNorm( |
|
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
|
) |
|
self.txt_mlp = MLP( |
|
hidden_size, |
|
mlp_hidden_dim, |
|
act_layer=get_activation_layer(mlp_act_type), |
|
bias=True, |
|
**factory_kwargs, |
|
) |
|
|
|
def enable_deterministic(self): |
|
self.deterministic = True |
|
|
|
def disable_deterministic(self): |
|
self.deterministic = False |
|
|
|
def forward( |
|
self, |
|
img: torch.Tensor, |
|
txt: torch.Tensor, |
|
vec: torch.Tensor, |
|
cu_seqlens_q: Optional[torch.Tensor] = None, |
|
cu_seqlens_kv: Optional[torch.Tensor] = None, |
|
max_seqlen_q: Optional[int] = None, |
|
max_seqlen_kv: Optional[int] = None, |
|
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
attn_mask: Optional[torch.Tensor] = None, |
|
upcast_rope: bool = True, |
|
) -> Tuple[torch.Tensor, torch.Tensor]: |
|
( |
|
img_mod1_shift, |
|
img_mod1_scale, |
|
img_mod1_gate, |
|
img_mod2_shift, |
|
img_mod2_scale, |
|
img_mod2_gate, |
|
) = self.img_mod(vec).chunk(6, dim=-1) |
|
( |
|
txt_mod1_shift, |
|
txt_mod1_scale, |
|
txt_mod1_gate, |
|
txt_mod2_shift, |
|
txt_mod2_scale, |
|
txt_mod2_gate, |
|
) = self.txt_mod(vec).chunk(6, dim=-1) |
|
|
|
|
|
img_modulated = self.img_norm1(img) |
|
img_modulated = modulate( |
|
img_modulated, shift=img_mod1_shift, scale=img_mod1_scale |
|
) |
|
img_qkv = self.img_attn_qkv(img_modulated) |
|
img_q, img_k, img_v = rearrange( |
|
img_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num |
|
) |
|
|
|
img_q = self.img_attn_q_norm(img_q).to(img_v) |
|
img_k = self.img_attn_k_norm(img_k).to(img_v) |
|
|
|
|
|
if freqs_cis is not None: |
|
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, upcast=upcast_rope) |
|
|
|
|
|
txt_modulated = self.txt_norm1(txt) |
|
txt_modulated = modulate( |
|
txt_modulated, shift=txt_mod1_shift, scale=txt_mod1_scale |
|
) |
|
txt_qkv = self.txt_attn_qkv(txt_modulated) |
|
txt_q, txt_k, txt_v = rearrange( |
|
txt_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num |
|
) |
|
|
|
|
|
txt_q = self.txt_attn_q_norm(txt_q).to(txt_v) |
|
txt_k = self.txt_attn_k_norm(txt_k).to(txt_v) |
|
|
|
if is_enhance_enabled_double(): |
|
feta_scores = get_feta_scores(img_q, img_k) |
|
|
|
|
|
q = torch.cat((img_q, txt_q), dim=1) |
|
k = torch.cat((img_k, txt_k), dim=1) |
|
v = torch.cat((img_v, txt_v), dim=1) |
|
|
|
attn = attention( |
|
q, |
|
k, |
|
v, |
|
heads = self.heads_num, |
|
mode=self.attention_mode, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_kv=cu_seqlens_kv, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_kv=max_seqlen_kv, |
|
batch_size=img_k.shape[0], |
|
attn_mask=attn_mask |
|
) |
|
|
|
img_attn, txt_attn = attn[:, : img.shape[1]], attn[:, img.shape[1] :] |
|
if is_enhance_enabled_double(): |
|
img_attn *= feta_scores |
|
|
|
|
|
img = img + apply_gate(self.img_attn_proj(img_attn), gate=img_mod1_gate) |
|
img = img + apply_gate( |
|
self.img_mlp( |
|
modulate( |
|
self.img_norm2(img), shift=img_mod2_shift, scale=img_mod2_scale |
|
) |
|
), |
|
gate=img_mod2_gate, |
|
) |
|
|
|
|
|
txt = txt + apply_gate(self.txt_attn_proj(txt_attn), gate=txt_mod1_gate) |
|
txt = txt + apply_gate( |
|
self.txt_mlp( |
|
modulate( |
|
self.txt_norm2(txt), shift=txt_mod2_shift, scale=txt_mod2_scale |
|
) |
|
), |
|
gate=txt_mod2_gate, |
|
) |
|
|
|
return img, txt |
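# Illustrative shape sketch (an assumption, not from the original code): a double-stream
# block keeps image/video and text tokens in separate streams and only mixes them inside
# the joint attention call, so per-stream shapes are preserved.
#
#     block = MMDoubleStreamBlock(hidden_size=3072, heads_num=24, mlp_width_ratio=4.0)
#     img = torch.randn(1, 8192, 3072)   # (B, img_seq_len, hidden_size)
#     txt = torch.randn(1, 256, 3072)    # (B, txt_seq_len, hidden_size)
#     vec = torch.randn(1, 3072)         # pooled conditioning vector
#     img, txt = block(img, txt, vec)    # same shapes out as in, per stream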
|
|
|
|
|
class MMSingleStreamBlock(nn.Module): |
|
""" |
|
    A DiT block with parallel linear layers as described in
    https://arxiv.org/abs/2302.05442, with an adapted modulation interface.
|
Also refer to (SD3): https://arxiv.org/abs/2403.03206 |
|
(Flux.1): https://github.com/black-forest-labs/flux |
|
""" |
|
|
|
def __init__( |
|
self, |
|
hidden_size: int, |
|
heads_num: int, |
|
mlp_width_ratio: float = 4.0, |
|
mlp_act_type: str = "gelu_tanh", |
|
qk_norm: bool = True, |
|
qk_norm_type: str = "rms", |
|
        qk_scale: Optional[float] = None,
|
dtype: Optional[torch.dtype] = None, |
|
device: Optional[torch.device] = None, |
|
attention_mode: str = "sdpa", |
|
): |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__() |
|
|
|
self.attention_mode = attention_mode |
|
|
|
self.deterministic = False |
|
self.hidden_size = hidden_size |
|
self.heads_num = heads_num |
|
head_dim = hidden_size // heads_num |
|
mlp_hidden_dim = int(hidden_size * mlp_width_ratio) |
|
self.mlp_hidden_dim = mlp_hidden_dim |
|
self.scale = qk_scale or head_dim ** -0.5 |
|
|
|
|
|
self.linear1 = nn.Linear( |
|
hidden_size, hidden_size * 3 + mlp_hidden_dim, **factory_kwargs |
|
) |
|
|
|
self.linear2 = nn.Linear( |
|
hidden_size + mlp_hidden_dim, hidden_size, **factory_kwargs |
|
) |
|
|
|
qk_norm_layer = get_norm_layer(qk_norm_type) |
|
self.q_norm = ( |
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
|
if qk_norm |
|
else nn.Identity() |
|
) |
|
self.k_norm = ( |
|
qk_norm_layer(head_dim, elementwise_affine=True, eps=1e-6, **factory_kwargs) |
|
if qk_norm |
|
else nn.Identity() |
|
) |
|
|
|
self.pre_norm = nn.LayerNorm( |
|
hidden_size, elementwise_affine=False, eps=1e-6, **factory_kwargs |
|
) |
|
|
|
self.mlp_act = get_activation_layer(mlp_act_type)() |
|
self.modulation = ModulateDiT( |
|
hidden_size, |
|
factor=3, |
|
act_layer=get_activation_layer("silu"), |
|
**factory_kwargs, |
|
) |
|
|
|
def enable_deterministic(self): |
|
self.deterministic = True |
|
|
|
def disable_deterministic(self): |
|
self.deterministic = False |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
vec: torch.Tensor, |
|
txt_len: int, |
|
cu_seqlens_q: Optional[torch.Tensor] = None, |
|
cu_seqlens_kv: Optional[torch.Tensor] = None, |
|
max_seqlen_q: Optional[int] = None, |
|
max_seqlen_kv: Optional[int] = None, |
|
        freqs_cis: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
attn_mask: Optional[torch.Tensor] = None, |
|
upcast_rope: bool = True, |
|
stg_mode: Optional[str] = None, |
|
) -> torch.Tensor: |
|
mod_shift, mod_scale, mod_gate = self.modulation(vec).chunk(3, dim=-1) |
|
x_mod = modulate(self.pre_norm(x), shift=mod_shift, scale=mod_scale) |
|
qkv, mlp = torch.split( |
|
self.linear1(x_mod), [3 * self.hidden_size, self.mlp_hidden_dim], dim=-1 |
|
) |
|
|
|
q, k, v = rearrange(qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num) |
|
|
|
|
|
q = self.q_norm(q).to(v) |
|
k = self.k_norm(k).to(v) |
|
|
|
|
|
if freqs_cis is not None: |
|
img_q, txt_q = q[:, :-txt_len, :, :], q[:, -txt_len:, :, :] |
|
img_k, txt_k = k[:, :-txt_len, :, :], k[:, -txt_len:, :, :] |
|
img_q, img_k = apply_rotary_emb(img_q, img_k, freqs_cis, upcast=upcast_rope) |
|
|
|
|
|
|
|
q = torch.cat((img_q, txt_q), dim=1) |
|
k = torch.cat((img_k, txt_k), dim=1) |
|
|
|
if is_enhance_enabled_single(): |
|
feta_scores = get_feta_scores(img_q, img_k) |
|
|
|
|
|
|
|
|
|
|
|
if stg_mode is not None: |
|
if stg_mode == "STG-A": |
|
attn = attention( |
|
q, |
|
k, |
|
v, |
|
heads = self.heads_num, |
|
mode=self.attention_mode, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_kv=cu_seqlens_kv, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_kv=max_seqlen_kv, |
|
batch_size=x.shape[0], |
|
do_stg=True, |
|
txt_len=txt_len, |
|
attn_mask=attn_mask |
|
) |
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) |
|
return x + apply_gate(output, gate=mod_gate) |
|
elif stg_mode == "STG-R": |
|
attn = attention( |
|
q, |
|
k, |
|
v, |
|
heads = self.heads_num, |
|
mode=self.attention_mode, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_kv=cu_seqlens_kv, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_kv=max_seqlen_kv, |
|
batch_size=x.shape[0], |
|
attn_mask=attn_mask |
|
) |
|
|
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) |
|
output = apply_gate(output, gate=mod_gate) |
|
batch_size = output.shape[0] |
|
output[:batch_size-1, :, :] = 0 |
|
return x + output |
|
else: |
|
attn = attention( |
|
q, |
|
k, |
|
v, |
|
heads = self.heads_num, |
|
mode=self.attention_mode, |
|
cu_seqlens_q=cu_seqlens_q, |
|
cu_seqlens_kv=cu_seqlens_kv, |
|
max_seqlen_q=max_seqlen_q, |
|
max_seqlen_kv=max_seqlen_kv, |
|
batch_size=x.shape[0], |
|
attn_mask=attn_mask |
|
) |
|
if is_enhance_enabled_single(): |
|
attn *= feta_scores |
|
|
|
|
|
|
|
output = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2)) |
|
output = x + apply_gate(output, gate=mod_gate) |
|
|
|
|
|
return output |
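# Illustrative note (an assumption about the layout, with example numbers): in the
# single-stream block, `linear1` fuses the QKV projection and the MLP up-projection into
# one matmul, and `linear2` fuses the attention output projection with the MLP
# down-projection. For hidden_size=3072 and mlp_width_ratio=4.0:
#
#     qkv, mlp = torch.split(self.linear1(x_mod), [3 * 3072, 12288], dim=-1)
#     out = self.linear2(torch.cat((attn, self.mlp_act(mlp)), 2))  # (.., 3072 + 12288) -> (.., 3072)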
|
|
|
|
|
class HYVideoDiffusionTransformer(ModelMixin, ConfigMixin): |
|
""" |
|
HunyuanVideo Transformer backbone |
|
|
|
    Inherits from ModelMixin and ConfigMixin for compatibility with diffusers' pipelines and checkpoint loading.
|
|
|
Reference: |
|
[1] Flux.1: https://github.com/black-forest-labs/flux |
|
[2] MMDiT: http://arxiv.org/abs/2403.03206 |
|
|
|
Parameters |
|
---------- |
|
|
patch_size: list |
|
The size of the patch. |
|
in_channels: int |
|
The number of input channels. |
|
out_channels: int |
|
The number of output channels. |
|
hidden_size: int |
|
The hidden size of the transformer backbone. |
|
heads_num: int |
|
The number of attention heads. |
|
mlp_width_ratio: float |
|
The ratio of the hidden size of the MLP in the transformer block. |
|
mlp_act_type: str |
|
The activation function of the MLP in the transformer block. |
|
    mm_double_blocks_depth: int
        The number of double-stream transformer blocks.
    mm_single_blocks_depth: int
        The number of single-stream transformer blocks.
|
rope_dim_list: list |
|
The dimension of the rotary embedding for t, h, w. |
|
qkv_bias: bool |
|
Whether to use bias in the qkv linear layer. |
|
qk_norm: bool |
|
Whether to use qk norm. |
|
qk_norm_type: str |
|
The type of qk norm. |
|
guidance_embed: bool |
|
Whether to use guidance embedding for distillation. |
|
text_projection: str |
|
The type of the text projection, default is single_refiner. |
|
use_attention_mask: bool |
|
Whether to use attention mask for text encoder. |
|
dtype: torch.dtype |
|
The dtype of the model. |
|
device: torch.device |
|
The device of the model. |
|
""" |
|
|
|
@register_to_config |
|
def __init__( |
|
self, |
|
patch_size: list = [1, 2, 2], |
|
in_channels: int = 4, |
|
        out_channels: Optional[int] = None,
|
hidden_size: int = 3072, |
|
heads_num: int = 24, |
|
mlp_width_ratio: float = 4.0, |
|
mlp_act_type: str = "gelu_tanh", |
|
mm_double_blocks_depth: int = 20, |
|
mm_single_blocks_depth: int = 40, |
|
rope_dim_list: List[int] = [16, 56, 56], |
|
qkv_bias: bool = True, |
|
qk_norm: bool = True, |
|
qk_norm_type: str = "rms", |
|
guidance_embed: bool = False, |
|
text_projection: str = "single_refiner", |
|
use_attention_mask: bool = True, |
|
text_states_dim: int = 4096, |
|
text_states_dim_2: int = 768, |
|
dtype: Optional[torch.dtype] = None, |
|
device: Optional[torch.device] = None, |
|
main_device: Optional[torch.device] = None, |
|
offload_device: Optional[torch.device] = None, |
|
attention_mode: str = "sdpa", |
|
): |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__() |
|
|
|
self.patch_size = patch_size |
|
self.in_channels = in_channels |
|
self.out_channels = in_channels if out_channels is None else out_channels |
|
self.unpatchify_channels = self.out_channels |
|
self.guidance_embed = guidance_embed |
|
self.rope_dim_list = rope_dim_list |
|
|
|
self.main_device = main_device |
|
self.offload_device = offload_device |
|
self.attention_mode = attention_mode |
|
|
|
|
|
|
|
self.use_attention_mask = use_attention_mask |
|
self.text_projection = text_projection |
|
|
|
self.text_states_dim = text_states_dim |
|
self.text_states_dim_2 = text_states_dim_2 |
|
|
|
if hidden_size % heads_num != 0: |
|
raise ValueError( |
|
f"Hidden size {hidden_size} must be divisible by heads_num {heads_num}" |
|
) |
|
pe_dim = hidden_size // heads_num |
|
if sum(rope_dim_list) != pe_dim: |
|
raise ValueError( |
|
f"Got {rope_dim_list} but expected positional dim {pe_dim}" |
|
) |
|
self.hidden_size = hidden_size |
|
self.heads_num = heads_num |
|
|
|
|
|
self.img_in = PatchEmbed( |
|
self.patch_size, self.in_channels, self.hidden_size, **factory_kwargs |
|
) |
|
|
|
|
|
if self.text_projection == "linear": |
|
self.txt_in = TextProjection( |
|
self.text_states_dim, |
|
self.hidden_size, |
|
get_activation_layer("silu"), |
|
**factory_kwargs, |
|
) |
|
elif self.text_projection == "single_refiner": |
|
self.txt_in = SingleTokenRefiner( |
|
self.text_states_dim, hidden_size, heads_num, depth=2, **factory_kwargs |
|
) |
|
else: |
|
raise NotImplementedError( |
|
f"Unsupported text_projection: {self.text_projection}" |
|
) |
|
|
|
|
|
self.time_in = TimestepEmbedder( |
|
self.hidden_size, get_activation_layer("silu"), **factory_kwargs |
|
) |
|
|
|
|
|
self.vector_in = MLPEmbedder( |
|
self.text_states_dim_2, self.hidden_size, **factory_kwargs |
|
) |
|
|
|
|
|
self.guidance_in = ( |
|
TimestepEmbedder( |
|
self.hidden_size, get_activation_layer("silu"), **factory_kwargs |
|
) |
|
if guidance_embed |
|
else None |
|
) |
|
|
|
|
|
self.double_blocks = nn.ModuleList( |
|
[ |
|
MMDoubleStreamBlock( |
|
self.hidden_size, |
|
self.heads_num, |
|
mlp_width_ratio=mlp_width_ratio, |
|
mlp_act_type=mlp_act_type, |
|
qk_norm=qk_norm, |
|
qk_norm_type=qk_norm_type, |
|
qkv_bias=qkv_bias, |
|
attention_mode=attention_mode, |
|
**factory_kwargs, |
|
) |
|
for _ in range(mm_double_blocks_depth) |
|
] |
|
) |
|
|
|
|
|
self.single_blocks = nn.ModuleList( |
|
[ |
|
MMSingleStreamBlock( |
|
self.hidden_size, |
|
self.heads_num, |
|
mlp_width_ratio=mlp_width_ratio, |
|
mlp_act_type=mlp_act_type, |
|
qk_norm=qk_norm, |
|
qk_norm_type=qk_norm_type, |
|
attention_mode=attention_mode, |
|
**factory_kwargs, |
|
) |
|
for _ in range(mm_single_blocks_depth) |
|
] |
|
) |
|
|
|
self.final_layer = FinalLayer( |
|
self.hidden_size, |
|
self.patch_size, |
|
self.out_channels, |
|
get_activation_layer("silu"), |
|
**factory_kwargs, |
|
) |
|
|
|
self.upcast_rope = True |
|
|
|
|
|
self.double_blocks_to_swap = -1 |
|
self.single_blocks_to_swap = -1 |
|
self.offload_txt_in = False |
|
self.offload_img_in = False |
|
|
|
|
|
self.enable_teacache = False |
|
self.cnt = 0 |
|
self.num_steps = 0 |
|
self.rel_l1_thresh = 0.15 |
|
self.accumulated_rel_l1_distance = 0 |
|
self.previous_modulated_input = None |
|
self.previous_residual = None |
|
self.last_dimensions = None |
|
self.last_frame_count = None |
|
|
|
|
|
def block_swap(self, double_blocks_to_swap, single_blocks_to_swap, offload_txt_in=False, offload_img_in=False): |
|
print(f"Swapping {double_blocks_to_swap + 1} double blocks and {single_blocks_to_swap + 1} single blocks") |
|
self.double_blocks_to_swap = double_blocks_to_swap |
|
self.single_blocks_to_swap = single_blocks_to_swap |
|
self.offload_txt_in = offload_txt_in |
|
self.offload_img_in = offload_img_in |
|
for b, block in enumerate(self.double_blocks): |
|
if b > self.double_blocks_to_swap: |
|
|
|
block.to(self.main_device) |
|
else: |
|
|
|
block.to(self.offload_device) |
|
for b, block in enumerate(self.single_blocks): |
|
if b > self.single_blocks_to_swap: |
|
block.to(self.main_device) |
|
else: |
|
block.to(self.offload_device) |
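    # Illustrative usage sketch (an assumption): offload the first 10 double blocks and the
    # first 20 single blocks to `offload_device`; forward() swaps each one back to
    # `main_device` just before it runs and offloads it again afterwards.
    #
    #     model.block_swap(double_blocks_to_swap=9, single_blocks_to_swap=19)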
|
|
|
def enable_auto_offload(self, dtype=torch.bfloat16, device="cuda"): |
|
def cast_to(weight, dtype=None, device=None, copy=False): |
|
if device is None or weight.device == device: |
|
if not copy: |
|
if dtype is None or weight.dtype == dtype: |
|
return weight |
|
return weight.to(dtype=dtype, copy=copy) |
|
|
|
r = torch.empty_like(weight, dtype=dtype, device=device) |
|
r.copy_(weight) |
|
return r |
|
|
|
def cast_weight(s, input=None, dtype=None, device=None): |
|
if input is not None: |
|
if dtype is None: |
|
dtype = input.dtype |
|
if device is None: |
|
device = input.device |
|
weight = cast_to(s.weight, dtype, device) |
|
return weight |
|
|
|
def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): |
|
if input is not None: |
|
if dtype is None: |
|
dtype = input.dtype |
|
if bias_dtype is None: |
|
bias_dtype = dtype |
|
if device is None: |
|
device = input.device |
|
weight = cast_to(s.weight, dtype, device) |
|
bias = cast_to(s.bias, bias_dtype, device) if s.bias is not None else None |
|
return weight, bias |
|
|
|
class quantized_layer: |
|
class Linear(torch.nn.Linear): |
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.dtype = dtype |
|
self.device = device |
|
|
|
def block_forward_(self, x, i, j, dtype, device): |
|
weight_ = cast_to( |
|
self.weight[j * self.block_size: (j + 1) * self.block_size, i * self.block_size: (i + 1) * self.block_size], |
|
dtype=dtype, device=device |
|
) |
|
if self.bias is None or i > 0: |
|
bias_ = None |
|
else: |
|
bias_ = cast_to(self.bias[j * self.block_size: (j + 1) * self.block_size], dtype=dtype, device=device) |
|
x_ = x[..., i * self.block_size: (i + 1) * self.block_size] |
|
y_ = torch.nn.functional.linear(x_, weight_, bias_) |
|
del x_, weight_, bias_ |
|
torch.cuda.empty_cache() |
|
return y_ |
|
|
|
def block_forward(self, x, **kwargs): |
|
|
|
y = torch.zeros(x.shape[:-1] + (self.out_features,), dtype=x.dtype, device=x.device) |
|
for i in range((self.in_features + self.block_size - 1) // self.block_size): |
|
for j in range((self.out_features + self.block_size - 1) // self.block_size): |
|
y[..., j * self.block_size: (j + 1) * self.block_size] += self.block_forward_(x, i, j, dtype=x.dtype, device=x.device) |
|
return y |
|
|
|
def forward(self, x, **kwargs): |
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) |
|
return torch.nn.functional.linear(x, weight, bias) |
|
|
|
|
|
class RMSNorm(torch.nn.Module): |
|
def __init__(self, module, dtype=torch.bfloat16, device="cuda"): |
|
super().__init__() |
|
self.module = module |
|
self.dtype = dtype |
|
self.device = device |
|
|
|
def forward(self, hidden_states, **kwargs): |
|
input_dtype = hidden_states.dtype |
|
variance = hidden_states.to(torch.float32).square().mean(-1, keepdim=True) |
|
hidden_states = hidden_states * torch.rsqrt(variance + self.module.eps) |
|
hidden_states = hidden_states.to(input_dtype) |
|
if self.module.weight is not None: |
|
weight = cast_weight(self.module, hidden_states, dtype=torch.bfloat16, device="cuda") |
|
hidden_states = hidden_states * weight |
|
return hidden_states |
|
|
|
class Conv3d(torch.nn.Conv3d): |
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.dtype = dtype |
|
self.device = device |
|
|
|
def forward(self, x): |
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) |
|
return torch.nn.functional.conv3d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups) |
|
|
|
class LayerNorm(torch.nn.LayerNorm): |
|
def __init__(self, *args, dtype=torch.bfloat16, device="cuda", **kwargs): |
|
super().__init__(*args, **kwargs) |
|
self.dtype = dtype |
|
self.device = device |
|
|
|
def forward(self, x): |
|
if self.weight is not None and self.bias is not None: |
|
weight, bias = cast_bias_weight(self, x, dtype=self.dtype, device=self.device) |
|
return torch.nn.functional.layer_norm(x, self.normalized_shape, weight, bias, self.eps) |
|
else: |
|
return torch.nn.functional.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) |
|
def replace_layer(model, dtype=torch.bfloat16, device="cuda"): |
|
for name, module in model.named_children(): |
|
if isinstance(module, torch.nn.Linear): |
|
with init_weights_on_device(): |
|
new_layer = quantized_layer.Linear( |
|
module.in_features, module.out_features, bias=module.bias is not None, |
|
dtype=dtype, device=device |
|
) |
|
new_layer.load_state_dict(module.state_dict(), assign=True) |
|
setattr(model, name, new_layer) |
|
elif isinstance(module, torch.nn.Conv3d): |
|
with init_weights_on_device(): |
|
new_layer = quantized_layer.Conv3d( |
|
module.in_channels, module.out_channels, kernel_size=module.kernel_size, stride=module.stride, |
|
dtype=dtype, device=device |
|
) |
|
new_layer.load_state_dict(module.state_dict(), assign=True) |
|
setattr(model, name, new_layer) |
|
elif isinstance(module, RMSNorm): |
|
new_layer = quantized_layer.RMSNorm( |
|
module, |
|
dtype=dtype, device=device |
|
) |
|
setattr(model, name, new_layer) |
|
elif isinstance(module, torch.nn.LayerNorm): |
|
with init_weights_on_device(): |
|
new_layer = quantized_layer.LayerNorm( |
|
module.normalized_shape, elementwise_affine=module.elementwise_affine, eps=module.eps, |
|
dtype=dtype, device=device |
|
) |
|
new_layer.load_state_dict(module.state_dict(), assign=True) |
|
setattr(model, name, new_layer) |
|
else: |
|
replace_layer(module, dtype=dtype, device=device) |
|
|
|
replace_layer(self, dtype=dtype, device=device) |
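    # Illustrative note (an assumption about intent): after `enable_auto_offload`, weights stay
    # wherever they were loaded (e.g. CPU), and each wrapped layer casts/copies its own weight
    # to the compute dtype/device only for the duration of its forward call.
    #
    #     model.enable_auto_offload(dtype=torch.bfloat16, device="cuda")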
|
|
|
def enable_deterministic(self): |
|
for block in self.double_blocks: |
|
block.enable_deterministic() |
|
for block in self.single_blocks: |
|
block.enable_deterministic() |
|
|
|
def disable_deterministic(self): |
|
for block in self.double_blocks: |
|
block.disable_deterministic() |
|
for block in self.single_blocks: |
|
block.disable_deterministic() |
|
|
|
def forward( |
|
self, |
|
x: torch.Tensor, |
|
t: torch.Tensor, |
|
        text_states: Optional[torch.Tensor] = None,
        text_mask: Optional[torch.Tensor] = None,
        text_states_2: Optional[torch.Tensor] = None,
        freqs_cos: Optional[torch.Tensor] = None,
        freqs_sin: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        stg_mode: Optional[str] = None,
|
stg_block_idx: int = -1, |
|
return_dict: bool = True, |
|
) -> Union[torch.Tensor, Dict[str, torch.Tensor]]: |
|
|
|
def _process_double_blocks(img, txt, vec, block_args): |
|
for b, block in enumerate(self.double_blocks): |
|
if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0: |
|
block.to(self.main_device) |
|
|
|
img, txt = block(img, txt, vec, *block_args) |
|
|
|
if b <= self.double_blocks_to_swap and self.double_blocks_to_swap >= 0: |
|
block.to(self.offload_device, non_blocking=True) |
|
return img, txt |
|
|
|
def _process_single_blocks(x, vec, txt_seq_len, block_args, stg_mode=None, stg_block_idx=None): |
|
for b, block in enumerate(self.single_blocks): |
|
if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0: |
|
block.to(self.main_device) |
|
|
|
curr_stg_mode = stg_mode if b == stg_block_idx else None |
|
x = block(x, vec, txt_seq_len, *block_args, curr_stg_mode) |
|
|
|
if b <= self.single_blocks_to_swap and self.single_blocks_to_swap >= 0: |
|
block.to(self.offload_device, non_blocking=True) |
|
return x |
|
|
|
out = {} |
|
img = x |
|
txt = text_states |
|
_, _, ot, oh, ow = x.shape |
|
tt, th, tw = ( |
|
ot // self.patch_size[0], |
|
oh // self.patch_size[1], |
|
ow // self.patch_size[2], |
|
) |
|
set_num_frames(img.shape[2]) |
|
|
|
current_dims = (ot, oh, ow) |
|
|
|
|
|
        # Reset TeaCache state whenever the latent dimensions change
        if self.last_dimensions != current_dims:
            self.cnt = 0
            self.accumulated_rel_l1_distance = 0
            self.previous_modulated_input = None
            self.previous_residual = None
            self.last_dimensions = current_dims
|
|
|
|
|
vec = self.time_in(t) |
|
|
|
|
|
if text_states_2 is not None: |
|
vec = vec + self.vector_in(text_states_2) |
|
|
|
|
|
if guidance is not None: |
|
|
|
vec = vec + self.guidance_in(guidance) |
|
|
|
|
|
if self.offload_txt_in: |
|
self.txt_in.to(self.main_device) |
|
if self.offload_img_in: |
|
self.img_in.to(self.main_device) |
|
|
|
img = self.img_in(img) |
|
if self.text_projection == "linear": |
|
txt = self.txt_in(txt) |
|
elif self.text_projection == "single_refiner": |
|
txt = self.txt_in(txt, t, text_mask if self.use_attention_mask else None) |
|
else: |
|
raise NotImplementedError( |
|
f"Unsupported text_projection: {self.text_projection}" |
|
) |
|
if self.offload_txt_in: |
|
self.txt_in.to(self.offload_device, non_blocking=True) |
|
if self.offload_img_in: |
|
self.img_in.to(self.offload_device, non_blocking=True) |
|
|
|
txt_seq_len = txt.shape[1] |
|
img_seq_len = img.shape[1] |
|
max_seqlen_q = max_seqlen_kv = img_seq_len + txt_seq_len |
|
|
|
if "varlen" not in self.attention_mode: |
|
cu_seqlens_q, cu_seqlens_kv = None, None |
|
|
|
attn_mask = torch.zeros((1, max_seqlen_q, max_seqlen_q), dtype=torch.bool, device=text_mask.device) |
|
|
|
|
|
text_len = text_mask[0].sum().item() |
|
total_len = text_len + img_seq_len |
|
|
|
|
|
attn_mask[0, :total_len, :total_len] = True |
|
else: |
|
attn_mask = None |
|
|
|
cu_seqlens_q = get_cu_seqlens(text_mask, img_seq_len) |
|
cu_seqlens_kv = cu_seqlens_q |
|
|
|
freqs_cis = (freqs_cos, freqs_sin) if freqs_cos is not None else None |
|
|
|
block_args = [cu_seqlens_q, cu_seqlens_kv, max_seqlen_q, max_seqlen_kv, freqs_cis, attn_mask, self.upcast_rope] |
|
|
|
|
|
if self.enable_teacache: |
|
inp = img.clone() |
|
vec_ = vec.clone() |
|
txt_ = txt.clone() |
|
self.double_blocks[0].to(self.main_device) |
|
( |
|
img_mod1_shift, |
|
img_mod1_scale, |
|
img_mod1_gate, |
|
img_mod2_shift, |
|
img_mod2_scale, |
|
img_mod2_gate, |
|
) = self.double_blocks[0].img_mod(vec_).chunk(6, dim=-1) |
|
normed_inp = self.double_blocks[0].img_norm1(inp) |
|
modulated_inp = modulate( |
|
normed_inp, shift=img_mod1_shift, scale=img_mod1_scale |
|
) |
|
|
|
if self.cnt == 0 or self.cnt == self.num_steps-1: |
|
should_calc = True |
|
self.accumulated_rel_l1_distance = 0 |
|
self.previous_modulated_input = modulated_inp.clone() |
|
else: |
|
coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01, -3.14987800e+00, 9.61237896e-02] |
|
rescale_func = np.poly1d(coefficients) |
|
self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item()) |
|
if self.accumulated_rel_l1_distance < self.rel_l1_thresh: |
|
should_calc = False |
|
else: |
|
should_calc = True |
|
self.accumulated_rel_l1_distance = 0 |
|
self.previous_modulated_input = modulated_inp.clone() |
|
self.cnt += 1 |
|
if self.cnt == self.num_steps: |
|
self.cnt = 0 |
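            # TeaCache decision (illustrative summary, an assumption about intent): when the
            # accumulated, polynomial-rescaled relative L1 change of the modulated input stays
            # below `rel_l1_thresh`, the transformer blocks are skipped and the cached residual
            # from the previous step is re-applied instead:
            #
            #     rel_l1 = ((cur - prev).abs().mean() / prev.abs().mean()).item()
            #     accumulated += np.poly1d(coefficients)(rel_l1)
            #     should_calc = accumulated >= rel_l1_thresh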
|
|
|
if not should_calc and self.previous_residual is not None: |
|
|
|
if img.shape == self.previous_residual.shape: |
|
img = img + self.previous_residual |
|
else: |
|
should_calc = True |
|
|
|
if should_calc: |
|
ori_img = img.clone() |
|
|
|
img, txt = _process_double_blocks(img, txt, vec, block_args) |
|
|
|
x = torch.cat((img, txt), 1) |
|
x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx) |
|
|
|
img = x[:, :img_seq_len, ...] |
|
self.previous_residual = img - ori_img |
|
else: |
|
|
|
img, txt = _process_double_blocks(img, txt, vec, block_args) |
|
|
|
x = torch.cat((img, txt), 1) |
|
x = _process_single_blocks(x, vec, txt.shape[1], block_args, stg_mode, stg_block_idx) |
|
img = x[:, :img_seq_len, ...] |
|
|
|
|
|
img = self.final_layer(img, vec) |
|
|
|
img = self.unpatchify(img, tt, th, tw) |
|
if return_dict: |
|
out["x"] = img |
|
return out |
|
return img |
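    # Illustrative forward-call sketch (an assumption about typical shapes; the RoPE tables
    # `freqs_cos`/`freqs_sin` are assumed to be precomputed elsewhere):
    #
    #     out = model(
    #         x=torch.randn(1, 16, 33, 60, 104),        # latent video (B, C, T, H, W)
    #         t=torch.tensor([999.0]),                   # diffusion timestep
    #         text_states=torch.randn(1, 256, 4096),     # LLM text embeddings
    #         text_mask=torch.ones(1, 256, dtype=torch.bool),
    #         text_states_2=torch.randn(1, 768),         # pooled CLIP embedding
    #         freqs_cos=freqs_cos, freqs_sin=freqs_sin,
    #     )["x"]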
|
|
|
def unpatchify(self, x, t, h, w): |
|
""" |
|
x: (N, T, patch_size**2 * C) |
|
imgs: (N, H, W, C) |
|
""" |
|
c = self.unpatchify_channels |
|
pt, ph, pw = self.patch_size |
|
assert t * h * w == x.shape[1] |
|
|
|
x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw)) |
|
x = torch.einsum("nthwcopq->nctohpwq", x) |
|
imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw)) |
|
|
|
return imgs |
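    # Illustrative shape sketch (an assumption; the channel count and grid are examples):
    # with patch_size=[1, 2, 2] and out_channels=16, a token grid of t=33, h=30, w=52
    # is folded back into the latent video:
    #
    #     x = torch.randn(2, 33 * 30 * 52, 1 * 2 * 2 * 16)
    #     video = model.unpatchify(x, t=33, h=30, w=52)   # -> (2, 16, 33, 60, 104)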
|