File size: 2,568 Bytes
4fb86de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import torch
import json
import torch.nn as nn
from .model import TransformerModel  # 确保与训练代码的模型定义一致

# --- Configuration -----------------------------------------------------------
# Maximum sequence length; presumably passed as `max_len` to encode/translate by callers (not visible here).
MAX_SEQ_LEN = 60
# Prefer the GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
MODEL_PATH = "./results/model/model5.pth"  # trained model weights (state_dict) path
SRC_VOCAB_PATH = "word2int_en.json"  # English (source) vocabulary path
TGT_VOCAB_PATH = "word2int_cn.json"  # Chinese (target) vocabulary path

# Load a vocabulary JSON file.
def load_vocab(file_path):
    """Parse the UTF-8 JSON file at *file_path* and return its contents.

    Callers in this module use the result as a token -> id dict.
    """
    with open(file_path, encoding='utf-8') as handle:
        vocab = json.load(handle)
    return vocab

# Encode an input sentence into a fixed-length, batched id tensor.
def encode_sentence(sentence, vocab, max_len, device=None):
    """Convert a whitespace-tokenized sentence into a (1, max_len) LongTensor.

    The sequence is ``<BOS> tokens... <EOS>`` followed by ``<PAD>`` up to
    ``max_len``; tokens beyond ``max_len - 2`` are truncated so the
    ``<BOS>``/``<EOS>`` markers always fit.

    Args:
        sentence: Source sentence; split on whitespace.
        vocab: token -> id dict that must contain ``<UNK>``, ``<BOS>``,
            ``<EOS>`` and ``<PAD>``.
        max_len: Total encoded length, including the special tokens.
        device: Target device; defaults to the module-level ``DEVICE``.

    Returns:
        A ``torch.long`` tensor of shape ``(1, max_len)`` (batch dimension added).

    Raises:
        ValueError: If ``max_len < 2`` (no room for ``<BOS>`` and ``<EOS>``).
        KeyError: If a required special token is missing from ``vocab``.
    """
    if max_len < 2:
        # Otherwise ids[:max_len - 2] slices from the end and the result overflows max_len.
        raise ValueError(f"max_len must be >= 2 to hold <BOS> and <EOS>, got {max_len}")

    unk_id = vocab["<UNK>"]  # looked up once instead of once per token
    body = [vocab.get(token, unk_id) for token in sentence.split()]
    ids = [vocab["<BOS>"], *body[:max_len - 2], vocab["<EOS>"]]
    ids += [vocab["<PAD>"]] * (max_len - len(ids))

    # DEVICE is only consulted when no explicit device is given.
    target = DEVICE if device is None else device
    return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(target)

# Decode a sequence of target ids back into a sentence.
def decode_sentence(ids, vocab):
    """Map target ids back to text, dropping special tokens.

    ``<PAD>`` and ``<BOS>`` ids are skipped, and decoding stops at the first
    ``<EOS>``: anything after the end marker is not part of the sentence.
    Tokens are joined without separators because the target language is
    Chinese.

    Args:
        ids: Iterable of integer token ids.
        vocab: token -> id dict containing ``<PAD>``, ``<BOS>`` and ``<EOS>``.

    Returns:
        The decoded string ("" for an empty or all-special sequence).

    Raises:
        KeyError: If an id is not present in ``vocab``.
    """
    reverse_vocab = {idx: word for word, idx in vocab.items()}
    eos_id = vocab["<EOS>"]
    skip_ids = {vocab["<PAD>"], vocab["<BOS>"]}

    tokens = []
    for token_id in ids:
        if token_id == eos_id:
            break  # end of sentence; ignore trailing padding/garbage
        if token_id not in skip_ids:
            tokens.append(reverse_vocab[token_id])
    return "".join(tokens)

# Build the model, load trained weights, and switch it to eval mode.
def load_model(model_path, src_vocab_size, tgt_vocab_size):
    """Construct a TransformerModel on DEVICE and load its saved state_dict.

    Args:
        model_path: Path to the saved state_dict file.
        src_vocab_size: Source vocabulary size passed to TransformerModel.
        tgt_vocab_size: Target vocabulary size passed to TransformerModel.

    Returns:
        The model, on DEVICE, in eval mode.
    """
    net = TransformerModel(src_vocab_size, tgt_vocab_size)
    net = net.to(DEVICE)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True if the installed torch version supports it).
    weights = torch.load(model_path, map_location=DEVICE)
    net.load_state_dict(weights)
    net.eval()
    return net

# Translate one sentence with greedy (argmax) autoregressive decoding.
def translate(model, sentence, src_vocab, tgt_vocab, max_len):
    """Translate *sentence* greedily, one target token per step.

    Starting from ``<BOS>``, the model is re-run on the growing target
    sequence (with a causal mask). At each step the argmax of the last
    position's output is appended. Decoding stops after ``<EOS>`` or after
    ``max_len`` steps.

    Args:
        model: Seq2seq model called as ``model(src, tgt, tgt_mask=...)``.
            Its output is indexed as ``[batch, position, vocab]`` below;
            assumed layout, confirm against model.py.
        sentence: Whitespace-tokenized source sentence.
        src_vocab: Source token -> id dict.
        tgt_vocab: Target token -> id dict containing ``<BOS>`` and ``<EOS>``.
        max_len: Source encoding length and maximum number of decoding steps.

    Returns:
        The decoded target sentence with special tokens removed.
    """
    bos_id = tgt_vocab["<BOS>"]
    eos_id = tgt_vocab["<EOS>"]

    # Pure inference: disable autograd so each decoding step does not build
    # and retain a computation graph (model.eval() alone does not do this).
    with torch.no_grad():
        src_tensor = encode_sentence(sentence, src_vocab, max_len)
        # Target sequence starts as [[<BOS>]], shape (1, 1).
        tgt_tensor = torch.tensor([[bos_id]], dtype=torch.long, device=DEVICE)

        for _ in range(max_len):
            # Causal mask so each target position only attends to earlier ones.
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)
            output = model(src_tensor, tgt_tensor, tgt_mask=tgt_mask)

            # Greedy choice from the prediction at the last target position.
            next_token = output[:, -1, :].argmax(dim=-1).item()
            next_tensor = torch.tensor([[next_token]], dtype=torch.long, device=DEVICE)
            tgt_tensor = torch.cat([tgt_tensor, next_tensor], dim=1)

            if next_token == eos_id:
                break

    return decode_sentence(tgt_tensor.squeeze(0).tolist(), tgt_vocab)