import torch
import math
import torch.nn.functional as F
from torch import nn, einsum
from inspect import isfunction


def exists(val):
    return val is not None


def uniq(arr):
    # remove duplicates while preserving order
    return {el: True for el in arr}.keys()


def default(val, d):
    # return val if it is not None, otherwise the default d (called if it is a function)
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    # uniform init with std 1/sqrt(dim) of the last dimension
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    """Gated GELU: project to 2x width, gate one half with GELU of the other."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


class SelfAttention(nn.Module):
    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x):
        q = self.to_q(x)  # B*N*(H*C)
        k = self.to_k(x)  # B*N*(H*C)
        v = self.to_v(x)  # B*N*(H*C)

        B, N, HC = q.shape
        H = self.heads
        C = HC // H

        # split heads and fold them into the batch dimension
        q = q.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C
        k = k.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C
        v = v.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C

        # scaled dot-product attention per head
        sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale  # (B*H)*N*N
        attn = sim.softmax(dim=-1)  # (B*H)*N*N

        out = torch.einsum('b i j, b j c -> b i c', attn, v)  # (B*H)*N*C
        # merge heads back into the channel dimension
        out = out.view(B, H, N, C).permute(0, 2, 1, 3).reshape(B, N, H * C)  # B*N*(H*C)
        return self.to_out(out)


class Resampler(nn.Module):
    """Pre-norm transformer block: self-attention and feedforward, each with a residual."""

    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()
        self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
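

# Minimal usage sketch (not part of the original module). It only illustrates the
# expected input/output shapes; the batch size, token count, and random input below
# are illustrative assumptions, not values prescribed by the code above.
if __name__ == "__main__":
    block = Resampler(query_dim=1024, n_heads=8, d_head=64)
    tokens = torch.randn(2, 16, 1024)  # assumed B=2, N=16 tokens of width query_dim=1024
    out = block(tokens)
    assert out.shape == tokens.shape  # residual block preserves the token shape
    print(out.shape)  # torch.Size([2, 16, 1024])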