Spaces:
Configuration error
Configuration error
import torch | |
import torch.nn as nn | |
class Adapter(nn.Module): | |
def __init__(self, D_features, mlp_ratio=0.25, act_layer=nn.GELU, skip_connect=True): | |
super().__init__() | |
self.skip_connect = skip_connect | |
D_hidden_features = int(D_features * mlp_ratio) | |
self.act = act_layer() | |
self.D_fc1 = nn.Linear(D_features, D_hidden_features) | |
self.D_fc2 = nn.Linear(D_hidden_features, D_features) | |
def forward(self, x): | |
# x is (BT, HW+1, D) | |
xs = self.D_fc1(x) | |
xs = self.act(xs) | |
xs = self.D_fc2(xs) | |
if self.skip_connect: | |
x = x + xs | |
else: | |
x = xs | |
return x |