File size: 1,033 Bytes
5cedadf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from datasets.load import load_dataset
import logging
import sacrebleu
import pandas as pd
from simpletransformers.t5 import T5Model, T5Args

# Fetch the 'iwslt2017-zh-en' configuration of the 'iwslt2017' dataset; the
# evaluation code further down reads its 'test' split.
raw_datasets = load_dataset("iwslt2017", "iwslt2017-zh-en")

# Root logging at INFO, but keep the "transformers" logger down to WARNING so
# its INFO-level messages do not flood the output.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

# Generation settings handed to the model: outputs of up to 512 tokens,
# a length penalty of 1, and beam search with 10 beams.
model_args = T5Args(
    max_length=512,
    length_penalty=1,
    num_beams=10,
)

# Build an "mt5"-type model from the local "outputs" directory.
# NOTE(review): presumably "outputs" holds a previously fine-tuned
# checkpoint -- confirm it exists before running.
model = T5Model("mt5", "outputs", args=model_args)

# Turn the test split's 'translation' column into a DataFrame so the 'zh' and
# 'en' sides can be pulled out as two parallel lists of sentences.
en_zh_test = pd.DataFrame(raw_datasets['test']['translation'])
zh_truth = en_zh_test['zh'].tolist()
en_input = en_zh_test['en'].tolist()

# --- English -> Chinese ---
# NOTE(review): the sentences are passed to predict() without any task
# prefix; verify this matches how the checkpoint was trained.
zh_preds = model.predict(en_input)
# corpus_bleu takes the references as a list of reference *streams* (one list
# per reference set, each aligned with the hypotheses), so the single set of
# references must be wrapped in an outer list. The 'zh' tokenizer is needed
# because the default tokenizer does not segment Chinese text.
en_zh_bleu = sacrebleu.corpus_bleu(zh_preds, [zh_truth], tokenize="zh")
print("----------------------------------------------")
print("English to Chinese: ", en_zh_bleu.score)

# --- Chinese -> English ---
# The gold Chinese sentences are the inputs here and the English side of the
# test set serves as the reference (again wrapped as a single stream).
en_preds = model.predict(zh_truth)
zh_en_bleu = sacrebleu.corpus_bleu(en_preds, [en_input])
print("----------------------------------------------")
print("Chinese to English: ", zh_en_bleu.score)