import torch import json import torch.nn as nn from .model import TransformerModel # 确保与训练代码的模型定义一致 # 配置参数 MAX_SEQ_LEN = 60 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") MODEL_PATH = "./results/model/model5.pth" # 模型权重路径 SRC_VOCAB_PATH = "word2int_en.json" # 英文词汇表路径 TGT_VOCAB_PATH = "word2int_cn.json" # 中文词汇表路径 # 加载词汇表 def load_vocab(file_path): with open(file_path, 'r', encoding='utf-8') as f: return json.load(f) # 编码输入句子 def encode_sentence(sentence, vocab, max_len): tokens = sentence.split() ids = [vocab.get(token, vocab[""]) for token in tokens] ids = [vocab[""]] + ids[:max_len - 2] + [vocab[""]] ids += [vocab[""]] * (max_len - len(ids)) return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(DEVICE) # 添加 batch 维度 # 解码输出句子 def decode_sentence(ids, vocab): reverse_vocab = {idx: word for word, idx in vocab.items()} tokens = [reverse_vocab[id] for id in ids if id not in {vocab[""], vocab[""], vocab[""]}] return "".join(tokens) # 中文不需要空格 # 加载模型 def load_model(model_path, src_vocab_size, tgt_vocab_size): model = TransformerModel(src_vocab_size, tgt_vocab_size).to(DEVICE) model.load_state_dict(torch.load(model_path, map_location=DEVICE)) model.eval() return model # 翻译函数 def translate(model, sentence, src_vocab, tgt_vocab, max_len): # 编码输入句子 src_tensor = encode_sentence(sentence, src_vocab, max_len) # 初始化目标序列为 tgt_tensor = torch.tensor([tgt_vocab[""]], dtype=torch.long).unsqueeze(0).to(DEVICE) for _ in range(max_len): # 生成目标序列的 mask tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE) # 推理得到输出 output = model(src_tensor, tgt_tensor, tgt_mask=tgt_mask) # 取最后一个时间步的预测结果 next_token = output[:, -1, :].argmax(dim=-1).item() # 将预测的 token 添加到目标序列中 tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1) # 如果预测到 ,停止生成 if next_token == tgt_vocab[""]: break # 解码目标序列为句子 return decode_sentence(tgt_tensor.squeeze(0).tolist(), tgt_vocab)