import torch.nn.functional as F
import torch.nn as nn
import torch
from monai.networks.blocks import TransformerBlock
from monai.networks.layers.utils import get_norm_layer, get_dropout_layer
from monai.networks.layers.factories import Conv
from einops import rearrange
class GEGLU(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.norm = nn.LayerNorm(in_channels)
self.proj = nn.Linear(in_channels, out_channels*2, bias=True)
def forward(self, x):
# x expected to be [B, C, *]
        # Workaround: LayerNorm cannot currently be applied over an arbitrary (channel-first) dimension, see https://github.com/pytorch/pytorch/issues/71465
b, c, *spatial = x.shape
x = x.reshape(b, c, -1).transpose(1, 2) # -> [B, C, N] -> [B, N, C]
x = self.norm(x)
x, gate = self.proj(x).chunk(2, dim=-1)
x = x * F.gelu(gate)
return x.transpose(1, 2).reshape(b, -1, *spatial) # -> [B, C, N] -> [B, C, *]
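# Hedged usage sketch (shapes below are illustrative assumptions, not part of the module):
#   geglu = GEGLU(in_channels=64, out_channels=256)
#   y = geglu(torch.randn(2, 64, 16, 16))  # -> [2, 256, 16, 16]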
def zero_module(module):
"""
Zero out the parameters of a module and return it.
"""
for p in module.parameters():
p.detach().zero_()
return module
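# Hedged sketch (assumption, mirroring how zero_module is used below): zeroing the final
# projection makes a residual branch start out as the identity, e.g.
#   proj = zero_module(nn.Conv1d(256, 256, 1))
#   assert all(torch.count_nonzero(p) == 0 for p in proj.parameters())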
def compute_attention(q, k, v, num_heads, scale):
q, k, v = map(lambda t: rearrange(t, 'b (h d) n -> (b h) d n', h=num_heads), (q, k, v)) # [(BxHeads), Dim_per_head, N]
    attn = (torch.einsum('b d i, b d j -> b i j', q * scale, k * scale)).softmax(dim=-1) # Matrix product: [(BxHeads), Dim_per_head, N] x [(BxHeads), Dim_per_head, N'] = [(BxHeads), N, N']
    out = torch.einsum('b i j, b d j -> b d i', attn, v) # Matrix product: [(BxHeads), N, N'] x [(BxHeads), Dim_per_head, N'] = [(BxHeads), Dim_per_head, N]
    out = rearrange(out, '(b h) d n -> b (h d) n', h=num_heads) # -> [B, (Heads x Dim_per_head), N]
return out
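# Hedged shape sketch (batch/head sizes are assumptions): 8 heads of 32 channels,
# 16 query positions attending over 10 key/value positions,
#   q = torch.randn(2, 8 * 32, 16); k = v = torch.randn(2, 8 * 32, 10)
#   out = compute_attention(q, k, v, num_heads=8, scale=32 ** -0.25)  # -> [2, 256, 16]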
class LinearTransformerNd(nn.Module):
""" Combines multi-head self-attention and multi-head cross-attention.
Multi-Head Self-Attention:
Similar to multi-head self-attention (https://arxiv.org/abs/1706.03762) without Norm+MLP (compare Monai TransformerBlock)
Proposed here: https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/diffusionmodules/openaimodel.py#L278
Similar to: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L80
Similar to: https://github.com/lucidrains/denoising-diffusion-pytorch/blob/dfbafee555bdae80b55d63a989073836bbfc257e/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py#L209
Similar to: https://github.com/CompVis/stable-diffusion/blob/21f890f9da3cfbeaba8e2ac3c425ee9e998d5229/ldm/modules/diffusionmodules/model.py#L150
CrossAttention:
Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L152
"""
def __init__(
self,
spatial_dims,
in_channels,
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
num_heads=8,
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
norm_name=("GROUP", {'num_groups':32, "affine": True}), # Or use LayerNorm but be aware of https://github.com/pytorch/pytorch/issues/71465 (=> GroupNorm with num_groups=1)
dropout=None,
emb_dim=None,
):
super().__init__()
hid_channels = num_heads*ch_per_head
self.num_heads = num_heads
        self.scale = ch_per_head**-0.25 # overall factor 1/sqrt(dim of queries/keys); the extra sqrt is needed because, following OpenAI, it is applied as (q * scale) * (k * scale) instead of (q * k) * scale
self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
emb_dim = in_channels if emb_dim is None else emb_dim
Convolution = Conv["conv", spatial_dims]
self.to_q = Convolution(in_channels, hid_channels, 1)
self.to_k = Convolution(emb_dim, hid_channels, 1)
self.to_v = Convolution(emb_dim, hid_channels, 1)
self.to_out = nn.Sequential(
zero_module(Convolution(hid_channels, out_channels, 1)),
nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
)
def forward(self, x, embedding=None):
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
# if no embedding is given, cross-attention defaults to self-attention
# Normalize
b, c, *spatial = x.shape
x_n = self.norm_x(x)
# Attention: embedding (cross-attention) or x (self-attention)
if embedding is None:
embedding = x_n # WARNING: This assumes that emb_dim==in_channels
else:
if embedding.ndim == 2:
embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *]
# Why no normalization for embedding here?
# Convolution
q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), *]
k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), *]
v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), *]
# Flatten
        q = q.reshape(*q.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N]; flatten over the projected (hid) channels, not in_channels
        k = k.reshape(*k.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N']; likewise, not emb_dim
        v = v.reshape(*v.shape[:2], -1) # -> [B, (Heads x Dim_per_head), N']
# Apply attention
out = compute_attention(q, k, v, self.num_heads, self.scale)
out = out.reshape(*out.shape[:2], *spatial) # -> [B, (Heads x Dim_per_head), *]
out = self.to_out(out) # -> [B, C', *]
if x.shape == out.shape:
out = x + out
return out # [B, C', *]
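# Hedged usage sketch (channel sizes are assumptions): 2D feature map with a 512-channel
# conditioning vector for cross-attention,
#   attn = LinearTransformerNd(spatial_dims=2, in_channels=256, out_channels=256,
#                              num_heads=8, ch_per_head=32, emb_dim=512)
#   y = attn(torch.randn(2, 256, 16, 16), embedding=torch.randn(2, 512))  # -> [2, 256, 16, 16]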
class LinearTransformer(nn.Module):
""" See LinearTransformer, however this implementation is fixed to Conv1d/Linear"""
def __init__(
self,
spatial_dims,
in_channels,
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
num_heads,
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
norm_name=("GROUP", {'num_groups':32, "affine": True}),
dropout=None,
emb_dim=None
):
super().__init__()
hid_channels = num_heads*ch_per_head
self.num_heads = num_heads
        self.scale = ch_per_head**-0.25 # overall factor 1/sqrt(dim of queries/keys); the extra sqrt is needed because, following OpenAI, it is applied as (q * scale) * (k * scale) instead of (q * k) * scale
self.norm_x = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
emb_dim = in_channels if emb_dim is None else emb_dim
        # Note: Conv1d and Linear are interchangeable here, but the expected input layout differs: [B, C, N] vs. [B, N, C]
self.to_q = nn.Conv1d(in_channels, hid_channels, 1)
self.to_k = nn.Conv1d(emb_dim, hid_channels, 1)
self.to_v = nn.Conv1d(emb_dim, hid_channels, 1)
# self.to_qkv = nn.Conv1d(emb_dim, hid_channels*3, 1)
self.to_out = nn.Sequential(
zero_module(nn.Conv1d(hid_channels, out_channels, 1)),
nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims)
)
def forward(self, x, embedding=None):
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
# if no embedding is given, cross-attention defaults to self-attention
# Normalize
b, c, *spatial = x.shape
x_n = self.norm_x(x)
# Attention: embedding (cross-attention) or x (self-attention)
if embedding is None:
embedding = x_n # WARNING: This assumes that emb_dim==in_channels
else:
if embedding.ndim == 2:
embedding = embedding.reshape(*embedding.shape[:2], *[1]*(x.ndim-2)) # [B, C*] -> [B, C*, *]
# Why no normalization for embedding here?
# Flatten
x_n = x_n.reshape(b, c, -1) # [B, C, *] -> [B, C, N]
embedding = embedding.reshape(*embedding.shape[:2], -1) # [B, C*, *] -> [B, C*, N']
# Convolution
q = self.to_q(x_n) # -> [B, (Heads x Dim_per_head), N]
k = self.to_k(embedding) # -> [B, (Heads x Dim_per_head), N']
v = self.to_v(embedding) # -> [B, (Heads x Dim_per_head), N']
# qkv = self.to_qkv(x_n)
# q,k,v = qkv.split(qkv.shape[1]//3, dim=1)
# Apply attention
out = compute_attention(q, k, v, self.num_heads, self.scale)
out = self.to_out(out) # -> [B, C', N]
out = out.reshape(*out.shape[:2], *spatial) # -> [B, C', *]
if x.shape == out.shape:
out = x + out
return out # [B, C', *]
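# Hedged usage sketch (sizes are assumptions): same interface as LinearTransformerNd,
# but everything is flattened to [B, C, N] before the 1D projections,
#   attn = LinearTransformer(spatial_dims=2, in_channels=256, out_channels=256,
#                            num_heads=8, ch_per_head=32)
#   y = attn(torch.randn(2, 256, 16, 16))  # self-attention -> [2, 256, 16, 16]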
class BasicTransformerBlock(nn.Module):
def __init__(
self,
spatial_dims,
in_channels,
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
num_heads,
ch_per_head=32,
norm_name=("GROUP", {'num_groups':32, "affine": True}),
dropout=None,
emb_dim=None
):
super().__init__()
self.self_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, None)
if emb_dim is not None:
self.cros_atn = LinearTransformer(spatial_dims, in_channels, in_channels, num_heads, ch_per_head, norm_name, dropout, emb_dim)
self.proj_out = nn.Sequential(
GEGLU(in_channels, in_channels*4),
nn.Identity() if dropout is None else get_dropout_layer(name=dropout, dropout_dim=spatial_dims),
Conv["conv", spatial_dims](in_channels*4, out_channels, 1, bias=True)
)
def forward(self, x, embedding=None):
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
x = self.self_atn(x)
if embedding is not None:
x = self.cros_atn(x, embedding=embedding)
out = self.proj_out(x)
        if out.shape[1] == x.shape[1]:
            return out + x
        return out # skip connection disabled when out_channels != in_channels
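# Hedged usage sketch (sizes are assumptions): self-attention, optional cross-attention
# against the embedding, then the GEGLU feed-forward projection,
#   block = BasicTransformerBlock(spatial_dims=2, in_channels=256, out_channels=256,
#                                 num_heads=8, ch_per_head=32, emb_dim=512)
#   y = block(torch.randn(2, 256, 16, 16), embedding=torch.randn(2, 512))  # -> [2, 256, 16, 16]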
class SpatialTransformer(nn.Module):
""" Proposed here: https://github.com/CompVis/stable-diffusion/blob/69ae4b35e0a0f6ee1af8bb9a5d0016ccb27e36dc/ldm/modules/attention.py#L218
Unrelated to: https://arxiv.org/abs/1506.02025
"""
def __init__(
self,
spatial_dims,
in_channels,
out_channels, # WARNING: if out_channels != in_channels, skip connection is disabled
num_heads,
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
dropout=None,
emb_dim=None,
depth=1
):
super().__init__()
self.in_channels = in_channels
self.norm = get_norm_layer(norm_name, spatial_dims=spatial_dims, channels=in_channels)
conv_class = Conv["conv", spatial_dims]
hid_channels = num_heads*ch_per_head
self.proj_in = conv_class(
in_channels,
hid_channels,
kernel_size=1,
stride=1,
padding=0,
)
self.transformer_blocks = nn.ModuleList([
BasicTransformerBlock(spatial_dims, hid_channels, hid_channels, num_heads, ch_per_head, norm_name, dropout=dropout, emb_dim=emb_dim)
for _ in range(depth)]
)
self.proj_out = conv_class( # Note: zero_module is used in original code
hid_channels,
out_channels,
kernel_size=1,
stride=1,
padding=0,
)
def forward(self, x, embedding=None):
# x expected to be [B, C, *] and embedding is None or [B, C*] or [B, C*, *]
# Note: if no embedding is given, cross-attention is disabled
h = self.norm(x)
h = self.proj_in(h)
for block in self.transformer_blocks:
h = block(h, embedding=embedding)
h = self.proj_out(h) # -> [B, C'', *]
if h.shape == x.shape:
return h + x
return h
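# Hedged usage sketch (sizes are assumptions): norm -> proj_in -> `depth` transformer
# blocks -> proj_out, with a residual connection when the shapes match,
#   st = SpatialTransformer(spatial_dims=2, in_channels=64, out_channels=64,
#                           num_heads=8, ch_per_head=32, emb_dim=512, depth=1)
#   y = st(torch.randn(2, 64, 16, 16), embedding=torch.randn(2, 512))  # -> [2, 64, 16, 16]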
class Attention(nn.Module):
def __init__(
self,
spatial_dims,
in_channels,
out_channels,
num_heads=8,
ch_per_head=32, # rule of thumb: 32 or 64 channels per head (see stable-diffusion / diffusion models beat GANs)
norm_name = ("GROUP", {'num_groups':32, "affine": True}),
dropout=0,
emb_dim=None,
depth=1,
attention_type='linear'
) -> None:
super().__init__()
if attention_type == 'spatial':
self.attention = SpatialTransformer(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
num_heads=num_heads,
ch_per_head=ch_per_head,
depth=depth,
norm_name=norm_name,
dropout=dropout,
emb_dim=emb_dim
)
elif attention_type == 'linear':
self.attention = LinearTransformer(
spatial_dims=spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
num_heads=num_heads,
ch_per_head=ch_per_head,
norm_name=norm_name,
dropout=dropout,
emb_dim=emb_dim
)
def forward(self, x, emb=None):
if hasattr(self, 'attention'):
return self.attention(x, emb)
else:
return x
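if __name__ == "__main__":
    # Hedged smoke test (shapes and sizes are assumptions, not part of the original module):
    # run both attention variants on a random 2D feature map with a conditioning vector.
    x = torch.randn(2, 64, 16, 16)
    cond = torch.randn(2, 512)
    for attention_type in ("linear", "spatial"):
        attn = Attention(
            spatial_dims=2,
            in_channels=64,
            out_channels=64,
            num_heads=8,
            ch_per_head=32,
            emb_dim=512,
            attention_type=attention_type,
        )
        y = attn(x, cond)
        print(attention_type, y.shape)  # expected: torch.Size([2, 64, 16, 16])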