import torch
import torch.nn as nn
from torch.nn import Transformer
from loguru import logger
import os

DEVICE = torch.device("cpu")


def get_sinusoid_encoding_table(max_len, d_model):
    # Standard sinusoidal positional encoding table of shape (1, max_len, d_model).
    pos_encoding = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(
        torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model)
    )
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    pos_encoding = pos_encoding.unsqueeze(0)
    return nn.Parameter(pos_encoding, requires_grad=False)


class TransformerModel(nn.Module):
    # NOTE: the special token strings used below ("<bos>", "<eos>", "<pad>", "<unk>")
    # must match the entries in the tokenizers' word2int vocabularies.

    def __init__(self, config, src_tokenizer, tgt_tokenizer):
        super().__init__()
        # Read hyperparameters from the config
        global DEVICE
        config = config.config
        DEVICE = torch.device(config['DEVICE'])
        self.batch_size = config.get('BATCH_SIZE')
        self.epochs = config.get('EPOCHS')
        self.learning_rate = config.get('LEARNING_RATE')
        self.max_seq_len = config.get('MAX_SEQ_LEN')
        self.d_model = config.get('D_MODEL')
        self.n_head = config.get('N_HEAD')
        self.num_layers = config.get('NUM_LAYERS')
        self.dim_feedforward = config.get('DIM_FEEDFORWARD')
        self.dropout = config.get('DROPOUT')
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

        self.transformer = Transformer(
            d_model=self.d_model,
            nhead=self.n_head,
            num_encoder_layers=self.num_layers,
            num_decoder_layers=self.num_layers,
            dim_feedforward=self.dim_feedforward,
            dropout=self.dropout,
        )

        # Embedding layers for source/target and the final projection onto the target vocabulary
        src_vocab_size = len(src_tokenizer.word2int)
        tgt_vocab_size = len(tgt_tokenizer.word2int)
        self.src_embedding = nn.Embedding(src_vocab_size, self.d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, self.d_model)
        self.fc = nn.Linear(self.d_model, tgt_vocab_size)

        # Fixed (non-trainable) sinusoidal positional encoding
        self.pos_encoding = get_sinusoid_encoding_table(self.max_seq_len, self.d_model)

        self.to(DEVICE)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None,
                src_padding_mask=None, tgt_padding_mask=None):
        # Embed and scale the inputs
        src_emb = self.src_embedding(src) * (self.d_model ** 0.5)
        tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5)
        # Add positional encodings
        src_emb = src_emb + self.pos_encoding[0, :src_emb.size(1)]
        tgt_emb = tgt_emb + self.pos_encoding[0, :tgt_emb.size(1)]
        # nn.Transformer expects src and tgt shaped (seq_len, batch_size, d_model)
        output = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        output = self.fc(output.permute(1, 0, 2))
        return output

    def decode_sentence(self, ids):
        int2word = self.tgt_tokenizer.int2word
        word2int = self.tgt_tokenizer.word2int
        specials = {word2int["<bos>"], word2int["<eos>"], word2int["<pad>"]}
        tokens = [int2word[token_id] for token_id in ids if token_id not in specials]
        return tokens, " ".join(tokens)  # Chinese output would not need the spaces

    def encode_sentence(self, sentence):
        tokens = sentence.split()
        word2int = self.src_tokenizer.word2int
        max_len = self.max_seq_len
        ids = [word2int.get(token, word2int["<unk>"]) for token in tokens]
        ids = [word2int["<bos>"]] + ids[:max_len - 2] + [word2int["<eos>"]]
        ids += [word2int["<pad>"]] * (max_len - len(ids))
        # Add the batch dimension
        return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(DEVICE)
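    # Illustrative round-trip (token ids are hypothetical; real values depend on the
    # trained tokenizers' vocabularies):
    #   encode_sentence("how are you") returns a (1, MAX_SEQ_LEN) LongTensor laid out
    #   as [<bos>, id("how"), id("are"), id("you"), <eos>, <pad>, ..., <pad>], ready
    #   to be passed to forward() as `src`; decode_sentence() strips <bos>/<eos>/<pad>
    #   again before joining the remaining tokens.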
    def translate(self, sentence):
        # Greedy decoding: repeatedly pick the most likely next token.
        src_tensor = self.encode_sentence(sentence)
        tgt_tensor = torch.tensor(
            [self.tgt_tokenizer.word2int["<bos>"]], dtype=torch.long
        ).unsqueeze(0).to(DEVICE)
        for _ in range(self.max_seq_len):
            # Causal mask for the target sequence
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
            # Run the model
            output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
            # Prediction at the last time step
            next_token = output[:, -1, :].argmax(dim=-1).item()
            # Append the predicted token to the target sequence
            tgt_tensor = torch.cat(
                [tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1
            )
            # Stop once <eos> is generated
            if next_token == self.tgt_tokenizer.word2int["<eos>"]:
                break
        return self.decode_sentence(tgt_tensor.squeeze(0).tolist())

    def translate_no_unk(self, sentence):
        # Greedy decoding that never emits <unk>.
        src_tensor = self.encode_sentence(sentence)
        tgt_tensor = torch.tensor(
            [self.tgt_tokenizer.word2int["<bos>"]], dtype=torch.long
        ).unsqueeze(0).to(DEVICE)
        for _ in range(self.max_seq_len):
            # Causal mask for the target sequence
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
            # Run the model
            output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
            # Prediction at the last time step
            logits = output[:, -1, :]
            sorted_indices = torch.argsort(logits, dim=-1, descending=True)
            next_token = sorted_indices[0, 0].item()
            unk_id = self.tgt_tokenizer.word2int["<unk>"]
            # If the most likely token is <unk>, fall back to the second-best one
            if next_token == unk_id:
                next_token = sorted_indices[0, 1].item()
            # Append the predicted token to the target sequence
            tgt_tensor = torch.cat(
                [tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1
            )
            # Stop once <eos> is generated
            if next_token == self.tgt_tokenizer.word2int["<eos>"]:
                break
        return self.decode_sentence(tgt_tensor.squeeze(0).tolist())

    def translate_beam_search(self, sentence, beam_size=3):
        src_tensor = self.encode_sentence(sentence)
        bos_id = self.tgt_tokenizer.word2int["<bos>"]
        eos_id = self.tgt_tokenizer.word2int["<eos>"]
        unk_id = self.tgt_tokenizer.word2int["<unk>"]
        # Initialize the beam with <bos>
        beams = [([bos_id], 0)]  # (sequence, score)
        for _ in range(self.max_seq_len):
            new_beams = []
            for seq, score in beams:
                tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
                output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
                logits = output[:, -1, :]
                probs = torch.softmax(logits, dim=-1).squeeze(0)
                sorted_indices = torch.argsort(probs, dim=-1, descending=True)
                top_indices = sorted_indices[:beam_size]
                for idx in top_indices:
                    next_token = idx.item()
                    if next_token == unk_id:
                        continue
                    new_seq = seq + [next_token]
                    new_score = score + torch.log(probs[next_token]).item()
                    new_beams.append((new_seq, new_score))
            # Sort by score and keep the top beam_size candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            # Check whether any beam ends with <eos>
            ended_beams = [beam for beam in beams if beam[0][-1] == eos_id]
            if ended_beams:
                best_beam = max(ended_beams, key=lambda x: x[1])
                return self.decode_sentence(best_beam[0])
        # No beam ended with <eos>; return the highest-scoring one
        best_beam = max(beams, key=lambda x: x[1])
        return self.decode_sentence(best_beam[0])
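    # Beam scores above are accumulated log-probabilities, so they are negative and
    # higher (closer to 0) is better. For example (numbers purely illustrative),
    # extending a beam whose score is -1.20 with a token of probability 0.5 yields
    # -1.20 + log(0.5) ≈ -1.89, which is then ranked against the other candidates.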
    def translate_no_unk_beam_search(self, sentence, beam_size=3):
        src_tensor = self.encode_sentence(sentence)
        bos_id = self.tgt_tokenizer.word2int["<bos>"]
        eos_id = self.tgt_tokenizer.word2int["<eos>"]
        unk_id = self.tgt_tokenizer.word2int["<unk>"]
        # Initialize the beam with <bos>
        beams = [([bos_id], 0)]  # (sequence, score)
        for _ in range(self.max_seq_len):
            new_beams = []
            for seq, score in beams:
                tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
                output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
                logits = output[:, -1, :]
                probs = torch.softmax(logits, dim=-1).squeeze(0)
                sorted_indices = torch.argsort(probs, dim=-1, descending=True)
                # Take one extra candidate in case one of them is <unk>
                top_indices = sorted_indices[:beam_size + 1]
                valid_count = 0
                for idx in top_indices:
                    next_token = idx.item()
                    if next_token == unk_id:
                        continue
                    new_seq = seq + [next_token]
                    new_score = score + torch.log(probs[next_token]).item()
                    new_beams.append((new_seq, new_score))
                    valid_count += 1
                    if valid_count >= beam_size:
                        break
            # Sort by score and keep the top beam_size candidates
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]
            # Check whether any beam ends with <eos>
            ended_beams = [beam for beam in beams if beam[0][-1] == eos_id]
            if ended_beams:
                best_beam = max(ended_beams, key=lambda x: x[1])
                return self.decode_sentence(best_beam[0])
        # No beam ended with <eos>; return the highest-scoring one
        best_beam = max(beams, key=lambda x: x[1])
        return self.decode_sentence(best_beam[0])
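

if __name__ == "__main__":
    # Quick smoke test with throwaway stand-ins for the real config and tokenizer
    # objects. The actual Config/Tokenizer classes live elsewhere in this project;
    # these stubs are hypothetical and only mimic the attributes this module
    # touches (.config, .word2int, .int2word) so that an untrained model can be
    # constructed and greedy decoding exercised end to end.
    class _StubTokenizer:
        def __init__(self, words):
            vocab = ["<pad>", "<bos>", "<eos>", "<unk>"] + words
            self.word2int = {w: i for i, w in enumerate(vocab)}
            self.int2word = {i: w for w, i in self.word2int.items()}

    class _StubConfig:
        config = {
            "DEVICE": "cpu", "BATCH_SIZE": 2, "EPOCHS": 1, "LEARNING_RATE": 1e-4,
            "MAX_SEQ_LEN": 16, "D_MODEL": 32, "N_HEAD": 4, "NUM_LAYERS": 2,
            "DIM_FEEDFORWARD": 64, "DROPOUT": 0.1,
        }

    src_tok = _StubTokenizer(["how", "are", "you"])
    tgt_tok = _StubTokenizer(["你", "好", "吗"])
    model = TransformerModel(_StubConfig(), src_tok, tgt_tok)

    # Inference should run with dropout disabled and without gradient tracking.
    model.eval()
    with torch.no_grad():
        tokens, text = model.translate("how are you")
    logger.info(f"untrained greedy output: {text}")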