import os
import re
import sys

from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

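# FLORES-200 language codes handled by this script, mapped to the
# corresponding mBART-50 language codes.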
langs_supported = {
    "eng_Latn": "en_XX",
    "guj_Gujr": "gu_IN",
    "hin_Deva": "hi_IN",
    "npi_Deva": "ne_NP",
    "ben_Beng": "bn_IN",
    "mal_Mlym": "ml_IN",
    "mar_Deva": "mr_IN",
    "tam_Taml": "ta_IN",
    "tel_Telu": "te_IN",
    "urd_Arab": "ur_PK",
}

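
# Translate one batch with beam search. `forced_bos_token_id` fixes the
# decoder's first token, which is how mBART-50 selects the output language.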
@torch.inference_mode()  # no gradients are needed at evaluation time
def predict(batch, tokenizer, model, bos_token_id):
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    generated_tokens = model.generate(
        **encoded_batch,
        num_beams=5,
        max_length=256,
        min_length=0,
        forced_bos_token_id=bos_token_id,
    )
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis

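
# For each language-pair directory (e.g. "eng_Latn-hin_Deva"), translate the
# devtest set in both directions and write predictions alongside the inputs.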
def main(devtest_data_dir, batch_size):
    # mBART-50 ships separate checkpoints per direction: one-to-many for
    # en -> xx and many-to-one for xx -> en.
    enxx_model_name = "facebook/mbart-large-50-one-to-many-mmt"
    xxen_model_name = "facebook/mbart-large-50-many-to-one-mmt"
    tokenizers = {
        "enxx": AutoTokenizer.from_pretrained(enxx_model_name),
        "xxen": AutoTokenizer.from_pretrained(xxen_model_name),
    }
    models = {
        "enxx": AutoModelForSeq2SeqLM.from_pretrained(enxx_model_name).cuda(),
        "xxen": AutoModelForSeq2SeqLM.from_pretrained(xxen_model_name).cuda(),
    }

    for model_name in models:
        models[model_name].eval()

    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue
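
        # Direction 1: src_lang -> tgt_lang with the one-to-many model. That
        # checkpoint only translates out of English, so the src side of each
        # pair is expected to be eng_Latn.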
print(f"Evaluating {src_lang}-{tgt_lang} ...") |
|
|
|
infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}") |
|
outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.mbart50") |
|
|
|
with open(infname, "r") as f: |
|
src_sents = f.read().split("\n") |
|
|
|
add_new_line = False |
|
if src_sents[-1] == "": |
|
add_new_line = True |
|
src_sents = src_sents[:-1] |

        tokenizers["enxx"].src_lang = langs_supported[src_lang]
        bos_token_id = tokenizers["enxx"].lang_code_to_id[langs_supported[tgt_lang]]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizers["enxx"], models["enxx"], bos_token_id)

        assert len(hypothesis) == len(src_sents)

        # Collapse internal whitespace so each hypothesis stays on one line.
        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            # Restore the trailing newline dropped while reading the input.
            hypothesis.append("")

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))

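        # Direction 2: tgt_lang -> src_lang with the many-to-one model, which
        # translates into English only.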
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.mbart50")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizers["xxen"].src_lang = langs_supported[tgt_lang]
        bos_token_id = tokenizers["xxen"].lang_code_to_id[langs_supported[src_lang]]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            batch = src_sents[i : i + batch_size]
            hypothesis += predict(batch, tokenizers["xxen"], models["xxen"], bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            hypothesis.append("")

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))


if __name__ == "__main__":
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    if not torch.cuda.is_available():
        print("No GPU available")
        sys.exit(1)

    main(devtest_data_dir, batch_size)
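
# Example invocation (script name and paths are illustrative):
#   python flores_eval_mbart50.py /path/to/flores200/devtest 32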