"""Building blocks for TiTok. Copyright (2024) Bytedance Ltd. and/or its affiliates Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. Reference: https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py https://github.com/baofff/U-ViT/blob/main/libs/timm.py """ import math import torch import torch.nn as nn from torch.utils.checkpoint import checkpoint from collections import OrderedDict import einops from einops.layers.torch import Rearrange def modulate(x, shift, scale): return x * (1 + scale) + shift class ResidualAttentionBlock(nn.Module): def __init__( self, d_model, n_head, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm ): super().__init__() self.ln_1 = norm_layer(d_model) self.attn = nn.MultiheadAttention(d_model, n_head) self.mlp_ratio = mlp_ratio # optionally we can disable the FFN if mlp_ratio > 0: self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.mlp = nn.Sequential( OrderedDict( [ ("c_fc", nn.Linear(d_model, mlp_width)), ("gelu", act_layer()), ("c_proj", nn.Linear(mlp_width, d_model)), ] ) ) def attention(self, x: torch.Tensor): return self.attn(x, x, x, need_weights=False)[0] def forward( self, x: torch.Tensor, ): attn_output = self.attention(x=self.ln_1(x)) x = x + attn_output if self.mlp_ratio > 0: x = x + self.mlp(self.ln_2(x)) return x if hasattr(torch.nn.functional, "scaled_dot_product_attention"): ATTENTION_MODE = "flash" else: try: import xformers import xformers.ops ATTENTION_MODE = "xformers" except: ATTENTION_MODE = "math" 
print(f"attention mode is {ATTENTION_MODE}")


class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The kernel is chosen by the module-level ``ATTENTION_MODE``
    (``"flash"``, ``"xformers"`` or ``"math"``).

    NOTE(review): only the ``"math"`` branch uses ``self.scale`` and
    ``self.attn_drop``; the ``"flash"`` and ``"xformers"`` branches call their
    kernels with default arguments, so ``qk_scale`` and ``attn_drop`` have no
    effect there. The ``"flash"`` branch also casts q/k/v to float32.

    Args:
        dim: embedding width; split evenly across ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: optional override for the ``head_dim ** -0.5`` logit scale.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, L, C)`` and return the same shape.

        Raises:
            NotImplementedError: if ``ATTENTION_MODE`` is not a known backend.
        """
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == "flash":
            qkv = einops.rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            ).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, "B H L D -> B L (H D)")
        elif ATTENTION_MODE == "xformers":
            qkv = einops.rearrange(
                qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
            )
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
        elif ATTENTION_MODE == "math":
            qkv = einops.rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            )
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # raising it would surface as a confusing TypeError.
            raise NotImplementedError(
                f"unsupported ATTENTION_MODE: {ATTENTION_MODE!r}"
            )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
    the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
    See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
    changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.

    Args:
        x: input tensor whose first dimension indexes samples.
        drop_prob: probability of zeroing a whole sample; ``None`` or ``0``
            disables dropping.
        training: dropping is only applied when this is true.

    Returns:
        ``x`` unchanged when disabled; otherwise ``x`` with whole samples
        zeroed at random and the survivors rescaled by ``1 / (1 - drop_prob)``.
    """
    # ``not drop_prob`` also covers None, which is DropPath's constructor
    # default and would otherwise fail on ``1 - None`` in training mode.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Apply :func:`drop_path` using the module's training flag."""
        return drop_path(x, self.drop_prob, self.training)


class Mlp(nn.Module):
    """Two-layer feed-forward network with dropout after each linear layer.

    Args:
        in_features: input width.
        hidden_features: hidden width; defaults to ``in_features``.
        out_features: output width; defaults to ``in_features``.
        act_layer: activation class placed between the two linear layers.
        drop: dropout probability (the same module is applied twice).
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Return ``drop(fc2(drop(act(fc1(x)))))``."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class UViTBlock(nn.Module):
    """Pre-norm transformer block with an optional long skip connection.

    Args:
        dim: embedding width.
        num_heads: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: whether the QKV projection has a bias.
        qk_scale: optional attention logit scale override.
        drop: dropout probability for the attention projection and the MLP.
        attn_drop: dropout probability on attention weights.
        drop_path: stochastic-depth probability; ``0`` disables it.
        act_layer: activation class for the MLP.
        norm_layer: normalization class.
        skip: when true, a linear layer fuses the ``skip`` input with ``x``.
        use_checkpoint: when true, run the block under activation
            checkpointing.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        skip=False,
        use_checkpoint=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )
        self.skip_linear = nn.Linear(2 * dim, dim) if skip else None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        """Apply the block, optionally under activation checkpointing.

        Args:
            x: input of shape ``(B, L, dim)``.
            skip: tensor concatenated with ``x`` on the last dimension when the
                block was built with ``skip=True``; ignored otherwise.
        """
        if self.use_checkpoint:
            return checkpoint(self._forward, x, skip)
        return self._forward(x, skip)

    def _forward(self, x, skip=None):
        """Un-checkpointed body of :meth:`forward`."""
        if self.skip_linear is not None:
            x = self.skip_linear(torch.cat([x, skip], dim=-1))
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x


def _expand_token(token, batch_size: int):
    """Broadcast a 2-D ``token`` to ``(batch_size, *token.shape)`` without copying."""
    return token.unsqueeze(0).expand(batch_size, -1, -1)


# (width, num_layers, num_heads) for each supported ViT size, shared by the
# encoder and decoder.
_VIT_SIZE_SPECS = {
    "small": (512, 8, 8),
    "base": (768, 12, 12),
    "large": (1024, 24, 16),
}


class TiTokEncoder(nn.Module):
    """ViT encoder that compresses an image into a 1-D sequence of latent tokens.

    Args:
        config: nested config object read via attribute access and ``.get``
            (presumably an OmegaConf ``DictConfig`` -- confirm against caller).
            Reads ``dataset.preprocessing.crop_size`` and, under
            ``model.vq_model``: ``vit_enc_patch_size``, ``vit_enc_model_size``
            (``"small"``/``"base"``/``"large"``), ``num_latent_tokens``,
            ``token_size``, and the optional ``quantize_mode`` and
            ``is_legacy`` keys.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_enc_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = config.model.vq_model.vit_enc_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size

        if config.model.vq_model.get("quantize_mode", "vq") == "vae":
            self.token_size = self.token_size * 2  # needs to split into mean and std

        self.is_legacy = config.model.vq_model.get("is_legacy", True)

        self.width, self.num_layers, self.num_heads = _VIT_SIZE_SPECS[self.model_size]

        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=self.width,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        scale = self.width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size**2 + 1, self.width)
        )
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList(
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        )
        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

    def forward(self, pixel_values, latent_tokens):
        """Encode images into latent tokens.

        Args:
            pixel_values: images of shape ``(B, 3, H, W)``.
            latent_tokens: learnable query tokens, 2-D and broadcast over the
                batch (presumably ``(num_latent_tokens, width)`` -- it is added
                to a positional embedding of that shape).

        Returns:
            Tensor of shape ``(B, token_size, 1, num_latent_tokens)``.
        """
        batch_size = pixel_values.shape[0]
        x = pixel_values
        x = self.patch_embed(x)
        x = x.reshape(x.shape[0], x.shape[1], -1)
        x = x.permute(0, 2, 1)  # shape = [*, grid ** 2, width]
        # class embeddings and positional embeddings
        x = torch.cat(
            [_expand_token(self.class_embedding, x.shape[0]).to(x.dtype), x], dim=1
        )
        x = x + self.positional_embedding.to(
            x.dtype
        )  # shape = [*, grid ** 2 + 1, width]

        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(
            x.dtype
        )
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the latent-token positions (after cls + image patches).
        latent_tokens = x[:, 1 + self.grid_size**2 :]
        latent_tokens = self.ln_post(latent_tokens)
        # fake 2D shape
        if self.is_legacy:
            # Legacy path reshapes (B, L, D) straight to (B, D, L, 1) without a
            # permute; kept as-is for compatibility with legacy checkpoints.
            latent_tokens = latent_tokens.reshape(
                batch_size, self.width, self.num_latent_tokens, 1
            )
        else:
            # Fix legacy problem.
            latent_tokens = latent_tokens.reshape(
                batch_size, self.num_latent_tokens, self.width, 1
            ).permute(0, 2, 1, 3)
        latent_tokens = self.conv_out(latent_tokens)
        latent_tokens = latent_tokens.reshape(
            batch_size, self.token_size, 1, self.num_latent_tokens
        )
        return latent_tokens


class TiTokDecoder(nn.Module):
    """ViT decoder that reconstructs a 2-D grid from 1-D latent tokens.

    Args:
        config: nested config object read via attribute access and ``.get``
            (presumably an OmegaConf ``DictConfig`` -- confirm against caller).
            Reads ``dataset.preprocessing.crop_size`` and, under
            ``model.vq_model``: ``vit_dec_patch_size``, ``vit_dec_model_size``
            (``"small"``/``"base"``/``"large"``), ``num_latent_tokens``,
            ``token_size`` and the optional ``is_legacy`` key.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        self.patch_size = config.model.vq_model.vit_dec_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = config.model.vq_model.vit_dec_model_size
        self.num_latent_tokens = config.model.vq_model.num_latent_tokens
        self.token_size = config.model.vq_model.token_size
        self.is_legacy = config.model.vq_model.get("is_legacy", True)
        self.width, self.num_layers, self.num_heads = _VIT_SIZE_SPECS[self.model_size]

        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)
        scale = self.width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size**2 + 1, self.width)
        )
        # add mask token and query pos embed
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList(
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        )
        self.ln_post = nn.LayerNorm(self.width)

        if self.is_legacy:
            # Legacy head: 1024 output channels per grid cell, no pixel output.
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
                nn.Tanh(),
                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
            )
            self.conv_out = nn.Identity()
        else:
            # Directly predicting RGB pixels
            self.ffn = nn.Sequential(
                nn.Conv2d(
                    self.width,
                    self.patch_size * self.patch_size * 3,
                    1,
                    padding=0,
                    bias=True,
                ),
                Rearrange(
                    "b (p1 p2 c) h w -> b c (h p1) (w p2)",
                    p1=self.patch_size,
                    p2=self.patch_size,
                ),
            )
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def _build_token_sequence(self, z_quantized):
        """Embed latents and prepend the cls + mask-token grid (NLD layout)."""
        N, C, H, W = z_quantized.shape
        assert (
            H == 1 and W == self.num_latent_tokens
        ), f"{H}, {W}, {self.num_latent_tokens}"
        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)  # NLD
        x = self.decoder_embed(x)

        batchsize, seq_len, _ = x.shape

        mask_tokens = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(
            x.dtype
        )
        mask_tokens = torch.cat(
            [
                _expand_token(self.class_embedding, mask_tokens.shape[0]).to(
                    mask_tokens.dtype
                ),
                mask_tokens,
            ],
            dim=1,
        )
        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)
        x = x + self.latent_token_positional_embedding[:seq_len]
        return torch.cat([mask_tokens, x], dim=1)

    def _decode_token_sequence(self, x):
        """Run the transformer on an NLD sequence and project the grid tokens."""
        batchsize = x.shape[0]
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = x[:, 1 : 1 + self.grid_size**2]  # remove cls embed
        x = self.ln_post(x)
        # N L D -> N D H W
        x = x.permute(0, 2, 1).reshape(
            batchsize, self.width, self.grid_size, self.grid_size
        )
        x = self.ffn(x.contiguous())
        x = self.conv_out(x)
        return x

    def forward(self, z_quantized):
        """Decode latent tokens.

        Args:
            z_quantized: tensor of shape ``(N, C, 1, num_latent_tokens)``.

        Returns:
            In legacy mode a ``(N, 1024, grid, grid)`` tensor; otherwise a
            ``(N, 3, grid * patch, grid * patch)`` tensor.
        """
        x = self._build_token_sequence(z_quantized)
        return self._decode_token_sequence(x)


class TATiTokDecoder(TiTokDecoder):
    """:class:`TiTokDecoder` that additionally attends to text-guidance tokens.

    Extra optional config keys under ``model.vq_model``:
    ``text_context_length`` (default 77) and ``text_embed_dim`` (default 768).
    """

    def __init__(self, config):
        super().__init__(config)
        scale = self.width**-0.5
        self.text_context_length = config.model.vq_model.get("text_context_length", 77)
        self.text_embed_dim = config.model.vq_model.get("text_embed_dim", 768)
        self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
        self.text_guidance_positional_embedding = nn.Parameter(
            scale * torch.randn(self.text_context_length, self.width)
        )

    def forward(self, z_quantized, text_guidance):
        """Decode latent tokens conditioned on text guidance.

        Args:
            z_quantized: tensor of shape ``(N, C, 1, num_latent_tokens)``.
            text_guidance: text features whose last dimension is
                ``text_embed_dim`` (presumably
                ``(N, text_context_length, text_embed_dim)`` -- it is added to
                a positional embedding of that length).

        Returns:
            Same as :meth:`TiTokDecoder.forward`.
        """
        x = self._build_token_sequence(z_quantized)

        # Text tokens go after the latents, so the grid slice taken during
        # decoding is unaffected.
        text_guidance = self.text_guidance_proj(text_guidance)
        text_guidance = text_guidance + self.text_guidance_positional_embedding
        x = torch.cat([x, text_guidance], dim=1)

        return self._decode_token_sequence(x)


class WeightTiedLMHead(nn.Module):
    """Output head whose weights are shared with an embedding table.

    Args:
        embeddings: module exposing a ``weight`` of shape
            ``(num_embeddings, embed_dim)``; the tensor is shared, not copied.
        target_codebook_size: number of leading embedding rows used as classes.
    """

    def __init__(self, embeddings, target_codebook_size):
        super().__init__()
        self.weight = embeddings.weight
        self.target_codebook_size = target_codebook_size

    def forward(self, x):
        """Project ``x`` of shape ``(B, L, embed_dim)`` to ``(B, L, target_codebook_size)`` logits."""
        # Get the weights for the target codebook size
        weight = self.weight[
            : self.target_codebook_size
        ]  # Shape: [target_codebook_size, embed_dim]
        # Compute the logits by matrix multiplication
        logits = torch.matmul(
            x, weight.t()
        )  # Shape: [batch_size, seq_len, target_codebook_size]
        return logits


class TimestepEmbedder(nn.Module):
    """
    Embeds scalar timesteps into vector representations.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Create sinusoidal timestep embeddings.
        :param t: a 1-D Tensor of N indices, one per batch element.
                          These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor of positional embeddings.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period)
            * torch.arange(start=0, end=half, dtype=torch.float32)
            / half
        ).to(device=t.device)
        args = t[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Odd ``dim``: pad one zero column so the output width is exact.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        """Return an ``(N, hidden_size)`` embedding for the 1-D timesteps ``t``."""
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
        t_emb = self.mlp(t_freq)
        return t_emb


class ResBlock(nn.Module):
    """
    A residual MLP block with adaLN conditioning; the channel count is preserved.
    :param channels: the number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """
        :param x: input features with ``channels`` on the last dimension.
        :param y: conditioning vector with ``channels`` on the last dimension.
        :return: ``x`` plus the gated, modulated MLP output.
        """
        shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(y).chunk(3, dim=-1)
        h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
        h = self.mlp(h)
        return x + gate_mlp * h


class FinalLayer(nn.Module):
    """
    The final layer adopted from DiT.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6
        )
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        """Modulate the normalized ``x`` with ``c`` and project to ``out_channels``."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        x = modulate(self.norm_final(x), shift, scale)
        x = self.linear(x)
        return x


class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks.
    :param grad_checkpointing: run each residual block under activation
        checkpointing when true.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )

        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Xavier-init linears, then zero the adaLN and output layers."""

        def _basic_init(module):
            if isinstance(module, nn.Linear):
                torch.nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

        self.apply(_basic_init)

        # Initialize timestep embedding MLP
        nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
        nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)

        # Zero-out adaLN modulation layers
        for block in self.res_blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

        # Zero-out output layers
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
        nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        x = self.input_proj(x)
        t = self.time_embed(t)
        c = self.cond_embed(c)

        y = t + c

        if self.grad_checkpointing and not torch.jit.is_scripting():
            for block in self.res_blocks:
                x = checkpoint(block, x, y)
        else:
            for block in self.res_blocks:
                x = block(x, y)

        return self.final_layer(x, y)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """
        Apply the model with classifier-free guidance.

        The first half of ``x`` is duplicated so both halves of the batch see
        the same inputs; the first ``in_channels`` output channels of the two
        halves are then blended as ``uncond + cfg_scale * (cond - uncond)``.
        :param x: an [N x C] Tensor of inputs; only the first ``N // 2`` rows
            are used.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning; presumably conditional rows first and
            unconditional rows second -- confirm against caller.
        :param cfg_scale: guidance strength.
        :return: guided outputs with the same shape as ``forward``'s result.
        """
        half = x[: len(x) // 2]
        combined = torch.cat([half, half], dim=0)
        model_out = self.forward(combined, t, c)
        eps, rest = model_out[:, : self.in_channels], model_out[:, self.in_channels :]
        cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
        half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
        eps = torch.cat([half_eps, half_eps], dim=0)
        return torch.cat([eps, rest], dim=1)