Spaces:
Running
on
Zero
Running
on
Zero
from typing import Tuple, Union | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from diffusers.models.embeddings import Timesteps, TimestepEmbedding, LabelEmbedding | |
import warnings | |
try: | |
from flash_attn import flash_attn_varlen_func | |
except ImportError: | |
warnings.warn("`flash-attn` is not installed. Training mode may not work properly.", UserWarning) | |
flash_attn_varlen_func = None | |
def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
) -> torch.Tensor:
    """Apply a rotary position embedding to ``x``.

    The last dimension of ``x`` is read as interleaved ``(real, imag)`` pairs.
    Each pair is rotated via ``x * cos + rotate(x) * sin``, where ``rotate``
    maps ``(real, imag)`` to ``(-imag, real)``. The arithmetic is done in
    float32 and the result is cast back to ``x.dtype``.

    Args:
        x: Tensor whose last dimension is even. Both call sites in this file
            pass a 4-D tensor laid out as (batch, heads, seq, head_dim).
        freqs_cis: Either a single tensor whose last axis has size 2 and holds
            ``(cos, sin)``, or an explicit ``(cos, sin)`` pair. Two leading
            singleton axes are prepended to cos and sin before they are
            broadcast against ``x``.

    Returns:
        A single tensor in ``x``'s dtype (not a tuple).
    """
    if isinstance(freqs_cis, (tuple, list)):
        cos, sin = freqs_cis
    else:
        cos, sin = freqs_cis.unbind(-1)

    # Prepend two broadcast axes and make sure the tables live next to x.
    cos = cos[None, None].to(x.device)
    sin = sin[None, None].to(x.device)

    # Split the channel axis into (pairs, 2) and build the rotated partner
    # (-imag, real); flatten(-2) merges the pair axis back whatever x's rank.
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

    return (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
class PatchEmbed(nn.Module):
    """Embed image patches with a single strided convolution.

    The convolution uses ``patch_size`` as both kernel size and stride. In
    eval mode the conv is applied to the input directly; in training mode the
    same weights are applied as a plain matmul via :meth:`forward_unfold`.
    """

    def __init__(self, patch_size, in_channels, embed_dim, bias=True):
        super().__init__()
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
            bias=bias,
        )

    def forward_unfold(self, x):
        """Apply the conv weights as a linear map over the last axis of ``x``.

        The kernel is viewed as (embed_dim, -1), so the last dimension of
        ``x`` has to match the flattened kernel size (presumably patches that
        were already unfolded by the caller -- confirm against the data
        pipeline).
        """
        kernel = self.proj.weight
        flat_kernel = kernel.view(kernel.size(0), -1)
        tokens = x.matmul(flat_kernel.t())

        conv_bias = self.proj.bias
        if conv_bias is not None:
            tokens += conv_bias.to(tokens.dtype)
        return tokens

    # Debug aid for strict numerical reproducibility (forces fp32):
    # @torch.autocast('cuda', enabled=False)
    def forward(self, x):
        """Return token embeddings; training mode takes the matmul route."""
        if self.training:
            return self.forward_unfold(x)

        feature_map = self.proj(x)
        # BCHW -> BNC
        return feature_map.flatten(2).transpose(1, 2)
class AdaLayerNorm(nn.Module):
    """Non-affine LayerNorm modulated by a conditioning embedding.

    The conditioning vector is passed through SiLU and a linear layer that
    emits six chunks: shift/scale/gate for attention and shift/scale/gate for
    the MLP. Only the attention shift and scale are applied here; the other
    four chunks are returned for the caller to use.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

    def forward(self, x, timestep, seqlen_list=None):
        """Normalise ``x`` and apply the attention shift/scale.

        Args:
            x: Hidden states to normalise.
            timestep: Conditioning embedding, one row per sample.
            seqlen_list: Optional per-sample repeat counts. When given, each
                sample's modulation row is repeated that many times and the
                rows are concatenated; otherwise a broadcast axis is inserted
                at position 1.

        Returns:
            ``(modulated_x, gate_msa, shift_mlp, scale_mlp, gate_mlp)`` where
            ``modulated_x`` is cast back to the dtype of ``x`` and the other
            four are float32.
        """
        orig_dtype = x.dtype
        modulation = self.linear(self.silu(timestep))

        if seqlen_list is None:
            modulation = modulation.unsqueeze(1)
        else:
            # Same result as `torch.repeat_interleave`, but faster.
            repeated_rows = []
            for row, count in zip(modulation, seqlen_list):
                repeated_rows.append(row[None].expand(count, -1))
            modulation = torch.cat(repeated_rows)

        chunks = modulation.float().chunk(6, dim=-1)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = chunks

        normed = self.norm(x).float()
        modulated = normed * (1 + scale_msa) + shift_msa
        return modulated.to(orig_dtype), gate_msa, shift_mlp, scale_mlp, gate_mlp
class FeedForward(nn.Module):
    """Two-layer MLP with a tanh-approximated GELU in between.

    Args:
        dim: Input feature size.
        dim_out: Output feature size; falls back to ``dim`` when ``None``.
        mult: Expansion factor used to derive the hidden width when
            ``inner_dim`` is ``None``.
        inner_dim: Explicit hidden width; overrides ``mult`` when provided.
        bias: Whether both linear layers carry a bias term.
    """

    def __init__(self, dim, dim_out=None, mult=4, inner_dim=None, bias=True):
        super().__init__()
        if inner_dim is None:
            inner_dim = int(dim * mult)
        if dim_out is None:
            dim_out = dim

        self.fc1 = nn.Linear(dim, inner_dim, bias=bias)
        self.fc2 = nn.Linear(inner_dim, dim_out, bias=bias)

    def forward(self, hidden_states):
        """Run ``fc1 -> GELU(tanh) -> fc2`` on ``hidden_states``."""
        expanded = self.fc1(hidden_states)
        activated = F.gelu(expanded, approximate="tanh")
        return self.fc2(activated)
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last axis with a learned gain.

    Args:
        dim: Size of the normalised (last) axis; the gain starts at ones.
        eps: Constant added to the mean square before the reciprocal sqrt.
    """

    def __init__(self, dim: int, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self, x):
        """Normalise in float32, apply the gain, and cast back to ``x.dtype``."""
        x_fp32 = x.float()
        mean_square = x_fp32.pow(2).mean(-1, keepdim=True)
        normed = x_fp32 * torch.rsqrt(mean_square + self.eps)
        return (self.weight * normed).to(x.dtype)
class Attention(nn.Module):
    """Multi-head attention with RMS-normalised queries/keys and rotary embedding.

    ``forward`` has two execution paths:

    * Padded path, taken when ``max_seqlen_q`` is ``None``: inputs are
      (batch, seq, dim) and ``F.scaled_dot_product_attention`` is used. This
      path asserts that the module is not in training mode.
    * Packed path, taken otherwise: inputs are token-concatenated and
      ``flash_attn_varlen_func`` is used with ``cu_seqlens_*`` /
      ``max_seqlen_*`` describing the individual sequences.

    Args:
        q_dim: Feature size of the query input.
        kv_dim: Feature size of the key/value input; defaults to ``q_dim``.
        heads: Number of attention heads.
        head_dim: Per-head feature size.
        dropout: Dropout probability handed to the padded-path kernel.
        bias: Whether the four projection layers carry a bias term.
    """

    def __init__(self, q_dim, kv_dim=None, heads=8, head_dim=64, dropout=0.0, bias=False):
        super().__init__()
        self.q_dim = q_dim
        self.kv_dim = kv_dim if kv_dim is not None else q_dim
        self.inner_dim = head_dim * heads
        self.dropout = dropout
        self.head_dim = head_dim
        self.num_heads = heads

        self.q_proj = nn.Linear(self.q_dim, self.inner_dim, bias=bias)
        self.k_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
        self.v_proj = nn.Linear(self.kv_dim, self.inner_dim, bias=bias)
        self.o_proj = nn.Linear(self.inner_dim, self.q_dim, bias=bias)

        # Normalisation is applied to the full projected vector, before the
        # split into heads.
        self.q_norm = RMSNorm(self.inner_dim)
        self.k_norm = RMSNorm(self.inner_dim)

    def prepare_attention_mask(
        self, attention_mask: torch.Tensor, target_length: int, batch_size: int, out_dim: int = 3
    ):
        """Pad an attention mask to ``target_length`` and expand it per head.

        Adapted from
        https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py#L694

        Args:
            attention_mask: Mask whose last axis indexes key positions, or
                ``None`` (returned unchanged).
            target_length: Desired size of the last axis.
            batch_size: Batch size, used to decide whether the mask still has
                to be repeated per head when ``out_dim == 3``.
            out_dim: ``3`` repeats the mask along axis 0 (once per head);
                ``4`` inserts a head axis at position 1 and repeats there.

        Returns:
            The prepared mask, or ``None`` when no mask was given.
        """
        if attention_mask is None:
            return attention_mask

        head_size = self.num_heads
        current_length: int = attention_mask.shape[-1]
        # Right-pad only by the missing amount so the last axis ends up exactly
        # ``target_length`` long. Padding by ``target_length`` itself (as the
        # code did before) produced ``current_length + target_length`` entries,
        # which cannot match the key length the attention kernel expects.
        remaining_length = target_length - current_length
        if remaining_length > 0:
            attention_mask = F.pad(attention_mask, (0, remaining_length), value=0.0)

        if out_dim == 3:
            if attention_mask.shape[0] < batch_size * head_size:
                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        elif out_dim == 4:
            attention_mask = attention_mask.unsqueeze(1)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)

        return attention_mask

    def forward(
        self,
        inputs_q,
        inputs_kv,
        attention_mask=None,
        cross_attention=False,
        rope_pos_embed=None,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        max_seqlen_q=None,
        max_seqlen_k=None,
    ):
        """Attend from ``inputs_q`` to ``inputs_kv``.

        Args:
            inputs_q: Query input. 3-D (batch, seq, dim) on the padded path;
                on the packed path it is viewed as (-1, heads, head_dim)
                after projection.
            inputs_kv: Key/value input; ``None`` means self-attention on
                ``inputs_q``.
            attention_mask: Optional mask, used on the padded path only.
            cross_attention: When true, the rotary embedding is applied to
                the queries only and the keys are left untouched.
            rope_pos_embed: Rotary table forwarded to ``apply_rotary_emb``.
            cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k: Forwarded
                to ``flash_attn_varlen_func`` on the packed path. A non-None
                ``max_seqlen_q`` is what selects the packed path.

        Returns:
            The attention output after the output projection.

        Raises:
            AssertionError: Padded path requested while in training mode.
            ImportError: Packed path requested but ``flash-attn`` is missing.
        """
        inputs_kv = inputs_q if inputs_kv is None else inputs_kv

        query_states = self.q_proj(inputs_q)
        key_states = self.k_proj(inputs_kv)
        value_states = self.v_proj(inputs_kv)

        query_states = self.q_norm(query_states)
        key_states = self.k_norm(key_states)

        if max_seqlen_q is None:
            assert not self.training, "PixelFlow needs sequence packing for training"

            bsz, q_len, _ = inputs_q.shape
            _, kv_len, _ = inputs_kv.shape

            # (batch, seq, inner) -> (batch, heads, seq, head_dim)
            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)
            value_states = value_states.view(bsz, kv_len, self.num_heads, self.head_dim).transpose(1, 2)

            query_states = apply_rotary_emb(query_states, rope_pos_embed)
            if not cross_attention:
                key_states = apply_rotary_emb(key_states, rope_pos_embed)

            if attention_mask is not None:
                attention_mask = self.prepare_attention_mask(attention_mask, kv_len, bsz)
                # scaled_dot_product_attention expects attention_mask shape to be
                # (batch, heads, source_length, target_length)
                attention_mask = attention_mask.view(bsz, self.num_heads, -1, attention_mask.shape[-1])

            # For strict numerical reproducibility (debug only), wrap the call in
            # torch.nn.attention.sdpa_kernel(backends=[torch.nn.attention.SDPBackend.MATH]).
            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=False,
            )

            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(bsz, q_len, self.inner_dim)
            attn_output = self.o_proj(attn_output)
            return attn_output

        # Sequence packing mode.
        if flash_attn_varlen_func is None:
            # The module-level import guard leaves this name as None when
            # flash-attn is missing; fail with an actionable message instead of
            # "'NoneType' object is not callable".
            raise ImportError(
                "`flash-attn` is not installed; sequence-packing attention requires "
                "`flash_attn_varlen_func`."
            )

        query_states = query_states.view(-1, self.num_heads, self.head_dim)
        key_states = key_states.view(-1, self.num_heads, self.head_dim)
        value_states = value_states.view(-1, self.num_heads, self.head_dim)

        # apply_rotary_emb is fed (1, heads, tokens, head_dim); permute there
        # and back around the call.
        query_states = apply_rotary_emb(query_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)
        if not cross_attention:
            key_states = apply_rotary_emb(key_states.permute(1, 0, 2)[None], rope_pos_embed)[0].permute(1, 0, 2)

        # NOTE(review): self.dropout is not forwarded on this path, so dropout
        # is never applied while training -- confirm this is intended.
        attn_output = flash_attn_varlen_func(
            query_states,
            key_states,
            value_states,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_k,
            max_seqlen_q=max_seqlen_q,
            max_seqlen_k=max_seqlen_k,
        )

        attn_output = attn_output.view(-1, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)
        return attn_output
class TransformerBlock(nn.Module):
    """Transformer block: modulated self-attention, optional cross-attention, MLP.

    ``norm1`` produces the modulated input for self-attention together with
    the attention gate and the MLP shift/scale/gate. Cross-attention (with
    its own RMSNorm) exists only when ``cross_attention_dim`` is given.
    """

    def __init__(self, dim, num_attention_heads, attention_head_dim, dropout=0.0,
        cross_attention_dim=None, attention_bias=False,
    ):
        super().__init__()
        self.norm1 = AdaLayerNorm(dim)

        # Self-attention.
        self.attn1 = Attention(
            q_dim=dim,
            kv_dim=None,
            heads=num_attention_heads,
            head_dim=attention_head_dim,
            dropout=dropout,
            bias=attention_bias,
        )

        if cross_attention_dim is None:
            self.attn2 = None
        else:
            # Cross-attention.
            self.norm2 = RMSNorm(dim, eps=1e-6)
            self.attn2 = Attention(
                q_dim=dim,
                kv_dim=cross_attention_dim,
                heads=num_attention_heads,
                head_dim=attention_head_dim,
                dropout=dropout,
                bias=attention_bias,
            )

        self.norm3 = RMSNorm(dim, eps=1e-6)
        self.mlp = FeedForward(dim)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        timestep=None,
        rope_pos_embed=None,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqlen_list_q=None,
        seqlen_list_k=None,
    ):
        """Run the block and return the updated hidden states.

        ``seqlen_list_q`` / ``seqlen_list_k`` hold per-sequence lengths for
        the packed mode; their maxima are handed to the attention layers as
        ``max_seqlen_*``. When they are ``None`` the attention layers receive
        ``None`` and use their padded path.
        """
        longest_q = max(seqlen_list_q) if seqlen_list_q is not None else None

        # --- self-attention, gated by the conditioning ---
        modulated, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
            hidden_states, timestep, seqlen_list_q
        )
        self_attn_out = self.attn1(
            inputs_q=modulated,
            inputs_kv=None,
            attention_mask=None,
            cross_attention=False,
            rope_pos_embed=rope_pos_embed,
            cu_seqlens_q=cu_seqlens_q,
            cu_seqlens_k=cu_seqlens_q,
            max_seqlen_q=longest_q,
            max_seqlen_k=longest_q,
        )
        self_attn_out = (gate_msa * self_attn_out.float()).to(self_attn_out.dtype)
        hidden_states = self_attn_out + hidden_states

        # --- optional cross-attention (ungated residual) ---
        if self.attn2 is not None:
            longest_k = max(seqlen_list_k) if seqlen_list_k is not None else None
            cross_attn_out = self.attn2(
                inputs_q=self.norm2(hidden_states),
                inputs_kv=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
                cross_attention=True,
                rope_pos_embed=rope_pos_embed,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=longest_q,
                max_seqlen_k=longest_k,
            )
            hidden_states = hidden_states + cross_attn_out

        # --- MLP with shift/scale modulation and gate ---
        mlp_input = self.norm3(hidden_states)
        mlp_input = (mlp_input.float() * (1 + scale_mlp) + shift_mlp).to(mlp_input.dtype)
        mlp_out = self.mlp(mlp_input)
        mlp_out = (gate_mlp * mlp_out.float()).to(mlp_out.dtype)
        return mlp_out + hidden_states
class PixelFlowModel(torch.nn.Module):
    """Patch-based transformer conditioned on timestep, stage size and class.

    Inputs are embedded by ``PatchEmbed``, processed by ``depth`` transformer
    blocks that all receive the same conditioning vector, and projected back
    to ``patch_size * patch_size * out_channels`` values per token.
    """

    def __init__(self, in_channels, out_channels, num_attention_heads, attention_head_dim,
        depth, patch_size, dropout=0.0, cross_attention_dim=None, attention_bias=True, num_classes=0,
    ):
        super().__init__()
        self.patch_size = patch_size
        self.attention_head_dim = attention_head_dim
        self.num_classes = num_classes
        self.out_channels = out_channels

        embed_dim = num_attention_heads * attention_head_dim

        self.patch_embed = PatchEmbed(patch_size=patch_size, in_channels=in_channels, embed_dim=embed_dim)

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)
        # [stage] embedding
        self.latent_size_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embed_dim)

        if self.num_classes > 0:
            # class conditional
            self.class_embedder = LabelEmbedding(num_classes, embed_dim, dropout_prob=0.1)

        blocks = [
            TransformerBlock(
                embed_dim, num_attention_heads, attention_head_dim, dropout, cross_attention_dim, attention_bias
            )
            for _ in range(depth)
        ]
        self.transformer_blocks = nn.ModuleList(blocks)

        self.norm_out = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out_1 = nn.Linear(embed_dim, 2 * embed_dim)
        self.proj_out_2 = nn.Linear(embed_dim, patch_size * patch_size * out_channels)

        self.initialize_from_scratch()

    def initialize_from_scratch(self):
        """(Re)initialise all weights; called at the end of ``__init__``.

        Order: Xavier-uniform on every ``nn.Linear`` (zero bias), Xavier on
        the flattened patch-embedding kernel, normal(std=0.02) on the
        timestep / stage / class embedders, then zeros for every block's
        modulation layer and for both output projections.
        """
        print("Starting Initialization...")

        def _basic_init(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
        conv_weight = self.patch_embed.proj.weight.data
        nn.init.xavier_uniform_(conv_weight.view([conv_weight.shape[0], -1]))
        nn.init.constant_(self.patch_embed.proj.bias, 0)

        for embedder in (self.timestep_embedder, self.latent_size_embedder):
            for layer in (embedder.linear_1, embedder.linear_2):
                nn.init.normal_(layer.weight, std=0.02)

        if self.num_classes > 0:
            nn.init.normal_(self.class_embedder.embedding_table.weight, std=0.02)

        # Zero modulation layers so every block starts with zero shift, scale
        # and gate.
        for block in self.transformer_blocks:
            modulation = block.norm1.linear
            nn.init.constant_(modulation.weight, 0)
            nn.init.constant_(modulation.bias, 0)

        for head in (self.proj_out_1, self.proj_out_2):
            nn.init.constant_(head.weight, 0)
            nn.init.constant_(head.bias, 0)

    def forward(
        self,
        hidden_states,
        encoder_hidden_states=None,
        class_labels=None,
        timestep=None,
        latent_size=None,
        encoder_attention_mask=None,
        pos_embed=None,
        cu_seqlens_q=None,
        cu_seqlens_k=None,
        seqlen_list_q=None,
        seqlen_list_k=None,
    ):
        """Run the model.

        In training mode the result is returned per token, reshaped to
        (tokens, out_channels * patch_size * patch_size) in channel-major
        order. In eval mode the tokens are folded back into an image of shape
        (-1, out_channels, H, W), where H and W are the last two input
        dimensions floor-divided by and re-multiplied with ``patch_size``.
        """
        # convert encoder_attention_mask to a bias the same way we do for attention_mask
        if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
            keep = encoder_attention_mask.to(hidden_states.dtype)
            encoder_attention_mask = ((1 - keep) * -10000.0).unsqueeze(1)

        orig_height = hidden_states.shape[-2]
        orig_width = hidden_states.shape[-1]

        hidden_states = self.patch_embed(hidden_states.to(torch.float32))
        embed_dtype = hidden_states.dtype

        # timestep, class_embed, latent_size_embed
        conditioning = self.timestep_embedder(self.time_proj(timestep).to(dtype=embed_dtype))

        if self.num_classes > 0:
            conditioning += self.class_embedder(class_labels)

        conditioning += self.latent_size_embedder(self.time_proj(latent_size).to(dtype=embed_dtype))

        for block in self.transformer_blocks:
            hidden_states = block(
                hidden_states,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                timestep=conditioning,
                rope_pos_embed=pos_embed,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                seqlen_list_q=seqlen_list_q,
                seqlen_list_k=seqlen_list_k,
            )

        shift, scale = self.proj_out_1(F.silu(conditioning)).float().chunk(2, dim=1)

        if seqlen_list_q is not None:
            # Packed mode: repeat each sample's row once per token.
            def _per_token(table):
                return torch.cat(
                    [row[None].expand(count, -1) for row, count in zip(table, seqlen_list_q)]
                )

            shift = _per_token(shift)
            scale = _per_token(scale)
        else:
            shift = shift.unsqueeze(1)
            scale = scale.unsqueeze(1)

        normed = self.norm_out(hidden_states).float()
        hidden_states = (normed * (1 + scale) + shift).to(hidden_states.dtype)
        hidden_states = self.proj_out_2(hidden_states)

        patch = self.patch_size
        channels = self.out_channels

        if self.training:
            # (tokens, p, p, c) -> (tokens, c, p, p) -> (tokens, c * p * p)
            per_token = hidden_states.reshape(hidden_states.shape[0], patch, patch, channels)
            return per_token.permute(0, 3, 1, 2).flatten(1)

        height = orig_height // patch
        width = orig_width // patch

        hidden_states = hidden_states.reshape(
            shape=(-1, height, width, patch, patch, channels)
        )
        hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
        return hidden_states.reshape(
            shape=(-1, channels, height * patch, width * patch)
        )

    def c2i_forward_cfg_torchdiffq(self, hidden_states, timestep, class_labels, latent_size, pos_embed, cfg_scale):
        """Class-conditional forward pass with classifier-free guidance.

        Used for evaluation with ODE ('dopri5') solver from torchdiffeq. The
        first half of ``hidden_states`` is duplicated, run through the model,
        and the two output halves are blended as
        ``first + cfg_scale * (second - first)``; the blend is returned
        twice-concatenated so the output batch size equals the model output's.
        """
        half_count = len(hidden_states) // 2
        first_half = hidden_states[:half_count]
        doubled = torch.cat([first_half, first_half], dim=0)

        out = self.forward(
            hidden_states=doubled,
            timestep=timestep,
            class_labels=class_labels,
            latent_size=latent_size,
            pos_embed=pos_embed,
        )

        uncond_out, cond_out = torch.split(out, len(out) // 2, dim=0)
        guided = uncond_out + cfg_scale * (cond_out - uncond_out)
        return torch.cat([guided, guided], dim=0)