# coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import logging
import math
from os.path import join as pjoin

import torch
import torch.nn as nn
import numpy as np

from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
from torch.nn.modules.utils import _pair
from scipy import ndimage

# Activation functions selectable by name; Mlp below uses "gelu".
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu}


class Attention(nn.Module):
    """Multi-head self-attention.

    ``forward`` takes ``hidden_states`` of shape (batch, seq_len, hidden_size)
    and returns a tensor of the same shape.

    ``hidden_size`` need not be divisible by ``num_heads``. Each head gets
    ``hidden_size // num_heads`` channels, so the concatenated heads span
    ``all_head_size`` channels (possibly fewer than ``hidden_size``). The
    output projection maps ``all_head_size`` back to ``hidden_size``.

    Config keys read: ``num_heads``, ``hidden_size``,
    ``attention_dropout_rate``.

    Raises:
        ValueError: if ``num_heads < 1`` or ``num_heads > hidden_size``
            (which would give zero-width heads).
    """

    def __init__(self, config):
        super(Attention, self).__init__()
        hidden_size = config['hidden_size']
        self.num_attention_heads = config["num_heads"]
        if self.num_attention_heads < 1:
            raise ValueError(
                "num_heads must be >= 1, got {}".format(self.num_attention_heads))
        self.attention_head_size = hidden_size // self.num_attention_heads
        if self.attention_head_size < 1:
            # A zero-width head would make forward() divide 0 by sqrt(0) -> NaN.
            raise ValueError(
                "hidden_size ({}) must be >= num_heads ({})".format(
                    hidden_size, self.num_attention_heads))
        # May be smaller than hidden_size when it is not divisible by num_heads.
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = Linear(hidden_size, self.all_head_size)
        self.key = Linear(hidden_size, self.all_head_size)
        self.value = Linear(hidden_size, self.all_head_size)
        # Project the concatenated heads (all_head_size) back to hidden_size so
        # the residual connection in Block lines up even when all_head_size
        # differs from hidden_size.
        self.out = Linear(self.all_head_size, hidden_size)

        self.attn_dropout = Dropout(config["attention_dropout_rate"])
        self.proj_dropout = Dropout(config["attention_dropout_rate"])

        self.softmax = Softmax(dim=-1)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (B, S, all_head) -> (B, heads, S, head)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states):
        """Apply scaled dot-product self-attention to ``hidden_states``."""
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # (B, heads, S, S) attention weights, scaled by 1/sqrt(head_size).
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_probs = self.softmax(attention_scores)
        attention_probs = self.attn_dropout(attention_probs)

        # Merge heads back: (B, heads, S, head) -> (B, S, all_head_size).
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)

        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output


class Mlp(nn.Module):
    """Position-wise feed-forward network: fc1 -> GELU -> dropout -> fc2 -> dropout.

    Config keys read: ``hidden_size``, ``mlp_dim``, ``dropout_rate``.
    """

    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config['hidden_size'], config["mlp_dim"])
        self.fc2 = Linear(config["mlp_dim"], config['hidden_size'])
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform weights and near-zero (std=1e-6) normal biases."""
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        """Apply the feed-forward network to ``x`` (last dim = hidden_size)."""
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x


class Block(nn.Module):
    """Pre-norm transformer block.

    Computes ``x + attn(LN(x))``, then ``x + ffn(LN(x))``.
    """

    def __init__(self, config):
        super(Block, self).__init__()
        # NOTE(review): `flag` is not read by the code shown here; kept so any
        # external code that inspects it keeps working.
        self.flag = config['num_heads']
        self.hidden_size = config['hidden_size']
        # Construction order is kept as-is. It determines how the random
        # parameter init draws from the RNG (seeded reproducibility) and the
        # order of state_dict keys.
        self.ffn_norm = LayerNorm(config['hidden_size'], eps=1e-6)
        self.ffn = Mlp(config)
        self.attn = Attention(config)
        self.attention_norm = LayerNorm(config['hidden_size'], eps=1e-6)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each with a residual."""
        h = x
        x = self.attention_norm(x)
        x = self.attn(x)
        x = x + h

        h = x
        x = self.ffn_norm(x)
        x = self.ffn(x)
        x = x + h
        return x


class Encoder(nn.Module):
    """Stack of ``num_layers`` independent Blocks followed by a final LayerNorm.

    Config keys read: ``num_layers`` plus everything Block reads.
    """

    def __init__(self, config):
        super(Encoder, self).__init__()
        # Each Block is freshly constructed with its own parameters, so no
        # deepcopy is needed. `layer` is registered before `encoder_norm` to
        # keep the original state_dict key order.
        self.layer = nn.ModuleList([Block(config) for _ in range(config["num_layers"])])
        self.encoder_norm = LayerNorm(config['hidden_size'], eps=1e-6)

    def forward(self, hidden_states):
        """Run all blocks in order, then normalize the final hidden states."""
        for layer_block in self.layer:
            hidden_states = layer_block(hidden_states)
        encoded = self.encoder_norm(hidden_states)
        return encoded