Spaces:
Build error
Build error
""" | |
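@Date: 2021/09/01
@description: Transformer building blocks: multi-head and windowed self-attention,
absolute / relative position embeddings, feed-forward and circular 1-D convolution layers.
"""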
import warnings | |
import math | |
import torch | |
import torch.nn.functional as F | |
from torch import nn, einsum | |
from einops import rearrange | |
def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
    """Fill ``tensor`` in place with values from a truncated normal distribution.

    The values follow N(mean, std**2) restricted to [a, b]. This is a copy of
    PyTorch's ``nn.init.trunc_normal_`` (kept until it is in enough official
    releases) and uses the inverse-CDF method described in
    https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf

    :param tensor: tensor to overwrite in place.
    :param mean: mean of the untruncated normal distribution.
    :param std: standard deviation of the untruncated normal distribution.
    :param a: lower truncation bound.
    :param b: upper truncation bound.
    :return: ``tensor`` (the same object, now filled).
    """
    def phi(z):
        # Standard normal CDF written in terms of the error function.
        return 0.5 * (1. + math.erf(z / math.sqrt(2.)))

    far_from_interval = mean < a - 2 * std or mean > b + 2 * std
    if far_from_interval:
        warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
                      "The distribution of values may be incorrect.",
                      stacklevel=2)

    with torch.no_grad():
        # CDF values of the standardized bounds.
        lower = phi((a - mean) / std)
        upper = phi((b - mean) / std)
        # erfinv maps U[2*lower - 1, 2*upper - 1] onto a standard normal
        # (divided by sqrt(2)) truncated to the standardized bounds.
        tensor.uniform_(2 * lower - 1, 2 * upper - 1)
        tensor.erfinv_()
        # Undo the 1/sqrt(2) factor, scale by std and shift by the mean.
        tensor.mul_(std * math.sqrt(2.)).add_(mean)
        # Clamp so every value lies within [a, b].
        tensor.clamp_(min=a, max=b)
    return tensor
class PreNorm(nn.Module):
    """Pre-normalization wrapper computing ``fn(LayerNorm(x), **kwargs)``.

    :param dim: size of the last input dimension, normalized by ``nn.LayerNorm``.
    :param fn: callable applied to the normalized input; extra keyword
        arguments given to ``forward`` are passed through to it.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)
# compatibility with pytorch < 1.4: a thin module wrapper around ``F.gelu``.
class GELU(nn.Module):
    """Apply the Gaussian Error Linear Unit element-wise using ``F.gelu``."""

    def forward(self, input):
        activated = F.gelu(input)
        return activated
class Attend(nn.Module):
    """Softmax along ``dim``, computed in the same dtype as the input.

    :param dim: dimension normalized by the softmax.
    """

    def __init__(self, dim=None):
        super().__init__()
        self.dim = dim

    def forward(self, input):
        target_dtype = input.dtype
        return F.softmax(input, dim=self.dim, dtype=target_dtype)
class FeedForward(nn.Module):
    """Position-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    :param dim: input and output feature size.
    :param hidden_dim: width of the hidden layer.
    :param dropout: dropout probability used after the activation and after
        the output projection.
    """

    def __init__(self, dim, hidden_dim, dropout=0.):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),
            GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)
class RelativePosition(nn.Module):
    """Relative position bias added to an attention map.

    A table with one row per relative-offset bucket and one column per head is
    gathered into a ``(1, heads, patch_num, patch_num)`` bias that ``forward``
    adds to its input. The offset between query position ``i`` and key
    position ``j`` is ``j - i``; it is mapped to a table row according to
    ``rpe``:

    * ``'lr_parameter'``: learnable; one row per offset in
      ``[-(patch_num - 1), patch_num - 1]`` (``2 * patch_num - 1`` rows).
    * ``'lr_parameter_mirror'``: learnable; offsets folded to
      ``min(|j - i|, patch_num - |j - i|)`` (``patch_num // 2 + 1`` rows).
    * ``'lr_parameter_half'``: learnable; offsets reduced modulo
      ``patch_num`` (``patch_num`` rows).
    * ``'fix_angle'``: same indexing as ``'lr_parameter_mirror'`` but with a
      fixed buffer whose rows decrease linearly from 1 (distance 0).

    :param heads: number of attention heads (table columns).
    :param patch_num: sequence length the bias is built for.
    :param rpe: one of the modes above.
    """

    _MODES = ('lr_parameter', 'lr_parameter_mirror', 'lr_parameter_half', 'fix_angle')

    def __init__(self, heads, patch_num=None, rpe=None):
        super().__init__()
        self.rpe = rpe
        self.heads = heads
        self.patch_num = patch_num
        if rpe == 'lr_parameter':
            # offsets -(P-1) ~ 0 ~ P-1 (e.g. -255 ~ 0 ~ 255): 2 * P - 1 rows
            count = patch_num * 2 - 1
            self.rpe_table = nn.Parameter(torch.empty(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_mirror':
            # folded distances 0 ~ P // 2 (e.g. 0 ~ 128): P // 2 + 1 rows
            count = patch_num // 2 + 1
            self.rpe_table = nn.Parameter(torch.empty(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'lr_parameter_half':
            # offsets -(P // 2 - 1) ~ 0 ~ P // 2 (e.g. -127 ~ 0 ~ 128): P rows
            count = patch_num
            self.rpe_table = nn.Parameter(torch.empty(count, heads))
            nn.init.xavier_uniform_(self.rpe_table)
        elif rpe == 'fix_angle':
            # folded distances 0 ~ P // 2: P // 2 + 1 rows
            count = patch_num // 2 + 1
            # we think that closer proximity should have stronger relationships
            rpe_table = (torch.arange(count, 0, -1) / count)[..., None].repeat(1, heads)
            self.register_buffer('rpe_table', rpe_table)

    def get_relative_pos_embed(self):
        """Return the bias of shape ``(1, heads, patch_num, patch_num)``.

        :raises ValueError: if ``rpe`` is not a supported mode (no table was
            created for it in ``__init__``).
        """
        if self.rpe not in self._MODES:
            raise ValueError(f"unsupported relative position embedding mode: {self.rpe!r}")
        num = self.patch_num
        # Build the indices on the table's device to avoid a host->device copy.
        range_vec = torch.arange(num, device=self.rpe_table.device)
        distance_mat = range_vec[None, :] - range_vec[:, None]  # [i, j] = j - i
        if self.rpe == 'lr_parameter':
            # -(P-1) ~ 0 ~ P-1 -> 0 ~ P-1 ~ 2P-2: shift to remove negatives
            distance_mat = distance_mat + (num - 1)
        elif self.rpe in ('lr_parameter_mirror', 'fix_angle'):
            distance_mat = distance_mat.abs()  # mirror
            # fold distances beyond half the length back (d -> P - d)
            distance_mat = torch.where(distance_mat > num // 2, num - distance_mat, distance_mat)
        else:  # 'lr_parameter_half'
            # wrap offsets > P // 2 downwards (e.g. 129 -> -127) ...
            distance_mat = torch.where(distance_mat > num // 2, distance_mat - num, distance_mat)
            # ... and offsets below (-P) // 2 + 1 upwards (e.g. -128 -> 128).
            # ``-num // 2`` parses as ``(-num) // 2``; the arithmetic is intentional.
            distance_mat = torch.where(distance_mat < -num // 2 + 1, distance_mat + num, distance_mat)
            # -127 ~ 0 ~ 128 -> 0 ~ 127 ~ 255: shift to remove negatives.
            # For odd patch_num the smallest value becomes -1, which selects the
            # last table row through negative indexing.
            distance_mat = distance_mat + (num // 2 - 1)
        # (P, P, heads) -> (1, heads, P, P)
        return self.rpe_table[distance_mat].permute(2, 0, 1)[None]

    def forward(self, attn):
        """Add the relative position bias to ``attn`` (broadcast over its leading dimension)."""
        return attn + self.get_relative_pos_embed()
class Attention(nn.Module):
    """Multi-head self-attention with an optional relative position bias."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., patch_num=None, rpe=None, rpe_pos=1):
        """
        :param dim: input and output feature size.
        :param heads: number of attention heads.
        :param dim_head: feature size of each head (inner width is heads * dim_head).
        :param dropout: dropout probability after the output projection; the
            projection (and this dropout) is skipped when heads == 1 and
            dim_head == dim.
        :param patch_num: sequence length for the relative position bias; the
            bias is disabled when this or ``rpe`` is None.
        :param rpe: relative position embedding mode passed to RelativePosition.
        :param rpe_pos: 0 adds the bias to the scaled scores before softmax,
            1 adds it to the attention weights after softmax; any other value
            never applies it.
        """
        super().__init__()
        use_rpe = patch_num is not None and rpe is not None
        self.relative_pos_embed = RelativePosition(heads, patch_num, rpe) if use_rpe else None
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5
        self.rpe_pos = rpe_pos
        self.attend = Attend(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        needs_projection = heads != 1 or dim_head != dim
        if needs_projection:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )
        else:
            self.to_out = nn.Identity()

    def forward(self, x):
        # The unpacking also rejects inputs that are not 3-D.
        _batch, _tokens, _, heads = *x.shape, self.heads
        # Split the fused projection into per-head q, k, v: (b, h, n, d).
        q, k, v = (rearrange(part, 'b n (h d) -> b h n d', h=heads)
                   for part in self.to_qkv(x).chunk(3, dim=-1))
        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        rel_pos = self.relative_pos_embed
        if rel_pos is not None and self.rpe_pos == 0:
            scores = rel_pos(scores)  # bias before softmax
        weights = self.attend(scores)
        if rel_pos is not None and self.rpe_pos == 1:
            weights = rel_pos(weights)  # bias after softmax
        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        return self.to_out(rearrange(mixed, 'b h n d -> b n (h d)'))
class AbsolutePosition(nn.Module):
    """Add an absolute position embedding of shape ``(1, patch_num, dim)`` to the input.

    :param dim: feature size of each position.
    :param dropout: unused; kept for interface compatibility.
    :param patch_num: number of positions (sequence length).
    :param ape: ``'lr_parameter'`` for a learnable embedding initialised with a
        truncated normal (std 0.02), or ``'fix_angle'`` for a fixed embedding
        ``sin(2 * pi * i / patch_num)`` repeated across all ``dim`` features.
    """

    def __init__(self, dim, dropout=0., patch_num=None, ape=None):
        super().__init__()
        self.ape = ape
        if ape == 'lr_parameter':
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, patch_num, dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)
        elif ape == 'fix_angle':
            angle = torch.arange(0, patch_num, dtype=torch.float) / patch_num * (math.pi * 2)
            # A non-persistent buffer follows the module through .to()/.cuda()/.double()
            # (a plain tensor attribute would stay behind on its original device and
            # dtype) while keeping state_dict keys unchanged for existing checkpoints.
            self.register_buffer('absolute_pos_embed',
                                 torch.sin(angle)[..., None].repeat(1, dim)[None],
                                 persistent=False)

    def forward(self, x):
        """Return ``x`` plus the embedding, broadcast over the leading dimension."""
        return x + self.absolute_pos_embed
class WinAttention(nn.Module):
    """Self-attention restricted to non-overlapping windows along the sequence.

    The sequence (second-to-last dimension) is optionally rolled by ``shift``,
    cut into windows of ``win_size`` tokens, attended within each window by an
    ``Attention`` module, then reassembled and rolled back.

    :param dim: feature size.
    :param win_size: tokens per window; the sequence length must be divisible by it.
    :param shift: roll offset applied before windowing (0 disables rolling).
    :param heads: number of attention heads.
    :param dim_head: feature size per head.
    :param dropout: dropout passed to ``Attention``.
    :param rpe: if not None, enables an ``'lr_parameter'`` relative position
        bias over ``win_size`` positions (the specific value is not used).
    :param rpe_pos: forwarded to ``Attention``.
    """

    def __init__(self, dim, win_size=8, shift=0, heads=8, dim_head=64, dropout=0., rpe=None, rpe_pos=1):
        super().__init__()
        self.win_size = win_size
        self.shift = shift
        window_rpe = None if rpe is None else 'lr_parameter'
        self.attend = Attention(dim, heads=heads, dim_head=dim_head, dropout=dropout,
                                patch_num=win_size, rpe=window_rpe, rpe_pos=rpe_pos)

    def forward(self, x):
        batch = x.shape[0]
        rolled = self.shift != 0
        if rolled:
            x = torch.roll(x, shifts=self.shift, dims=-2)
        windows = rearrange(x, 'b (m w) d -> (b m) w d', w=self.win_size)  # split windows
        merged = rearrange(self.attend(windows), '(b m) w d -> b (m w) d ', b=batch)  # recover windows
        if rolled:
            merged = torch.roll(merged, shifts=-self.shift, dims=-2)
        return merged
class Conv(nn.Module):
    """Kernel-3 1-D convolution along the sequence with wrap-around padding.

    :param dim: feature size, used as both input and output channels.
    :param dropout: dropout probability applied after the convolution.
    """

    def __init__(self, dim, dropout=0.):
        super().__init__()
        self.dim = dim
        conv = nn.Conv1d(dim, dim, kernel_size=3, stride=1, padding=0)
        self.net = nn.Sequential(conv, nn.Dropout(dropout))

    def forward(self, x):
        # Conv1d expects channels in dim 1, so move the features there.
        channels_first = x.transpose(1, 2)
        # Pad one step per side by wrapping: last column in front, first behind.
        head, tail = channels_first[..., :1], channels_first[..., -1:]
        padded = torch.cat((tail, channels_first, head), dim=-1)
        return self.net(padded).transpose(1, 2)