Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# coding=utf-8 | |
import math | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from transformers import AutoModel | |
from model.module.char_embedding import CharEmbedding | |
class WordDropout(nn.Dropout):
    """Dropout that removes whole feature vectors instead of single features.

    A single keep/drop decision is sampled for every position over all leading
    dimensions of the input and broadcast across the last (feature) dimension.
    Scaling and train/eval behaviour are those of ``F.dropout`` applied to a
    tensor of ones.
    """

    def forward(self, input_tensor):
        # With p == 0 there is nothing to drop: return the very same object.
        if self.p == 0:
            return input_tensor
        keep_shape = input_tensor.shape[:-1]
        per_position = F.dropout(
            input_tensor.new_ones(keep_shape),
            p=self.p,
            training=self.training,
            inplace=False,
        )
        # Add a trailing singleton axis so one decision covers the whole vector.
        return input_tensor * per_position[..., None]
class Encoder(nn.Module):
    """Word-level encoder on top of a pretrained transformer (``AutoModel``).

    The forward pass:
      1. runs the transformer and keeps every hidden layer after the embeddings;
      2. layer-normalises them and mixes them with softmax-normalised learned
         per-layer weights (``self.scores``), randomly removing layers from the
         mix during training;
      3. pools subword vectors into word vectors with a learned attention that
         is restricted by the ``to_scatter`` membership mask;
      4. optionally adds character-level form embeddings (with word dropout on
         the transformer part); and
      5. expands every word into ``query_length`` decoder queries.
    """

    # Probability of removing each hidden layer from the layer mix during
    # training. Class-level default so instances built without
    # ``args.layer_dropout`` (or unpickled from older versions) keep 0.1.
    layer_dropout = 0.1

    def __init__(self, args, dataset):
        """Build the encoder.

        Args:
            args: configuration namespace. Reads ``hidden_size``,
                ``n_encoder_layers``, ``query_length``, ``encoder`` (passed to
                ``AutoModel.from_pretrained``), ``encoder_freeze_embedding``,
                ``freeze_bert``, ``char_embedding``, ``char_embedding_size``,
                ``dropout_word`` and, optionally, ``layer_dropout``.
            dataset: object exposing ``char_form_vocab_size``; only read when
                ``args.char_embedding`` is truthy.
        """
        super().__init__()
        self.dim = args.hidden_size
        # NOTE(review): must equal the number of transformer hidden layers
        # (excluding embeddings). With 1 it would silently broadcast and sum
        # all layers instead of averaging them — confirm config.
        self.n_layers = args.n_encoder_layers
        self.width_factor = args.query_length
        self.layer_dropout = getattr(args, "layer_dropout", self.layer_dropout)

        self.bert = AutoModel.from_pretrained(args.encoder, add_pooling_layer=False)
        if args.encoder_freeze_embedding:
            # Freeze the embedding tables but keep their LayerNorm trainable.
            self.bert.embeddings.requires_grad_(False)
            self.bert.embeddings.LayerNorm.requires_grad_(True)
        if args.freeze_bert:
            self.bert.requires_grad_(False)

        self.use_char_embedding = args.char_embedding
        if self.use_char_embedding:
            self.form_char_embedding = CharEmbedding(dataset.char_form_vocab_size, args.char_embedding_size, self.dim)
        self.word_dropout = WordDropout(args.dropout_word)

        self.post_layer_norm = nn.LayerNorm(self.dim)
        self.subword_attention = nn.Linear(self.dim, 1)

        if self.width_factor > 1:
            self.query_generator = nn.Linear(self.dim, self.dim * self.width_factor)
        else:
            self.query_generator = nn.Identity()

        self.encoded_layer_norm = nn.LayerNorm(self.dim)
        # One learnable mixing logit per hidden layer; zeros -> uniform average.
        self.scores = nn.Parameter(torch.zeros(self.n_layers, 1, 1, 1), requires_grad=True)

    def forward(self, bert_input, form_chars, to_scatter, n_words):
        """Encode a batch into word vectors and decoder queries.

        Args:
            bert_input: ``(tokens, mask)`` pair fed to the transformer; positions
                where ``mask == 0`` are zeroed after layer mixing.
            form_chars: triple whose three items are passed to the character
                embedding; only used when character embeddings are enabled.
            to_scatter: (B, T_subword, T_word) mask; word ``w`` attends only to
                subwords ``s`` with ``to_scatter[b, s, w] != 0``. Words with no
                subwords get a zero pooled vector.
            n_words: unused; kept for interface compatibility.

        Returns:
            ``(encoder_output, decoder_input)`` of shapes (B, T_word, H) and
            (B, T_word * query_length, H).
        """
        tokens, mask = bert_input
        batch_size = tokens.size(0)

        encoded = self.bert(tokens, attention_mask=mask, output_hidden_states=True).hidden_states[1:]
        encoded = torch.stack(encoded, dim=0)  # shape: (L, B, T, H)
        encoded = self.encoded_layer_norm(encoded)

        if self.training:
            scores = self.scores.expand(-1, batch_size, -1, -1)  # shape: (L, B, 1, 1)
            drop = torch.rand(self.n_layers, batch_size, 1, 1, device=self.scores.device) < self.layer_dropout
            # If every layer is dropped for a sample, softmax over all -inf is
            # NaN and poisons the whole batch; keep all layers in that case.
            drop = drop & ~drop.all(dim=0, keepdim=True)
            scores = scores.masked_fill(drop, float("-inf"))
        else:
            scores = self.scores  # shape: (L, 1, 1, 1)

        scores = F.softmax(scores, dim=0)
        encoded = (scores * encoded).sum(0)  # shape: (B, T, H)
        encoded = encoded.masked_fill(mask.unsqueeze(-1) == 0, 0.0)  # shape: (B, T, H)

        subword_attention = self.subword_attention(encoded) / math.sqrt(self.dim)  # shape: (B, T, 1)
        subword_attention = subword_attention.expand_as(to_scatter)  # shape: (B, T_subword, T_word)
        subword_attention = subword_attention.masked_fill(to_scatter == 0, float("-inf"))  # shape: (B, T_subword, T_word)
        subword_attention = torch.softmax(subword_attention, dim=1)  # shape: (B, T_subword, T_word)
        # Words without any subword produced NaN columns above; zero them out.
        subword_attention = subword_attention.masked_fill(to_scatter.sum(1, keepdim=True) == 0, value=0.0)  # shape: (B, T_subword, T_word)

        encoder_output = torch.einsum("bsd,bsw->bwd", encoded, subword_attention)  # shape: (B, T_word, H)
        encoder_output = self.post_layer_norm(encoder_output)

        if self.use_char_embedding:
            form_char_embedding = self.form_char_embedding(form_chars[0], form_chars[1], form_chars[2])
            encoder_output = self.word_dropout(encoder_output) + form_char_embedding

        decoder_input = self.query_generator(encoder_output)
        decoder_input = decoder_input.view(batch_size, -1, self.width_factor, self.dim).flatten(1, 2)  # shape: (B, T*Q, D)

        return encoder_output, decoder_input