# Copyright 2024 EPFL and Apple Inc. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
from typing import Optional, Tuple, Union | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils.checkpoint | |
from einops import rearrange, repeat | |
from diffusers.models.embeddings import ( | |
GaussianFourierProjection, | |
TimestepEmbedding, | |
Timesteps, | |
) | |
from diffusers.models.unet_2d_blocks import ( | |
DownBlock2D, | |
UpBlock2D, | |
) | |
from diffusers.models.resnet import Downsample2D, Upsample2D | |
from diffusers.configuration_utils import ConfigMixin | |
from diffusers.models.modeling_utils import ModelMixin | |
# xFormers imports | |
try: | |
from xformers.ops import memory_efficient_attention, unbind | |
XFORMERS_AVAILABLE = True | |
except ImportError: | |
print("xFormers not available") | |
XFORMERS_AVAILABLE = False | |
def modulate(x, shift, scale): | |
return x * (1.0 + scale.unsqueeze(1)) + shift.unsqueeze(1) | |
def pair(t): | |
return t if isinstance(t, tuple) else (t, t) | |
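# Note: `modulate` applies a per-sample FiLM-style shift and scale to a token sequence.
# With x of shape [B, N, D] and shift/scale of shape [B, D] (as produced by the adaLN
# projections further below), unsqueeze(1) broadcasts the modulation over the N tokens.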
def build_2d_sincos_posemb(h, w, embed_dim=1024, temperature=10000.): | |
"""Sine-cosine positional embeddings as used in MoCo-v3 | |
    Returns positional embedding of shape [1, D, H, W]
""" | |
grid_w = torch.arange(w, dtype=torch.float32) | |
grid_h = torch.arange(h, dtype=torch.float32) | |
grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing='ij') | |
assert embed_dim % 4 == 0, 'Embed dimension must be divisible by 4 for 2D sin-cos position embedding' | |
pos_dim = embed_dim // 4 | |
omega = torch.arange(pos_dim, dtype=torch.float32) / pos_dim | |
omega = 1. / (temperature ** omega) | |
out_w = torch.einsum('m,d->md', [grid_w.flatten(), omega]) | |
out_h = torch.einsum('m,d->md', [grid_h.flatten(), omega]) | |
pos_emb = torch.cat([torch.sin(out_w), torch.cos(out_w), torch.sin(out_h), torch.cos(out_h)], dim=1)[None, :, :] | |
pos_emb = rearrange(pos_emb, 'b (h w) d -> b d h w', h=h, w=w) | |
return pos_emb | |
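# Shape note (illustrative): build_2d_sincos_posemb(h=16, w=16, embed_dim=768) returns a
# fixed tensor of shape [1, 768, 16, 16], i.e. channels-first, so it can be resized with
# F.interpolate like any feature map (as done in the Transformer bottlenecks below).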
def drop_path(x, drop_prob: float = 0., training: bool = False): | |
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, | |
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... | |
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for | |
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use | |
'survival rate' as the argument. | |
""" | |
if drop_prob == 0. or not training: | |
return x | |
keep_prob = 1 - drop_prob | |
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets | |
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) | |
random_tensor.floor_() # binarize | |
output = x.div(keep_prob) * random_tensor | |
return output | |
class DropPath(nn.Module): | |
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). | |
""" | |
def __init__(self, drop_prob=None): | |
super(DropPath, self).__init__() | |
self.drop_prob = drop_prob | |
def forward(self, x): | |
return drop_path(x, self.drop_prob, self.training) | |
def extra_repr(self) -> str: | |
return 'p={}'.format(self.drop_prob) | |
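# Note: DropPath zeroes the entire residual branch for a random subset of samples during
# training and rescales the survivors by 1/keep_prob so the expected value is unchanged;
# when self.training is False it is a no-op.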
class Mlp(nn.Module): | |
def __init__(self, in_features, temb_dim=None, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): | |
super().__init__() | |
out_features = out_features or in_features | |
self.hidden_features = hidden_features or in_features | |
self.fc1 = nn.Linear(in_features, self.hidden_features) | |
self.act = act_layer() | |
self.fc2 = nn.Linear(self.hidden_features, out_features) | |
self.drop = nn.Dropout(drop) | |
if temb_dim is not None: | |
self.adaLN_modulation = nn.Linear(temb_dim, 2 * self.hidden_features) | |
def forward(self, x, temb=None): | |
x = self.fc1(x) | |
x = self.act(x) | |
# Shift and scale using time emb (see https://arxiv.org/abs/2301.11093) | |
if hasattr(self, 'adaLN_modulation'): | |
shift, scale = self.adaLN_modulation(F.silu(temb)).chunk(2, dim=-1) | |
x = modulate(x, shift, scale) | |
x = self.fc2(x) | |
x = self.drop(x) | |
return x | |
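# Minimal usage sketch (illustrative only, not called anywhere in this file): an Mlp built
# with a temb_dim shifts and scales its hidden activations with the time embedding before
# the second projection. The sizes below are arbitrary.
def _example_mlp_with_temb():
    mlp = Mlp(in_features=64, temb_dim=128, hidden_features=256)
    x = torch.randn(2, 10, 64)   # [B, N, D] tokens
    temb = torch.randn(2, 128)   # [B, temb_dim] time embedding
    return mlp(x, temb).shape    # torch.Size([2, 10, 64])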
class Attention(nn.Module): | |
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): | |
super().__init__() | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
self.scale = head_dim ** -0.5 | |
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x, mask=None): | |
B, N, C = x.shape | |
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) | |
if XFORMERS_AVAILABLE: | |
q, k, v = unbind(qkv, 2) | |
if mask is not None: | |
# Wherever mask is True it becomes -infinity, otherwise 0 | |
mask = mask.to(q.dtype) * -torch.finfo(q.dtype).max | |
x = memory_efficient_attention(q, k, v, attn_bias=mask) | |
x = x.reshape([B, N, C]) | |
else: | |
qkv = qkv.permute(2, 0, 3, 1, 4) | |
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple) | |
attn = (q @ k.transpose(-2, -1)) * self.scale | |
if mask is not None: | |
mask = mask.unsqueeze(1) | |
attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max) | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape(B, N, C) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
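# Mask convention for Attention: `mask` is boolean with True marking positions that must
# not be attended to. The xFormers path converts it to an additive bias of (approximately)
# -inf, while the plain PyTorch path applies masked_fill before the softmax.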
class CrossAttention(nn.Module): | |
def __init__(self, dim, dim_context=None, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.): | |
super().__init__() | |
dim_context = dim_context or dim | |
self.num_heads = num_heads | |
head_dim = dim // num_heads | |
self.scale = head_dim ** -0.5 | |
self.q = nn.Linear(dim, dim, bias=qkv_bias) | |
self.kv = nn.Linear(dim_context, dim * 2, bias=qkv_bias) | |
self.attn_drop = nn.Dropout(attn_drop) | |
self.proj = nn.Linear(dim, dim) | |
self.proj_drop = nn.Dropout(proj_drop) | |
def forward(self, x, context, mask=None): | |
B, N, C = x.shape | |
_, M, _ = context.shape | |
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads) | |
kv = self.kv(context).reshape(B, M, 2, self.num_heads, C // self.num_heads) | |
if XFORMERS_AVAILABLE: | |
k, v = unbind(kv, 2) | |
if mask is not None: | |
# Wherever mask is True it becomes -infinity, otherwise 0 | |
mask = mask.to(q.dtype) * -torch.finfo(q.dtype).max | |
x = memory_efficient_attention(q, k, v, attn_bias=mask) | |
x = x.reshape([B, N, C]) | |
else: | |
q = q.permute(0, 2, 1, 3) | |
kv = kv.permute(2, 0, 3, 1, 4) | |
k, v = kv[0], kv[1] | |
attn = (q @ k.transpose(-2, -1)) * self.scale | |
if mask is not None: | |
mask = rearrange(mask, "b n m -> b 1 n m") | |
attn = attn.masked_fill(mask, -torch.finfo(attn.dtype).max) | |
attn = attn.softmax(dim=-1) | |
attn = self.attn_drop(attn) | |
x = (attn @ v).transpose(1, 2).reshape(B, N, -1) | |
x = self.proj(x) | |
x = self.proj_drop(x) | |
return x | |
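# Shape note for CrossAttention: queries come from x [B, N, dim] and keys/values from
# context [B, M, dim_context]; the optional boolean mask has shape [B, N, M] with True
# marking context positions a given query token may not attend to.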
class Block(nn.Module): | |
def __init__(self, dim, num_heads, temb_dim=None, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., | |
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, temb_in_mlp=False, temb_after_norm=True, temb_gate=True): | |
super().__init__() | |
self.norm1 = norm_layer(dim) | |
self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) | |
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | |
self.norm2 = norm_layer(dim) | |
mlp_hidden_dim = int(dim * mlp_ratio) | |
self.mlp = Mlp(in_features=dim, temb_dim=temb_dim if temb_in_mlp else None, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) | |
if temb_after_norm and temb_dim is not None: | |
# adaLN modulation (see https://arxiv.org/abs/2212.09748) | |
self.adaLN_modulation = nn.Linear(temb_dim, 4 * dim) | |
if temb_gate and temb_dim is not None: | |
# adaLN-Zero gate (see https://arxiv.org/abs/2212.09748) | |
self.adaLN_gate = nn.Linear(temb_dim, 2 * dim) | |
nn.init.zeros_(self.adaLN_gate.weight) | |
nn.init.zeros_(self.adaLN_gate.bias) | |
self.skip_linear = nn.Linear(2*dim, dim) if skip else None | |
def forward(self, x, temb=None, mask=None, skip_connection=None): | |
gate_msa, gate_mlp = self.adaLN_gate(F.silu(temb)).unsqueeze(1).chunk(2, dim=-1) if hasattr(self, 'adaLN_gate') else (1.0, 1.0) | |
shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(F.silu(temb)).chunk(4, dim=-1) if hasattr(self, 'adaLN_modulation') else 4*[0.0] | |
if self.skip_linear is not None: | |
x = self.skip_linear(torch.cat([x, skip_connection], dim=-1)) | |
x = x + gate_msa * self.drop_path(self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask)) | |
x = x + gate_mlp * self.drop_path(self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), temb)) | |
return x | |
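# Note on adaLN-Zero: because adaLN_gate is zero-initialized, gate_msa and gate_mlp start
# at zero, so a freshly initialized Block acts as the identity map on x; the residual
# branches only start contributing as the gates are learned.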
class DecoderBlock(nn.Module): | |
def __init__(self, dim, num_heads, temb_dim=None, dim_context=None, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., | |
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, skip=False, temb_in_mlp=False, temb_after_norm=True, temb_gate=True): | |
super().__init__() | |
dim_context = dim_context or dim | |
self.norm1 = norm_layer(dim) | |
self.self_attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) | |
self.cross_attn = CrossAttention(dim, dim_context=dim_context, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop) | |
self.query_norm = norm_layer(dim) | |
self.context_norm = norm_layer(dim_context) | |
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity() | |
self.norm2 = norm_layer(dim) | |
mlp_hidden_dim = int(dim * mlp_ratio) | |
self.mlp = Mlp(in_features=dim, temb_dim=temb_dim if temb_in_mlp else None, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop) | |
if temb_after_norm and temb_dim is not None: | |
# adaLN modulation (see https://arxiv.org/abs/2212.09748) | |
self.adaLN_modulation = nn.Linear(temb_dim, 6 * dim) | |
if temb_gate and temb_dim is not None: | |
# adaLN-Zero gate (see https://arxiv.org/abs/2212.09748) | |
self.adaLN_gate = nn.Linear(temb_dim, 3 * dim) | |
nn.init.zeros_(self.adaLN_gate.weight) | |
nn.init.zeros_(self.adaLN_gate.bias) | |
self.skip_linear = nn.Linear(2*dim, dim) if skip else None | |
def forward(self, x, context, temb=None, sa_mask=None, xa_mask=None, skip_connection=None): | |
gate_msa, gate_mxa, gate_mlp = self.adaLN_gate(F.silu(temb)).unsqueeze(1).chunk(3, dim=-1) if hasattr(self, 'adaLN_gate') else (1.0, 1.0, 1.0) | |
shift_msa, scale_msa, shift_mxa, scale_mxa, shift_mlp, scale_mlp = self.adaLN_modulation(F.silu(temb)).chunk(6, dim=-1) if hasattr(self, 'adaLN_modulation') else 6*[0.0] | |
if self.skip_linear is not None: | |
x = self.skip_linear(torch.cat([x, skip_connection], dim=-1)) | |
x = x + gate_msa * self.drop_path(self.self_attn(modulate(self.norm1(x), shift_msa, scale_msa), sa_mask)) | |
x = x + gate_mxa * self.drop_path(self.cross_attn(modulate(self.query_norm(x), shift_mxa, scale_mxa), self.context_norm(context), xa_mask)) | |
x = x + gate_mlp * self.drop_path(self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp), temb)) | |
return x | |
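# DecoderBlock extends Block with a gated cross-attention step between the self-attention
# and the MLP; with adaLN-Zero gating all three residual branches start gated to zero, so
# a freshly initialized DecoderBlock likewise acts as the identity on x.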
class TransformerConcatCond(nn.Module): | |
"""UViT Transformer bottleneck that concatenates the condition to the input. | |
Args: | |
unet_dim: Number of channels in the last UNet down block. | |
cond_dim: Number of channels in the condition. | |
mid_layers: Number of Transformer layers. | |
mid_num_heads: Number of attention heads. | |
mid_dim: Transformer dimension. | |
mid_mlp_ratio: Ratio of MLP hidden dim to Transformer dim. | |
mid_qkv_bias: Whether to add bias to the query, key, and value projections. | |
mid_drop_rate: Dropout rate. | |
mid_attn_drop_rate: Attention dropout rate. | |
mid_drop_path_rate: Stochastic depth rate. | |
time_embed_dim: Dimension of the time embedding. | |
hw_posemb: Size (side) of the 2D positional embedding. | |
use_long_skip: Whether to use long skip connections. | |
See https://arxiv.org/abs/2209.12152 for more details. | |
""" | |
def __init__( | |
self, | |
unet_dim: int = 1024, | |
cond_dim: int = 32, | |
mid_layers: int = 12, | |
mid_num_heads: int = 12, | |
mid_dim: int = 768, | |
mid_mlp_ratio: int = 4, | |
mid_qkv_bias: bool = True, | |
mid_drop_rate: float = 0.0, | |
mid_attn_drop_rate: float = 0.0, | |
mid_drop_path_rate: float = 0.0, | |
time_embed_dim: int = 512, | |
hw_posemb: int = 16, | |
use_long_skip: bool = False, | |
): | |
super().__init__() | |
self.mid_pos_emb = build_2d_sincos_posemb(h=hw_posemb, w=hw_posemb, embed_dim=mid_dim) | |
self.mid_pos_emb = nn.Parameter(self.mid_pos_emb, requires_grad=False) | |
self.use_long_skip = use_long_skip | |
if use_long_skip: | |
assert mid_layers % 2 == 1, 'mid_layers must be odd when using long skip connection' | |
dpr = [x.item() for x in torch.linspace(0, mid_drop_path_rate, mid_layers)] # stochastic depth decay rule | |
self.mid_block = nn.ModuleList([ | |
Block(dim=mid_dim, temb_dim=time_embed_dim, num_heads=mid_num_heads, mlp_ratio=mid_mlp_ratio, qkv_bias=mid_qkv_bias, | |
drop=mid_drop_rate, attn_drop=mid_attn_drop_rate, drop_path=dpr[i], skip=i > mid_layers//2 and use_long_skip) | |
for i in range(mid_layers) | |
]) | |
self.mid_cond_proj = nn.Linear(cond_dim, mid_dim) | |
self.mid_proj_in = nn.Linear(unet_dim, mid_dim) | |
self.mid_proj_out = nn.Linear(mid_dim, unet_dim) | |
self.mask_token = nn.Parameter(torch.zeros(mid_dim), requires_grad=True) | |
def forward(self, | |
x: torch.Tensor, | |
temb: torch.Tensor, | |
cond: torch.Tensor, | |
cond_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor: | |
"""TransformerConcatCond forward pass. | |
Args: | |
x: UNet features from the last down block of shape [B, C_mid, H_mid, W_mid]. | |
temb: Time embedding of shape [B, temb_dim]. | |
cond: Condition of shape [B, cond_dim, H_cond, W_cond]. If H_cond and W_cond are | |
different from H_mid and W_mid, cond is interpolated to match the spatial size | |
of x. | |
            cond_mask: Condition mask of shape [B, H_mid, W_mid]. If given, every
                position where cond_mask is True has its condition token replaced
                by a learnable mask token.
Returns: | |
Features of shape [B, C_mid, H_mid, W_mid] to pass to the UNet up blocks. | |
""" | |
B, C_mid, H_mid, W_mid = x.shape | |
# Rearrange and proj UNet features to sequence of tokens | |
x = rearrange(x, 'b d h w -> b (h w) d') | |
x = self.mid_proj_in(x) | |
# Rearrange and proj conditioning to sequence of tokens | |
cond = F.interpolate(cond, (H_mid, W_mid)) # Interpolate if necessary | |
cond = rearrange(cond, 'b d h w -> b (h w) d') | |
cond = self.mid_cond_proj(cond) | |
# If a mask is defined, replace masked-out tokens by a learnable mask-token | |
# Wherever cond_mask is True, the condition gets replaced by the mask token | |
if cond_mask is not None: | |
cond_mask = F.interpolate(cond_mask.unsqueeze(1).float(), (H_mid, W_mid), mode='nearest') > 0.5 | |
cond_mask = rearrange(cond_mask, 'b 1 h w -> b (h w)') | |
cond[cond_mask] = self.mask_token.to(dtype=cond.dtype) | |
x = x + cond | |
# Interpolate and rearrange positional embedding to sequence of tokens | |
mid_pos_emb = F.interpolate(self.mid_pos_emb, (H_mid, W_mid), mode='bicubic', align_corners=False) | |
mid_pos_emb = rearrange(mid_pos_emb, 'b d h w -> b (h w) d') | |
x = x + mid_pos_emb | |
# Transformer forward pass with or without long skip connections | |
if not self.use_long_skip: | |
for blk in self.mid_block: | |
x = blk(x, temb) | |
else: | |
skip_connections = [] | |
num_skips = len(self.mid_block) // 2 | |
for blk in self.mid_block[:num_skips]: | |
x = blk(x, temb) | |
skip_connections.append(x) | |
x = self.mid_block[num_skips](x, temb) | |
for blk in self.mid_block[num_skips+1:]: | |
x = blk(x, temb, skip_connection=skip_connections.pop()) | |
x = self.mid_proj_out(x) # Project Transformer output back to UNet channels | |
x = rearrange(x, 'b (h w) d -> b d h w', h=H_mid, w=W_mid) # Rearrange Transformer tokens to a spatial feature map for conv layers | |
return x | |
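# Minimal shape sketch (illustrative sizes, not a training configuration; with xFormers
# installed this also assumes a device its attention kernels support):
def _example_concat_cond_bottleneck():
    bottleneck = TransformerConcatCond(
        unet_dim=32, cond_dim=8, mid_layers=2, mid_num_heads=2, mid_dim=64,
        time_embed_dim=16, hw_posemb=8,
    )
    x = torch.randn(2, 32, 8, 8)    # UNet features [B, C_mid, H_mid, W_mid]
    temb = torch.randn(2, 16)       # time embedding [B, time_embed_dim]
    cond = torch.randn(2, 8, 8, 8)  # condition [B, cond_dim, H_cond, W_cond]
    return bottleneck(x, temb, cond).shape  # torch.Size([2, 32, 8, 8])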
class TransformerXattnCond(nn.Module): | |
"""UViT Transformer bottleneck that incroporates the condition via cross-attention. | |
Args: | |
unet_dim: Number of channels in the last UNet down block. | |
cond_dim: Number of channels in the condition. | |
mid_layers: Number of Transformer layers. | |
mid_num_heads: Number of attention heads. | |
mid_dim: Transformer dimension. | |
mid_mlp_ratio: Ratio of MLP hidden dim to Transformer dim. | |
mid_qkv_bias: Whether to add bias to the query, key, and value projections. | |
mid_drop_rate: Dropout rate. | |
mid_attn_drop_rate: Attention dropout rate. | |
mid_drop_path_rate: Stochastic depth rate. | |
time_embed_dim: Dimension of the time embedding. | |
hw_posemb: Size (side) of the 2D positional embedding. | |
use_long_skip: Whether to use long skip connections. | |
See https://arxiv.org/abs/2209.12152 for more details. | |
""" | |
def __init__( | |
self, | |
unet_dim: int = 1024, | |
cond_dim: int = 32, | |
mid_layers: int = 12, | |
mid_num_heads: int = 12, | |
mid_dim: int = 768, | |
mid_mlp_ratio: int = 4, | |
mid_qkv_bias: bool = True, | |
mid_drop_rate: float = 0.0, | |
mid_attn_drop_rate: float = 0.0, | |
mid_drop_path_rate: float = 0.0, | |
time_embed_dim: int = 512, | |
hw_posemb: int = 16, | |
use_long_skip: bool = False, | |
): | |
super().__init__() | |
self.mid_pos_emb = build_2d_sincos_posemb(h=hw_posemb, w=hw_posemb, embed_dim=mid_dim) | |
self.mid_pos_emb = nn.Parameter(self.mid_pos_emb, requires_grad=False) | |
self.use_long_skip = use_long_skip | |
if use_long_skip: | |
assert mid_layers % 2 == 1, 'mid_layers must be odd when using long skip connection' | |
dpr = [x.item() for x in torch.linspace(0, mid_drop_path_rate, mid_layers)] # stochastic depth decay rule | |
self.mid_block = nn.ModuleList([ | |
DecoderBlock( | |
dim=mid_dim, temb_dim=time_embed_dim, num_heads=mid_num_heads, dim_context=cond_dim, | |
mlp_ratio=mid_mlp_ratio, qkv_bias=mid_qkv_bias, drop=mid_drop_rate, | |
attn_drop=mid_attn_drop_rate, drop_path=dpr[i], | |
skip=i > mid_layers//2 and use_long_skip | |
) | |
for i in range(mid_layers) | |
]) | |
self.mid_proj_in = nn.Linear(unet_dim, mid_dim) | |
self.mid_proj_out = nn.Linear(mid_dim, unet_dim) | |
def forward(self, | |
x: torch.Tensor, | |
temb: torch.Tensor, | |
cond: torch.Tensor, | |
cond_mask: Optional[torch.BoolTensor] = None) -> torch.Tensor: | |
"""TransformerXattnCond forward pass. | |
Args: | |
x: UNet features from the last down block of shape [B, C_mid, H_mid, W_mid]. | |
temb: Time embedding of shape [B, temb_dim]. | |
cond: Condition of shape [B, cond_dim, H_cond, W_cond]. | |
cond_mask: Condition cross-attention mask of shape [B, H_cond, W_cond]. | |
If a mask is defined, wherever cond_mask is True, the condition at that | |
spatial location is not cross-attended to. | |
Returns: | |
Features of shape [B, C_mid, H_mid, W_mid] to pass to the UNet up blocks. | |
""" | |
B, C_mid, H_mid, W_mid = x.shape | |
# Rearrange and proj UNet features to sequence of tokens | |
x = rearrange(x, 'b d h w -> b (h w) d') | |
x = self.mid_proj_in(x) | |
# Rearrange conditioning to sequence of tokens | |
cond = rearrange(cond, 'b d h w -> b (h w) d') | |
# Interpolate and rearrange positional embedding to sequence of tokens | |
mid_pos_emb = F.interpolate(self.mid_pos_emb, (H_mid, W_mid), mode='nearest') | |
mid_pos_emb = rearrange(mid_pos_emb, 'b d h w -> b (h w) d') | |
# Add UNet mid-block features and positional embedding | |
x = x + mid_pos_emb | |
# Prepare the conditioning cross-attention mask | |
xa_mask = repeat(cond_mask, 'b h w -> b n (h w)', n=x.shape[1]) if cond_mask is not None else None | |
# Transformer forward pass with or without long skip connections. | |
# In each layer, cross-attend to the conditioning. | |
if not self.use_long_skip: | |
for blk in self.mid_block: | |
x = blk(x, cond, temb, xa_mask=xa_mask) | |
else: | |
skip_connections = [] | |
num_skips = len(self.mid_block) // 2 | |
for blk in self.mid_block[:num_skips]: | |
x = blk(x, cond, temb, xa_mask=xa_mask) | |
skip_connections.append(x) | |
x = self.mid_block[num_skips](x, cond, temb, xa_mask=xa_mask) | |
for blk in self.mid_block[num_skips+1:]: | |
x = blk(x, cond, temb, xa_mask=xa_mask, skip_connection=skip_connections.pop()) | |
x = self.mid_proj_out(x) # Project Transformer output back to UNet channels | |
x = rearrange(x, 'b (h w) d -> b d h w', h=H_mid, w=W_mid) # Rearrange Transformer tokens to a spatial feature map for conv layers | |
return x | |
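# TransformerXattnCond keeps the condition at its own resolution and attends to it via
# cross-attention instead of adding it to the tokens; cond_mask (True = ignore) is
# repeated over the query tokens to form the [B, N, H_cond*W_cond] cross-attention mask.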
class UViT(ModelMixin, ConfigMixin): | |
"""UViT model = Conditional UNet with Transformer bottleneck | |
    blocks and optional patching.
See https://arxiv.org/abs/2301.11093 for more details. | |
Args: | |
sample_size: Size of the input images. | |
in_channels: Number of input channels. | |
out_channels: Number of output channels. | |
patch_size: Size of the input patching operation. | |
See https://arxiv.org/abs/2207.04316 for more details. | |
block_out_channels: Number of output channels of each UNet ResNet-block. | |
layers_per_block: Number of ResNet blocks per UNet block. | |
downsample_before_mid: Whether to downsample before the Transformer bottleneck. | |
mid_layers: Number of Transformer blocks. | |
mid_num_heads: Number of attention heads. | |
mid_dim: Transformer dimension. | |
mid_mlp_ratio: Transformer MLP ratio. | |
mid_qkv_bias: Whether to use bias in the Transformer QKV projection. | |
mid_drop_rate: Dropout rate of the Transformer MLP and attention output projection. | |
mid_attn_drop_rate: Dropout rate of the Transformer attention. | |
mid_drop_path_rate: Stochastic depth rate of the Transformer blocks. | |
mid_hw_posemb: Size (side) of the Transformer positional embedding. | |
mid_use_long_skip: Whether to use long skip connections in the Transformer blocks. | |
See https://arxiv.org/abs/2209.12152 for more details. | |
cond_dim: Dimension of the conditioning vector. | |
cond_type: Type of conditioning. | |
'concat' for concatenation, 'xattn' for cross-attention. | |
downsample_padding: Padding of the UNet downsampling convolutions. | |
act_fn: Activation function. | |
norm_num_groups: Number of groups in the UNet ResNet-block normalization. | |
norm_eps: Epsilon of the UNet ResNet-block normalization. | |
resnet_time_scale_shift: Time scale shift of the UNet ResNet-blocks. | |
resnet_out_scale_factor: Output scale factor of the UNet ResNet-blocks. | |
time_embedding_type: Type of the time embedding. | |
'positional' for positional, 'fourier' for Fourier. | |
time_embedding_dim: Dimension of the time embedding. | |
time_embedding_act_fn: Activation function of the time embedding. | |
timestep_post_act: Activation function after the time embedding. | |
time_cond_proj_dim: Dimension of the optional conditioning projection. | |
flip_sin_to_cos: Whether to flip the sine to cosine in the time embedding. | |
freq_shift: Frequency shift of the time embedding. | |
res_embedding: Whether to perform original resolution conditioning. | |
See SDXL https://arxiv.org/abs/2307.01952 for more details. | |
""" | |
def __init__(self, | |
# UNet settings | |
sample_size: Optional[int] = None, | |
in_channels: int = 3, | |
out_channels: int = 3, | |
patch_size: int = 4, | |
                 block_out_channels: Tuple[int, ...] = (128, 256, 512),
                 layers_per_block: Union[int, Tuple[int, ...]] = 2,
downsample_before_mid: bool = False, | |
# Mid-block Transformer settings | |
mid_layers: int = 12, | |
mid_num_heads: int = 12, | |
mid_dim: int = 768, | |
mid_mlp_ratio: int = 4, | |
mid_qkv_bias: bool = True, | |
mid_drop_rate: float = 0.0, | |
mid_attn_drop_rate: float = 0.0, | |
mid_drop_path_rate: float = 0.0, | |
mid_hw_posemb: int = 32, | |
mid_use_long_skip: bool = False, | |
# Conditioning settings | |
cond_dim: int = 32, | |
cond_type: str = 'concat', | |
# ResNet blocks settings | |
downsample_padding: int = 1, | |
act_fn: str = "silu", | |
norm_num_groups: Optional[int] = 32, | |
norm_eps: float = 1e-5, | |
resnet_time_scale_shift: str = "default", | |
                 resnet_out_scale_factor: float = 1.0,
# Time embedding settings | |
time_embedding_type: str = "positional", | |
time_embedding_dim: Optional[int] = None, | |
time_embedding_act_fn: Optional[str] = None, | |
timestep_post_act: Optional[str] = None, | |
time_cond_proj_dim: Optional[int] = None, | |
flip_sin_to_cos: bool = True, | |
freq_shift: int = 0, | |
# Original resolution embedding settings | |
res_embedding: bool = False): | |
super().__init__() | |
self.sample_size = sample_size | |
self.in_channels = in_channels | |
self.out_channels = out_channels | |
self.mid_dim = block_out_channels[-1] | |
self.res_embedding = res_embedding | |
# input patching | |
self.conv_in = nn.Conv2d( | |
in_channels, block_out_channels[0], kernel_size=patch_size, padding=0, stride=patch_size | |
) | |
# time | |
if time_embedding_type == "fourier": | |
time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 | |
if time_embed_dim % 2 != 0: | |
raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") | |
self.time_proj = GaussianFourierProjection( | |
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
) | |
timestep_input_dim = time_embed_dim | |
elif time_embedding_type == "positional": | |
time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 | |
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
timestep_input_dim = block_out_channels[0] | |
else: | |
raise ValueError( | |
f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." | |
) | |
self.time_embedding = TimestepEmbedding( | |
timestep_input_dim, | |
time_embed_dim, | |
act_fn=act_fn, | |
post_act_fn=timestep_post_act, | |
cond_proj_dim=time_cond_proj_dim, | |
) | |
if time_embedding_act_fn is None: | |
self.time_embed_act = None | |
elif time_embedding_act_fn == "swish": | |
self.time_embed_act = lambda x: F.silu(x) | |
elif time_embedding_act_fn == "mish": | |
self.time_embed_act = nn.Mish() | |
elif time_embedding_act_fn == "silu": | |
self.time_embed_act = nn.SiLU() | |
elif time_embedding_act_fn == "gelu": | |
self.time_embed_act = nn.GELU() | |
else: | |
raise ValueError(f"Unsupported activation function: {time_embedding_act_fn}") | |
# original resolution embedding | |
if res_embedding: | |
if time_embedding_type == "fourier": | |
self.h_proj = GaussianFourierProjection( | |
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
) | |
self.w_proj = GaussianFourierProjection( | |
time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos | |
) | |
elif time_embedding_type == "positional": | |
self.height_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
self.width_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) | |
self.height_embedding = TimestepEmbedding( | |
timestep_input_dim, time_embed_dim, act_fn=act_fn, | |
post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, | |
) | |
self.width_embedding = TimestepEmbedding( | |
timestep_input_dim, time_embed_dim, act_fn=act_fn, | |
post_act_fn=timestep_post_act, cond_proj_dim=time_cond_proj_dim, | |
) | |
self.down_blocks = nn.ModuleList([]) | |
self.up_blocks = nn.ModuleList([]) | |
if isinstance(layers_per_block, int): | |
layers_per_block = [layers_per_block] * len(block_out_channels) | |
# down | |
output_channel = block_out_channels[0] | |
for i in range(len(block_out_channels)): | |
input_channel = output_channel | |
output_channel = block_out_channels[i] | |
is_final_block = i == len(block_out_channels) - 1 | |
down_block = DownBlock2D( | |
num_layers=layers_per_block[i], | |
in_channels=input_channel, | |
out_channels=output_channel, | |
temb_channels=time_embed_dim, | |
add_downsample=not is_final_block, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
downsample_padding=downsample_padding, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
output_scale_factor=resnet_out_scale_factor, | |
) | |
self.down_blocks.append(down_block) | |
if downsample_before_mid: | |
self.downsample_mid = Downsample2D(self.mid_dim, use_conv=True, out_channels=self.mid_dim) | |
# mid | |
if cond_type == 'concat': | |
self.mid_block = TransformerConcatCond( | |
unet_dim=self.mid_dim, cond_dim=cond_dim, mid_layers=mid_layers, mid_num_heads=mid_num_heads, | |
mid_dim=mid_dim, mid_mlp_ratio=mid_mlp_ratio, mid_qkv_bias=mid_qkv_bias, | |
mid_drop_rate=mid_drop_rate, mid_attn_drop_rate=mid_attn_drop_rate, mid_drop_path_rate=mid_drop_path_rate, | |
time_embed_dim=time_embed_dim, hw_posemb=mid_hw_posemb, use_long_skip=mid_use_long_skip, | |
) | |
elif cond_type == 'xattn': | |
self.mid_block = TransformerXattnCond( | |
unet_dim=self.mid_dim, cond_dim=cond_dim, mid_layers=mid_layers, mid_num_heads=mid_num_heads, | |
mid_dim=mid_dim, mid_mlp_ratio=mid_mlp_ratio, mid_qkv_bias=mid_qkv_bias, | |
mid_drop_rate=mid_drop_rate, mid_attn_drop_rate=mid_attn_drop_rate, mid_drop_path_rate=mid_drop_path_rate, | |
time_embed_dim=time_embed_dim, hw_posemb=mid_hw_posemb, use_long_skip=mid_use_long_skip, | |
) | |
else: | |
raise ValueError(f"Unsupported cond_type: {cond_type}") | |
# count how many layers upsample the images | |
self.num_upsamplers = 0 | |
# up | |
if downsample_before_mid: | |
self.upsample_mid = Upsample2D(self.mid_dim, use_conv=True, out_channels=self.mid_dim) | |
reversed_block_out_channels = list(reversed(block_out_channels)) | |
reversed_layers_per_block = list(reversed(layers_per_block)) | |
output_channel = reversed_block_out_channels[0] | |
for i in range(len(reversed_block_out_channels)): | |
is_final_block = i == len(block_out_channels) - 1 | |
prev_output_channel = output_channel | |
output_channel = reversed_block_out_channels[i] | |
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] | |
# add upsample block for all BUT final layer | |
if not is_final_block: | |
add_upsample = True | |
self.num_upsamplers += 1 | |
else: | |
add_upsample = False | |
up_block = UpBlock2D( | |
num_layers=reversed_layers_per_block[i] + 1, | |
in_channels=input_channel, | |
out_channels=output_channel, | |
prev_output_channel=prev_output_channel, | |
temb_channels=time_embed_dim, | |
add_upsample=add_upsample, | |
resnet_eps=norm_eps, | |
resnet_act_fn=act_fn, | |
resnet_groups=norm_num_groups, | |
resnet_time_scale_shift=resnet_time_scale_shift, | |
output_scale_factor=resnet_out_scale_factor, | |
) | |
self.up_blocks.append(up_block) | |
prev_output_channel = output_channel | |
# out | |
if norm_num_groups is not None: | |
self.conv_norm_out = nn.GroupNorm( | |
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps | |
) | |
if act_fn == "swish": | |
self.conv_act = lambda x: F.silu(x) | |
elif act_fn == "mish": | |
self.conv_act = nn.Mish() | |
elif act_fn == "silu": | |
self.conv_act = nn.SiLU() | |
elif act_fn == "gelu": | |
self.conv_act = nn.GELU() | |
else: | |
raise ValueError(f"Unsupported activation function: {act_fn}") | |
else: | |
self.conv_norm_out = None | |
self.conv_act = None | |
self.conv_out = nn.ConvTranspose2d( | |
block_out_channels[0], out_channels, kernel_size=patch_size, stride=patch_size | |
) | |
self.init_weights() | |
def init_weights(self) -> None: | |
"""Weight initialization following MAE's initialization scheme""" | |
for name, m in self.named_modules(): | |
# Handle already zero-init gates | |
if "adaLN_gate" in name: | |
continue | |
            # Zero-init the second conv of each ResNet block (not done by diffusers),
            # so the residual branches start close to identity
            if "conv2" in name:
                nn.init.zeros_(m.weight)
                nn.init.zeros_(m.bias)
# Linear | |
elif isinstance(m, nn.Linear): | |
if 'qkv' in name: | |
# treat the weights of Q, K, V separately | |
val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1])) | |
nn.init.uniform_(m.weight, -val, val) | |
elif 'kv' in name: | |
# treat the weights of K, V separately | |
val = math.sqrt(6. / float(m.weight.shape[0] // 2 + m.weight.shape[1])) | |
nn.init.uniform_(m.weight, -val, val) | |
else: | |
nn.init.xavier_uniform_(m.weight) | |
if isinstance(m, nn.Linear) and m.bias is not None: | |
nn.init.constant_(m.bias, 0) | |
# LayerNorm | |
elif isinstance(m, nn.LayerNorm): | |
nn.init.constant_(m.bias, 0) | |
nn.init.constant_(m.weight, 1.0) | |
# Embedding | |
elif isinstance(m, nn.Embedding): | |
nn.init.normal_(m.weight, std=self.init_std) | |
# Conv2d | |
elif isinstance(m, nn.Conv2d): | |
if '.conv_in' in name or '.conv_out' in name: | |
# From MAE, initialize projection like nn.Linear (instead of nn.Conv2d) | |
w = m.weight.data | |
nn.init.xavier_uniform_(w.view([w.shape[0], -1])) | |
def forward( | |
self, | |
sample: torch.FloatTensor, | |
timestep: Union[torch.Tensor, float, int], | |
condition: torch.Tensor, | |
cond_mask: Optional[torch.Tensor] = None, | |
timestep_cond: Optional[torch.Tensor] = None, | |
orig_res: Optional[Union[torch.LongTensor, Tuple[int, int]]] = None, | |
**kwargs, | |
) -> torch.Tensor: | |
"""UViT forward pass. | |
Args: | |
sample: Noisy image of shape (B, C, H, W). | |
timestep: Timestep(s) of the current batch. | |
condition: Conditioning tensor of shape (B, C_cond, H_cond, W_cond). When concatenating | |
the condition, it is interpolated to the resolution of the transformer (H_mid, W_mid). | |
cond_mask: Mask tensor of shape (B, H_mid, W_mid) when concatenating the condition | |
to the transformer, and (B, H_cond, W_cond) when using cross-attention. True for | |
masked out / ignored regions. | |
timestep_cond: Optional conditioning to add to the timestep embedding. | |
orig_res: The original resolution of the image to condition the diffusion on. Ignored if None. | |
See SDXL https://arxiv.org/abs/2307.01952 for more details. | |
Returns: | |
Diffusion objective target image of shape (B, C, H, W). | |
""" | |
# By default samples have to be AT least a multiple of the overall upsampling factor. | |
# The overall upsampling factor is equal to 2 ** (# num of upsampling layers). | |
# However, the upsampling interpolation output size can be forced to fit any upsampling size | |
# on the fly if necessary. | |
default_overall_up_factor = 2**self.num_upsamplers | |
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` | |
forward_upsample_size = False | |
upsample_size = None | |
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): | |
forward_upsample_size = True | |
# 1. time | |
timesteps = timestep | |
is_mps = sample.device.type == "mps" | |
if not torch.is_tensor(timesteps): | |
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can | |
# This would be a good case for the `match` statement (Python 3.10+) | |
if isinstance(timestep, float): | |
dtype = torch.float32 if is_mps else torch.float64 | |
else: | |
dtype = torch.int32 if is_mps else torch.int64 | |
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) | |
elif len(timesteps.shape) == 0: | |
timesteps = timesteps[None].to(sample.device) | |
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML | |
timesteps = timesteps.expand(sample.shape[0]) | |
t_emb = self.time_proj(timesteps) | |
# `Timesteps` does not contain any weights and will always return f32 tensors | |
# but time_embedding might actually be running in fp16. so we need to cast here. | |
# there might be better ways to encapsulate this. | |
t_emb = t_emb.to(dtype=sample.dtype) | |
emb = self.time_embedding(t_emb, timestep_cond) | |
# 1.5 original resolution conditioning (see SDXL paper) | |
if orig_res is not None and self.res_embedding: | |
if not torch.is_tensor(orig_res): | |
h_orig, w_orig = orig_res | |
dtype = torch.int32 if is_mps else torch.int64 | |
h_orig = torch.tensor([h_orig], dtype=dtype, device=sample.device).expand(sample.shape[0]) | |
w_orig = torch.tensor([w_orig], dtype=dtype, device=sample.device).expand(sample.shape[0]) | |
else: | |
h_orig, w_orig = orig_res[:,0], orig_res[:,1] | |
h_emb = self.height_proj(h_orig).to(dtype=sample.dtype) | |
w_emb = self.width_proj(w_orig).to(dtype=sample.dtype) | |
emb = emb + self.height_embedding(h_emb) | |
emb = emb + self.width_embedding(w_emb) | |
if self.time_embed_act is not None: | |
emb = self.time_embed_act(emb) | |
# 2. pre-process | |
sample = self.conv_in(sample) | |
# 3. down | |
down_block_res_samples = (sample,) | |
for downsample_block in self.down_blocks: | |
sample, res_samples = downsample_block(hidden_states=sample, temb=emb) | |
down_block_res_samples += res_samples | |
if hasattr(self, 'downsample_mid'): | |
sample = self.downsample_mid(sample) | |
# 4. mid | |
sample = self.mid_block(sample, emb, condition, cond_mask) | |
# 5. up | |
if hasattr(self, 'upsample_mid'): | |
sample = self.upsample_mid(sample) | |
for i, upsample_block in enumerate(self.up_blocks): | |
is_final_block = i == len(self.up_blocks) - 1 | |
res_samples = down_block_res_samples[-len(upsample_block.resnets) :] | |
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] | |
# if we have not reached the final block and need to forward the | |
# upsample size, we do it here | |
if not is_final_block and forward_upsample_size: | |
upsample_size = down_block_res_samples[-1].shape[2:] | |
sample = upsample_block( | |
hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size | |
) | |
# 6. post-process | |
if self.conv_norm_out: | |
sample = self.conv_norm_out(sample) | |
sample = self.conv_act(sample) | |
sample = self.conv_out(sample) | |
return sample | |
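# The factory functions below appear to follow a uvit_<size>_p<patch>_f<factor> naming
# scheme: <size> is the Transformer width (b/l/h), p<patch> the conv_in patch size, and
# f<factor> the total spatial downsampling in front of the Transformer (patch size times
# the UNet/mid downsampling steps), e.g. p4 with one UNet downsample and
# downsample_before_mid=True gives 4 * 2 * 2 = 16.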
def uvit_b_p4_f16(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=True, | |
mid_layers=12, | |
mid_num_heads=12, | |
mid_dim=768, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
def uvit_l_p4_f16(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=True, | |
mid_layers=24, | |
mid_num_heads=16, | |
mid_dim=1024, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
def uvit_h_p4_f16(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=True, | |
mid_layers=32, | |
mid_num_heads=16, | |
mid_dim=1280, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
def uvit_b_p4_f16_longskip(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=True, | |
mid_layers=13, | |
mid_num_heads=12, | |
mid_dim=768, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
mid_use_long_skip=True, | |
**kwargs | |
) | |
def uvit_l_p4_f16_longskip(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=True, | |
mid_layers=25, | |
mid_num_heads=16, | |
mid_dim=1024, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
mid_use_long_skip=True, | |
**kwargs | |
) | |
def uvit_b_p4_f8(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=False, | |
mid_layers=12, | |
mid_num_heads=12, | |
mid_dim=768, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
def uvit_l_p4_f8(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256), | |
layers_per_block=2, | |
downsample_before_mid=False, | |
mid_layers=24, | |
mid_num_heads=16, | |
mid_dim=1024, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
def uvit_b_p4_f16_extraconv(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256, 512), | |
layers_per_block=2, | |
downsample_before_mid=False, | |
mid_layers=12, | |
mid_num_heads=12, | |
mid_dim=768, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
def uvit_l_p4_f16_extraconv(**kwargs): | |
return UViT( | |
patch_size=4, | |
block_out_channels=(128, 256, 512), | |
layers_per_block=2, | |
downsample_before_mid=False, | |
mid_layers=24, | |
mid_num_heads=16, | |
mid_dim=1024, | |
mid_mlp_ratio=4, | |
mid_qkv_bias=True, | |
**kwargs | |
) | |
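# Hedged end-to-end sketch: build the smallest preset and run one denoising forward pass
# on random data. The image size and conditioning shape are illustrative, not values
# prescribed by this file; with xFormers installed, run on a device its kernels support.
if __name__ == "__main__":
    model = uvit_b_p4_f16(in_channels=3, out_channels=3, cond_dim=32, cond_type='concat')
    x = torch.randn(1, 3, 64, 64)       # noisy sample [B, C, H, W]
    t = torch.tensor([10])              # diffusion timestep(s)
    cond = torch.randn(1, 32, 16, 16)   # dense conditioning map [B, cond_dim, H_cond, W_cond]
    out = model(x, t, cond)
    print(out.shape)                    # expected: torch.Size([1, 3, 64, 64])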