import math
import torch
from torch import nn, einsum
from einops import rearrange, repeat
from .utils import exist


class Identity(nn.Module):
    """Pass-through module: returns its input unchanged, ignoring extra arguments."""

    def __init__(self, *args, **kwargs):
        super().__init__()

    @staticmethod
    def forward(x, *args, **kwargs):
        return x


class SinusoidalPosEmb(nn.Module):
    """Sinusoidal embedding of scalar positions / diffusion timesteps."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        # x: 1-D float tensor of positions/timesteps, shape (batch,)
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=x.device, dtype=x.dtype) * -emb)
        emb = rearrange(x, 'i -> i 1') * rearrange(emb, 'j -> 1 j')  # (batch, half_dim)
        return torch.cat((emb.sin(), emb.cos()), dim=-1)  # (batch, dim)
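
# Usage sketch (added for illustration, not in the original module). Assumes the
# embedding is applied to a 1-D float tensor of diffusion timesteps:
#   pos_emb = SinusoidalPosEmb(dim=128)
#   t = torch.rand(4) * 1000             # (batch,)
#   emb = pos_emb(t)                     # (batch, 128): first half sin, second half cos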


class ConditionalGroupNorm(nn.Module):
    """GroupNorm whose scale and shift are predicted from a context embedding."""

    def __init__(self, groups, normalized_shape, context_dim):
        super().__init__()
        self.norm = nn.GroupNorm(groups, normalized_shape, affine=False)
        self.context_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(context_dim, 2 * normalized_shape)
        )
        # Zero-init so the layer starts out as a plain (unconditioned) GroupNorm.
        self.context_mlp[1].weight.data.zero_()
        self.context_mlp[1].bias.data.zero_()

    def forward(self, x, context):
        context = self.context_mlp(context)
        # Append singleton spatial dims so scale/shift broadcast over x.
        ndims = ' 1' * len(x.shape[2:])
        context = rearrange(context, f'b c -> b c{ndims}')
        scale, shift = context.chunk(2, dim=1)
        x = self.norm(x) * (scale + 1.) + shift
        return x
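
# Usage sketch (added for illustration, not in the original module). Assumes an
# NCHW feature map and a per-sample context vector such as a time embedding:
#   norm = ConditionalGroupNorm(groups=32, normalized_shape=64, context_dim=128)
#   h = torch.randn(4, 64, 32, 32)
#   ctx = torch.randn(4, 128)
#   y = norm(h, ctx)                     # same shape as h; zero-init means plain GroupNorm at start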


class Attention(nn.Module):
    """Multi-head cross-attention: queries from x, keys/values from context."""

    def __init__(self, in_channels, out_channels, context_dim, head_dim=64):
        super().__init__()
        assert out_channels % head_dim == 0
        self.num_heads = out_channels // head_dim
        self.scale = head_dim ** -0.5

        self.to_query = nn.Linear(in_channels, out_channels, bias=False)
        self.to_key = nn.Linear(context_dim, out_channels, bias=False)
        self.to_value = nn.Linear(context_dim, out_channels, bias=False)
        self.output_layer = nn.Linear(out_channels, out_channels, bias=False)

    def forward(self, x, context, context_mask=None):
        # Project and split channels into attention heads.
        query = rearrange(self.to_query(x), 'b n (h d) -> b h n d', h=self.num_heads)
        key = rearrange(self.to_key(context), 'b n (h d) -> b h n d', h=self.num_heads)
        value = rearrange(self.to_value(context), 'b n (h d) -> b h n d', h=self.num_heads)

        attention_matrix = einsum('b h i d, b h j d -> b h i j', query, key) * self.scale
        if exist(context_mask):
            # Mask out padded context tokens before the softmax.
            max_neg_value = -torch.finfo(attention_matrix.dtype).max
            context_mask = rearrange(context_mask, 'b j -> b 1 1 j')
            attention_matrix = attention_matrix.masked_fill(~context_mask, max_neg_value)
        attention_matrix = attention_matrix.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attention_matrix, value)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out = self.output_layer(out)
        return out
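

# Minimal smoke test (added for illustration, not part of the original file).
# Shapes and the boolean mask convention (True = attend to the token) are
# assumptions; run as a module (python -m <package>.<this_file>) so the
# relative import of `.utils` resolves.
if __name__ == '__main__':
    attn = Attention(in_channels=64, out_channels=128, context_dim=512, head_dim=64)
    x = torch.randn(2, 16, 64)           # (batch, query tokens, in_channels)
    ctx = torch.randn(2, 77, 512)        # (batch, context tokens, context_dim)
    mask = torch.ones(2, 77, dtype=torch.bool)
    out = attn(x, ctx, context_mask=mask)
    print(out.shape)                     # expected: torch.Size([2, 16, 128])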