# PDL_translate / app.py
# Hugging Face Space by vtiw, commit df369fa ("Update app.py"), 4.08 kB.
import gradio as gr
import nltk
import torch
from IndicTransToolkit import IndicProcessor
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Fetch the NLTK "punkt_tab" resource; presumably needed by nltk.sent_tokenize,
# which split_text_into_batches relies on.
nltk.download('punkt_tab')

# Choose the compute device up front so the model can be moved there right
# after loading.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# IndicTrans2 checkpoint. trust_remote_code=True lets transformers execute the
# model repository's own tokenizer/model code.
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name, trust_remote_code=True
).to(DEVICE)  # nn.Module.to returns the same module, now on DEVICE

# IndicTransToolkit pre/post-processor used around every model call.
ip = IndicProcessor(inference=True)
def split_text_into_batches(text, max_tokens_per_batch, sentence_splitter=None):
    """Group the sentences of *text* into space-joined batches for translation.

    Sentences are packed greedily, in their original order, into batches whose
    length is at most ``max_tokens_per_batch``. Despite the parameter name, the
    limit is measured in characters, not model tokens. A single sentence longer
    than the limit becomes a batch of its own. It is not split further, so it
    may still be truncated by the tokenizer downstream.

    Args:
        text: The input text.
        max_tokens_per_batch: Maximum length of a batch, in characters.
        sentence_splitter: Optional callable mapping a string to a list of
            sentence strings. Defaults to ``nltk.sent_tokenize``.

    Returns:
        A list of non-empty batch strings. It is empty when *text* contains no
        non-blank sentences.
    """
    import re

    splitter = nltk.sent_tokenize if sentence_splitter is None else sentence_splitter

    # NLTK's default (English) Punkt model only ends sentences at ".", "?" and
    # "!". Hindi/Punjabi sentences ending in a danda (U+0964) or double danda
    # (U+0965) would otherwise be returned as one oversized "sentence", so also
    # break after a danda that is followed by whitespace.
    sentences = [
        piece.strip()
        for sentence in splitter(text)
        for piece in re.split(r"(?<=[\u0964\u0965])\s+", sentence)
        if piece.strip()
    ]

    batches = []
    current = []      # sentences in the batch being built
    current_len = 0   # == len(" ".join(current))
    for sentence in sentences:
        joined_len = current_len + 1 + len(sentence) if current else len(sentence)
        if current and joined_len > max_tokens_per_batch:
            # Flush only a non-empty batch; an oversized sentence simply
            # starts (and possibly fills) a new batch on its own.
            batches.append(" ".join(current))
            current = [sentence]
            current_len = len(sentence)
        else:
            current.append(sentence)
            current_len = joined_len
    if current:
        batches.append(" ".join(current))
    return batches
def _read_uploaded_text(file_value):
    """Return the UTF-8 contents of a Gradio ``gr.File`` value.

    Accepts either a path (``str`` / ``os.PathLike``) or a file-like wrapper
    exposing the on-disk path as ``.name``.
    """
    import os

    if isinstance(file_value, (str, os.PathLike)):
        path = file_value
    else:
        path = file_value.name
    with open(path, "r", encoding="utf-8") as file:
        return file.read()


def _translate_batch(batch, src_lang, tgt_lang):
    """Translate one batch string with the loaded model and return the text."""
    batch_preprocessed = ip.preprocess_batch([batch], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = tokenizer(
        batch_preprocessed,
        truncation=True,
        padding="longest",
        return_tensors="pt",
        return_attention_mask=True,
    ).to(DEVICE)
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )
    with tokenizer.as_target_tokenizer():
        decoded_tokens = tokenizer.batch_decode(
            generated_tokens.detach().cpu().tolist(),
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
    translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
    return " ".join(translations)


def run_translation(file_uploader, input_text, source_language, target_language):
    """Translate the uploaded file (or the textbox text) and save the result.

    Args:
        file_uploader: Value of the ``gr.File`` input, or ``None``. When present,
            the file's UTF-8 contents replace *input_text*.
        input_text: Text from the input textbox. ``None`` is treated as empty.
        source_language: "Hindi", "Punjabi" or "English".
        target_language: "Hindi", "Punjabi" or "English".

    Returns:
        ``(translated_text, "result.txt")``. The same text is also written to
        ``result.txt`` (relative to the current working directory) for download.

    Raises:
        KeyError: If a language name is not one of the three mapped below.
    """
    if file_uploader is not None:
        input_text = _read_uploaded_text(file_uploader)

    # NOTE(review): the module loads the indic-indic IndicTrans2 checkpoint;
    # confirm it actually supports eng_Latn before offering "English".
    lang_code_map = {
        "Hindi": "hin_Deva",
        "Punjabi": "pan_Guru",
        "English": "eng_Latn",
    }
    src_lang = lang_code_map[source_language]
    tgt_lang = lang_code_map[target_language]

    max_tokens_per_batch = 256  # character budget per batch (see split_text_into_batches)
    batches = split_text_into_batches(input_text or "", max_tokens_per_batch)

    # Blank batches carry nothing to translate; don't spend a model call on them.
    translated_parts = [
        _translate_batch(batch, src_lang, tgt_lang)
        for batch in batches
        if batch.strip()
    ]
    output = " ".join(translated_parts).strip()

    _output_name = "result.txt"
    with open(_output_name, "w", encoding="utf-8") as out_file:
        out_file.write(output)
    return output, _output_name
# Gradio UI: inputs in the left column, translation results in the right one.
_LANGUAGE_CHOICES = ("Hindi", "Punjabi", "English")

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_uploader = gr.File(label="Upload a text file (Optional)")
            input_text = gr.Textbox(
                label="Input text", lines=5, placeholder="Enter text here..."
            )
            # Each dropdown gets its own copy of the choices list.
            source_language = gr.Dropdown(
                label="Source language",
                choices=list(_LANGUAGE_CHOICES),
                value="Hindi",
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=list(_LANGUAGE_CHOICES),
                value="English",
            )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text", lines=5)
            output_file = gr.File(label="Translated text file")

    # Wire the button: run_translation receives the four input values and its
    # (text, file path) return value fills the two output components.
    btn.click(
        run_translation,
        inputs=[file_uploader, input_text, source_language, target_language],
        outputs=[output_text, output_file],
    )


def _main():
    """Start the Gradio server with debug output enabled."""
    demo.launch(debug=True)


if __name__ == "__main__":
    _main()