|
import torch |
|
from src.model import TransformerModel |
|
from src.tokenizer import Tokenizer |
|
from src.eval import evaluate |
|
from src.config import Config |
|
from loguru import logger |
|
import os |
|
# Index of the GPU to evaluate on. Defaults to 5 (the value this script has
# always used); override with the EVAL_CUDA_DEVICE environment variable.
# When CUDA_VISIBLE_DEVICES is set, this index is relative to the GPUs it lists.
CUDA_DEVICE_INDEX = int(os.environ.get("EVAL_CUDA_DEVICE", "5"))

logger.info(f'CUDA_VISIBLE_DEVICES={os.environ.get("CUDA_VISIBLE_DEVICES")}')

if torch.cuda.is_available():
    # Fail early with a readable message instead of CUDA's
    # "invalid device ordinal" when the requested GPU is not visible.
    if not 0 <= CUDA_DEVICE_INDEX < torch.cuda.device_count():
        raise RuntimeError(
            f"Requested CUDA device {CUDA_DEVICE_INDEX}, but only "
            f"{torch.cuda.device_count()} device(s) are visible "
            f"(CUDA_VISIBLE_DEVICES={os.environ.get('CUDA_VISIBLE_DEVICES')}). "
            "Set EVAL_CUDA_DEVICE to a valid index."
        )
    # Make the chosen GPU the current device so bare "cuda" allocations made
    # elsewhere also land on it.
    torch.cuda.set_device(CUDA_DEVICE_INDEX)
    DEVICE = torch.device("cuda", CUDA_DEVICE_INDEX)
else:
    # set_device is only meaningful (and only safe to call) when CUDA exists.
    DEVICE = torch.device("cpu")

logger.info(f'使用设备:{DEVICE}')
|
|
|
def main(checkpoint_path='./models/Baseline_epoch_40.pth', config_path='./config.yaml'):
    """Restore a trained TransformerModel and run ``evaluate`` in test mode.

    Builds the source (``_en``) and target (``_cn``) tokenizers from the JSON
    word tables under ``./wordtable``, constructs the model from the YAML
    config, restores its weights from ``checkpoint_path``, switches it to eval
    mode, moves it to the module-level ``DEVICE`` and passes it to
    ``src.eval.evaluate`` with ``mode='test'``.

    Args:
        checkpoint_path: Path to a saved ``state_dict`` for the model.
            Defaults to ``./models/Baseline_epoch_40.pth``.
        config_path: Path to the YAML file handed to ``Config``.
            Defaults to ``./config.yaml``.
    """
    en_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_en.json',
        int2word_path='./wordtable/int2word_en.json',
    )
    cn_tokenizer = Tokenizer(
        word2int_path='./wordtable/word2int_cn.json',
        int2word_path='./wordtable/int2word_cn.json',
    )

    config = Config(config_path)
    model = TransformerModel(config=config, src_tokenizer=en_tokenizer, tgt_tokenizer=cn_tokenizer)

    # Deserialize onto the CPU regardless of which GPU the checkpoint was saved
    # from. Without map_location, torch.load restores tensors to their original
    # device, which fails on CPU-only hosts or when that GPU index is not
    # visible. load_state_dict copies the values into the model's own
    # parameters, and the model is moved to DEVICE below.
    # NOTE(review): if checkpoints can come from untrusted sources, also pass
    # weights_only=True (available since torch 1.13; default since 2.6).
    state_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(state_dict)

    model.eval()
    model.to(DEVICE)

    # NOTE(review): assumes evaluate() disables autograd itself
    # (torch.no_grad / inference_mode) — confirm in src/eval.py.
    evaluate(model=model, config=config, mode='test')
|
|
|
|
|
if __name__ == "__main__":
    # Run the evaluation only when this file is executed as a script,
    # not when it is imported.
    main()