Spaces:
Running
Running
import torch | |
import json | |
import torch.nn as nn | |
from .model import TransformerModel # 确保与训练代码的模型定义一致 | |
# 配置参数 | |
MAX_SEQ_LEN = 60 | |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
MODEL_PATH = "./results/model/model5.pth" # 模型权重路径 | |
SRC_VOCAB_PATH = "word2int_en.json" # 英文词汇表路径 | |
TGT_VOCAB_PATH = "word2int_cn.json" # 中文词汇表路径 | |
# 加载词汇表 | |
def load_vocab(file_path): | |
with open(file_path, 'r', encoding='utf-8') as f: | |
return json.load(f) | |
# 编码输入句子 | |
def encode_sentence(sentence, vocab, max_len): | |
tokens = sentence.split() | |
ids = [vocab.get(token, vocab["<UNK>"]) for token in tokens] | |
ids = [vocab["<BOS>"]] + ids[:max_len - 2] + [vocab["<EOS>"]] | |
ids += [vocab["<PAD>"]] * (max_len - len(ids)) | |
return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(DEVICE) # 添加 batch 维度 | |
# 解码输出句子 | |
def decode_sentence(ids, vocab): | |
reverse_vocab = {idx: word for word, idx in vocab.items()} | |
tokens = [reverse_vocab[id] for id in ids if id not in {vocab["<PAD>"], vocab["<BOS>"], vocab["<EOS>"]}] | |
return "".join(tokens) # 中文不需要空格 | |
# 加载模型 | |
def load_model(model_path, src_vocab_size, tgt_vocab_size): | |
model = TransformerModel(src_vocab_size, tgt_vocab_size).to(DEVICE) | |
model.load_state_dict(torch.load(model_path, map_location=DEVICE)) | |
model.eval() | |
return model | |
# 翻译函数 | |
def translate(model, sentence, src_vocab, tgt_vocab, max_len): | |
# 编码输入句子 | |
src_tensor = encode_sentence(sentence, src_vocab, max_len) | |
# 初始化目标序列为 <BOS> | |
tgt_tensor = torch.tensor([tgt_vocab["<BOS>"]], dtype=torch.long).unsqueeze(0).to(DEVICE) | |
for _ in range(max_len): | |
# 生成目标序列的 mask | |
tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE) | |
# 推理得到输出 | |
output = model(src_tensor, tgt_tensor, tgt_mask=tgt_mask) | |
# 取最后一个时间步的预测结果 | |
next_token = output[:, -1, :].argmax(dim=-1).item() | |
# 将预测的 token 添加到目标序列中 | |
tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1) | |
# 如果预测到 <EOS>,停止生成 | |
if next_token == tgt_vocab["<EOS>"]: | |
break | |
# 解码目标序列为句子 | |
return decode_sentence(tgt_tensor.squeeze(0).tolist(), tgt_vocab) |