import os
import re
import sys

from tqdm import tqdm

import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
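
# Baseline evaluation of facebook/m2m100-12B-last-ckpt on FLORES-style devtest directories.
# Usage: python this_script.py <devtest_data_dir> <batch_size>

# FLORES-200 language codes mapped to the ISO 639-1 codes used by the M2M-100 tokenizer.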
langs_supported = {
    "eng_Latn": "en",
    "ben_Beng": "bn",
    "guj_Gujr": "gu",
    "hin_Deva": "hi",
    "kan_Knda": "kn",
    "mal_Mlym": "ml",
    "mar_Deva": "mr",
    "npi_Deva": "ne",
    "ory_Orya": "or",
    "pan_Guru": "pa",
    "snd_Arab": "sd",
    "tam_Taml": "ta",
    "urd_Arab": "ur",
}
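

# Translate one batch of sentences: tokenize, generate with beam search, and decode back to text.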
def predict(batch, tokenizer, model, bos_token_id):
    encoded_batch = tokenizer(batch, padding=True, return_tensors="pt").to(model.device)
    generated_tokens = model.generate(
        **encoded_batch,
        num_beams=5,
        max_length=256,
        min_length=0,
        forced_bos_token_id=bos_token_id,
    )
    hypothesis = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
    return hypothesis


def main(devtest_data_dir, batch_size):
    model_name = "facebook/m2m100-12B-last-ckpt"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Move the model to the GPU; the entry point exits earlier if CUDA is unavailable.
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to("cuda")
    model.eval()
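
    # Each "<src>-<tgt>" sub-directory of devtest_data_dir is expected to contain
    # test.<src> and test.<tgt> files with one sentence per line.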
    for pair in sorted(os.listdir(devtest_data_dir)):
        if "-" not in pair:
            continue

        src_lang, tgt_lang = pair.split("-")

        if src_lang not in langs_supported or tgt_lang not in langs_supported:
            print(f"Skipping {src_lang}-{tgt_lang} ...")
            continue

        print(f"Evaluating {src_lang}-{tgt_lang} ...")
        infname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}.pred.m2m100")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        # Remember whether the input file ended with a trailing newline so the
        # prediction file can end the same way.
        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizer.src_lang = langs_supported[src_lang]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            start, end = i, min(len(src_sents), i + batch_size)
            batch = src_sents[start:end]
            bos_token_id = tokenizer.lang_code_to_id[langs_supported[tgt_lang]]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        # Collapse any whitespace in the generated text to single spaces.
        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            hypothesis += [""]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))
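
        # Second direction: translate test.<tgt_lang> into <src_lang>.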
        infname = os.path.join(devtest_data_dir, pair, f"test.{tgt_lang}")
        outfname = os.path.join(devtest_data_dir, pair, f"test.{src_lang}.pred.m2m100")

        with open(infname, "r") as f:
            src_sents = f.read().split("\n")

        add_new_line = False
        if src_sents[-1] == "":
            add_new_line = True
            src_sents = src_sents[:-1]

        tokenizer.src_lang = langs_supported[tgt_lang]

        hypothesis = []
        for i in tqdm(range(0, len(src_sents), batch_size)):
            start, end = i, min(len(src_sents), i + batch_size)
            batch = src_sents[start:end]
            bos_token_id = tokenizer.lang_code_to_id[langs_supported[src_lang]]
            hypothesis += predict(batch, tokenizer, model, bos_token_id)

        assert len(hypothesis) == len(src_sents)

        hypothesis = [
            re.sub(r"\s+", " ", x.replace("\n", " ").replace("\t", " ")).strip()
            for x in hypothesis
        ]
        if add_new_line:
            hypothesis += [""]

        with open(outfname, "w") as f:
            f.write("\n".join(hypothesis))


if __name__ == "__main__":
    devtest_data_dir = sys.argv[1]
    batch_size = int(sys.argv[2])

    if not torch.cuda.is_available():
        print("No GPU available")
        sys.exit(1)

    main(devtest_data_dir, batch_size)