|
import functools
import os
import tempfile

import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
|
|
|
|
# Hugging Face Hub id of the NLLB-200 checkpoint; shared by model and tokenizer.
_MODEL_ID = "facebook/nllb-200-3.3B"

# Use the GPU when PyTorch can see one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the seq2seq weights, then move them onto the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_ID)
model = model.to(device)

tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
|
|
|
|
|
# UI display names mapped to NLLB-200 language codes.
# Keys are Persian words: "Persian", "Kurdish", "English".
langs_dict = {
    "فارسی": "pes_Arab",
    "کردی": "ckb_Arab",
    "انگلیسی": "eng_Latn",
}

# The NLLB codes offered by this demo, in the same order as ``langs_dict``.
LANGS = list(langs_dict.values())
|
|
|
@functools.lru_cache(maxsize=32)
def _get_translation_pipeline(src_code, tgt_code):
    """Build once, then reuse, a translation pipeline for a pair of NLLB language codes.

    All cached pipelines share the module-level ``model`` and ``tokenizer``.
    Only the lightweight pipeline wrapper is created per language pair.
    """
    return pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=src_code,
        tgt_lang=tgt_code,
        max_length=400,
        device=device,
    )


def translate(text, src_lang, tgt_lang):
    """Translate ``text`` from ``src_lang`` to ``tgt_lang``.

    Args:
        text: The text to translate.
        src_lang: Display name of the source language; must be a key of ``langs_dict``.
        tgt_lang: Display name of the target language; must be a key of ``langs_dict``.

    Returns:
        The translated string, or ``""`` when ``text`` is empty/whitespace-only.

    Raises:
        KeyError: if either language name is not a key of ``langs_dict``.
    """
    # Resolve the codes first so unknown languages fail the same way regardless of ``text``.
    src_code = langs_dict[src_lang]
    tgt_code = langs_dict[tgt_lang]

    # Skip the model for blank input so blank lines stay blank in file translations.
    if not text or not text.strip():
        return ""

    # Reuse a cached pipeline instead of constructing a new one on every call
    # (file_translate calls this once per input line).
    translation_pipeline = _get_translation_pipeline(src_code, tgt_code)
    result = translation_pipeline(text)
    return result[0]["translation_text"]
|
|
|
def _resolve_output_path(source_path, requested_name):
    """Return a writable output path inside a fresh temporary directory.

    Only the base name of ``requested_name`` is kept, so user input from the web
    UI cannot direct writes elsewhere on the server. A blank name falls back to
    ``<input stem>_translated<input ext or .txt>``.
    """
    name = os.path.basename((requested_name or "").strip())
    if name in ("", ".", ".."):
        stem, ext = os.path.splitext(os.path.basename(source_path))
        name = f"{stem}_translated{ext or '.txt'}"
    # A new directory per call avoids collisions between concurrent requests.
    return os.path.join(tempfile.mkdtemp(), name)


def file_translate(sorce_file_path, src_lang, tgt_lang, pred_file_path):
    """Translate a UTF-8 text file line by line and write the result to a new file.

    Args:
        sorce_file_path: The uploaded input file. Either a path (``str`` /
            ``os.PathLike``) or an object exposing ``.name``. Older Gradio
            ``File`` components pass a temp-file wrapper.
        src_lang: Display name of the source language (a key of ``langs_dict``).
        tgt_lang: Display name of the target language (a key of ``langs_dict``).
        pred_file_path: Desired output file name (optional). Only its base name
            is used, and the file is created in a new temporary directory.

    Returns:
        The path of the written output file, with one translated line per input line.

    Raises:
        ValueError: if no input file was provided.
    """
    if sorce_file_path is None:
        raise ValueError("No input file was provided.")
    if isinstance(sorce_file_path, (str, os.PathLike)):
        source_path = sorce_file_path
    else:
        source_path = sorce_file_path.name

    output_path = _resolve_output_path(os.fspath(source_path), pred_file_path)

    with open(source_path, "r", encoding="utf-8") as source_file:
        source_lines = [line.strip() for line in source_file]

    translations = [translate(line, src_lang, tgt_lang) for line in source_lines]

    with open(output_path, "w", encoding="utf-8") as output_file:
        output_file.writelines(translation + "\n" for translation in translations)

    return output_path
|
|
|
def add_line(input_path, output_path):
    """Copy ``input_path`` to ``output_path`` with the text "\\nسلام" ("hello") appended.

    Both files are handled as UTF-8 text. The input is read completely before the
    output is opened, so the two paths may name the same file. This helper is not
    wired into the Gradio interface in this file.

    Returns:
        ``output_path``, unchanged.
    """
    with open(input_path, encoding="utf-8") as reader:
        original_text = reader.read()

    with open(output_path, "w", encoding="utf-8") as writer:
        writer.write(original_text)
        writer.write("\nسلام")

    return output_path
|
|
|
if __name__ == '__main__':
    # Web UI: upload a text file, pick source/target languages, download the translation.
    language_names = list(langs_dict.keys())

    interface = gr.Interface(
        fn=file_translate,
        inputs=[
            gr.components.File(label="Input File"),
            # Labels are Persian for "source language" / "target language".
            # Defaults ensure a real key reaches langs_dict; an unselected
            # dropdown would otherwise pass None and raise KeyError.
            gr.components.Dropdown(
                label="زبان مبدا",
                choices=language_names,
                value=language_names[0],
            ),
            gr.components.Dropdown(
                label="زبان مقصد",
                choices=language_names,
                value=language_names[-1],
            ),
            gr.components.Textbox(label="Output File Name (optional)"),
        ],
        outputs=[
            gr.components.File(label="Modified File"),
        ],
        title="NLLB 3.3B - (Translation Demo)",
        # Report the device actually in use instead of always claiming CPU.
        description=f"This Gradio demo translates text files. ({device.upper()})",
    )

    interface.launch()
|
|
|
|