"""Transformer encoder components: multi-head self-attention, MLP, block, and encoder."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math

from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage


# Activation functions selectable by name; the Mlp below uses "gelu".
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}


class Attention(nn.Module):
    def __init__(self, config):
        super(Attention, self).__init__()
        self.num_attention_heads = config["num_heads"]
        self.attention_head_size = int(config["hidden_size"] / self.num_attention_heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(config["hidden_size"], self.all_head_size)
        self.key = Linear(config["hidden_size"], self.all_head_size)
        self.value = Linear(config["hidden_size"], self.all_head_size)

        self.out = Linear(self.all_head_size, config["hidden_size"])
        self.attn_dropout = Dropout(config["attention_dropout_rate"])
        self.proj_dropout = Dropout(config["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        # (batch, seq_len, hidden_size) -> (batch, num_heads, seq_len, head_size)
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Scaled dot-product attention: softmax(QK^T / sqrt(d_head)) V.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge heads: (batch, num_heads, seq_len, head_size) -> (batch, seq_len, hidden_size).
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output
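
# A minimal shape check for Attention: an illustrative sketch, not part of the
# module proper. The config values are assumed (ViT-Base-like), and the snippet
# is kept in a comment so that importing this file has no side effects:
#
#   config = {"hidden_size": 768, "num_heads": 12, "attention_dropout_rate": 0.0}
#   attn = Attention(config)
#   x = torch.zeros(2, 197, 768)       # (batch, seq_len, hidden_size)
#   assert attn(x).shape == x.shape    # self-attention preserves the shape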


class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config["hidden_size"], config["mlp_dim"])
        self.fc2 = Linear(config["mlp_dim"], config["hidden_size"])
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config["dropout_rate"])
        self._init_weights()

    def _init_weights(self):
        # Xavier-uniform weights with near-zero biases.
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
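
# The same check for Mlp (illustrative, assumed values): fc1 expands hidden_size
# to mlp_dim and fc2 projects back, so the output shape matches the input:
#
#   mlp = Mlp({"hidden_size": 768, "mlp_dim": 3072, "dropout_rate": 0.1})
#   assert mlp(torch.zeros(2, 197, 768)).shape == (2, 197, 768)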


class Block(nn.Module):
    def __init__(self, config):
        super(Block, self).__init__()
        self.num_heads = config["num_heads"]
        self.hidden_size = config["hidden_size"]
        self.attention_norm = LayerNorm(config["hidden_size"], eps=1e-6)
        self.attn = Attention(config)
        self.ffn_norm = LayerNorm(config["hidden_size"], eps=1e-6)
        self.ffn = Mlp(config)

    def forward(self, x):
        # Pre-norm residual attention.
        h = x
        x = self.attention_norm(x)
        x = self.attn(x)
        x = x + h

        # Pre-norm residual MLP.
        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x
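
# Block applies the pre-norm residual pattern twice:
#
#   x = x + attn(attention_norm(x))
#   x = x + ffn(ffn_norm(x))
#
# A sketch with assumed config values (a Block needs the keys read by both
# submodules); shapes are preserved end to end:
#
#   config = {"hidden_size": 768, "num_heads": 12, "mlp_dim": 3072,
#             "dropout_rate": 0.1, "attention_dropout_rate": 0.0}
#   block = Block(config)
#   assert block(torch.zeros(2, 197, 768)).shape == (2, 197, 768)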


class Encoder(nn.Module):
    def __init__(self, config):
        super(Encoder, self).__init__()
        self.layer = nn.ModuleList()
        self.encoder_norm = LayerNorm(config["hidden_size"], eps=1e-6)
        for _ in range(config["num_layers"]):
            # Each iteration constructs an independent Block (no weight sharing).
            self.layer.append(Block(config))

    def forward(self, hidden_states):
        for layer_block in self.layer:
            hidden_states = layer_block(hidden_states)
        encoded = self.encoder_norm(hidden_states)
        return encoded
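

# A hedged end-to-end smoke test for the full Encoder. The config below is an
# assumed example (ViT-Base-like numbers), not a configuration shipped with
# this module; run the file directly to execute it.
if __name__ == "__main__":
    test_config = {
        "hidden_size": 768,            # token embedding width
        "num_heads": 12,               # attention heads per Block
        "mlp_dim": 3072,               # hidden width of each Mlp
        "num_layers": 12,              # number of stacked Blocks
        "dropout_rate": 0.1,
        "attention_dropout_rate": 0.0,
    }
    encoder = Encoder(test_config)
    tokens = torch.zeros(2, 197, 768)  # (batch, seq_len, hidden_size)
    encoded = encoder(tokens)
    assert encoded.shape == tokens.shape
    print("Encoder output shape:", tuple(encoded.shape))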