ashourzadeh7's picture
Update app.py
cefb4b4 verified
raw
history blame
2.67 kB
import os
import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes
def load_models():
    """Load each configured NLLB checkpoint together with its tokenizer.

    Returns:
        dict: for every configured short name ``<name>`` two entries,
        ``<name>_model`` (the seq2seq model) and ``<name>_tokenizer``.
    """
    # Short name -> Hugging Face hub id. Only the distilled 600M checkpoint
    # is enabled; the larger variants are kept here as disabled options.
    checkpoints = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
        # 'nllb-1.3B': 'facebook/nllb-200-1.3B',
        # 'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
        # 'nllb-3.3B': 'facebook/nllb-200-3.3B',
    }

    loaded = {}
    for short_name in checkpoints:
        hub_id = checkpoints[short_name]
        print('\tLoading model: %s' % short_name)
        # Model first, then tokenizer, both from the same hub id.
        seq2seq_model = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        seq2seq_tokenizer = AutoTokenizer.from_pretrained(hub_id)
        loaded[short_name + '_model'] = seq2seq_model
        loaded[short_name + '_tokenizer'] = seq2seq_tokenizer
    return loaded
# Display name shown to the user -> FLORES-200 language code.
# Insertion order matters: code elsewhere in this file indexes the keys by position.
langs_dict = {
    "فارسی": "pes_Arab",
    "کردی": "ckb_Arab",
    "انگلیسی": "eng_Latn",
}

# The supported language codes, in the same order as the mapping above.
LANGS = list(langs_dict.values())
def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from the source language to the target language.

    Args:
        text: The text to translate.
        src_lang: Display name of the source language; must be a key of
            ``langs_dict``.
        tgt_lang: Display name of the target language; must be a key of
            ``langs_dict``.

    Returns:
        The ``translation_text`` field of the pipeline's first result.

    Raises:
        RuntimeError: If ``model_dict`` holds no ``*_model`` entry.
        KeyError: If ``src_lang`` or ``tgt_lang`` is not in ``langs_dict``,
            or the chosen model has no matching ``*_tokenizer`` entry.
    """
    # Pick the model from what is actually loaded. The previous code
    # hard-coded 'nllb-3.3B', which load_models() does not load (only
    # 'nllb-distilled-600M' is enabled), so every call raised KeyError.
    model_suffix = '_model'
    loaded_names = [
        key[:-len(model_suffix)]
        for key in model_dict
        if key.endswith(model_suffix)
    ]
    if not loaded_names:
        raise RuntimeError('No translation model is loaded in model_dict')
    model_name = loaded_names[0]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    # The pipeline takes FLORES-200 codes, so map the display names first.
    translation_pipeline = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=langs_dict[src_lang],
        tgt_lang=langs_dict[tgt_lang],
        max_length=400,
        device="cpu",
    )
    result = translation_pipeline(text)
    return result[0]['translation_text']
def file_translate(sorce_file_path, pred_file_path):
    """Translate a UTF-8 text file line by line into a new UTF-8 text file.

    Each line is stripped of surrounding whitespace and translated from the
    first language in ``langs_dict`` to the second one. The output has one
    line per input line.

    Args:
        sorce_file_path: Path of the input file, or an object that exposes
            the path as ``.name`` (e.g. an uploaded temp-file wrapper).
        pred_file_path: Path of the output file. When empty or ``None``, the
            output is written next to the input as ``<stem>_translated<ext>``.

    Returns:
        The path of the file that was written.
    """
    # Accept both a plain path and a file-like upload object.
    if isinstance(sorce_file_path, (str, os.PathLike)):
        source_path = os.fspath(sorce_file_path)
    else:
        source_path = sorce_file_path.name

    # The UI presents the output name as optional; open("") would fail,
    # so fall back to a name derived from the input file.
    if pred_file_path is None or not str(pred_file_path).strip():
        stem, ext = os.path.splitext(source_path)
        pred_file_path = stem + "_translated" + (ext or ".txt")

    # Source and target are the first two display names, by insertion order.
    src_lang, tgt_lang = list(langs_dict)[:2]

    with open(source_path, "r", encoding="utf-8") as sorce_file:
        sorce_list = [line.strip() for line in sorce_file]

    # A blank line has nothing to translate: write it back as-is so the
    # output stays line-aligned with the input without invoking the model.
    pred_list = [
        translate(line, src_lang, tgt_lang) if line else ""
        for line in sorce_list
    ]

    with open(pred_file_path, "w", encoding="utf-8") as output_file:
        output_file.writelines(translation + "\n" for translation in pred_list)
    return pred_file_path
if __name__ == '__main__':
    # Script entry point: load the models once, then serve the file
    # translation UI.
    print('\tinit models')
    # translate() reads ``model_dict`` as a module-level name. An assignment
    # at module scope is already global, so no ``global`` statement is needed.
    model_dict = load_models()

    interface = gr.Interface(
        fn=file_translate,
        inputs=[
            gr.components.File(label="Input File"),
            gr.components.Textbox(label="Output File Name (optional)"),
        ],
        outputs=[
            gr.components.File(label="Translated File"),
        ],
        # The previous title/description were template leftovers about
        # appending a 'Hello' line and did not describe this app.
        title="Persian to Kurdish Text File Translator",
        description=(
            "Upload a UTF-8 text file. Each line is translated from Persian "
            "(pes_Arab) to Kurdish (ckb_Arab) with NLLB-200 and returned as "
            "a new text file."
        ),
    )
    interface.launch()