|
import os |
|
import torch |
|
import gradio as gr |
|
import time |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline |
|
from flores200_codes import flores_codes |
|
|
|
|
|
def load_models():
    """Load every configured checkpoint together with its tokenizer.

    Returns:
        dict: For each short model name ``<name>`` registered below, two
        entries, inserted in this order: ``'<name>_model'`` (from
        ``AutoModelForSeq2SeqLM.from_pretrained``) and ``'<name>_tokenizer'``
        (from ``AutoTokenizer.from_pretrained``).
    """
    # Short name (used as the key prefix) -> checkpoint id given to from_pretrained.
    checkpoints = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
    }

    loaded = {}
    for short_name, checkpoint_id in checkpoints.items():
        print(f'\tLoading model: {short_name}')
        # Dict-literal entries are evaluated left to right, so the model is
        # loaded before the tokenizer, and inserted first.
        loaded.update({
            f'{short_name}_model': AutoModelForSeq2SeqLM.from_pretrained(checkpoint_id),
            f'{short_name}_tokenizer': AutoTokenizer.from_pretrained(checkpoint_id),
        })

    return loaded
|
|
|
# FLORES-200 language codes offered by this app.
LANGS = ["pes_Arab", "ckb_Arab", "eng_Latn"]

# UI display name (written in Persian) -> FLORES-200 code.
# The keys read "Persian", "Kurdish" and "English" respectively, paired
# positionally with LANGS. Order matters: file_translate takes its
# source/target languages from the first two entries.
langs_dict = dict(zip(["فارسی", "کردی", "انگلیسی"], LANGS))
|
|
|
def translate(text, src_lang, tgt_lang, model_name=None):
    """Translate *text* from *src_lang* to *tgt_lang* with a loaded model.

    Args:
        text: Text to translate.
        src_lang: Display name of the source language (a key of ``langs_dict``).
        tgt_lang: Display name of the target language (a key of ``langs_dict``).
        model_name: Key prefix of the model in the global ``model_dict``
            (e.g. ``'nllb-distilled-600M'``). Defaults to the first model
            found in ``model_dict``.

    Returns:
        The translated string, or ``""`` when *text* is empty, ``None`` or
        whitespace-only.

    Raises:
        KeyError: If a language name or *model_name* is not registered.
        RuntimeError: If ``model_dict`` contains no model.
    """
    # Nothing to translate: skip building a pipeline and running the model.
    if not text or not text.strip():
        return ""

    if model_name is None:
        # Use whichever model is actually present in model_dict rather than a
        # hard-coded name that may never have been loaded.
        suffix = '_model'
        loaded = [key[:-len(suffix)] for key in model_dict if key.endswith(suffix)]
        if not loaded:
            raise RuntimeError("No translation model loaded; call load_models() first.")
        model_name = loaded[0]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translation_pipeline = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=langs_dict[src_lang],
        tgt_lang=langs_dict[tgt_lang],
        max_length=400,
        device="cpu",
    )
    result = translation_pipeline(text)
    return result[0]['translation_text']
|
|
|
def file_translate(sorce_file_path, pred_file_path, src_lang=None, tgt_lang=None):
    """Translate a UTF-8 text file line by line and write the result to a file.

    Args:
        sorce_file_path: Path of the input file (str, bytes or ``os.PathLike``),
            or a file-like object whose ``.name`` is the path. Depending on the
            Gradio version, the ``File`` component may pass such a wrapper
            instead of a plain path.
        pred_file_path: Path of the output file. When empty or ``None``, the
            output goes next to the input as ``<input stem>_translated.txt``.
        src_lang: Display name of the source language (a key of ``langs_dict``);
            defaults to the first entry of ``langs_dict``.
        tgt_lang: Display name of the target language; defaults to the second
            entry of ``langs_dict``.

    Returns:
        The path of the written output file.
    """
    if isinstance(sorce_file_path, (str, bytes, os.PathLike)):
        source_path = os.fsdecode(sorce_file_path)
    else:
        source_path = sorce_file_path.name

    if not pred_file_path:
        # The UI marks the output name as optional; fall back to a sibling file.
        stem, _ = os.path.splitext(source_path)
        pred_file_path = stem + "_translated.txt"

    if src_lang is None or tgt_lang is None:
        lang_names = list(langs_dict)
        if src_lang is None:
            src_lang = lang_names[0]
        if tgt_lang is None:
            tgt_lang = lang_names[1]

    with open(source_path, "r", encoding="utf-8") as source_file:
        source_lines = [line.strip() for line in source_file]

    # Blank lines stay blank (no model call) so the output remains
    # line-aligned with the input.
    translations = [
        translate(line, src_lang, tgt_lang) if line else ""
        for line in source_lines
    ]

    with open(pred_file_path, "w", encoding="utf-8") as output_file:
        output_file.writelines(translation + "\n" for translation in translations)

    return pred_file_path
|
|
|
if __name__ == '__main__':
    print('\tinit models')

    # Assigned at module scope so translate() can read it as a global.
    model_dict = load_models()

    interface = gr.Interface(
        fn=file_translate,
        inputs=[
            gr.components.File(label="Input File"),
            gr.components.Textbox(label="Output File Name (optional)"),
        ],
        outputs=[
            gr.components.File(label="Translated File"),
        ],
        title="Persian to Kurdish Text File Translator",
        description=(
            "Upload a UTF-8 text file; each line is translated from Persian "
            "to Central Kurdish with an NLLB-200 model and the translated "
            "file is returned."
        ),
    )

    interface.launch()
|
|
|
|