import torch
import torch.nn as nn
# ========== Configuration ==========
BATCH_SIZE = 128
EPOCHS = 50
LEARNING_RATE = 1e-4
MAX_SEQ_LEN = 60  # maximum sequence length covered by the positional encoding
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Base-Transformer dimensions (Vaswani et al., 2017): 512-d model, 8 heads,
# 6 encoder/decoder layers, 2048-d feed-forward.
D_MODEL = 512
N_HEAD = 8
NUM_LAYERS = 6
DIM_FEEDFORWARD = 2048
# ========== Model definition ==========
class TransformerModel(nn.Module):
    """Seq2seq Transformer: token embeddings + learned positional encoding + nn.Transformer."""

    def __init__(self, src_vocab_size, tgt_vocab_size, d_model=D_MODEL, nhead=N_HEAD,
                 num_layers=NUM_LAYERS, dim_feedforward=DIM_FEEDFORWARD):
        super().__init__()
        # Token embeddings for the source and target vocabularies.
        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        # Learned positional encoding: one trainable vector per position up to MAX_SEQ_LEN.
        self.positional_encoding = nn.Parameter(torch.zeros(1, MAX_SEQ_LEN, d_model))
        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_layers,
            num_decoder_layers=num_layers,
            dim_feedforward=dim_feedforward,
            dropout=0.1,
        )
        # Projects decoder outputs to target-vocabulary logits.
        self.fc_out = nn.Linear(d_model, tgt_vocab_size)
        self.d_model = d_model
    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_padding_mask=None, tgt_padding_mask=None):
        # src, tgt: (batch, seq_len) token indices. Scale embeddings by
        # sqrt(d_model), as in "Attention Is All You Need", then add the
        # positional encoding truncated to the actual sequence length.
        src_emb = self.src_embedding(src) * (self.d_model ** 0.5) + self.positional_encoding[:, :src.size(1), :]
        tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5) + self.positional_encoding[:, :tgt.size(1), :]
        # nn.Transformer defaults to (seq_len, batch, d_model), hence the permutes.
        output = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
            memory_key_padding_mask=src_padding_mask,  # decoder ignores padded encoder positions
        )
        # Back to (batch, seq_len, tgt_vocab_size) logits.
        return self.fc_out(output.permute(1, 0, 2))
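
# Minimal forward-pass smoke test — a hypothetical sketch, not part of the
# uploaded training code. The vocabulary sizes and PAD index below are assumed
# placeholders; real values would come from the dataset/tokenizer (not shown).
if __name__ == "__main__":
    SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, PAD_IDX = 8000, 8000, 0  # assumed values

    model = TransformerModel(SRC_VOCAB_SIZE, TGT_VOCAB_SIZE).to(DEVICE)

    # Random non-PAD token indices, shaped (batch, seq_len).
    src = torch.randint(1, SRC_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN), device=DEVICE)
    tgt = torch.randint(1, TGT_VOCAB_SIZE, (BATCH_SIZE, MAX_SEQ_LEN), device=DEVICE)

    # Causal mask: each target position may only attend to earlier positions.
    tgt_mask = model.transformer.generate_square_subsequent_mask(tgt.size(1)).to(DEVICE)
    # Key-padding masks: True where a token equals PAD and should be ignored.
    src_padding_mask = src == PAD_IDX
    tgt_padding_mask = tgt == PAD_IDX

    logits = model(src, tgt, tgt_mask=tgt_mask,
                   src_padding_mask=src_padding_mask,
                   tgt_padding_mask=tgt_padding_mask)
    print(logits.shape)  # expected: (BATCH_SIZE, MAX_SEQ_LEN, TGT_VOCAB_SIZE)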