import torch.nn as nn
from timm.models.layers import DropPath


class FeedForward(nn.Module):
    """Transformer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, dim, hidden_dim, dropout, out_dim=None):
        super().__init__()
        self.fc1 = nn.Linear(dim, hidden_dim)
        self.act = nn.GELU()
        if out_dim is None:
            out_dim = dim
        self.fc2 = nn.Linear(hidden_dim, out_dim)
        self.drop = nn.Dropout(dropout)

    def unwrapped(self):
        return self

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
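
# --- Usage sketch (not part of the original module) ---------------------------
# Illustrates the shape contract of FeedForward: the MLP expands `dim` to
# `hidden_dim` and projects back, so the token shape is preserved. The function
# name and all dimensions below are illustrative assumptions.
def _feedforward_demo():
    import torch  # local import: the module itself only needs torch.nn

    ff = FeedForward(dim=64, hidden_dim=256, dropout=0.1)
    x = torch.randn(2, 16, 64)  # (batch, tokens, dim)
    out = ff(x)
    assert out.shape == x.shape  # output keeps the input shape
    return out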
class Attention(nn.Module):
    """Multi-head self-attention; forward returns the output tokens and the attention map."""

    def __init__(self, dim, heads, dropout):
        super().__init__()
        self.heads = heads
        head_dim = dim // heads
        self.scale = head_dim**-0.5
        self.attn = None

        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(dropout)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(dropout)

    def unwrapped(self):
        return self

    def forward(self, x, mask=None):
        # `mask` is accepted for interface compatibility but is not used here.
        B, N, C = x.shape
        # Project to q, k, v and split heads: (3, B, heads, N, head_dim).
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.heads, C // self.heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product attention over the token dimension.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back to (B, N, C) and project.
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x, attn
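
# --- Usage sketch (not part of the original module) ---------------------------
# Shows that Attention returns both the mixed tokens and the attention map,
# which is what Block relies on for `return_attention=True`. The function name
# and sizes are illustrative assumptions: dim=64 split across heads=8 gives
# head_dim=8.
def _attention_demo():
    import torch

    attn = Attention(dim=64, heads=8, dropout=0.0)
    x = torch.randn(2, 16, 64)  # (batch, tokens, dim)
    out, weights = attn(x)
    assert out.shape == (2, 16, 64)
    assert weights.shape == (2, 8, 16, 16)  # (batch, heads, tokens, tokens)
    return out, weights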
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each behind a residual path."""

    def __init__(self, dim, heads, mlp_dim, dropout, drop_path):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = Attention(dim, heads, dropout)
        self.mlp = FeedForward(dim, mlp_dim, dropout)
        # Stochastic depth on the residual branches when drop_path > 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

    def forward(self, x, mask=None, return_attention=False):
        y, attn = self.attn(self.norm1(x), mask)
        if return_attention:
            return attn
        x = x + self.drop_path(y)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
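
# --- Usage sketch (not part of the original module) ---------------------------
# Runs one pre-norm block end to end and also pulls out the attention map via
# `return_attention=True`. The hyperparameters below are illustrative
# assumptions, not values from the original file.
if __name__ == "__main__":
    import torch

    block = Block(dim=64, heads=8, mlp_dim=256, dropout=0.1, drop_path=0.1)
    x = torch.randn(2, 16, 64)  # (batch, tokens, dim)
    out = block(x)  # same shape as the input
    attn_map = block(x, return_attention=True)  # (batch, heads, tokens, tokens)
    print(out.shape, attn_map.shape)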