# Provenance (from the hosting page): uploaded by Suburst in verified commit f8bd4d2 ("Upload 27 files"); original size 1.02 kB.
import torch
from src.model import TransformerModel
from src.tokenizer import Tokenizer
from src.eval import evaluate
from src.config import Config
from loguru import logger
import os
# Index of the GPU to evaluate on. The script previously hard-coded 5; it can now be
# overridden via the EVAL_GPU_INDEX environment variable. The index is relative to the
# GPUs exposed by CUDA_VISIBLE_DEVICES.
GPU_INDEX = int(os.environ.get("EVAL_GPU_INDEX", "5"))

# Record which physical GPUs this process can see (helps interpret GPU_INDEX).
logger.info(f"CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}")


def _select_device(gpu_index):
    """Return the device to run on and, when it is a valid GPU, make it the current one.

    Args:
        gpu_index: Requested CUDA device ordinal.

    Returns:
        ``torch.device("cpu")`` when CUDA is unavailable, ``cuda:<gpu_index>`` when
        the ordinal exists, otherwise the default ``cuda`` device (with a warning).
    """
    # Check availability *before* touching torch.cuda.set_device: calling it on a
    # CPU-only build or with a non-existent ordinal raises instead of falling back.
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if 0 <= gpu_index < torch.cuda.device_count():
        # Also make it the current device so any code that uses a bare "cuda"
        # device string ends up on the same GPU as the model.
        torch.cuda.set_device(gpu_index)
        return torch.device("cuda", gpu_index)
    logger.warning(
        f"GPU index {gpu_index} is not available "
        f"({torch.cuda.device_count()} visible); using the default CUDA device."
    )
    return torch.device("cuda")


DEVICE = _select_device(GPU_INDEX)
logger.info(f'使用设备:{DEVICE}')
def main(checkpoint_path='./models/Baseline_epoch_40.pth', config_path='./config.yaml'):
    """Load the trained Transformer model from a checkpoint and evaluate it in 'test' mode.

    Args:
        checkpoint_path: Path to the saved ``state_dict`` to evaluate. Defaults to the
            epoch-40 baseline checkpoint the script originally hard-coded.
        config_path: Path to the YAML file passed to ``Config``.
    """
    # Source-side (English word table) and target-side (Chinese word table) tokenizers.
    en_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_en.json',
        int2word_path='./wordtable/int2word_en.json',
    )
    cn_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_cn.json',
        int2word_path='./wordtable/int2word_cn.json',
    )
    config = Config(config_path)
    model = TransformerModel(config=config, src_tokenizer=en_tokenizer, tgt_tokenizer=cn_tokenizer)

    # map_location remaps tensors that were saved on another device (a different GPU
    # ordinal, or any GPU when running on CPU). Without it, loading fails when that
    # device does not exist on this machine.
    # NOTE: torch.load unpickles the file -- only load checkpoints from trusted sources.
    state_dict = torch.load(checkpoint_path, map_location=DEVICE)
    # Load the state dict into the model.
    model.load_state_dict(state_dict)
    # Switch layers such as dropout/batch-norm (if present) to inference behavior.
    model.eval()
    model.to(DEVICE)
    evaluate(model=model, config=config, mode='test')
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()