Spaces:
Runtime error
Runtime error
import re | |
import torch.nn as nn | |
from . import register_connector | |
from .base import Connector | |
# Lookup table from the activation suffix of ``connector_type``
# (e.g. 'gelu' in 'mlp2x_gelu') to the torch.nn module class that the
# MLP connector instantiates between its Linear layers.
ACT_TYPE = dict(
    relu=nn.ReLU,
    gelu=nn.GELU,
)
class MLPConnector(Connector):
    """Project vision features to the model hidden size with a small MLP.

    The MLP shape is parsed from ``config.connector_type``, which must have the
    form ``mlp<depth>x_<act>`` (e.g. ``'mlp2x_gelu'`` or ``'mlp2x_relu'``), where
    ``<act>`` is a key of ``ACT_TYPE``. The resulting layer stack is::

        Linear(vision_hidden_size, hidden_size)
        then, (depth - 1) times: act(), Linear(hidden_size, hidden_size)

    and is stored in ``self._connector`` as an ``nn.Sequential``.

    NOTE(review): ``register_connector`` is imported by this module but is not
    applied in this class definition; confirm whether the class is meant to be
    decorated with it (e.g. ``@register_connector('mlp')``) so it is registered.
    """

    def __init__(self, config):
        """Build the MLP described by ``config``.

        Args:
            config: object exposing ``connector_type`` (str such as
                ``'mlp2x_gelu'``), ``vision_hidden_size`` and ``hidden_size``
                (used as the Linear layer sizes).

        Raises:
            ValueError: if ``connector_type`` is not of the form
                ``mlp<depth>x_<act>``, or ``<act>`` is not a key of ``ACT_TYPE``.
        """
        super().__init__()

        connector_type = config.connector_type
        # Parse depth and activation together so every activation offered by
        # ACT_TYPE (not only 'gelu') is accepted; a non-matching string is
        # reported clearly instead of failing on ``None.group``.
        match = re.fullmatch(r'mlp(\d+)x_(\w+)', connector_type)
        if match is None:
            raise ValueError(
                f"Unsupported connector_type {connector_type!r}; "
                "expected 'mlp<depth>x_<act>', e.g. 'mlp2x_gelu'"
            )
        mlp_depth = int(match.group(1))
        act_type = match.group(2)
        if act_type not in ACT_TYPE:
            raise ValueError(
                f"Unsupported activation {act_type!r} in connector_type "
                f"{connector_type!r}; expected one of {sorted(ACT_TYPE)}"
            )
        act_cls = ACT_TYPE[act_type]

        modules = [nn.Linear(config.vision_hidden_size, config.hidden_size)]
        # Each additional depth level appends a fresh activation followed by a
        # square hidden_size -> hidden_size Linear layer.
        for _ in range(1, mlp_depth):
            modules.append(act_cls())
            modules.append(nn.Linear(config.hidden_size, config.hidden_size))

        self._connector = nn.Sequential(*modules)