File size: 1,033 Bytes
5cedadf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 |
from datasets.load import load_dataset
import logging
import sacrebleu
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args
# Evaluate an mT5 model on the IWSLT 2017 zh-en test split in both
# translation directions and print the corpus-level BLEU score for each.
raw_datasets = load_dataset('iwslt2017', 'iwslt2017-zh-en')

# Root logger at INFO; the "transformers" logger is raised to WARNING so its
# lower-level messages do not drown out the results printed below.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Settings handed to the model for prediction.
model_args = T5Args()
model_args.max_length = 512
model_args.length_penalty = 1
model_args.num_beams = 10

# NOTE(review): "outputs" is presumably the directory a previous training run
# saved its checkpoint to -- confirm before running.
model = T5Model("mt5", "outputs", args=model_args)

# Build parallel lists of Chinese and English sentences from the test split's
# 'translation' records (one 'zh' and one 'en' entry per record).
en_zh_test = pd.DataFrame(raw_datasets['test']['translation'])
zh_truth = en_zh_test['zh'].tolist()
en_input = en_zh_test['en'].tolist()

# --- English -> Chinese ---
# NOTE(review): sentences are fed to the model without any task prefix;
# presumably this matches how the checkpoint was trained -- confirm.
zh_preds = model.predict(en_input)
# sacrebleu takes references as a list of reference *streams* (one inner list
# per reference set), so the single reference set must be wrapped in a list.
# The 'zh' tokenizer is required for Chinese output: the default tokenizer
# does not segment Chinese text, which makes the n-gram statistics meaningless.
en_zh_bleu = sacrebleu.corpus_bleu(zh_preds, [zh_truth], tokenize='zh')
print("----------------------------------------------")
print("English to Chinese: ", en_zh_bleu.score)

# --- Chinese -> English ---
en_preds = model.predict(zh_truth)
# Same reference-stream wrapping as above; the default tokenizer is
# appropriate for English output.
zh_en_bleu = sacrebleu.corpus_bleu(en_preds, [en_input])
print("----------------------------------------------")
print("Chinese to English: ", zh_en_bleu.score)
|