File size: 1,024 Bytes
f8bd4d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import torch
from src.model import TransformerModel
from src.tokenizer import Tokenizer
from src.eval import evaluate
from src.config import Config
from loguru import logger
import os
# Index of the GPU this script prefers to run on.
_PREFERRED_GPU_INDEX = 5

# Only select the preferred GPU when it actually exists. Calling
# torch.cuda.set_device() unconditionally raises on CPU-only hosts and on
# hosts with fewer GPUs, which made the "cpu" fallback below unreachable.
if torch.cuda.is_available():
    if torch.cuda.device_count() > _PREFERRED_GPU_INDEX:
        torch.cuda.set_device(_PREFERRED_GPU_INDEX)
    else:
        logger.warning(
            f"GPU index {_PREFERRED_GPU_INDEX} is not available "
            f"({torch.cuda.device_count()} visible device(s)); "
            "keeping the current CUDA device"
        )
# Show which GPUs the process is allowed to see (None when the variable is unset).
print(os.environ.get("CUDA_VISIBLE_DEVICES"))
# "cuda" without an index refers to the current CUDA device selected above.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f'使用设备:{DEVICE}')

def main():
    """Restore the Baseline epoch-40 checkpoint and evaluate it in 'test' mode.

    Builds the ``en`` (source) and ``cn`` (target) tokenizers from the JSON
    word tables under ``./wordtable``, reads ``./config.yaml``, constructs the
    ``TransformerModel``, loads the saved weights into it, and hands the model
    to ``evaluate``.
    """
    en_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_en.json',
        int2word_path='./wordtable/int2word_en.json',
    )
    cn_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_cn.json',
        int2word_path='./wordtable/int2word_cn.json',
    )
    config = Config('./config.yaml')
    model = TransformerModel(
        config=config,
        src_tokenizer=en_tokenizer,
        tgt_tokenizer=cn_tokenizer,
    )
    # map_location remaps the stored tensors onto the device used here, so the
    # checkpoint loads even if it was saved from a GPU that this host lacks.
    state_dict = torch.load('./models/Baseline_epoch_40.pth', map_location=DEVICE)
    # Load the state dict into the model.
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)
    evaluate(model=model, config=config, mode='test')


# Run the evaluation only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()