Spaces:
Runtime error
Runtime error
File size: 9,365 Bytes
cc0dd3c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn.bricks import DropPath
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION
def rope(x, dim):
    """Rotate the two channel halves of ``x`` by position-dependent angles.

    The position of an element is its flattened index over the selected
    spatial dimension(s); the last axis of ``x`` is split into two halves
    which are rotated against each other.

    Args:
        x (torch.Tensor): Input tensor.
        dim (int | list[int]): The spatial dimension(s) to apply
            rotary position embedding.

    Returns:
        torch.Tensor: The tensor after applying rotary position
        embedding.

    Reference:
        `RoFormer: Enhanced Transformer with Rotary
        Position Embedding <https://arxiv.org/abs/2104.09864>`_
    """
    dims = [dim] if isinstance(dim, int) else dim
    in_shape = x.shape
    grid_shape = [in_shape[axis] for axis in dims]

    # Number of positions is the product of the selected spatial extents.
    num_positions = 1
    for extent in grid_shape:
        num_positions *= extent

    position = torch.arange(
        num_positions, dtype=torch.int, device=x.device).reshape(grid_shape)

    # Append one singleton axis for every axis that lies between the last
    # spatial dimension and the channel axis, so that the position grid
    # broadcasts over those axes.
    for _ in range(dims[-1] + 1, len(in_shape) - 1):
        position = position.unsqueeze(-1)

    half = in_shape[-1] // 2
    # Per-channel angular rate: 10000 ** (i / half) for i in [0, half).
    exponent = torch.arange(
        half, dtype=torch.int, device=x.device) / float(half)
    rates = 10000**exponent

    angles = position[..., None] * rates[None, None, :]
    sin, cos = angles.sin(), angles.cos()

    first, second = x.chunk(2, dim=-1)
    rotated_first = first * cos - second * sin
    rotated_second = second * cos + first * sin
    return torch.cat([rotated_first, rotated_second], dim=-1)
class Scale(nn.Module):
    """Multiply the input element-wise by a scale vector.

    Args:
        dim (int): The dimension of the scale vector.
        init_value (float, optional): The initial value of the scale vector.
            Defaults to 1.0.
        trainable (bool, optional): Whether the scale vector is trainable.
            Defaults to True.
    """

    def __init__(self, dim, init_value=1., trainable=True):
        super().__init__()
        # Every entry of the vector starts at ``init_value``.
        initial = init_value * torch.ones(dim)
        self.scale = nn.Parameter(initial, requires_grad=trainable)

    def forward(self, x):
        """Return ``x`` multiplied by the scale vector."""
        scaled = x * self.scale
        return scaled
class ScaleNorm(nn.Module):
    """Scale Norm.

    Divides the input by the L2 norm of its last axis (rescaled by
    ``dim ** -0.5`` and clamped from below) and multiplies by a single
    learnable gain.

    Args:
        dim (int): The dimension of the scale vector.
        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.

    Reference:
        `Transformers without Tears: Improving the Normalization
        of Self-Attention <https://arxiv.org/abs/1910.05895>`_
    """

    def __init__(self, dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = dim**-0.5
        # Scalar gain shared by all channels.
        self.g = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: The tensor after applying scale norm.
        """
        scaled_norm = x.norm(dim=-1, keepdim=True) * self.scale
        # Clamp so that (near-)zero vectors do not cause a division by zero.
        denom = scaled_norm.clamp(min=self.eps)
        return x / denom * self.g
class RTMCCBlock(nn.Module):
    """Gated Attention Unit (GAU) in RTMBlock.

    Args:
        num_token (int): The number of tokens.
        in_token_dims (int): The input token dimension.
        out_token_dims (int): The output token dimension.
        expansion_factor (int, optional): The expansion factor of the
            intermediate token dimension. Defaults to 2.
        s (int, optional): The self-attention feature dimension.
            Defaults to 128.
        eps (float, optional): The minimum value in clamp. Defaults to 1e-5.
        dropout_rate (float, optional): The dropout rate. Defaults to 0.0.
        drop_path (float, optional): The drop path rate. Defaults to 0.0.
        attn_type (str, optional): Type of attention which should be one of
            the following options:

            - 'self-attn': Self-attention.
            - 'cross-attn': Cross-attention.

            Defaults to 'self-attn'.
        act_fn (str, optional): The activation function which should be one
            of the following options:

            - 'ReLU': ReLU activation.
            - 'SiLU': SiLU activation.

            Defaults to 'SiLU'.
        bias (bool, optional): Whether to use bias in linear layers.
            Defaults to False.
        use_rel_bias (bool, optional): Whether to use relative bias.
            Defaults to True.
        pos_enc (bool, optional): Whether to use rotary position
            embedding. Defaults to False.

    Reference:
        `Transformer Quality in Linear Time
        <https://arxiv.org/abs/2202.10447>`_
    """

    def __init__(self,
                 num_token,
                 in_token_dims,
                 out_token_dims,
                 expansion_factor=2,
                 s=128,
                 eps=1e-5,
                 dropout_rate=0.,
                 drop_path=0.,
                 attn_type='self-attn',
                 act_fn='SiLU',
                 bias=False,
                 use_rel_bias=True,
                 pos_enc=False):
        super().__init__()
        self.s = s
        self.num_token = num_token
        self.use_rel_bias = use_rel_bias
        self.attn_type = attn_type
        self.pos_enc = pos_enc
        self.drop_path = DropPath(drop_path) \
            if drop_path > 0. else nn.Identity()

        # Width of the gated (expanded) branch.
        self.e = int(in_token_dims * expansion_factor)
        if use_rel_bias:
            if attn_type == 'self-attn':
                # One learnable weight per relative offset in
                # [-(num_token - 1), num_token - 1].
                self.w = nn.Parameter(
                    torch.rand([2 * num_token - 1], dtype=torch.float))
            else:
                # Query-side / key-side vectors whose rotated inner product
                # forms the cross-attention bias (see ``rel_pos_bias``).
                self.a = nn.Parameter(torch.rand([1, s], dtype=torch.float))
                self.b = nn.Parameter(torch.rand([1, s], dtype=torch.float))
        self.o = nn.Linear(self.e, out_token_dims, bias=bias)

        if attn_type == 'self-attn':
            # A single projection yields u (e), v (e) and the shared base (s)
            # from which q and k are derived with per-branch affine params.
            self.uv = nn.Linear(in_token_dims, 2 * self.e + self.s, bias=bias)
            self.gamma = nn.Parameter(torch.rand((2, self.s)))
            self.beta = nn.Parameter(torch.rand((2, self.s)))
        else:
            # The projection yields u (e) and q (s); k and v come from the
            # separately supplied key/value inputs.
            self.uv = nn.Linear(in_token_dims, self.e + self.s, bias=bias)
            self.k_fc = nn.Linear(in_token_dims, self.s, bias=bias)
            self.v_fc = nn.Linear(in_token_dims, self.e, bias=bias)
            nn.init.xavier_uniform_(self.k_fc.weight)
            nn.init.xavier_uniform_(self.v_fc.weight)

        self.ln = ScaleNorm(in_token_dims, eps=eps)

        nn.init.xavier_uniform_(self.uv.weight)

        # Any value other than 'SiLU' selects ReLU.
        if act_fn == 'SiLU':
            assert digit_version(TORCH_VERSION) >= digit_version('1.7.0'), \
                'SiLU activation requires PyTorch version >= 1.7'
            self.act_fn = nn.SiLU(True)
        else:
            self.act_fn = nn.ReLU(True)

        # A residual connection is only possible when the token width is
        # unchanged by the block.
        if in_token_dims == out_token_dims:
            self.shortcut = True
            self.res_scale = Scale(in_token_dims)
        else:
            self.shortcut = False

        self.sqrt_s = math.sqrt(s)

        self.dropout_rate = dropout_rate

        if dropout_rate > 0.:
            self.dropout = nn.Dropout(dropout_rate)

    def rel_pos_bias(self, seq_len, k_len=None):
        """Build the relative position bias added to the attention logits.

        Args:
            seq_len (int): Number of query tokens.
            k_len (int, optional): Number of key tokens. Only used for
                cross-attention. Defaults to None.

        Returns:
            torch.Tensor: For self-attention, a tensor of shape
            ``(1, seq_len, seq_len)`` whose entry ``[0, i, j]`` is
            ``self.w[seq_len - 1 + j - i]``. For cross-attention, the
            batched product of the rotary-embedded ``self.a`` and
            ``self.b`` (one row per query / key position).
        """
        if self.attn_type == 'self-attn':
            # Pad-and-repeat trick: padding the weight vector with
            # ``seq_len`` zeros and tiling it makes each reshaped row a copy
            # of the previous one shifted right by one element.
            t = F.pad(self.w[:2 * seq_len - 1], [0, seq_len]).repeat(seq_len)
            t = t[..., :-seq_len].reshape(-1, seq_len, 3 * seq_len - 2)
            r = (2 * seq_len - 1) // 2
            # Keep the central ``seq_len`` columns. An explicit end index is
            # used because ``r:-r`` would be the empty slice ``0:0`` when
            # ``seq_len == 1`` (``r == 0``).
            t = t[..., r:r + seq_len]
        else:
            a = rope(self.a.repeat(seq_len, 1), dim=0)
            b = rope(self.b.repeat(k_len, 1), dim=0)
            t = torch.bmm(a, b.permute(0, 2, 1))
        return t

    def _forward(self, inputs):
        """GAU Forward function.

        Args:
            inputs (torch.Tensor | tuple[torch.Tensor]): The token tensor
                for self-attention, or a ``(x, k, v)`` triple for
                cross-attention.

        Returns:
            torch.Tensor: The block output before drop-path and the
            residual connection.
        """
        if self.attn_type == 'self-attn':
            x = inputs
        else:
            x, k, v = inputs

        x = self.ln(x)

        # Shape comments below assume x is [B, K, in_token_dims].
        uv = self.uv(x)

        if self.attn_type == 'self-attn':
            # [B, K, e + e + s] -> [B, K, e], [B, K, e], [B, K, s]
            u, v, base = torch.split(
                self.act_fn(uv), [self.e, self.e, self.s], dim=-1)
            # [B, K, 1, s] * [1, 1, 2, s] + [2, s] -> [B, K, 2, s]
            base = base.unsqueeze(2) * self.gamma[None, None, :] + self.beta

            if self.pos_enc:
                base = rope(base, dim=1)
            # [B, K, 2, s] -> [B, K, s], [B, K, s]
            q, k = torch.unbind(base, dim=-2)
        else:
            # [B, K, e + s] -> [B, K, e], [B, K, s]
            u, q = torch.split(self.act_fn(uv), [self.e, self.s], dim=-1)

            k = self.k_fc(k)
            v = self.v_fc(v)

            if self.pos_enc:
                q = rope(q, 1)
                k = rope(k, 1)

        # Attention logits: [B, Kq, s] x [B, s, Kk] -> [B, Kq, Kk]
        qk = torch.bmm(q, k.permute(0, 2, 1))

        if self.use_rel_bias:
            if self.attn_type == 'self-attn':
                bias = self.rel_pos_bias(q.size(1))
            else:
                bias = self.rel_pos_bias(q.size(1), k.size(1))
            qk += bias[:, :q.size(1), :k.size(1)]

        # Squared-ReLU attention kernel instead of softmax.
        kernel = torch.square(F.relu(qk / self.sqrt_s))

        if self.dropout_rate > 0.:
            kernel = self.dropout(kernel)

        # Gate the attended values with u, then project to the output width.
        x = u * torch.bmm(kernel, v)
        x = self.o(x)

        return x

    def forward(self, x):
        """Forward function.

        Args:
            x (torch.Tensor | tuple[torch.Tensor]): The token tensor for
                self-attention, or a ``(x, k, v)`` triple for
                cross-attention.

        Returns:
            torch.Tensor: The block output. When the input and output token
            dimensions match, the scaled input is added as a residual.
        """
        if self.shortcut:
            if self.attn_type == 'cross-attn':
                # Only the query stream takes part in the residual.
                res_shortcut = x[0]
            else:
                res_shortcut = x
            main_branch = self.drop_path(self._forward(x))
            return self.res_scale(res_shortcut) + main_branch
        else:
            return self.drop_path(self._forward(x))
|