"""Load a saved TransformerModel checkpoint and run ``evaluate`` in 'test' mode."""

import os

import torch
from loguru import logger

from src.config import Config
from src.eval import evaluate
from src.model import TransformerModel
from src.tokenizer import Tokenizer

# Index of the GPU this script prefers to run on.
GPU_INDEX = 5

# Only select the preferred GPU when it actually exists. Calling
# torch.cuda.set_device() unconditionally raises on hosts without CUDA (or
# with fewer GPUs), which would make the CPU fallback below unreachable.
if torch.cuda.is_available():
    if GPU_INDEX < torch.cuda.device_count():
        torch.cuda.set_device(GPU_INDEX)
    else:
        logger.warning(
            f'GPU {GPU_INDEX} is not available '
            f'({torch.cuda.device_count()} visible); using the default CUDA device'
        )

print(os.environ.get("CUDA_VISIBLE_DEVICES"))

# "cuda" without an index resolves to the current CUDA device selected above.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f'使用设备:{DEVICE}')


def main():
    """Build the model, load the checkpoint weights, and evaluate in 'test' mode.

    Reads the word tables under ``./wordtable``, the configuration from
    ``./config.yaml`` and the weights from ``./models/Baseline_epoch_40.pth``;
    all paths are relative to the current working directory.
    """
    en_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_en.json',
        int2word_path='./wordtable/int2word_en.json',
    )
    cn_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_cn.json',
        int2word_path='./wordtable/int2word_cn.json',
    )

    config = Config('./config.yaml')

    model = TransformerModel(
        config=config,
        src_tokenizer=en_tokenizer,
        tgt_tokenizer=cn_tokenizer,
    )

    # map_location remaps tensors saved on another device (e.g. a GPU that is
    # not present here) onto DEVICE, so the checkpoint also loads on CPU.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load('./models/Baseline_epoch_40.pth', map_location=DEVICE)

    # Load the state dict into the model.
    model.load_state_dict(state_dict)
    model.eval()
    model.to(DEVICE)

    evaluate(model=model, config=config, mode='test')


if __name__ == '__main__':
    main()