from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

import models.configs as configs
from models.attention import Attention
from models.embed import Embeddings

import pdb


def swish(x):
    return x * torch.sigmoid(x)


ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
class Mlp(nn.Module):
    """Transformer feed-forward (MLP) block: Linear -> GELU -> Dropout -> Linear -> Dropout."""

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights, near-zero normally distributed biases.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
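# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). A minimal
# smoke test for Mlp, assuming only that the config object exposes
# `hidden_size` and a `transformer` dict with "mlp_dim" and "dropout_rate";
# the concrete values below are hypothetical, not taken from models.configs.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        hidden_size=768,                                      # hypothetical
        transformer={"mlp_dim": 3072, "dropout_rate": 0.1},   # hypothetical
    )
    mlp = Mlp(cfg)
    tokens = torch.randn(2, 197, cfg.hidden_size)  # (batch, tokens, hidden)
    out = mlp(tokens)
    assert out.shape == tokens.shape  # the MLP block preserves the token shape
    print("Mlp forward OK:", tuple(out.shape))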