import torch
import torch.nn as nn
# ========== Configuration ==========
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 1e-4
MAX_SEQ_LEN = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
D_MODEL = 512
N_HEAD = 8
NUM_LAYERS = 6
DIM_FEEDFORWARD = 2048
# ========== Model definition ==========

class TransformerModel(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=D_MODEL, nhead=N_HEAD, num_layers=NUM_LAYERS, dim_feedforward=DIM_FEEDFORWARD):
        super().__init__()
        # Token embeddings for the source and target vocabularies
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Learned positional encoding, one vector per position up to MAX_SEQ_LEN
        self.positional_encoding = nn.Parameter(torch.zeros(1, MAX_SEQ_LEN, d_model))

        # Standard encoder-decoder Transformer; batch_first defaults to False (sequence-first)
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
        )

        # Project decoder states to target-vocabulary logits
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # src, tgt: (batch, seq_len) token indices.
        # Scale embeddings by sqrt(d_model) and add positional encodings (broadcast over the batch).
        src_emb = self.src_embedding(src) * (self.d_model ** 0.5) + self.positional_encoding[:, :src.size(1), :]
        tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5) + self.positional_encoding[:, :tgt.size(1), :]

        # nn.Transformer expects (seq_len, batch, d_model), so permute the batch-first embeddings.
        output = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )

        # Back to (batch, seq_len, tgt_vocab_size) logits.
        return self.fc_out(output.permute(1, 0, 2))
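

# The file ends with the model definition and does not show how it is driven, so the
# block below is a minimal usage sketch only. SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, PAD_IDX,
# and the dummy sequence lengths are illustrative assumptions, not values from the
# original training setup.
if __name__ == "__main__":
    SRC_VOCAB_SIZE = 10_000  # assumed vocabulary sizes, for illustration only
    TGT_VOCAB_SIZE = 10_000
    PAD_IDX = 0              # assumed padding token index

    model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE).to(DEVICE)

    # Dummy batch-first inputs: (batch, seq_len) token indices (1.. avoids the pad index).
    src = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, 40), device=DEVICE)
    tgt = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, 35), device=DEVICE)

    # Teacher forcing: the decoder sees tgt shifted right and predicts the next token.
    tgt_in, tgt_out = tgt[:, :-1], tgt[:, 1:]

    # Causal mask of shape (T, T) so position i cannot attend to positions > i.
    tgt_mask = model.transformer.generate_square_subsequent_mask(tgt_in.size(1)).to(DEVICE)
    # Padding masks of shape (batch, seq_len); True marks positions to ignore.
    src_padding_mask = src.eq(PAD_IDX)
    tgt_padding_mask = tgt_in.eq(PAD_IDX)

    logits = model(
        src,
        tgt_in,
        tgt_mask=tgt_mask,
        src_padding_mask=src_padding_mask,
        tgt_padding_mask=tgt_padding_mask,
    )
    print(logits.shape)  # (BATCH_SIZE, 34, TGT_VOCAB_SIZE)

    # A typical training loss flattens the logits and ignores padding positions.
    criterion = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
    loss = criterion(logits.reshape(-1, TGT_VOCAB_SIZE), tgt_out.reshape(-1))
    print(loss.item())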