import torch
import torch.nn as nn

# ========== Configuration ==========
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 1e-4
MAX_SEQ_LEN = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
D_MODEL = 512
N_HEAD = 8
NUM_LAYERS = 6
DIM_FEEDFORWARD = 2048


# ========== Model Definition ==========
class TransformerModel(nn.Module):
    """Seq2seq Transformer with learned positional embeddings."""

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=D_MODEL,
                 nhead=N_HEAD, num_layers=NUM_LAYERS,
                 dim_feedforward=DIM_FEEDFORWARD):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Learned positional encoding: one slot per position up to MAX_SEQ_LEN.
        self.positional_encoding = nn.Parameter(torch.zeros(1, MAX_SEQ_LEN, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
        )
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_padding_mask=None, tgt_padding_mask=None):
        # Scale embeddings by sqrt(d_model) and add positional encodings
        # (inputs are batch-first: (batch, seq_len)).
        src_emb = self.src_embedding(src) * (self.d_model ** 0.5) + self.positional_encoding[:, :src.size(1), :]
        tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5) + self.positional_encoding[:, :tgt.size(1), :]
        # nn.Transformer defaults to (seq_len, batch, d_model), so permute from batch-first.
        output = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            # Also mask padded source positions in the decoder's cross-attention.
            memory_key_padding_mask=src_padding_mask,
        )
        # Back to (batch, seq_len, tgt_vocab_size).
        return self.fc_out(output.permute(1, 0, 2))
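

# --- Usage sketch (illustrative, not part of the original script) ---
# Shows how the model above is instantiated and called with a causal target
# mask and padding masks. SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, and PAD_IDX are
# placeholder assumptions; in practice they come from the tokenizer/vocabulary
# built during data loading.
if __name__ == "__main__":
    SRC_VOCAB_SIZE = 8000   # assumed for illustration
    TGT_VOCAB_SIZE = 8000   # assumed for illustration
    PAD_IDX = 0             # assumed padding token id

    model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE).to(DEVICE)

    # Dummy batch-first token id tensors (random ids, no real data).
    src = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN), device=DEVICE)
    tgt = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN), device=DEVICE)

    # Causal mask so each target position attends only to earlier positions.
    tgt_mask = model.transformer.generate_square_subsequent_mask(tgt.size(1)).to(DEVICE)
    # Key padding masks: True marks positions that should be ignored.
    src_padding_mask = src == PAD_IDX
    tgt_padding_mask = tgt == PAD_IDX

    logits = model(src, tgt, tgt_mask=tgt_mask,
                   src_padding_mask=src_padding_mask,
                   tgt_padding_mask=tgt_padding_mask)
    print(logits.shape)  # expected: (BATCH_SIZE, MAX_SEQ_LEN, TGT_VOCAB_SIZE)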