# File size: 1,024 Bytes
# f8bd4d2
import torch
from src.model import TransformerModel
from src.tokenizer import Tokenizer
from src.eval import evaluate
from src.config import Config
from loguru import logger
import os
# Index of the CUDA device to evaluate on. The index is interpreted relative to
# the devices exposed through CUDA_VISIBLE_DEVICES. It can be overridden with
# the EVAL_GPU_INDEX environment variable and defaults to 5, the value this
# script originally hard-coded.
GPU_INDEX = int(os.environ.get("EVAL_GPU_INDEX", "5"))
# Show which physical GPUs are visible, to help interpret GPU_INDEX. Printed
# before touching CUDA so it is still shown if selecting the device fails.
print(os.environ.get("CUDA_VISIBLE_DEVICES"))
if torch.cuda.is_available():
    # Only select a CUDA device when CUDA exists. An unconditional
    # set_device call raises on CPU-only machines, which would make the CPU
    # fallback below unreachable.
    torch.cuda.set_device(GPU_INDEX)
    # Name the device explicitly instead of relying on the "current device".
    DEVICE = torch.device(f"cuda:{GPU_INDEX}")
else:
    DEVICE = torch.device("cpu")
logger.info(f'使用设备:{DEVICE}')
def main(config_path='./config.yaml', checkpoint_path='./models/Baseline_epoch_40.pth'):
    """Restore a trained TransformerModel from a checkpoint and evaluate it.

    Builds the English (source) and Chinese (target) tokenizers from the word
    tables under ./wordtable, loads the configuration, restores the model
    weights from ``checkpoint_path`` and runs ``evaluate`` with
    ``mode='test'``.

    Args:
        config_path: Path of the configuration file passed to ``Config``.
        checkpoint_path: Path of a file containing the model's state dict
            (as accepted by ``model.load_state_dict``).
    """
    en_tokenizer = Tokenizer(word2int_path='./wordtable/word2int_en.json', int2word_path='./wordtable/int2word_en.json')
    cn_tokenizer = Tokenizer(word2int_path='./wordtable/word2int_cn.json', int2word_path='./wordtable/int2word_cn.json')
    config = Config(config_path)
    model = TransformerModel(config=config, src_tokenizer=en_tokenizer, tgt_tokenizer=cn_tokenizer)
    # map_location remaps tensors saved on another device (a different GPU, or
    # any GPU when running on CPU) onto DEVICE. Without it, loading fails
    # whenever the device the checkpoint was saved from is missing here.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    # Load the state dict into the model.
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)
    evaluate(model=model, config=config, mode='test')
# Run the evaluation only when executed as a script, not on import.
if __name__ == '__main__':
    main()