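"""Transformer encoder for embedding tokenized protein/target sequences.

The model combines token and position embeddings with a stack of
multi-head self-attention encoder layers and returns the hidden state
of the first sequence position as a fixed-size representation.
"""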
import copy
import math
import torch
from torch import nn


class Transformer(nn.Module):
    """Transformer encoder that maps a tokenized sequence to the hidden state of its first position."""

    def __init__(self,
                 input_dim,
                 emb_size,
                 max_position_size,
                 dropout,
                 n_layer,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout,
                 hidden_dropout,
                 ):
        super().__init__()
        self.emb = Embeddings(input_dim,
                              emb_size,
                              max_position_size,
                              dropout)
        self.encoder = MultiLayeredEncoder(n_layer,
                                           emb_size,
                                           intermediate_size,
                                           num_attention_heads,
                                           attention_probs_dropout,
                                           hidden_dropout)

    def forward(self, v):
        # v is a (token_ids, attention_mask) pair of tensors of shape (batch, seq_len).
        e = v[0].long()
        e_mask = v[1].long()
        # Broadcast the mask to (batch, 1, 1, seq_len) and turn padded positions into
        # a large negative bias that is added to the raw attention scores.
        ex_e_mask = e_mask.unsqueeze(1).unsqueeze(2)
        ex_e_mask = (1.0 - ex_e_mask) * -10000.0

        emb = self.emb(e)
        encoded_layers = self.encoder(emb.float(), ex_e_mask.float())
        # Return the hidden state at the first position as the sequence representation.
        return encoded_layers[:, 0]


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with learnable gain (gamma) and bias (beta)."""

    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta


class Embeddings(nn.Module):
    """Construct embeddings from protein/target token ids plus position embeddings."""

    def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class SelfAttention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        # Reshape (batch, seq_len, all_head_size) to (batch, heads, seq_len, head_size).
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        mixed_query_layer = self.query(hidden_states)
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = nn.functional.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # Merge the heads back into a single (batch, seq_len, all_head_size) tensor.
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        return context_layer


class SelfOutput(nn.Module):
    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        # Project, apply dropout, then add the residual connection before layer norm.
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Attention(nn.Module):
    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
        self.output = SelfOutput(hidden_size, hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class Intermediate(nn.Module):
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = nn.functional.relu(hidden_states)
        return hidden_states


class Output(nn.Module):
    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Encoder(nn.Module):
    def __init__(self, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        self.attention = Attention(hidden_size, num_attention_heads, attention_probs_dropout_prob, hidden_dropout_prob)
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class MultiLayeredEncoder(nn.Module):
    def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        # Stack n_layer copies of the encoder block; each copy has its own parameters.
        layer = Encoder(hidden_size, intermediate_size, num_attention_heads, attention_probs_dropout_prob,
                        hidden_dropout_prob)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)])

    def forward(self, hidden_states, attention_mask):
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states
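

if __name__ == "__main__":
    # Minimal smoke test sketch, not part of the original module: the hyperparameter
    # values below are illustrative assumptions, not the settings used upstream.
    model = Transformer(input_dim=100,            # vocabulary size passed to Embeddings
                        emb_size=128,
                        max_position_size=256,
                        dropout=0.1,
                        n_layer=2,
                        intermediate_size=512,
                        num_attention_heads=8,
                        attention_probs_dropout=0.1,
                        hidden_dropout=0.1)

    batch_size, seq_len = 4, 256
    token_ids = torch.randint(0, 100, (batch_size, seq_len))
    mask = torch.ones(batch_size, seq_len)

    # forward() expects a (token_ids, attention_mask) pair and returns the
    # first-position hidden state, i.e. a (batch, emb_size) tensor.
    with torch.no_grad():
        pooled = model((token_ids, mask))
    print(pooled.shape)  # expected: torch.Size([4, 128])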