sky24h's picture
init commit
1d24639
import torch
import math
import torch.nn.functional as F
from torch import nn, einsum
from inspect import isfunction
def exists(val):
    """Return True when ``val`` is anything other than ``None``."""
    is_missing = val is None
    return not is_missing
def uniq(arr):
    """Return a keys view of the distinct elements of ``arr``.

    Elements keep the order of their first occurrence and must be hashable.
    """
    seen = dict.fromkeys(arr, True)
    return seen.keys()
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    When ``d`` is a plain Python function it is called with no arguments and
    its result is used as the fallback; any other object is returned as-is.
    """
    if not exists(val):
        fallback = d
        if isfunction(fallback):
            # Lazily computed default: only evaluated when actually needed.
            fallback = fallback()
        return fallback
    return val
def max_neg_value(t):
    """Return the negated largest finite value of ``t``'s floating dtype."""
    limits = torch.finfo(t.dtype)
    return -limits.max
def init_(tensor):
    """Fill ``tensor`` in place with uniform noise scaled by its last dimension.

    Values are drawn uniformly from ``[-std, std)`` where
    ``std = 1 / sqrt(tensor.shape[-1])``.

    Args:
        tensor: Tensor (or ``nn.Parameter``) to initialise; modified in place.

    Returns:
        The same ``tensor`` object, to allow call chaining.
    """
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    # In-place ops on leaf tensors that require grad (e.g. nn.Parameter) raise
    # a RuntimeError while autograd is recording, so disable it for the fill.
    with torch.no_grad():
        tensor.uniform_(-std, std)
    return tensor
# feedforward
class GEGLU(nn.Module):
    """Gated linear unit whose gate is passed through GELU.

    A single linear layer produces ``2 * dim_out`` features; the first half is
    the value and the second half, after GELU, is the multiplicative gate.
    """

    def __init__(self, dim_in, dim_out):
        super().__init__()
        # One projection yields both the value and the gate halves.
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        """Project ``x`` and return ``value * gelu(gate)`` over the last axis."""
        projected = self.proj(x)
        value, gate = projected.chunk(2, dim=-1)
        activated_gate = F.gelu(gate)
        return value * activated_gate
class FeedForward(nn.Module):
    """Two-layer feed-forward network with an optional GEGLU input projection.

    Args:
        dim: Size of the input features.
        dim_out: Size of the output features; falls back to ``dim`` when None.
        mult: Multiplier applied to ``dim`` to get the hidden width.
        glu: Use a ``GEGLU`` input projection when True, otherwise
            ``Linear`` followed by ``GELU``.
        dropout: Dropout probability applied after the input projection.
    """

    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        hidden_dim = int(dim * mult)
        dim_out = default(dim_out, dim)

        # Only the selected input projection is constructed.
        if glu:
            project_in = GEGLU(dim, hidden_dim)
        else:
            project_in = nn.Sequential(
                nn.Linear(dim, hidden_dim),
                nn.GELU(),
            )

        layers = [
            project_in,
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim_out),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through projection, dropout and the output linear layer."""
        return self.net(x)
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        query_dim: Feature size of the input (and output) tokens.
        heads: Number of attention heads.
        dim_head: Feature size of each head.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        self.heads = heads
        self.scale = dim_head ** -0.5
        inner_dim = heads * dim_head

        # Bias-free projections to queries, keys and values.
        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Attend every token of ``x`` to every other token.

        ``x`` is unpacked as three axes (B, N, features); the result has shape
        (B, N, query_dim).
        """
        queries, keys, values = self.to_q(x), self.to_k(x), self.to_v(x)

        batch, tokens, width = queries.shape
        n_heads = self.heads
        head_dim = width // n_heads

        def fold_heads(t):
            # (B, N, H*C) -> (B*H, N, C): move heads into the batch axis.
            t = t.view(batch, tokens, n_heads, head_dim).permute(0, 2, 1, 3)
            return t.reshape(batch * n_heads, tokens, head_dim)

        queries, keys, values = (fold_heads(t) for t in (queries, keys, values))

        # Pairwise similarities, scaled by dim_head ** -0.5: (B*H, N, N).
        scores = torch.einsum('b i c, b j c -> b i j', queries, keys) * self.scale
        weights = scores.softmax(dim=-1)

        # Weighted sum of values: (B*H, N, C).
        mixed = torch.einsum('b i j, b j c -> b i c', weights, values)

        # (B*H, N, C) -> (B, N, H*C): undo the head folding.
        mixed = mixed.view(batch, n_heads, tokens, head_dim).permute(0, 2, 1, 3)
        mixed = mixed.reshape(batch, tokens, n_heads * head_dim)
        return self.to_out(mixed)
class Resampler(nn.Module):
    """Self-attention followed by a feed-forward network, both residual.

    Each sub-layer receives a LayerNorm-ed copy of its input and its output is
    added back onto the un-normalised input.

    Args:
        query_dim: Feature size of the tokens.
        n_heads: Number of attention heads.
        d_head: Feature size of each attention head.
    """

    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()
        self.attn = SelfAttention(
            query_dim=query_dim,
            heads=n_heads,
            dim_head=d_head,
        )
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        """Apply the attention and feed-forward residual branches to ``x``."""
        attn_branch = self.attn(self.norm1(x))
        x = x + attn_branch
        ff_branch = self.ff(self.norm2(x))
        x = x + ff_branch
        return x