import torch.nn.functional as F
import torch.nn as nn
import torch
from monai.networks.blocks import TransformerBlock
from monai.networks.layers.utils import get_norm_layer, get_dropout_layer
from monai.networks.layers.factories import Conv
from einops import rearrange

class GEGLU(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.norm = nn.LayerNorm(in_channels)
        self.proj = nn.Linear(in_channels, out_channels * 2, bias=True)

    def forward(self, x):
        # x expected to be [B, C, *]
        # Workaround, as layer norm can't currently be applied on arbitrary dimensions: https://github.com/pytorch/pytorch/issues/71465
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1).transpose(1, 2)  # -> [B, C, N] -> [B, N, C]
        x = self.norm(x)
        x, gate = self.proj(x).chunk(2, dim=-1)
        x = x * F.gelu(gate)
        return x.transpose(1, 2).reshape(b, -1, *spatial)  # -> [B, C', N] -> [B, C', *]
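
# Illustrative shape check, not part of the original module: it assumes a 2D input
# [B, C, H, W] and verifies that the gated projection maps C=16 channels to
# out_channels=32 while preserving the spatial dimensions.
def _example_geglu_shapes():
    geglu = GEGLU(in_channels=16, out_channels=32)
    x = torch.randn(2, 16, 8, 8)        # [B, C, H, W]
    y = geglu(x)                        # LayerNorm + gated GELU projection
    assert y.shape == (2, 32, 8, 8)     # channels projected, spatial dims kept
    return y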

def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module

def compute_attention(q, k, v, num_heads, scale):
    q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> (b h) d n', h=num_heads), (q, k, v))  # -> [(B x Heads), Dim_per_head, N]
    attn = (torch.einsum('b d i, b d j -> b i j', q * scale, k * scale)).softmax(dim=-1)  # [(B x Heads), Dim_per_head, N] x [(B x Heads), Dim_per_head, N'] -> [(B x Heads), N, N']
    out = torch.einsum('b i j, b d j -> b d i', attn, v)  # [(B x Heads), N, N'] x [(B x Heads), Dim_per_head, N'] -> [(B x Heads), Dim_per_head, N]
    out = rearrange(out, '(b h) d n -> b (h d) n', h=num_heads)  # -> [B, (Heads x Dim_per_head), N]
    return out
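
# Illustrative sketch, not in the original file: compute_attention is agnostic to the
# sequence lengths, so queries with N tokens can attend to keys/values with N' tokens.
# The sizes below (B=2, 4 heads, 8 channels per head, N=10, N'=7) are arbitrary assumptions.
def _example_compute_attention_shapes():
    num_heads, ch_per_head = 4, 8
    q = torch.randn(2, num_heads * ch_per_head, 10)  # [B, Heads x Dim_per_head, N]
    k = torch.randn(2, num_heads * ch_per_head, 7)   # [B, Heads x Dim_per_head, N']
    v = torch.randn(2, num_heads * ch_per_head, 7)   # [B, Heads x Dim_per_head, N']
    out = compute_attention(q, k, v, num_heads, scale=ch_per_head ** -0.25)
    assert out.shape == (2, num_heads * ch_per_head, 10)  # queries keep their token count
    return out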


class LinearTransformerNd(nn.Module):
    """Combines multi-head self-attention and multi-head cross-attention.

    Multi-Head Self-Attention:
        Similar to multi-head self-attention (https://arxiv.org/abs/1706.03762) without Norm+MLP (compare MONAI TransformerBlock).
        Proposed here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66
        Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/diffusionmodules/openaimodel.py#L278
        Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L80
        Similar to: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/dfbafee555bdae80b55d63a989073836bbfc257e/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L209
        Similar to: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py#L150

    Cross-Attention:
        Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L152
    """

    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,    # WARNING: if out_channels != in_channels, the skip connection is disabled
        num_heads=8,
        ch_per_head=32,  # rule of thumb: 32 or 64 channels per head (see stable-diffusion / "Diffusion Models Beat GANs")
        norm_name=("GROUP", {'num_groups': 32, "affine": True}),  # or use LayerNorm, but be aware of https://github.com/pytorch/pytorch/issues/71465 (=> GroupNorm with num_groups=1)
        dropout=None,
        emb_dim=None,
    ):
        super().__init__()
        hid_channels = num_heads * ch_per_head
        self.num_heads = num_heads
        self.scale = ch_per_head ** -0.25  # should be 1/sqrt(dim of queries and keys); the additional sqrt is needed because, following OpenAI, attention is computed as (q * scale) * (k * scale) instead of (q * k) * scale
        self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
        emb_dim = in_channels if emb_dim is None else emb_dim

        Convolution = Conv["conv", spatial_dims]
        self.to_q = Convolution(in_channels, hid_channels, 1)
        self.to_k = Convolution(emb_dim, hid_channels, 1)
        self.to_v = Convolution(emb_dim, hid_channels, 1)

        self.to_out = nn.Sequential(
            zero_module(Convolution(hid_channels, out_channels, 1)),
            nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
        )

    def forward(self, x, embedding=None):
        # x expected to be [B, C, *]; embedding is None, [B, C*] or [B, C*, *]
        # If no embedding is given, cross-attention defaults to self-attention.

        # Normalize
        b, c, *spatial = x.shape
        x_n = self.norm_x(x)

        # Attention: embedding (cross-attention) or x (self-attention)
        if embedding is None:
            embedding = x_n  # WARNING: this assumes emb_dim == in_channels
        else:
            if embedding.ndim == 2:
                embedding = embedding.reshape(*embedding.shape[:2], *[1] * (x.ndim - 2))  # [B, C*] -> [B, C*, *]
            # Why no normalization for embedding here?

        # Convolution
        q = self.to_q(x_n)        # -> [B, (Heads x Dim_per_head), *]
        k = self.to_k(embedding)  # -> [B, (Heads x Dim_per_head), *]
        v = self.to_v(embedding)  # -> [B, (Heads x Dim_per_head), *]

        # Flatten
        q = q.reshape(*q.shape[:2], -1)  # -> [B, (Heads x Dim_per_head), N]
        k = k.reshape(*k.shape[:2], -1)  # -> [B, (Heads x Dim_per_head), N']
        v = v.reshape(*v.shape[:2], -1)  # -> [B, (Heads x Dim_per_head), N']

        # Apply attention
        out = compute_attention(q, k, v, self.num_heads, self.scale)
        out = out.reshape(*out.shape[:2], *spatial)  # -> [B, (Heads x Dim_per_head), *]
        out = self.to_out(out)                       # -> [B, C', *]

        if x.shape == out.shape:
            out = x + out
        return out  # [B, C', *]
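
# Illustrative self-attention sketch, not part of the original module. It assumes a 2D
# feature map and chooses num_heads * ch_per_head == in_channels == out_channels so the
# residual skip connection stays enabled; num_groups is picked to divide the channel count.
def _example_linear_transformer_nd():
    attn = LinearTransformerNd(
        spatial_dims=2, in_channels=32, out_channels=32,
        num_heads=4, ch_per_head=8,
        norm_name=("GROUP", {'num_groups': 8, "affine": True}),
    )
    x = torch.randn(2, 32, 16, 16)  # [B, C, H, W]
    y = attn(x)                     # no embedding -> cross-attention falls back to self-attention
    assert y.shape == x.shape       # skip connection applied because shapes match
    return y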


class LinearTransformer(nn.Module):
    """See LinearTransformerNd; this implementation, however, is fixed to Conv1d/Linear."""

    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,    # WARNING: if out_channels != in_channels, the skip connection is disabled
        num_heads,
        ch_per_head=32,  # rule of thumb: 32 or 64 channels per head (see stable-diffusion / "Diffusion Models Beat GANs")
        norm_name=("GROUP", {'num_groups': 32, "affine": True}),
        dropout=None,
        emb_dim=None
    ):
        super().__init__()
        hid_channels = num_heads * ch_per_head
        self.num_heads = num_heads
        self.scale = ch_per_head ** -0.25  # should be 1/sqrt(dim of queries and keys); the additional sqrt is needed because, following OpenAI, attention is computed as (q * scale) * (k * scale) instead of (q * k) * scale
        self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
        emb_dim = in_channels if emb_dim is None else emb_dim

        # Note: Conv1d and Linear are interchangeable, but the input layout changes: [B, C, N] <-> [B, N, C]
        self.to_q = nn.Conv1d(in_channels, hid_channels, 1)
        self.to_k = nn.Conv1d(emb_dim, hid_channels, 1)
        self.to_v = nn.Conv1d(emb_dim, hid_channels, 1)
        # self.to_qkv = nn.Conv1d(emb_dim, hid_channels * 3, 1)

        self.to_out = nn.Sequential(
            zero_module(nn.Conv1d(hid_channels, out_channels, 1)),
            nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
        )

    def forward(self, x, embedding=None):
        # x expected to be [B, C, *]; embedding is None, [B, C*] or [B, C*, *]
        # If no embedding is given, cross-attention defaults to self-attention.

        # Normalize
        b, c, *spatial = x.shape
        x_n = self.norm_x(x)

        # Attention: embedding (cross-attention) or x (self-attention)
        if embedding is None:
            embedding = x_n  # WARNING: this assumes emb_dim == in_channels
        else:
            if embedding.ndim == 2:
                embedding = embedding.reshape(*embedding.shape[:2], *[1] * (x.ndim - 2))  # [B, C*] -> [B, C*, *]
            # Why no normalization for embedding here?

        # Flatten
        x_n = x_n.reshape(b, c, -1)                              # [B, C, *] -> [B, C, N]
        embedding = embedding.reshape(*embedding.shape[:2], -1)  # [B, C*, *] -> [B, C*, N']

        # Convolution
        q = self.to_q(x_n)        # -> [B, (Heads x Dim_per_head), N]
        k = self.to_k(embedding)  # -> [B, (Heads x Dim_per_head), N']
        v = self.to_v(embedding)  # -> [B, (Heads x Dim_per_head), N']
        # qkv = self.to_qkv(x_n)
        # q, k, v = qkv.split(qkv.shape[1] // 3, dim=1)

        # Apply attention
        out = compute_attention(q, k, v, self.num_heads, self.scale)
        out = self.to_out(out)                       # -> [B, C', N]
        out = out.reshape(*out.shape[:2], *spatial)  # -> [B, C', *]

        if x.shape == out.shape:
            out = x + out
        return out  # [B, C', *]
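
# Illustrative cross-attention sketch, not part of the original module: a [B, emb_dim]
# conditioning vector is broadcast to a single conditioning token internally, so every
# spatial position of x attends to it. All sizes below are arbitrary assumptions.
def _example_linear_transformer_cross_attention():
    attn = LinearTransformer(
        spatial_dims=2, in_channels=32, out_channels=32,
        num_heads=4, ch_per_head=8,
        norm_name=("GROUP", {'num_groups': 8, "affine": True}),
        emb_dim=64,
    )
    x = torch.randn(2, 32, 16, 16)  # [B, C, H, W]
    emb = torch.randn(2, 64)        # [B, emb_dim], e.g. a time or class embedding
    y = attn(x, embedding=emb)
    assert y.shape == x.shape
    return y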


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,    # WARNING: if out_channels != in_channels, the skip connection is disabled
        num_heads,
        ch_per_head=32,
        norm_name=("GROUP", {'num_groups': 32, "affine": True}),
        dropout=None,
        emb_dim=None
    ):
        super().__init__()
        self.self_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, None)
        if emb_dim is not None:
            self.cros_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, emb_dim)
        self.proj_out = nn.Sequential(
            GEGLU(in_channels, in_channels * 4),
            nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims),
            Conv["conv", spatial_dims](in_channels * 4, out_channels, 1, bias=True)
        )

    def forward(self, x, embedding=None):
        # x expected to be [B, C, *]; embedding is None, [B, C*] or [B, C*, *]
        x = self.self_atn(x)
        if embedding is not None:
            x = self.cros_atn(x, embedding=embedding)
        out = self.proj_out(x)
        if out.shape[1] == x.shape[1]:
            return out + x
        return out
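
# Illustrative sketch, not part of the original module: one BasicTransformerBlock with
# self-attention, cross-attention against a hypothetical 64-dim conditioning vector, and
# the GEGLU feed-forward; channel counts are chosen so both residual paths stay active.
def _example_basic_transformer_block():
    block = BasicTransformerBlock(
        spatial_dims=2, in_channels=32, out_channels=32,
        num_heads=4, ch_per_head=8,
        norm_name=("GROUP", {'num_groups': 8, "affine": True}),
        emb_dim=64,
    )
    x = torch.randn(2, 32, 16, 16)
    emb = torch.randn(2, 64)
    y = block(x, embedding=emb)
    assert y.shape == x.shape
    return y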


class SpatialTransformer(nn.Module):
    """Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L218
    Unrelated to: https://arxiv.org/abs/1506.02025
    """

    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,    # WARNING: if out_channels != in_channels, the skip connection is disabled
        num_heads,
        ch_per_head=32,  # rule of thumb: 32 or 64 channels per head (see stable-diffusion / "Diffusion Models Beat GANs")
        norm_name=("GROUP", {'num_groups': 32, "affine": True}),
        dropout=None,
        emb_dim=None,
        depth=1
    ):
        super().__init__()
        self.in_channels = in_channels
        self.norm = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
        conv_class = Conv["conv", spatial_dims]
        hid_channels = num_heads * ch_per_head

        self.proj_in = conv_class(
            in_channels,
            hid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

        self.transformer_blocks = nn.ModuleList([
            BasicTransformerBlock(spatial_dims, hid_channels, hid_channels, num_heads, ch_per_head, norm_name, dropout=dropout, emb_dim=emb_dim)
            for _ in range(depth)
        ])

        self.proj_out = conv_class(  # Note: zero_module is used in the original code
            hid_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
        )

    def forward(self, x, embedding=None):
        # x expected to be [B, C, *]; embedding is None, [B, C*] or [B, C*, *]
        # Note: if no embedding is given, cross-attention is disabled
        h = self.norm(x)
        h = self.proj_in(h)

        for block in self.transformer_blocks:
            h = block(h, embedding=embedding)

        h = self.proj_out(h)  # -> [B, C'', *]
        if h.shape == x.shape:
            return h + x
        return h
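
# Illustrative sketch, not part of the original module: a depth-1 SpatialTransformer that
# projects in, runs one BasicTransformerBlock with cross-attention, and projects back out;
# in_channels == out_channels here, so the outer residual connection is applied.
def _example_spatial_transformer():
    st = SpatialTransformer(
        spatial_dims=2, in_channels=32, out_channels=32,
        num_heads=4, ch_per_head=8,
        norm_name=("GROUP", {'num_groups': 8, "affine": True}),
        emb_dim=64, depth=1,
    )
    x = torch.randn(2, 32, 16, 16)
    emb = torch.randn(2, 64)
    y = st(x, embedding=emb)
    assert y.shape == x.shape
    return y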


class Attention(nn.Module):
    def __init__(
        self,
        spatial_dims,
        in_channels,
        out_channels,
        num_heads=8,
        ch_per_head=32,  # rule of thumb: 32 or 64 channels per head (see stable-diffusion / "Diffusion Models Beat GANs")
        norm_name=("GROUP", {'num_groups': 32, "affine": True}),
        dropout=0,
        emb_dim=None,
        depth=1,
        attention_type='linear'
    ) -> None:
        super().__init__()
        if attention_type == 'spatial':
            self.attention = SpatialTransformer(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                num_heads=num_heads,
                ch_per_head=ch_per_head,
                depth=depth,
                norm_name=norm_name,
                dropout=dropout,
                emb_dim=emb_dim
            )
        elif attention_type == 'linear':
            self.attention = LinearTransformer(
                spatial_dims=spatial_dims,
                in_channels=in_channels,
                out_channels=out_channels,
                num_heads=num_heads,
                ch_per_head=ch_per_head,
                norm_name=norm_name,
                dropout=dropout,
                emb_dim=emb_dim
            )

    def forward(self, x, emb=None):
        if hasattr(self, 'attention'):
            return self.attention(x, emb)
        else:
            return x
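

# Minimal smoke test (illustrative, not part of the original module): runs the example
# sketches above plus the Attention factory with the 'linear' backend on random data.
if __name__ == "__main__":
    _example_geglu_shapes()
    _example_compute_attention_shapes()
    _example_linear_transformer_nd()
    _example_linear_transformer_cross_attention()
    _example_basic_transformer_block()
    _example_spatial_transformer()

    attention = Attention(
        spatial_dims=2, in_channels=32, out_channels=32,
        num_heads=4, ch_per_head=8,
        norm_name=("GROUP", {'num_groups': 8, "affine": True}),
        dropout=None, emb_dim=64, attention_type='linear',
    )
    x = torch.randn(2, 32, 16, 16)
    emb = torch.randn(2, 64)
    print(attention(x, emb).shape)  # expected: torch.Size([2, 32, 16, 16])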