"""Evaluate a fine-tuned mT5 model on the IWSLT 2017 zh-en test split.

The script loads the test split, runs the model over the English and the
Chinese side of every sentence pair, and prints a corpus-level BLEU score
for each translation direction.
"""
import logging

import pandas as pd
import sacrebleu
from datasets.load import load_dataset
from simpletransformers.t5 import T5Model, T5Args

raw_datasets = load_dataset('iwslt2017', 'iwslt2017-zh-en')

# Keep our own logging at INFO but silence the chattier transformers logger.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Generation settings used for every predict() call below.
model_args = T5Args()
model_args.max_length = 512
model_args.length_penalty = 1
model_args.num_beams = 10

# "outputs" is passed as the model name/path; presumably this is the
# directory a previous simpletransformers training run wrote to.
model = T5Model("mt5", "outputs", args=model_args)

# Each entry of the 'translation' column is used as a record with 'en' and
# 'zh' keys, so the DataFrame gets one column per language.
en_zh_test = pd.DataFrame(raw_datasets['test']['translation'])
zh_truth = en_zh_test['zh'].tolist()
en_input = en_zh_test['en'].tolist()

# NOTE(review): both predict() calls receive the raw sentences without any
# task prefix. If the model was trained with prefixes (e.g.
# "translate english to chinese: ..."), the same prefixes must be prepended
# here -- confirm against the training script.
zh_preds = model.predict(en_input)
# corpus_bleu expects a list of reference streams, hence the extra list
# around the single reference set. Chinese has no whitespace word
# boundaries, so the dedicated 'zh' tokenizer is required for a meaningful
# score.
en_zh_bleu = sacrebleu.corpus_bleu(zh_preds, [zh_truth], tokenize='zh')
print("----------------------------------------------")
print("English to Chinese: ", en_zh_bleu.score)

en_preds = model.predict(zh_truth)
# Single reference stream again; the default tokenizer is fine for English.
zh_en_bleu = sacrebleu.corpus_bleu(en_preds, [en_input])
print("----------------------------------------------")
print("Chinese to English: ", zh_en_bleu.score)
|