import torch
import json
import torch.nn as nn
from .model import TransformerModel  # must match the model definition used in training
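# Assumed interface (not shown in this file): TransformerModel takes the source and
# target vocabulary sizes in its constructor, and its forward pass is
# forward(src, tgt, tgt_mask=None) -> logits of shape [batch, tgt_len, tgt_vocab_size].
# The inference code below relies on exactly this call signature.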

# Configuration
MAX_SEQ_LEN = 60
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL_PATH = "./results/model/model5.pth"  # path to the model weights
SRC_VOCAB_PATH = "word2int_en.json"  # English vocabulary path
TGT_VOCAB_PATH = "word2int_cn.json"  # Chinese vocabulary path

# Load a vocabulary (token -> id mapping)
def load_vocab(file_path):
    with open(file_path, 'r', encoding='utf-8') as f:
        return json.load(f)

# Encode an input sentence into a padded id tensor
def encode_sentence(sentence, vocab, max_len):
    tokens = sentence.split()
    ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens]
    ids = [vocab["<BOS>"]] + ids[:max_len - 2] + [vocab["<EOS>"]]
    ids += [vocab["<PAD>"]] * (max_len - len(ids))
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(DEVICE)  # add a batch dimension
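
# Example with a hypothetical vocabulary and max_len=6: "hello world" becomes
#   [<BOS>, hello, world, <EOS>, <PAD>, <PAD>]
# i.e. the sentence is truncated to at most max_len - 2 tokens so <BOS> and <EOS>
# always fit, then right-padded with <PAD> up to max_len.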

# Decode a sequence of ids back into a sentence
def decode_sentence(ids, vocab):
    reverse_vocab = {idx: word for word, idx in vocab.items()}
    specials = {vocab["<PAD>"], vocab["<BOS>"], vocab["<EOS>"]}
    tokens = [reverse_vocab[i] for i in ids if i not in specials]
    return "".join(tokens)  # Chinese text needs no spaces between tokens

# Load the trained model
def load_model(model_path, src_vocab_size, tgt_vocab_size):
    model = TransformerModel(src_vocab_size, tgt_vocab_size).to(DEVICE)
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.eval()
    return model

# Translate one sentence by greedy decoding
@torch.no_grad()
def translate(model, sentence, src_vocab, tgt_vocab, max_len):
    # Encode the input sentence
    src_tensor = encode_sentence(sentence, src_vocab, max_len)
    # Initialize the target sequence with <BOS>
    tgt_tensor = torch.tensor([tgt_vocab["<BOS>"]], dtype=torch.long).unsqueeze(0).to(DEVICE)
    for _ in range(max_len):
        # Causal mask so each target position only attends to earlier positions
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
        # Run the model to get logits for every target position
        output = model(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
        # Take the highest-scoring token at the last time step (greedy choice)
        next_token = output[:, -1, :].argmax(dim=-1).item()
        # Append the predicted token to the target sequence
        tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1)
        # Stop once <EOS> is generated
        if next_token == tgt_vocab["<EOS>"]:
            break
    # Decode the target id sequence into a sentence
    return decode_sentence(tgt_tensor.squeeze(0).tolist(), tgt_vocab)
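
# A minimal usage sketch, assuming this file is run as a module of the same package
# as .model (e.g. python -m <package>.<this_module>) and that the vocabulary files
# contain the <PAD>/<UNK>/<BOS>/<EOS> special tokens used above. The input sentence
# is only a hypothetical illustration.
if __name__ == "__main__":
    src_vocab = load_vocab(SRC_VOCAB_PATH)
    tgt_vocab = load_vocab(TGT_VOCAB_PATH)
    model = load_model(MODEL_PATH, len(src_vocab), len(tgt_vocab))
    sentence = "how are you"  # hypothetical whitespace-tokenized English input
    print(translate(model, sentence, src_vocab, tgt_vocab, MAX_SEQ_LEN))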