# Source: PixelFlow / pixelflow/model.py
# Hosting-page residue kept for provenance: uploader "ShoufaChen",
# commit message "init", revision 137645c (verified).
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding
import warnings
try:
from flash_attn import flash_attn_varlen_func
except ImportError:
warnings.warn("`flash-attn` is not installed. Training mode may not work properly.", UserWarning)
flash_attn_varlen_func = None
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """Apply rotary position embedding to ``x``.

    The last axis of ``x`` is read as interleaved ``(real, imag)`` pairs; each
    pair is rotated using the matching entries of the cos/sin tables.

    Args:
        x: Input tensor. Callers in this file pass it laid out as
            ``(batch, heads, seq, head_dim)``.
        freqs_cis: Either a single tensor whose last axis stacks ``(cos, sin)``
            (it is split with ``unbind(-1)``), or an explicit ``(cos, sin)``
            pair. Two leading singleton axes are prepended to both tables so
            they broadcast against ``x``.

    Returns:
        The rotated tensor, cast back to ``x.dtype`` (the arithmetic itself is
        carried out in float32).
    """
    if isinstance(freqs_cis, (tuple, list)):
        # The annotation has always advertised a (cos, sin) pair; accept it
        # instead of failing on the missing ``.unbind`` attribute.
        cos, sin = freqs_cis
    else:
        cos, sin = freqs_cis.unbind(-1)

    # Prepend two broadcast axes and make sure the tables live on x's device.
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    # Split the last axis into (pairs, 2) and pull the two components apart.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    # Re-interleave as (-imag, real). ``flatten(-2)`` merges the trailing pair
    # axis whatever the rank of ``x`` is (identical to ``flatten(3)`` for the
    # 4-D inputs used in this file).
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
    return out
class PatchEmbed(nn.Module):
    """Project image patches to token embeddings.

    The projection weights live in a strided ``nn.Conv2d`` (kernel size and
    stride both equal to ``patch_size``). Two code paths share those weights:

    * eval mode runs the convolution on an image-shaped input and returns the
      result flattened to ``(batch, num_patches, embed_dim)``;
    * training mode multiplies the input directly with the flattened kernel,
      i.e. the input's last axis must already match the flattened kernel
      width (``in_channels * patch_size * patch_size``).
    """

    def __init__(self, patch_size, in_channels, embed_dim, bias=True):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward_unfold(self, x):
        """Apply the conv kernel as a plain matrix product on ``x``."""
        kernel = self.proj.weight
        flat_kernel = kernel.view(kernel.size(0), -1)
        tokens = x.matmul(flat_kernel.t())

        bias = self.proj.bias
        if bias is None:
            return tokens
        # In-place add; the bias is cast to the matmul output dtype first.
        tokens += bias.to(tokens.dtype)
        return tokens

    # Debug hint from the authors: decorate ``forward`` with
    # ``@torch.autocast('cuda', enabled=False)`` to force fp32 when strict
    # numerical reproducibility is required.
    def forward(self, x):
        """Embed ``x`` using the training (matmul) or eval (conv) path."""
        if self.training:
            return self.forward_unfold(x)

        feature_map = self.proj(x)
        # BCHW -> BNC
        return feature_map.flatten(2).transpose(1, 2)
class AdaLayerNorm(nn.Module):
    """Layer norm whose shift/scale come from a conditioning vector.

    A single linear layer maps the (SiLU-activated) conditioning embedding to
    six chunks of width ``embedding_dim``: shift/scale/gate for the attention
    branch and shift/scale/gate for the MLP branch. Only the attention
    shift/scale are applied here; the remaining chunks are returned so the
    caller can apply them later.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, seqlen_list=None):
        """Normalise and modulate ``x``.

        Args:
            x: Tokens to normalise.
            timestep: Conditioning embedding, one row per sample.
            seqlen_list: Optional per-sample token counts. When given, each
                sample's modulation row is repeated that many times and the
                rows are concatenated (packed layout); otherwise a broadcast
                axis is inserted at position 1.

        Returns:
            ``(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`` where
            ``modulated_x`` has the dtype of ``x`` and the other four tensors
            are float32.
        """
        orig_dtype = x.dtype
        modulation = self.linear(self.silu(timestep))

        if seqlen_list is None:
            modulation = modulation.unsqueeze(1)
        else:
            # Same result as ``torch.repeat_interleave`` but faster:
            # ``expand`` creates views, only ``cat`` copies.
            per_sample = []
            for sample_mod, n_tokens in zip(modulation, seqlen_list):
                per_sample.append(sample_mod[None].expand(n_tokens, -1))
            modulation = torch.cat(per_sample)

        chunks = modulation.float().chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

        normed = self.norm(x).float()
        modulated = normed * (1 + scale_msa) + shift_msa
        return modulated.to(orig_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
class FeedForward(nn.Module):
    """Two-layer MLP with a tanh-approximated GELU in between.

    Args:
        dim: Input width.
        dim_out: Output width; falls back to ``dim`` when ``None``.
        mult: Hidden-width multiplier, used only when ``inner_dim`` is ``None``.
        inner_dim: Explicit hidden width; overrides ``mult`` when given.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
        self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)

    def forward(self, hidden_states):
        """Return ``fc2(gelu_tanh(fc1(hidden_states)))``."""
        expanded = self.fc1(hidden_states)
        activated = F.gelu(expanded, approximate="tanh")
        return self.fc2(activated)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Size of the last axis; also the length of the gain vector.
        eps: Constant added to the mean square before the reciprocal root.
    """

    def __init__(self, dim: int, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Normalise ``x`` in float32 and return it in the input dtype."""
        x32 = x.float()
        mean_square = x32.pow(2).mean(-1, keepdim=True)
        normalised = x32 * torch.rsqrt(mean_square + self.eps)
        return (self.weight * normalised).to(x.dtype)
class Attention(nn.Module):
    """Multi-head attention with RMS-normalised queries/keys and rotary embedding.

    ``forward`` has two execution paths:

    * **dense** (``max_seqlen_q is None``): inputs are ``(batch, seq, dim)``
      and ``F.scaled_dot_product_attention`` is used. This path asserts that
      the module is not in training mode.
    * **packed** (``max_seqlen_q`` given): tokens of all samples are flattened
      into one axis and ``flash_attn_varlen_func`` is used with the supplied
      cumulative sequence lengths.

    Args:
        q_dim: Width of the query input.
        kv_dim: Width of the key/value input; defaults to ``q_dim``.
        heads: Number of attention heads.
        head_dim: Width of each head.
        dropout: Attention dropout probability (applied only in training mode).
        bias: Whether the four projection layers carry a bias term.
    """

    def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
        super().__init__()
        self.q_dim = q_dim
        self.kv_dim = kv_dim if kv_dim is not None else q_dim
        self.inner_dim = head_dim * heads
        self.dropout = dropout
        self.head_dim = head_dim
        self.num_heads = heads

        self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
        self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
        self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)

        # Q/K normalisation spans the full projected width (all heads
        # together); it is applied before the per-head split in ``forward``.
        self.q_norm = RMSNorm(self.inner_dim)
        self.k_norm = RMSNorm(self.inner_dim)

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ):
        """Pad and replicate an attention mask across heads.

        Adapted from
        https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L694

        Args:
            attention_mask: Mask tensor, or ``None`` (returned unchanged).
            target_length: Expected size of the mask's last axis.
            batch_size: Batch size, used to detect whether the mask has
                already been replicated per head.
            out_dim: ``3`` replicates heads along axis 0; ``4`` inserts a new
                head axis at position 1 and replicates along it.

        Returns:
            The prepared mask, or ``None`` when no mask was given.
        """
        head_size = self.num_heads
        if attention_mask is None:
            return attention_mask

        current_length: int = attention_mask.shape[-1]
        if current_length != target_length:
            # NOTE(review): this appends ``target_length`` zeros rather than
            # ``target_length - current_length``; the resulting last axis is
            # ``current_length + target_length``. Behaviour kept as-is --
            # confirm whether any caller relies on this branch.
            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def forward(
        self,
        inputs_q,
        inputs_kv,
        attention_mask=None,
        cross_attention=False,
        rope_pos_embed=None,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        max_seqlen_q=None,
        max_seqlen_k=None,
    ):
        """Run attention in dense or packed mode.

        Args:
            inputs_q: Query-side tokens.
            inputs_kv: Key/value-side tokens; ``None`` means self-attention on
                ``inputs_q``.
            attention_mask: Optional mask, used only on the dense path.
            cross_attention: When true, rotary embedding is applied to the
                queries only (keys are left unrotated).
            rope_pos_embed: Rotary table forwarded to ``apply_rotary_emb``.
            cu_seqlens_q, cu_seqlens_k: Cumulative sequence lengths for the
                packed path.
            max_seqlen_q, max_seqlen_k: Longest sequence lengths for the
                packed path; ``max_seqlen_q is None`` selects the dense path.

        Returns:
            Attention output after the output projection; ``(batch, q_len,
            q_dim)`` on the dense path, ``(total_tokens, q_dim)`` on the
            packed path.

        Raises:
            RuntimeError: Packed mode was requested but ``flash-attn`` could
                not be imported.
        """
        inputs_kv = inputs_q if inputs_kv is None else inputs_kv

        query_states = self.q_proj(inputs_q)
        key_states = self.k_proj(inputs_kv)
        value_states = self.v_proj(inputs_kv)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        # Dropout is only ever active in training mode.
        dropout_p = self.dropout if self.training else 0.0

        if max_seqlen_q is None:
            assert not self.training, "PixelFlow needs sequence packing for training"

            bsz, q_len, _ = inputs_q.shape
            _, kv_len, _ = inputs_kv.shape

            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

            query_states = apply_rotary_emb(query_states, rope_pos_embed)
            if not cross_attention:
                key_states = apply_rotary_emb(key_states, rope_pos_embed)

            if attention_mask is not None:
                attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])

            # Debug hint from the authors: wrap the call below in
            # ``torch.nn.attention.sdpa_kernel(backends=[SDPBackend.MATH])``
            # for strict numerical reproducibility.
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=dropout_p,
                is_causal=False,
            )
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(bsz, q_len, self.inner_dim)
            attn_output = self.o_proj(attn_output)
            return attn_output

        # ---- sequence packing mode ----
        if flash_attn_varlen_func is None:
            # The module-level import guard sets the symbol to None when
            # flash-attn is missing; fail with an actionable message instead
            # of "'NoneType' object is not callable".
            raise RuntimeError(
                "Sequence-packed attention requires `flash-attn` "
                "(`flash_attn_varlen_func`), which is not installed."
            )

        # (total_tokens, heads * head_dim) -> (total_tokens, heads, head_dim)
        query_states = query_states.view(-1, self.num_heads, self.head_dim)
        key_states = key_states.view(-1, self.num_heads, self.head_dim)
        value_states = value_states.view(-1, self.num_heads, self.head_dim)

        # apply_rotary_emb is fed (1, heads, tokens, head_dim); permute into
        # that layout, rotate, and permute back.
        query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
        if not cross_attention:
            key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)

        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
            # Previously the configured dropout was silently ignored here.
            dropout_p=dropout_p,
        )
        attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        return attn_output
class TransformerBlock(nn.Module):
    """One transformer layer: adaptive-norm self-attention, optional
    cross-attention, and a gated, modulated MLP.

    Args:
        dim: Token width.
        num_attention_heads: Number of attention heads.
        attention_head_dim: Width of each head.
        dropout: Dropout probability forwarded to the attention modules.
        cross_attention_dim: Width of the encoder states; ``None`` disables
            the cross-attention sub-layer (``attn2`` is then ``None`` and
            ``norm2`` is not created).
        attention_bias: Whether attention projections carry a bias term.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, dropout=0.0,
                 cross_attention_dim=None, attention_bias=False,
                 ):
        super().__init__()
        self.norm1 = AdaLayerNorm(dim)

        # Self Attention
        self.attn1 = Attention(
            q_dim=dim,
            kv_dim=None,
            heads=num_attention_heads,
            head_dim=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )

        if cross_attention_dim is None:
            self.attn2 = None
        else:
            # Cross Attention
            self.norm2 = RMSNorm(dim, eps=1e-6)
            self.attn2 = Attention(
                q_dim=dim,
                kv_dim=cross_attention_dim,
                heads=num_attention_heads,
                head_dim=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
            )

        self.norm3 = RMSNorm(dim, eps=1e-6)
        self.mlp = FeedForward(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        rope_pos_embed=None,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqlen_list_q=None,
        seqlen_list_k=None,
    ):
        """Apply the block to ``hidden_states``.

        ``timestep`` is the conditioning embedding consumed by ``norm1``.
        ``seqlen_list_q`` / ``seqlen_list_k`` hold per-sample token counts for
        the packed layout; when ``seqlen_list_q`` is ``None`` the attention
        modules receive ``max_seqlen_q=None``.
        """
        # Longest query sequence, shared by both attention sub-layers.
        longest_q = None if seqlen_list_q is None else max(seqlen_list_q)

        # --- self-attention with adaptive norm and gate ---
        normed, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, seqlen_list_q
        )
        self_attn_out = self.attn1(
            inputs_q=normed,
            inputs_kv=None,
            attention_mask=None,
            cross_attention=False,
            rope_pos_embed=rope_pos_embed,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_q=longest_q,
            max_seqlen_k=longest_q,
        )
        # Gate in float32, then return to the attention output dtype.
        self_attn_out = (gate_msa * self_attn_out.float()).to(self_attn_out.dtype)
        hidden_states = self_attn_out + hidden_states

        # --- optional cross-attention (ungated residual) ---
        if self.attn2 is not None:
            longest_k = None if seqlen_list_k is None else max(seqlen_list_k)
            cross_out = self.attn2(
                inputs_q=self.norm2(hidden_states),
                inputs_kv=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                cross_attention=True,
                rope_pos_embed=rope_pos_embed,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=longest_q,
                max_seqlen_k=longest_k,
            )
            hidden_states = hidden_states + cross_out

        # --- MLP with shift/scale modulation and gate ---
        mlp_in = self.norm3(hidden_states)
        mlp_in = (mlp_in.float() * (1 + scale_mlp) + shift_mlp).to(mlp_in.dtype)
        mlp_out = self.mlp(mlp_in)
        mlp_out = (gate_mlp * mlp_out.float()).to(mlp_out.dtype)
        hidden_states = mlp_out + hidden_states

        return hidden_states
class PixelFlowModel(torch.nn.Module):
    """Patch-based transformer conditioned on timestep, latent size and
    (optionally) a class label.

    Args:
        in_channels: Channels of the input image.
        out_channels: Channels of the predicted output.
        num_attention_heads: Heads per transformer block.
        attention_head_dim: Width per head; the token width is
            ``num_attention_heads * attention_head_dim``.
        depth: Number of transformer blocks.
        patch_size: Side length of a square patch.
        dropout: Dropout probability forwarded to every block.
        cross_attention_dim: Encoder-state width; ``None`` disables
            cross-attention in the blocks.
        attention_bias: Whether attention projections carry a bias term.
        num_classes: Number of class labels; ``0`` disables the class embedder.
    """

    def __init__(self, in_channels, out_channels, num_attention_heads, attention_head_dim,
                 depth, patch_size, dropout=0.0, cross_attention_dim=None, attention_bias=True, num_classes=0,
                 ):
        super().__init__()
        self.patch_size = patch_size
        self.attention_head_dim = attention_head_dim
        self.num_classes = num_classes
        self.out_channels = out_channels

        embed_dim = num_attention_heads * attention_head_dim

        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)

        # One sinusoidal projector is shared by the timestep and the
        # latent-size embedders.
        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
        # [stage] embedding
        self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)

        if self.num_classes > 0:
            # class conditional
            self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)

        blocks = []
        for _ in range(depth):
            blocks.append(
                TransformerBlock(
                    embed_dim,
                    num_attention_heads,
                    attention_head_dim,
                    dropout,
                    cross_attention_dim,
                    attention_bias,
                )
            )
        self.transformer_blocks = nn.ModuleList(blocks)

        self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
        self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)

        self.initialize_from_scratch()

    def initialize_from_scratch(self):
        """(Re-)initialise all weights; called at the end of ``__init__``."""
        print("Starting Initialization...")

        def _basic_init(module):
            # Xavier-uniform weights and zero biases for every linear layer.
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        conv_weight = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(conv_weight.view([conv_weight.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        # Conditioning embedders get small normal weights.
        for embedder in (self.timestep_embedder, self.latent_size_embedder):
            for layer in (embedder.linear_1, embedder.linear_2):
                nn.init.normal_(layer.weight, std=0.02)

        if self.num_classes > 0:
            nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)

        # Zero the modulation layers of every block ...
        for block in self.transformer_blocks:
            modulation = block.norm1.linear
            nn.init.constant_(modulation.weight, 0)
            nn.init.constant_(modulation.bias, 0)

        # ... and both output projections.
        for head in (self.proj_out_1, self.proj_out_2):
            nn.init.constant_(head.weight, 0)
            nn.init.constant_(head.bias, 0)

    @staticmethod
    def _expand_per_token(values, seqlen_list):
        """Repeat each row of ``values`` by its entry in ``seqlen_list`` and concatenate."""
        return torch.cat([row[None].expand(count, -1) for row, count in zip(values, seqlen_list)])

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        class_labels=None,
        timestep=None,
        latent_size=None,
        encoder_attention_mask=None,
        pos_embed=None,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqlen_list_q=None,
        seqlen_list_k=None,
    ):
        """Run the model.

        In eval mode the output is reshaped to
        ``(-1, out_channels, height, width)`` using the spatial size of the
        input. In training mode the per-token predictions are returned as
        ``(tokens, out_channels * patch_size * patch_size)`` rows.

        ``seqlen_list_q`` / ``seqlen_list_k`` and ``cu_seqlens_*`` describe the
        packed layout and are forwarded to every block.
        """
        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            mask_as_float = encoder_attention_mask.to(hidden_states.dtype)
            encoder_attention_mask = ((1 - mask_as_float) * -10000.0).unsqueeze(1)

        orig_height = hidden_states.shape[-2]
        orig_width = hidden_states.shape[-1]

        hidden_states = self.patch_embed(hidden_states.to(torch.float32))

        # timestep, class_embed, latent_size_embed
        token_dtype = hidden_states.dtype
        time_features = self.time_proj(timestep).to(dtype=token_dtype)
        conditioning = self.timestep_embedder(time_features)

        if self.num_classes > 0:
            conditioning += self.class_embedder(class_labels)

        size_features = self.time_proj(latent_size).to(dtype=token_dtype)
        conditioning += self.latent_size_embedder(size_features)

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=conditioning,
                rope_pos_embed=pos_embed,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                seqlen_list_q=seqlen_list_q,
                seqlen_list_k=seqlen_list_k,
            )

        # Final adaptive norm: shift/scale come from the conditioning vector.
        shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)
        if seqlen_list_q is not None:
            shift = self._expand_per_token(shift, seqlen_list_q)
            scale = self._expand_per_token(scale, seqlen_list_q)
        else:
            shift = shift.unsqueeze(1)
            scale = scale.unsqueeze(1)

        modulated = self.norm_out(hidden_states).float() * (1 + scale) + shift
        hidden_states = self.proj_out_2(modulated.to(hidden_states.dtype))

        patch = self.patch_size
        channels = self.out_channels

        if self.training:
            # (tokens, p * p * c) -> (tokens, p, p, c) -> (tokens, c, p, p) -> (tokens, c * p * p)
            per_token = hidden_states.reshape(hidden_states.shape[0], patch, patch, channels)
            return per_token.permute(0, 3, 1, 2).flatten(1)

        # Un-patchify back to an image grid.
        height = orig_height // patch
        width = orig_width // patch
        grid = hidden_states.reshape(shape=(-1, height, width, patch, patch, channels))
        grid = torch.einsum("nhwpqc->nchpwq", grid)
        output = grid.reshape(shape=(-1, channels, height * patch, width * patch))
        return output

    def c2i_forward_cfg_torchdiffq(self, hidden_states, timestep, class_labels, latent_size, pos_embed, cfg_scale):
        """Class-conditional forward with classifier-free guidance.

        Used for evaluation with the ODE ('dopri5') solver from torchdiffeq.
        The first half of ``hidden_states`` is duplicated, run through
        ``forward``, and the two output halves are blended as
        ``first + cfg_scale * (second - first)``; the blend is returned
        duplicated so the output keeps the batch size of the input.
        NOTE(review): the code names the first output half "uncond" and the
        second "cond"; that ordering must match ``class_labels`` -- confirm
        against the caller.
        """
        n_half = len(hidden_states) // 2
        first_half = hidden_states[:n_half]
        doubled = torch.cat([first_half, first_half], dim=0)

        out = self.forward(
            hidden_states=doubled,
            timestep=timestep,
            class_labels=class_labels,
            latent_size=latent_size,
            pos_embed=pos_embed,
        )

        uncond_out, cond_out = torch.split(out, len(out) // 2, dim=0)
        guided = uncond_out + cfg_scale * (cond_out - uncond_out)
        return torch.cat([guided, guided], dim=0)