import math
from inspect import isfunction

import torch
import torch.nn.functional as F
from torch import nn


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor
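

# Illustrative only (not part of the original module): `default` falls back to
# `d` when `val` is None, calling `d` first if it is a plain function, so
# expensive defaults are only built when needed.
def _default_example():
    assert default(5, 10) == 5
    assert default(None, 10) == 10
    assert default(None, lambda: 10) == 10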


# feedforward
class GEGLU(nn.Module):
    # Gated GELU: project to 2 * dim_out, then gate one half with GELU of the other.
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)
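

# Illustrative only: a small shape sketch for GEGLU (the sizes are arbitrary
# assumptions, not taken from the original code).
def _geglu_shape_example():
    geglu = GEGLU(dim_in=64, dim_out=128)
    x = torch.randn(2, 16, 64)   # (batch, tokens, dim_in)
    y = geglu(x)                 # value half * GELU(gate half)
    assert y.shape == (2, 16, 128)
    return y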


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)
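

# Illustrative only: FeedForward keeps the last dimension unless dim_out is
# given, and uses the GEGLU branch when glu=True (the sizes below are assumed).
def _feedforward_example():
    ff = FeedForward(dim=1024, mult=4, glu=True, dropout=0.0)
    x = torch.randn(2, 77, 1024)
    y = ff(x)                    # GEGLU(1024 -> 4096), dropout, Linear(4096 -> 1024)
    assert y.shape == x.shape
    return y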


class SelfAttention(nn.Module):
    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x):
        q = self.to_q(x)  # B*N*(H*C)
        k = self.to_k(x)  # B*N*(H*C)
        v = self.to_v(x)  # B*N*(H*C)

        B, N, HC = q.shape
        H = self.heads
        C = HC // H

        q = q.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C
        k = k.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C
        v = v.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)  # (B*H)*N*C

        sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale  # (B*H)*N*N
        attn = sim.softmax(dim=-1)                                      # (B*H)*N*N
        out = torch.einsum('b i j, b j c -> b i c', attn, v)            # (B*H)*N*C
        out = out.view(B, H, N, C).permute(0, 2, 1, 3).reshape(B, N, H * C)  # B*N*(H*C)

        return self.to_out(out)
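

# Illustrative only: SelfAttention is shape-preserving; with the sizes assumed
# below it projects 1024 -> 8 heads * 64 dims and back to 1024.
def _self_attention_example():
    attn = SelfAttention(query_dim=1024, heads=8, dim_head=64)
    x = torch.randn(2, 77, 1024)  # (batch, tokens, query_dim)
    y = attn(x)                   # full token-to-token attention per head
    assert y.shape == x.shape
    return y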


class Resampler(nn.Module):
    # A single pre-LayerNorm transformer block: self-attention and feed-forward,
    # each wrapped in a residual connection.
    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()
        self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
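

# Illustrative only: a minimal usage sketch for Resampler. The 1024-dim,
# 257-token input is an assumption (e.g. a CLIP-like image-token sequence),
# not something specified by the original code.
def _resampler_example():
    resampler = Resampler(query_dim=1024, n_heads=8, d_head=64)
    tokens = torch.randn(1, 257, 1024)
    out = resampler(tokens)       # same shape as the input
    assert out.shape == tokens.shape
    return out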