huaweilin's picture
update
14ce5a9
"""Building blocks for TiTok.
Copyright (2024) Bytedance Ltd. and/or its affiliates
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
Reference:
https://github.com/mlfoundations/open_clip/blob/main/src/open_clip/transformer.py
https://github.com/baofff/U-ViT/blob/main/libs/timm.py
"""
import math
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from collections import OrderedDict
import einops
from einops.layers.torch import Rearrange
def modulate(x, shift, scale):
    """Apply an affine modulation to ``x``.

    The input is multiplied by ``1 + scale`` and then offset by ``shift``,
    so ``scale == 0`` and ``shift == 0`` leave ``x`` unchanged.

    Args:
        x: value (or tensor) to modulate.
        shift: additive term applied after scaling.
        scale: multiplicative term, expressed as an offset from 1.

    Returns:
        ``x * (1 + scale) + shift``.
    """
    gain = 1 + scale
    return x * gain + shift
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention plus an optional MLP.

    Each sub-layer normalises its input, transforms it, and adds the result
    back onto the residual stream. Attention is an ``nn.MultiheadAttention``
    built with its default arguments.

    Args:
        d_model: channel width of the residual stream.
        n_head: number of attention heads.
        mlp_ratio: hidden width of the MLP as a multiple of ``d_model``.
            A value that is not greater than zero disables the MLP entirely
            (no ``ln_2`` / ``mlp`` sub-modules are created).
        act_layer: zero-argument factory for the MLP activation.
        norm_layer: factory taking ``d_model`` and returning a norm module.
    """

    def __init__(
        self, d_model, n_head, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm
    ):
        super().__init__()
        self.ln_1 = norm_layer(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.mlp_ratio = mlp_ratio

        # The feed-forward branch is optional.
        if mlp_ratio > 0:
            self.ln_2 = norm_layer(d_model)
            hidden_width = int(d_model * mlp_ratio)
            # Sub-module names are kept stable so checkpoints keep loading.
            stages = OrderedDict()
            stages["c_fc"] = nn.Linear(d_model, hidden_width)
            stages["gelu"] = act_layer()
            stages["c_proj"] = nn.Linear(hidden_width, d_model)
            self.mlp = nn.Sequential(stages)

    def attention(self, x: torch.Tensor):
        """Run self-attention on ``x`` (query = key = value) and return the output only."""
        attended, _ = self.attn(x, x, x, need_weights=False)
        return attended

    def forward(
        self,
        x: torch.Tensor,
    ):
        """Apply the attention sub-layer and, if enabled, the MLP sub-layer."""
        x = x + self.attention(x=self.ln_1(x))
        if not self.mlp_ratio > 0:
            return x
        return x + self.mlp(self.ln_2(x))
# Pick the attention backend once, at import time:
#   "flash"    - torch.nn.functional.scaled_dot_product_attention is available
#   "xformers" - otherwise, if the optional xformers package imports cleanly
#   "math"     - plain matmul/softmax fallback
if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
    ATTENTION_MODE = "flash"
else:
    try:
        import xformers
        import xformers.ops

        ATTENTION_MODE = "xformers"
    except Exception:
        # xformers is optional and may fail with more than ImportError (e.g. a
        # broken binary build), so any ordinary failure falls back to "math".
        # Unlike a bare ``except:``, this lets KeyboardInterrupt and SystemExit
        # propagate.
        ATTENTION_MODE = "math"
print(f"attention mode is {ATTENTION_MODE}")
class Attention(nn.Module):
    """Multi-head self-attention with a fused QKV projection.

    The attention kernel is selected by the module-level ``ATTENTION_MODE``
    ("flash", "xformers" or "math").

    NOTE: only the "math" path applies ``self.scale`` (and therefore
    ``qk_scale``) and ``attn_drop``. The "flash" and "xformers" paths call
    their kernels with just ``q, k, v``, so those two settings have no effect
    there.

    Args:
        dim: input/output channel width; split evenly across heads.
        num_heads: number of attention heads.
        qkv_bias: whether the fused QKV projection has a bias.
        qk_scale: overrides the default ``head_dim ** -0.5`` logit scale
            (a falsy value selects the default).
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Attend over the sequence axis of ``x``.

        Args:
            x: tensor of shape ``(B, L, C)``.

        Returns:
            Tensor with the same ``(B, L, C)`` layout.

        Raises:
            NotImplementedError: if ``ATTENTION_MODE`` is not one of the
                supported backends.
        """
        B, L, C = x.shape

        qkv = self.qkv(x)
        if ATTENTION_MODE == "flash":
            # The fused kernel is fed float32 inputs.
            qkv = einops.rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            ).float()
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
            x = einops.rearrange(x, "B H L D -> B L (H D)")
        elif ATTENTION_MODE == "xformers":
            qkv = einops.rearrange(
                qkv, "B L (K H D) -> K B L H D", K=3, H=self.num_heads
            )
            q, k, v = qkv[0], qkv[1], qkv[2]  # B L H D
            x = xformers.ops.memory_efficient_attention(q, k, v)
            x = einops.rearrange(x, "B L H D -> B L (H D)", H=self.num_heads)
        elif ATTENTION_MODE == "math":
            qkv = einops.rearrange(
                qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads
            )
            q, k, v = qkv[0], qkv[1], qkv[2]  # B H L D
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            # ``NotImplemented`` is a comparison sentinel, not an exception;
            # raising it would surface as a confusing TypeError.
            raise NotImplementedError(
                f"unsupported ATTENTION_MODE: {ATTENTION_MODE!r}"
            )

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    One Bernoulli draw is made per element of the leading (batch) dimension
    and broadcast over the remaining dimensions. Kept samples are rescaled
    by ``1 / keep_prob``.

    The technique is sometimes called "DropConnect", but that name belongs to
    a different form of dropout, hence "drop path" here. See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956

    Args:
        x: input tensor; its first dimension is treated as the sample axis.
        drop_prob: probability of zeroing a sample. ``None`` and ``0`` both
            disable dropping.
        training: dropping only happens when this is true.

    Returns:
        ``x`` itself when dropping is disabled, otherwise a new tensor of the
        same shape.
    """
    # ``not drop_prob`` covers both 0.0 and None (the DropPath default), which
    # would otherwise fail on ``1 - None`` below.
    if not drop_prob or not training:
        return x
    keep_prob = 1 - drop_prob
    if keep_prob <= 0.0:
        # Every sample is dropped. Dividing by a zero keep_prob would produce
        # inf * 0 = NaN, so return zeros that stay attached to the graph.
        return x * 0.0
    shape = (x.shape[0],) + (1,) * (
        x.ndim - 1
    )  # work with diff dim tensors, not just 2D ConvNets
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize
    output = x.div(keep_prob) * random_tensor
    return output
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: probability of dropping a sample's path. ``None`` (the
            default) and ``0`` both make the layer an identity.
    """

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        """Return ``x`` unchanged when inactive, otherwise apply ``drop_path``."""
        # Short-circuit here so the default ``drop_prob=None`` never reaches
        # the arithmetic in ``drop_path`` while in training mode.
        if not self.drop_prob or not self.training:
            return x
        return drop_path(x, self.drop_prob, self.training)
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> dropout -> linear -> dropout.

    Args:
        in_features: width of the input.
        hidden_features: width of the hidden layer; a falsy value falls back
            to ``in_features``.
        out_features: width of the output; a falsy value falls back to
            ``in_features``.
        act_layer: zero-argument factory for the activation module.
        drop: dropout probability; the same dropout module is applied after
            the activation and after the second linear layer.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features

        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Project to the hidden width, activate, and project back out."""
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class UViTBlock(nn.Module):
    """Pre-norm transformer block with an optional long-skip input.

    When built with ``skip=True`` the block first concatenates its input with
    a skip tensor along the last axis and projects the result back to ``dim``.
    It then applies attention and an MLP, each as a residual branch wrapped
    in the same drop-path module.

    Args:
        dim: channel width.
        num_heads: number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``dim``.
        qkv_bias: forwarded to ``Attention``.
        qk_scale: forwarded to ``Attention``.
        drop: dropout probability for the attention output projection and
            for the MLP.
        attn_drop: forwarded to ``Attention``.
        drop_path: stochastic-depth probability; values not greater than
            zero install an identity instead.
        act_layer: activation factory for the MLP.
        norm_layer: factory taking ``dim`` and returning a norm module.
        skip: whether to create the skip-merging linear layer.
        use_checkpoint: run the block under activation checkpointing.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        skip=False,
        use_checkpoint=False,
    ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        # Stochastic depth on both residual branches (identity when disabled).
        if drop_path > 0.0:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
        )

        if skip:
            self.skip_linear = nn.Linear(2 * dim, dim)
        else:
            self.skip_linear = None
        self.use_checkpoint = use_checkpoint

    def forward(self, x, skip=None):
        """Run the block, optionally under activation checkpointing."""
        if not self.use_checkpoint:
            return self._forward(x, skip)
        return torch.utils.checkpoint.checkpoint(self._forward, x, skip)

    def _forward(self, x, skip=None):
        """Merge the skip input (if configured), then apply attention and MLP."""
        if self.skip_linear is not None:
            merged = torch.cat([x, skip], dim=-1)
            x = self.skip_linear(merged)

        attn_branch = self.attn(self.norm1(x))
        x = x + self.drop_path(attn_branch)

        mlp_branch = self.mlp(self.norm2(x))
        return x + self.drop_path(mlp_branch)
def _expand_token(token, batch_size: int):
    """Prepend a batch axis to ``token`` and broadcast it to ``batch_size``.

    Uses ``expand`` rather than ``repeat``. For a 2-D ``token`` the result
    has shape ``(batch_size, *token.shape)``.
    """
    with_batch_axis = token[None]
    return with_batch_axis.expand(batch_size, -1, -1)
class TiTokEncoder(nn.Module):
    """ViT-style encoder that compresses an image into a short run of latent tokens.

    The image is patchified by a strided convolution and prefixed with a
    class token. Learnable latent tokens are appended, and the whole sequence
    goes through a stack of ``ResidualAttentionBlock`` layers. Only the
    latent-token positions are kept and projected to ``token_size`` channels.

    Config fields read:
        ``config.dataset.preprocessing.crop_size``,
        ``config.model.vq_model.{vit_enc_patch_size, vit_enc_model_size,
        num_latent_tokens, token_size}``, plus the optional
        ``quantize_mode`` (default ``"vq"``) and ``is_legacy`` (default ``True``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        vq_cfg = config.model.vq_model
        self.patch_size = vq_cfg.vit_enc_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = vq_cfg.vit_enc_model_size
        self.num_latent_tokens = vq_cfg.num_latent_tokens
        self.token_size = vq_cfg.token_size

        if vq_cfg.get("quantize_mode", "vq") == "vae":
            # VAE mode emits twice the channels: they are split into mean and std.
            self.token_size = self.token_size * 2

        self.is_legacy = vq_cfg.get("is_legacy", True)

        # (width, depth, heads) for each supported model size.
        presets = {
            "small": (512, 8, 8),
            "base": (768, 12, 12),
            "large": (1024, 24, 16),
        }
        self.width, self.num_layers, self.num_heads = presets[self.model_size]

        self.patch_embed = nn.Conv2d(
            in_channels=3,
            out_channels=self.width,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            bias=True,
        )

        scale = self.width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size**2 + 1, self.width)
        )
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList(
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        )
        self.ln_post = nn.LayerNorm(self.width)
        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)

    def forward(self, pixel_values, latent_tokens):
        """Encode a batch of images.

        Args:
            pixel_values: image batch fed to the 3-channel patch convolution.
            latent_tokens: learnable latent tokens, broadcast over the batch
                (presumably ``(num_latent_tokens, width)`` - confirm against
                the caller).

        Returns:
            Tensor of shape ``(batch, token_size, 1, num_latent_tokens)``.
        """
        batch_size = pixel_values.shape[0]

        # Patchify and flatten the spatial grid: [*, grid ** 2, width].
        patches = self.patch_embed(pixel_values)
        x = patches.flatten(2).transpose(1, 2)

        # Prepend the class token and add positions: [*, grid ** 2 + 1, width].
        cls_token = _expand_token(self.class_embedding, x.shape[0]).to(x.dtype)
        x = torch.cat([cls_token, x], dim=1)
        x = x + self.positional_embedding.to(x.dtype)

        # Append the latent tokens with their own positional embedding.
        latent_tokens = _expand_token(latent_tokens, x.shape[0]).to(x.dtype)
        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(
            x.dtype
        )
        x = torch.cat([x, latent_tokens], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for layer in self.transformer:
            x = layer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the latent-token positions (after class token and patches).
        latent_start = 1 + self.grid_size**2
        latent_tokens = self.ln_post(x[:, latent_start:])

        # Fake a 2D layout so the 1x1 convolution can be applied.
        if self.is_legacy:
            # Legacy path: a plain reshape without transposing the token and
            # channel axes.
            latent_tokens = latent_tokens.reshape(
                batch_size, self.width, self.num_latent_tokens, 1
            )
        else:
            # Fixed path: the channel axis is moved in front explicitly.
            latent_tokens = latent_tokens.reshape(
                batch_size, self.num_latent_tokens, self.width, 1
            ).permute(0, 2, 1, 3)

        latent_tokens = self.conv_out(latent_tokens)
        return latent_tokens.reshape(
            batch_size, self.token_size, 1, self.num_latent_tokens
        )
class TiTokDecoder(nn.Module):
    """ViT-style decoder that expands latent tokens back onto a spatial grid.

    A grid of mask tokens (prefixed with a class token) is concatenated with
    the embedded latent tokens and run through a stack of
    ``ResidualAttentionBlock`` layers. Only the grid positions are kept and
    passed to the output head.

    Output head:
        * legacy (``is_legacy=True``, the default): two 1x1 convolutions
          ending in 1024 channels per grid cell; ``conv_out`` is an identity.
        * otherwise: a 1x1 convolution to ``patch_size**2 * 3`` channels,
          rearranged into RGB pixels, followed by a 3x3 convolution.

    Config fields read:
        ``config.dataset.preprocessing.crop_size``,
        ``config.model.vq_model.{vit_dec_patch_size, vit_dec_model_size,
        num_latent_tokens, token_size}`` and the optional ``is_legacy``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.image_size = config.dataset.preprocessing.crop_size
        vq_cfg = config.model.vq_model
        self.patch_size = vq_cfg.vit_dec_patch_size
        self.grid_size = self.image_size // self.patch_size
        self.model_size = vq_cfg.vit_dec_model_size
        self.num_latent_tokens = vq_cfg.num_latent_tokens
        self.token_size = vq_cfg.token_size
        self.is_legacy = vq_cfg.get("is_legacy", True)

        # (width, depth, heads) for each supported model size.
        presets = {
            "small": (512, 8, 8),
            "base": (768, 12, 12),
            "large": (1024, 24, 16),
        }
        self.width, self.num_layers, self.num_heads = presets[self.model_size]

        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)

        scale = self.width**-0.5
        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))
        self.positional_embedding = nn.Parameter(
            scale * torch.randn(self.grid_size**2 + 1, self.width)
        )
        # Mask token (one per grid cell at run time) and latent-token positions.
        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))
        self.latent_token_positional_embedding = nn.Parameter(
            scale * torch.randn(self.num_latent_tokens, self.width)
        )
        self.ln_pre = nn.LayerNorm(self.width)
        self.transformer = nn.ModuleList(
            ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0)
            for _ in range(self.num_layers)
        )
        self.ln_post = nn.LayerNorm(self.width)

        if self.is_legacy:
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),
                nn.Tanh(),
                nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),
            )
            self.conv_out = nn.Identity()
        else:
            # Directly predict RGB pixels for every patch.
            pixels_per_patch = self.patch_size * self.patch_size * 3
            self.ffn = nn.Sequential(
                nn.Conv2d(self.width, pixels_per_patch, 1, padding=0, bias=True),
                Rearrange(
                    "b (p1 p2 c) h w -> b c (h p1) (w p2)",
                    p1=self.patch_size,
                    p2=self.patch_size,
                ),
            )
            self.conv_out = nn.Conv2d(3, 3, 3, padding=1, bias=True)

    def forward(self, z_quantized):
        """Decode latent tokens.

        Args:
            z_quantized: tensor of shape ``(N, C, 1, num_latent_tokens)``.

        Returns:
            The output of the configured head applied to the
            ``(N, width, grid_size, grid_size)`` feature map.
        """
        N, C, H, W = z_quantized.shape
        assert (
            H == 1 and W == self.num_latent_tokens
        ), f"{H}, {W}, {self.num_latent_tokens}"

        tokens = z_quantized.reshape(N, C * H, W).transpose(1, 2)  # NLD
        tokens = self.decoder_embed(tokens)
        batchsize, seq_len, _ = tokens.shape

        # One mask token per grid cell, with the class token in front.
        grid = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(
            tokens.dtype
        )
        cls_token = _expand_token(self.class_embedding, grid.shape[0]).to(grid.dtype)
        grid = torch.cat([cls_token, grid], dim=1)
        grid = grid + self.positional_embedding.to(grid.dtype)

        # NOTE(review): unlike the other embeddings, this one is added without
        # a cast to the activation dtype.
        tokens = tokens + self.latent_token_positional_embedding[:seq_len]

        x = torch.cat([grid, tokens], dim=1)
        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for layer in self.transformer:
            x = layer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Drop the class token and the latent tokens; keep the grid positions.
        grid_end = 1 + self.grid_size**2
        x = self.ln_post(x[:, 1:grid_end])

        # N L D -> N D H W
        feature_map = x.permute(0, 2, 1).reshape(
            batchsize, self.width, self.grid_size, self.grid_size
        )
        out = self.ffn(feature_map.contiguous())
        return self.conv_out(out)
class TATiTokDecoder(TiTokDecoder):
    """``TiTokDecoder`` variant that also attends to text-guidance embeddings.

    The text embeddings are linearly projected to the decoder width, given
    their own positional embedding, and appended after the mask and latent
    tokens before the transformer stack.

    Extra config fields read from ``config.model.vq_model``:
        ``text_context_length`` (default 77) and ``text_embed_dim``
        (default 768).
    """

    def __init__(self, config):
        super().__init__(config)
        vq_cfg = config.model.vq_model
        scale = self.width**-0.5
        self.text_context_length = vq_cfg.get("text_context_length", 77)
        self.text_embed_dim = vq_cfg.get("text_embed_dim", 768)
        self.text_guidance_proj = nn.Linear(self.text_embed_dim, self.width)
        self.text_guidance_positional_embedding = nn.Parameter(
            scale * torch.randn(self.text_context_length, self.width)
        )

    def forward(self, z_quantized, text_guidance):
        """Decode latent tokens conditioned on text guidance.

        Args:
            z_quantized: tensor of shape ``(N, C, 1, num_latent_tokens)``.
            text_guidance: text embeddings whose last dimension is
                ``text_embed_dim`` (presumably
                ``(N, text_context_length, text_embed_dim)`` - confirm
                against the caller).

        Returns:
            The output of the configured head applied to the
            ``(N, width, grid_size, grid_size)`` feature map.
        """
        N, C, H, W = z_quantized.shape
        assert (
            H == 1 and W == self.num_latent_tokens
        ), f"{H}, {W}, {self.num_latent_tokens}"

        tokens = z_quantized.reshape(N, C * H, W).transpose(1, 2)  # NLD
        tokens = self.decoder_embed(tokens)
        batchsize, seq_len, _ = tokens.shape

        # One mask token per grid cell, with the class token in front.
        grid = self.mask_token.repeat(batchsize, self.grid_size**2, 1).to(
            tokens.dtype
        )
        cls_token = _expand_token(self.class_embedding, grid.shape[0]).to(grid.dtype)
        grid = torch.cat([cls_token, grid], dim=1)
        grid = grid + self.positional_embedding.to(grid.dtype)

        tokens = tokens + self.latent_token_positional_embedding[:seq_len]
        x = torch.cat([grid, tokens], dim=1)

        # Project the text guidance to the model width and append it.
        guidance = self.text_guidance_proj(text_guidance)
        guidance = guidance + self.text_guidance_positional_embedding
        x = torch.cat([x, guidance], dim=1)

        x = self.ln_pre(x)
        x = x.permute(1, 0, 2)  # NLD -> LND
        for layer in self.transformer:
            x = layer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

        # Keep only the grid positions (drops class, latent and text tokens).
        grid_end = 1 + self.grid_size**2
        x = self.ln_post(x[:, 1:grid_end])

        # N L D -> N D H W
        feature_map = x.permute(0, 2, 1).reshape(
            batchsize, self.width, self.grid_size, self.grid_size
        )
        out = self.ffn(feature_map.contiguous())
        return self.conv_out(out)
class WeightTiedLMHead(nn.Module):
    """Output head that reuses an embedding table's weight as its projection.

    Args:
        embeddings: module exposing a ``weight`` attribute; that same object
            is stored on this head, so the two stay tied.
        target_codebook_size: number of leading rows of the weight to score
            against.
    """

    def __init__(self, embeddings, target_codebook_size):
        super().__init__()
        self.weight = embeddings.weight
        self.target_codebook_size = target_codebook_size

    def forward(self, x):
        """Return logits of shape ``[batch_size, seq_len, target_codebook_size]``.

        ``x`` is expected as ``[batch_size, seq_len, embed_dim]``.
        """
        # Restrict to the first ``target_codebook_size`` rows:
        # [target_codebook_size, embed_dim].
        codebook = self.weight[: self.target_codebook_size]
        # Dot product of every position with every selected row.
        return torch.matmul(x, codebook.t())
class TimestepEmbedder(nn.Module):
    """
    Maps scalar timesteps to vectors: sinusoidal features followed by a small MLP.

    :param hidden_size: width of the MLP hidden layer and of the output.
    :param frequency_embedding_size: width of the sinusoidal feature vector.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size, bias=True),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """
        Build sinusoidal timestep features.

        :param t: a 1-D Tensor of N indices, one per batch element.
                  These may be fractional.
        :param dim: the dimension of the output.
        :param max_period: controls the minimum frequency of the embeddings.
        :return: an (N, D) Tensor laid out as [cos | sin], zero-padded by one
                 column when ``dim`` is odd.
        """
        # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
        half = dim // 2
        exponents = (
            -math.log(max_period)
            * torch.arange(half, dtype=torch.float32)
            / half
        )
        freqs = torch.exp(exponents).to(device=t.device)
        phases = t.float().unsqueeze(1) * freqs.unsqueeze(0)
        embedding = torch.cat([phases.cos(), phases.sin()], dim=-1)
        if dim % 2 != 0:
            pad = torch.zeros_like(embedding[:, :1])
            embedding = torch.cat([embedding, pad], dim=-1)
        return embedding

    def forward(self, t):
        """Embed timesteps ``t`` into ``hidden_size``-dimensional vectors."""
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features)
class ResBlock(nn.Module):
    """
    Gated residual MLP block with adaptive layer-norm conditioning.

    The conditioning vector is mapped to a shift, a scale and a gate. The
    normalised input is modulated by shift/scale and passed through an MLP;
    the result is multiplied by the gate and added back to the input. Input
    and output both have ``channels`` features.

    :param channels: the number of input (and output) channels.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        self.in_ln = nn.LayerNorm(channels, eps=1e-6)
        self.mlp = nn.Sequential(
            nn.Linear(channels, channels, bias=True),
            nn.SiLU(),
            nn.Linear(channels, channels, bias=True),
        )

        # Produces shift, scale and gate in one projection (3 * channels).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(channels, 3 * channels, bias=True)
        )

    def forward(self, x, y):
        """Update ``x`` using conditioning ``y``; both carry ``channels`` features."""
        conditioning = self.adaLN_modulation(y)
        shift_mlp, scale_mlp, gate_mlp = conditioning.chunk(3, dim=-1)
        branch = self.mlp(modulate(self.in_ln(x), shift_mlp, scale_mlp))
        return x + gate_mlp * branch
class FinalLayer(nn.Module):
    """
    Output layer adopted from DiT: conditioned layer norm followed by a linear projection.

    The layer norm has no learnable affine parameters; shift and scale come
    from the conditioning vector instead.

    :param model_channels: width of the incoming features and conditioning.
    :param out_channels: width of the projected output.
    """

    def __init__(self, model_channels, out_channels):
        super().__init__()
        self.norm_final = nn.LayerNorm(
            model_channels, elementwise_affine=False, eps=1e-6
        )
        self.linear = nn.Linear(model_channels, out_channels, bias=True)
        # Produces shift and scale in one projection (2 * model_channels).
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(model_channels, 2 * model_channels, bias=True)
        )

    def forward(self, x, c):
        """Normalise ``x``, modulate it with conditioning ``c``, and project."""
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=-1)
        modulated = modulate(self.norm_final(x), shift, scale)
        return self.linear(modulated)
class SimpleMLPAdaLN(nn.Module):
    """
    The MLP for Diffusion Loss.

    The input is projected to ``model_channels`` and refined by a stack of
    ``ResBlock`` layers. Each block is conditioned on the sum of a timestep
    embedding and a projected condition vector. A ``FinalLayer`` then maps
    the result to ``out_channels``.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param z_channels: channels in the condition.
    :param num_res_blocks: number of residual blocks in the stack.
    :param grad_checkpointing: run each residual block under activation
        checkpointing.
    """

    def __init__(
        self,
        in_channels,
        model_channels,
        out_channels,
        z_channels,
        num_res_blocks,
        grad_checkpointing=False,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.grad_checkpointing = grad_checkpointing

        self.time_embed = TimestepEmbedder(model_channels)
        self.cond_embed = nn.Linear(z_channels, model_channels)

        self.input_proj = nn.Linear(in_channels, model_channels)

        self.res_blocks = nn.ModuleList(
            [ResBlock(model_channels) for _ in range(num_res_blocks)]
        )
        self.final_layer = FinalLayer(model_channels, out_channels)

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all linear layers, then apply the special-case overrides."""

        def _init_linear(module):
            # Xavier-uniform weights and zero biases for every nn.Linear.
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        self.apply(_init_linear)

        # Timestep-embedding MLP: small normal weights on both linear layers.
        for idx in (0, 2):
            nn.init.normal_(self.time_embed.mlp[idx].weight, std=0.02)

        # Zero-out adaLN modulation layers in every residual block.
        for block in self.res_blocks:
            modulation = block.adaLN_modulation[-1]
            nn.init.constant_(modulation.weight, 0)
            nn.init.constant_(modulation.bias, 0)

        # Zero-out the output layers.
        final_modulation = self.final_layer.adaLN_modulation[-1]
        nn.init.constant_(final_modulation.weight, 0)
        nn.init.constant_(final_modulation.bias, 0)
        nn.init.constant_(self.final_layer.linear.weight, 0)
        nn.init.constant_(self.final_layer.linear.bias, 0)

    def forward(self, x, t, c):
        """
        Apply the model to an input batch.
        :param x: an [N x C] Tensor of inputs.
        :param t: a 1-D batch of timesteps.
        :param c: conditioning from AR transformer.
        :return: an [N x C] Tensor of outputs.
        """
        h = self.input_proj(x)
        # Blocks are conditioned on timestep embedding + projected condition.
        y = self.time_embed(t) + self.cond_embed(c)

        use_checkpoint = self.grad_checkpointing and not torch.jit.is_scripting()
        for block in self.res_blocks:
            h = checkpoint(block, h, y) if use_checkpoint else block(h, y)

        return self.final_layer(h, y)

    def forward_with_cfg(self, x, t, c, cfg_scale):
        """
        Forward pass with classifier-free guidance.

        The first half of ``x`` is duplicated and run through the model. The
        first ``in_channels`` output channels are split into two halves along
        the batch axis (unpacked as conditional, then unconditional) and
        blended as ``uncond + cfg_scale * (cond - uncond)``. The blended
        result is written to both halves; the remaining channels pass
        through unchanged.
        """
        first_half = x[: len(x) // 2]
        doubled = torch.cat([first_half, first_half], dim=0)
        model_out = self.forward(doubled, t, c)

        eps = model_out[:, : self.in_channels]
        rest = model_out[:, self.in_channels :]

        cond_eps, uncond_eps = eps.split(len(eps) // 2, dim=0)
        guided = uncond_eps + cfg_scale * (cond_eps - uncond_eps)

        eps = torch.cat([guided, guided], dim=0)
        return torch.cat([eps, rest], dim=1)