import copy
import math

import torch
from torch import nn


class Transformer(nn.Module):
    """Token/position embeddings followed by a stack of self-attention encoder layers.

    ``forward`` returns the last layer's hidden state at sequence position 0
    for every element of the batch.
    """

    def __init__(self,
                 input_dim,
                 emb_size,
                 max_position_size,
                 dropout,
                 n_layer,
                 intermediate_size,
                 num_attention_heads,
                 attention_probs_dropout,
                 hidden_dropout,
                 ):
        """Build the embedding layer and the encoder stack.

        Args:
            input_dim: Number of rows in the token embedding table (vocabulary size).
            emb_size: Hidden size shared by the embeddings and every encoder layer.
            max_position_size: Number of learned position embeddings, i.e. the
                longest sequence that can be embedded.
            dropout: Dropout probability applied after the embedding LayerNorm.
            n_layer: Number of stacked encoder layers.
            intermediate_size: Width of each layer's feed-forward projection.
            num_attention_heads: Number of attention heads; must divide ``emb_size``.
            attention_probs_dropout: Dropout probability on the attention weights.
            hidden_dropout: Dropout probability on each sub-layer output before
                its residual connection.
        """
        super().__init__()
        self.emb = Embeddings(input_dim, emb_size, max_position_size, dropout)
        self.encoder = MultiLayeredEncoder(n_layer, emb_size, intermediate_size,
                                           num_attention_heads, attention_probs_dropout,
                                           hidden_dropout)

    def forward(self, v):
        """Encode a batch and return the first-position representation.

        Args:
            v: Indexable pair ``(token_ids, mask)`` of ``(batch, seq_len)``
                tensors. ``mask`` is 1 where a position may be attended to and
                0 for positions that must be ignored (e.g. padding).

        Returns:
            Tensor of shape ``(batch, emb_size)``: the final hidden state at
            sequence position 0.
        """
        e = v[0].long()
        e_mask = v[1].long()
        # (batch, seq) -> (batch, 1, 1, seq) so the mask broadcasts over heads
        # and query positions. Kept positions add 0 to the attention scores,
        # masked ones add -10000, which drives their softmax weight to ~0.
        ex_e_mask = e_mask.unsqueeze(1).unsqueeze(2)
        ex_e_mask = (1.0 - ex_e_mask) * -10000.0
        emb = self.emb(e)
        # Match the mask to the activations' dtype instead of forcing float32,
        # so the model also runs after e.g. ``.double()``. With the default
        # float32 parameters this is identical to casting both to float32.
        ex_e_mask = ex_e_mask.to(dtype=emb.dtype)
        encoded_layers = self.encoder(emb, ex_e_mask)
        return encoded_layers[:, 0]


class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with a learnable scale and shift.

    The parameters are named ``gamma``/``beta``; renaming them would change the
    module's ``state_dict`` keys.
    """

    def __init__(self, hidden_size, variance_epsilon=1e-12):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(hidden_size))
        self.beta = nn.Parameter(torch.zeros(hidden_size))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        # Biased (population) variance over the last dimension.
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.variance_epsilon)
        return self.gamma * x + self.beta


class Embeddings(nn.Module):
    """Sum of learned token and absolute-position embeddings, then LayerNorm and dropout."""

    def __init__(self, vocab_size, hidden_size, max_position_size, dropout_rate):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
        self.position_embeddings = nn.Embedding(max_position_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, input_ids):
        """Embed ``(batch, seq_len)`` token ids into ``(batch, seq_len, hidden_size)``.

        ``seq_len`` must not exceed ``max_position_size``; otherwise the
        position-embedding lookup fails.
        """
        seq_length = input_ids.size(1)
        # Positions 0..seq_len-1, shared by every row of the batch.
        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)

        embeddings = words_embeddings + position_embeddings
        embeddings = self.LayerNorm(embeddings)
        embeddings = self.dropout(embeddings)
        return embeddings


class SelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention (projections only, no output layer)."""

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob):
        super().__init__()
        if hidden_size % num_attention_heads != 0:
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (hidden_size, num_attention_heads))
        self.num_attention_heads = num_attention_heads
        # Exact integer division; divisibility was checked above.
        self.attention_head_size = hidden_size // num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        self.query = nn.Linear(hidden_size, self.all_head_size)
        self.key = nn.Linear(hidden_size, self.all_head_size)
        self.value = nn.Linear(hidden_size, self.all_head_size)

        self.dropout = nn.Dropout(attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads: (batch, seq, all_head) -> (batch, heads, seq, head_size)."""
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*new_x_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask):
        """Attend over ``hidden_states`` of shape ``(batch, seq_len, hidden_size)``.

        Args:
            hidden_states: Input sequence representations.
            attention_mask: Additive mask broadcastable to
                ``(batch, heads, seq_len, seq_len)``; added to the raw scores
                before the softmax.

        Returns:
            Context vectors of shape ``(batch, seq_len, all_head_size)``.
        """
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Take the dot product between "query" and "key" to get the raw attention scores.
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)
        attention_scores = attention_scores + attention_mask

        # Normalize the attention scores to probabilities.
        attention_probs = torch.softmax(attention_scores, dim=-1)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        attention_probs = self.dropout(attention_probs)

        context_layer = torch.matmul(attention_probs, value_layer)
        # (batch, heads, seq, head_size) -> (batch, seq, heads * head_size)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        return context_layer.view(*new_context_layer_shape)


class SelfOutput(nn.Module):
    """Projection + dropout + residual add + LayerNorm applied after self-attention."""

    def __init__(self, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(hidden_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        # Post-norm residual: normalise after adding the sub-layer input back.
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Attention(nn.Module):
    """Self-attention sub-layer: SelfAttention followed by SelfOutput."""

    def __init__(self, hidden_size, num_attention_heads, attention_probs_dropout_prob,
                 hidden_dropout_prob):
        super().__init__()
        self.self = SelfAttention(hidden_size, num_attention_heads, attention_probs_dropout_prob)
        self.output = SelfOutput(hidden_size, hidden_dropout_prob)

    def forward(self, input_tensor, attention_mask):
        self_output = self.self(input_tensor, attention_mask)
        attention_output = self.output(self_output, input_tensor)
        return attention_output


class Intermediate(nn.Module):
    """First half of the feed-forward sub-layer: linear expansion followed by ReLU."""

    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.dense = nn.Linear(hidden_size, intermediate_size)

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
        hidden_states = nn.functional.relu(hidden_states)
        return hidden_states


class Output(nn.Module):
    """Second half of the feed-forward sub-layer: projection, dropout, residual add, LayerNorm."""

    def __init__(self, intermediate_size, hidden_size, hidden_dropout_prob):
        super().__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = LayerNorm(hidden_size)
        self.dropout = nn.Dropout(hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class Encoder(nn.Module):
    """One encoder layer: self-attention sub-layer followed by a feed-forward sub-layer."""

    def __init__(self, hidden_size, intermediate_size, num_attention_heads,
                 attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        self.attention = Attention(hidden_size, num_attention_heads,
                                   attention_probs_dropout_prob, hidden_dropout_prob)
        self.intermediate = Intermediate(hidden_size, intermediate_size)
        self.output = Output(intermediate_size, hidden_size, hidden_dropout_prob)

    def forward(self, hidden_states, attention_mask):
        attention_output = self.attention(hidden_states, attention_mask)
        intermediate_output = self.intermediate(attention_output)
        # The residual for the feed-forward block is the attention output.
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output


class MultiLayeredEncoder(nn.Module):
    """Stack of ``n_layer`` Encoder layers applied in sequence with a shared mask.

    Every layer is a deep copy of a single freshly built Encoder. The layers
    therefore start with identical weights but own separate parameters.
    """

    def __init__(self, n_layer, hidden_size, intermediate_size, num_attention_heads,
                 attention_probs_dropout_prob, hidden_dropout_prob):
        super().__init__()
        layer = Encoder(hidden_size, intermediate_size, num_attention_heads,
                        attention_probs_dropout_prob, hidden_dropout_prob)
        self.layer = nn.ModuleList([copy.deepcopy(layer) for _ in range(n_layer)])

    def forward(self, hidden_states, attention_mask):
        for layer_module in self.layer:
            hidden_states = layer_module(hidden_states, attention_mask)
        return hidden_states