File size: 4,028 Bytes
8044721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
#!/usr/bin/env python3
# coding=utf-8

import math

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import AutoModel
from model.module.char_embedding import CharEmbedding


class WordDropout(nn.Dropout):
    """Dropout that removes entire feature vectors instead of single elements.

    One keep/drop decision is sampled for every position of the leading
    dimensions (everything except the last one) and broadcast over the last
    dimension. Kept vectors are rescaled exactly as ``F.dropout`` rescales,
    i.e. by ``1 / (1 - p)``. Outside training mode the mask is all ones.
    When ``p == 0`` the very same input tensor object is returned.
    """

    def forward(self, input_tensor):
        drop_prob = self.p
        if drop_prob == 0:
            return input_tensor

        # One scalar per "word" (all dims but the feature dim), dropped out as usual.
        keep = F.dropout(
            input_tensor.new_ones(input_tensor.shape[:-1]),
            p=drop_prob,
            training=self.training,
            inplace=False,
        )
        # Broadcast the per-word mask over the feature dimension.
        return input_tensor * keep[..., None]


class Encoder(nn.Module):
    """Encodes subword tokens with a pretrained transformer and pools them into word vectors.

    Pipeline (see ``forward``):
      1. run the pretrained model and take every layer's hidden states
         (embedding output excluded),
      2. mix the layers with learned, softmax-normalised per-layer weights
         (with whole-layer dropout during training),
      3. pool subwords into words with a learned attention restricted by
         ``to_scatter``,
      4. optionally add character-level word embeddings,
      5. expand each word vector into ``query_length`` decoder queries.
    """

    # Probability of dropping a whole transformer layer from the layer mix,
    # sampled independently per layer and per batch item (training only).
    LAYER_DROPOUT = 0.1

    def __init__(self, args, dataset):
        super(Encoder, self).__init__()

        self.dim = args.hidden_size
        self.n_layers = args.n_encoder_layers  # must equal the number of hidden layers the model returns
        self.width_factor = args.query_length  # number of decoder queries generated per word

        self.bert = AutoModel.from_pretrained(args.encoder, add_pooling_layer=False)
        # self.bert._set_gradient_checkpointing(self.bert.encoder, value=True)
        if args.encoder_freeze_embedding:
            # Freeze the embedding tables but keep the embedding LayerNorm trainable.
            self.bert.embeddings.requires_grad_(False)
            self.bert.embeddings.LayerNorm.requires_grad_(True)

        if args.freeze_bert:
            self.bert.requires_grad_(False)

        self.use_char_embedding = args.char_embedding
        if self.use_char_embedding:
            self.form_char_embedding = CharEmbedding(dataset.char_form_vocab_size, args.char_embedding_size, self.dim)
            self.word_dropout = WordDropout(args.dropout_word)

        self.post_layer_norm = nn.LayerNorm(self.dim)
        self.subword_attention = nn.Linear(self.dim, 1)

        if self.width_factor > 1:
            self.query_generator = nn.Linear(self.dim, self.dim * self.width_factor)
        else:
            self.query_generator = nn.Identity()

        self.encoded_layer_norm = nn.LayerNorm(self.dim)
        # Learned per-layer mixing logits; zeros -> uniform average at initialisation.
        self.scores = nn.Parameter(torch.zeros(self.n_layers, 1, 1, 1), requires_grad=True)

    def forward(self, bert_input, form_chars, to_scatter, n_words):
        """Encode a batch and build the decoder queries.

        Args:
            bert_input: pair ``(tokens, mask)``. ``tokens.size(0)`` is the batch
                size; ``mask`` is passed as the attention mask, and positions
                where it equals 0 are zeroed. Presumably both are (B, T) — confirm
                against the caller.
            form_chars: 3-tuple forwarded positionally to ``CharEmbedding``.
                It is only used when character embeddings are enabled.
            to_scatter: (B, T_subword, T_word); nonzero where a subword belongs
                to a word.
            n_words: unused here; kept for interface compatibility.

        Returns:
            ``(encoder_output, decoder_input)``:
              - ``encoder_output`` has shape (B, T_word, H);
              - ``decoder_input`` has shape (B, T_word * query_length, H).
        """
        tokens, mask = bert_input
        batch_size = tokens.size(0)

        hidden_states = self.bert(tokens, attention_mask=mask, output_hidden_states=True).hidden_states[1:]
        encoded = self._mix_layers(hidden_states, batch_size)  # shape: (B, T, H)
        encoded = encoded.masked_fill(mask.unsqueeze(-1) == 0, 0.0)  # shape: (B, T, H)

        encoder_output = self._pool_subwords(encoded, to_scatter)  # shape: (B, T_word, H)
        encoder_output = self.post_layer_norm(encoder_output)

        if self.use_char_embedding:
            form_char_embedding = self.form_char_embedding(form_chars[0], form_chars[1], form_chars[2])
            encoder_output = self.word_dropout(encoder_output) + form_char_embedding

        decoder_input = self.query_generator(encoder_output)
        decoder_input = decoder_input.view(batch_size, -1, self.width_factor, self.dim).flatten(1, 2)  # shape: (B, T*Q, D)

        return encoder_output, decoder_input

    def _mix_layers(self, hidden_states, batch_size):
        """Combine per-layer hidden states into one tensor with learned softmax weights."""
        encoded = torch.stack(hidden_states, dim=0)  # shape: (L, B, T, H)
        encoded = self.encoded_layer_norm(encoded)

        if self.training:
            # Broadcasting over T is enough; no need to materialise a (L, B, T, 1) tensor.
            scores = self.scores.expand(-1, batch_size, -1, -1)  # shape: (L, B, 1, 1)
            dropped = torch.empty(self.n_layers, batch_size, 1, 1, dtype=torch.bool, device=self.scores.device)
            dropped.bernoulli_(self.LAYER_DROPOUT)
            # If every layer of an item was dropped, the softmax would see only
            # -inf and produce NaN; keep all layers for that item instead.
            dropped = dropped & ~dropped.all(dim=0, keepdim=True)
            scores = scores.masked_fill(dropped, float("-inf"))
        else:
            scores = self.scores  # shape: (L, 1, 1, 1)

        weights = F.softmax(scores, dim=0)
        return (weights * encoded).sum(0)  # shape: (B, T, H)

    def _pool_subwords(self, encoded, to_scatter):
        """Attention-pool subword vectors into word vectors, restricted to each word's subwords."""
        logits = self.subword_attention(encoded) / math.sqrt(self.dim)  # shape: (B, T, 1)
        logits = logits.expand_as(to_scatter)  # shape: (B, T_subword, T_word)
        logits = logits.masked_fill(to_scatter == 0, float("-inf"))  # shape: (B, T_subword, T_word)
        weights = torch.softmax(logits, dim=1)  # shape: (B, T_subword, T_word)
        # Words with no subwords have an all -inf column (NaN after softmax); zero them.
        weights = weights.masked_fill(to_scatter.sum(1, keepdim=True) == 0, value=0.0)  # shape: (B, T_subword, T_word)
        return torch.einsum("bsd,bsw->bwd", encoded, weights)  # shape: (B, T_word, H)