File size: 1,612 Bytes
7900c16
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import torch.nn as nn


class Model(nn.Module):
    """
    Pretraining model assembled from up to five pluggable parts:
        - embedding
        - encoder
        - tgt_embedding (optional, may be None)
        - decoder (optional, may be None)
        - target

    The three mandatory parts are embedding, encoder and target; the
    target-side embedding and decoder are only used when a decoder is given.
    """

    def __init__(self, args, embedding, encoder, tgt_embedding, decoder, target):
        """
        Store the sub-modules and optionally tie/share their weights.

        Args:
            args: Configuration object. Reads ``args.target`` (membership is
                tested with ``in``), ``args.tie_weights`` and
                ``args.share_embedding``.
            embedding: Source-side embedding, called as ``embedding(src, seg)``.
            encoder: Encoder, called as ``encoder(emb, seg)``.
            tgt_embedding: Target-side embedding or None.
            decoder: Decoder or None.
            target: Target module, called as ``target(memory_bank, tgt, seg)``.
        """
        super().__init__()
        self.embedding = embedding
        self.encoder = encoder
        self.tgt_embedding = tgt_embedding
        self.decoder = decoder
        self.target = target

        # Tie the output projection of the target to a word-embedding matrix.
        # The branch order matters: "mlm" wins over "lm", and the source-side
        # word embedding wins over the target-side one.
        if "mlm" in args.target and args.tie_weights:
            self.target.mlm.linear_2.weight = self.embedding.word.embedding.weight
        elif "lm" in args.target and args.tie_weights and "word" in self.embedding.embedding_name_list:
            self.target.lm.output_layer.weight = self.embedding.word.embedding.weight
        elif (
            "lm" in args.target
            and args.tie_weights
            # tgt_embedding is optional; without this guard a missing
            # target-side embedding raised AttributeError here.
            and self.tgt_embedding is not None
            and "word" in self.tgt_embedding.embedding_name_list
        ):
            self.target.lm.output_layer.weight = self.tgt_embedding.word.embedding.weight

        # Let the decoder side reuse the source-side word-embedding matrix.
        if self.decoder is not None and args.share_embedding:
            self.tgt_embedding.word.embedding.weight = self.embedding.word.embedding.weight

    def forward(self, src, tgt, seg, tgt_in=None, tgt_seg=None):
        """
        Run embedding -> encoder -> (optional decoder) -> target.

        Args:
            src: Source input passed to the embedding.
            tgt: Target passed to the target module.
            seg: Segment information passed to the embedding, the encoder,
                the decoder (as part of a tuple) and the target module.
            tgt_in: Decoder-side input; only used when a decoder is present.
            tgt_seg: Decoder-side segment information; only used when a
                decoder is present.

        Returns:
            Whatever the target module returns for
            ``target(memory_bank, tgt, seg)``.
        """
        emb = self.embedding(src, seg)
        memory_bank = self.encoder(emb, seg)
        # Explicit None check (consistent with __init__): a decoder object
        # that happens to be falsy must still be applied.
        if self.decoder is not None:
            tgt_emb = self.tgt_embedding(tgt_in, tgt_seg)
            memory_bank = self.decoder(memory_bank, tgt_emb, (seg, tgt_seg))

        loss_info = self.target(memory_bank, tgt, seg)

        return loss_info