# Copyright (c) OpenMMLab. All rights reserved.
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import DropPath
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION

def rope(x, dim):
    """Applies Rotary Position Embedding to input tensor.

    Args:
        x (torch.Tensor): Input tensor.
        dim (int | list[int]): The spatial dimension(s) to apply
            rotary position embedding.

    Returns:
        torch.Tensor: The tensor after applying rotary position
            embedding.

    Reference:
        `RoFormer: Enhanced Transformer with Rotary
        Position Embedding <https://arxiv.org/abs/2104.09864>`_
    """
    shape = x.shape
    if isinstance(dim, int):
        dim = [dim]

    spatial_shape = [shape[i] for i in dim]
    total_len = 1
    for i in spatial_shape:
        total_len *= i

    position = torch.reshape(
        torch.arange(total_len, dtype=torch.int, device=x.device),
        spatial_shape)

    for i in range(dim[-1] + 1, len(shape) - 1, 1):
        position = torch.unsqueeze(position, dim=-1)

    half_size = shape[-1] // 2
    freq_seq = -torch.arange(
        half_size, dtype=torch.int, device=x.device) / float(half_size)
    inv_freq = 10000**-freq_seq

    sinusoid = position[..., None] * inv_freq[None, None, :]
    sin = torch.sin(sinusoid)
    cos = torch.cos(sinusoid)
    x1, x2 = torch.chunk(x, 2, dim=-1)

    return torch.cat([x1 * cos - x2 * sin, x2 * cos + x1 * sin], dim=-1)

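
# Illustrative sketch added by the editor (not part of the upstream module):
# a minimal shape check for ``rope``. It assumes a ``[B, K, C]`` token tensor
# with the rotary embedding applied over the token axis (``dim=1``); the
# channel size ``C`` must be even because channels are rotated in pairs.
def _rope_shape_example():
    x = torch.randn(2, 17, 256)  # hypothetical batch: 2 samples, 17 tokens
    out = rope(x, dim=1)         # rotate channel pairs by token position
    assert out.shape == x.shape  # RoPE is shape-preserving
    return out
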
class Scale(nn.Module):
    """Scale vector by element multiplications.

    Args:
        dim (int): The dimension of the scale vector.
        init_value (float, optional): The initial value of the scale vector.
            Defaults to 1.0.
        trainable (bool, optional): Whether the scale vector is trainable.
            Defaults to True.
    """

    def __init__(self, dim, init_value=1., trainable=True):
        super().__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        """Forward function."""

        return x * self.scale

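
# Illustrative sketch added by the editor: ``Scale`` is a per-channel
# learnable gain, used below to scale the residual branch of RTMCCBlock.
# A hypothetical check of its broadcasting behaviour:
def _scale_example():
    layer = Scale(dim=256, init_value=1.)
    y = layer(torch.randn(2, 17, 256))  # the [256] scale broadcasts over [B, K, C]
    assert y.shape == (2, 17, 256)
    return y
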
class ScaleNorm(nn.Module):
    """Scale Norm.

    Args:
        dim (int): The dimension of the scale vector.
        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.

    Reference:
        `Transformers without Tears: Improving the Normalization
        of Self-Attention <https://arxiv.org/abs/1910.05895>`_
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.scale = dim**-0.5
        self.eps = eps
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The tensor after applying scale norm.
        """

        norm = torch.norm(x, dim=-1, keepdim=True) * self.scale
        return x / norm.clamp(min=self.eps) * self.g

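
# Illustrative sketch added by the editor: ScaleNorm rescales each token
# vector to RMS-unit length and applies a single learnable gain ``g``, i.e.
# ``y = g * x / max(||x||_2 / sqrt(d), eps)``. A hypothetical numerical check:
def _scale_norm_example():
    norm = ScaleNorm(dim=256)
    x = torch.randn(2, 17, 256)
    y = norm(x)
    # With g initialised to 1, each output token has (roughly) unit RMS norm.
    rms = y.norm(dim=-1) * 256**-0.5
    assert torch.allclose(rms, torch.ones_like(rms), atol=1e-4)
    return y
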
class RTMCCBlock(nn.Module):
    """Gated Attention Unit (GAU) in RTMBlock.

    Args:
        num_token (int): The number of tokens.
        in_token_dims (int): The input token dimension.
        out_token_dims (int): The output token dimension.
        expansion_factor (int, optional): The expansion factor of the
            intermediate token dimension. Defaults to 2.
        s (int, optional): The self-attention feature dimension.
            Defaults to 128.
        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
        dropout_rate (float, optional): The dropout rate. Defaults to 0.0.
        drop_path (float, optional): The drop path rate. Defaults to 0.0.
        attn_type (str, optional): Type of attention which should be one of
            the following options:

            - 'self-attn': Self-attention.
            - 'cross-attn': Cross-attention.

            Defaults to 'self-attn'.
        act_fn (str, optional): The activation function which should be one
            of the following options:

            - 'ReLU': ReLU activation.
            - 'SiLU': SiLU activation.

            Defaults to 'SiLU'.
        bias (bool, optional): Whether to use bias in linear layers.
            Defaults to False.
        use_rel_bias (bool, optional): Whether to use relative bias.
            Defaults to True.
        pos_enc (bool, optional): Whether to use rotary position
            embedding. Defaults to False.

    Reference:
        `Transformer Quality in Linear Time
        <https://arxiv.org/abs/2202.10447>`_
    """

    def __init__(self,
                 num_token,
                 in_token_dims,
                 out_token_dims,
                 expansion_factor=2,
                 s=128,
                 eps=1e-5,
                 dropout_rate=0.,
                 drop_path=0.,
                 attn_type='self-attn',
                 act_fn='SiLU',
                 bias=False,
                 use_rel_bias=True,
                 pos_enc=False):

        super(RTMCCBlock, self).__init__()
        self.s = s
        self.num_token = num_token
        self.use_rel_bias = use_rel_bias
        self.attn_type = attn_type
        self.pos_enc = pos_enc
        self.drop_path = DropPath(drop_path) \
            if drop_path > 0. else nn.Identity()

        # expanded (intermediate) token dimension
        self.e = int(in_token_dims * expansion_factor)

        if use_rel_bias:
            if attn_type == 'self-attn':
                # learnable relative-position bias over 2 * num_token - 1 offsets
                self.w = nn.Parameter(
                    torch.rand([2 * num_token - 1], dtype=torch.float))
            else:
                self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
                self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
        self.o = nn.Linear(self.e, out_token_dims, bias=bias)

        if attn_type == 'self-attn':
            # a single projection produces u, v and the shared q/k base
            self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
            self.gamma = nn.Parameter(torch.rand((2, self.s)))
            self.beta = nn.Parameter(torch.rand((2, self.s)))
        else:
            self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
            self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
            self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)

            nn.init.xavier_uniform_(self.k_fc.weight)
            nn.init.xavier_uniform_(self.v_fc.weight)

        self.ln = ScaleNorm(in_token_dims, eps=eps)

        nn.init.xavier_uniform_(self.uv.weight)

        if act_fn == 'SiLU':
            assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \
                'SiLU activation requires PyTorch version >= 1.7'

            self.act_fn = nn.SiLU(True)
        else:
            self.act_fn = nn.ReLU(True)

        if in_token_dims == out_token_dims:
            # identity shortcut with a learnable per-channel residual scale
            self.shortcut = True
            self.res_scale = Scale(in_token_dims)
        else:
            self.shortcut = False

        self.sqrt_s = math.sqrt(s)
        self.dropout_rate = dropout_rate

        if dropout_rate > 0.:
            self.dropout = nn.Dropout(dropout_rate)

    def rel_pos_bias(self, seq_len, k_len=None):
        """Add relative position bias."""

        if self.attn_type == 'self-attn':
            # Build a [1, seq_len, seq_len] Toeplitz bias from the learnable
            # vector ``w`` of 2 * seq_len - 1 relative offsets.
            t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
            t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
            r = (2 * seq_len - 1) // 2
            t = t[..., r:-r]
        else:
            # Cross-attention: the bias is the outer product of two
            # rotary-encoded learnable vectors -> [1, seq_len, k_len].
            a = rope(self.a.repeat(seq_len, 1), dim=0)
            b = rope(self.b.repeat(k_len, 1), dim=0)
            t = torch.bmm(a, b.permute(0, 2, 1))
        return t

    def _forward(self, inputs):
        """GAU Forward function."""

        if self.attn_type == 'self-attn':
            x = inputs
        else:
            x, k, v = inputs

        x = self.ln(x)

        # [B, K, in_token_dims] -> [B, K, 2e + s] (self-attn) or [B, K, e + s]
        uv = self.uv(x)

        if self.attn_type == 'self-attn':
            # [B, K, 2e + s] -> [B, K, e], [B, K, e], [B, K, s]
            u, v, base = torch.split(
                self.act_fn(uv), [self.e, self.e, self.s], dim=-1)

            # q and k are two cheap affine transforms of a shared base:
            # [B, K, 1, s] * [1, 1, 2, s] + [2, s] -> [B, K, 2, s]
            base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta

            if self.pos_enc:
                base = rope(base, dim=1)

            # [B, K, 2, s] -> [B, K, s], [B, K, s]
            q, k = torch.unbind(base, dim=-2)
        else:
            # [B, K, e + s] -> [B, K, e], [B, K, s]
            u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)

            k = self.k_fc(k)  # -> [B, Kk, s]
            v = self.v_fc(v)  # -> [B, Kk, e]

            if self.pos_enc:
                q = rope(q, 1)
                k = rope(k, 1)

        # attention logits: [B, Kq, s] x [B, s, Kk] -> [B, Kq, Kk]
        qk = torch.bmm(q, k.permute(0, 2, 1))

        if self.use_rel_bias:
            if self.attn_type == 'self-attn':
                bias = self.rel_pos_bias(q.size(1))
            else:
                bias = self.rel_pos_bias(q.size(1), k.size(1))
            qk += bias[:, :q.size(1), :k.size(1)]

        # squared-ReLU attention kernel in place of softmax: [B, Kq, Kk]
        kernel = torch.square(F.relu(qk / self.sqrt_s))

        if self.dropout_rate > 0.:
            kernel = self.dropout(kernel)

        # gated aggregation: u * ([B, Kq, Kk] x [B, Kk, e]) -> [B, Kq, e]
        x = u * torch.bmm(kernel, v)

        # [B, Kq, e] -> [B, Kq, out_token_dims]
        x = self.o(x)

        return x

    def forward(self, x):
        """Forward function."""

        if self.shortcut:
            if self.attn_type == 'cross-attn':
                res_shortcut = x[0]
            else:
                res_shortcut = x

            main_branch = self.drop_path(self._forward(x))

            return self.res_scale(res_shortcut) + main_branch
        else:
            return self.drop_path(self._forward(x))
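
# Illustrative sketch added by the editor (not part of the upstream module):
# a hypothetical usage of RTMCCBlock in both attention modes, guarded by
# ``if __name__ == '__main__'`` so importing this file stays side-effect free.
# Shapes follow the conventions above: tokens are laid out as [B, K, C].
if __name__ == '__main__':
    feats = torch.randn(2, 17, 256)  # e.g. 17 keypoint tokens, 256 channels

    # Self-attention GAU: a single tensor goes in, the token layout is kept.
    self_gau = RTMCCBlock(num_token=17, in_token_dims=256, out_token_dims=256)
    print(self_gau(feats).shape)  # torch.Size([2, 17, 256])

    # Cross-attention GAU: the input is a (query, key, value) tuple.
    cross_gau = RTMCCBlock(
        num_token=17,
        in_token_dims=256,
        out_token_dims=256,
        attn_type='cross-attn')
    context = torch.randn(2, 32, 256)  # 32 context tokens
    print(cross_gau((feats, context, context)).shape)  # torch.Size([2, 17, 256])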