import gradio as gr
import nltk
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit import IndicProcessor

nltk.download('punkt_tab')

# Load IndicTrans2 model
model_name = "ai4bharat/indictrans2-indic-indic-dist-320M"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name, trust_remote_code=True)
ip = IndicProcessor(inference=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model.to(DEVICE)
def split_text_into_batches(text, max_tokens_per_batch):
    # Character count is used as a cheap proxy for token count when sizing batches.
    sentences = nltk.sent_tokenize(text)  # Tokenize text into sentences
    batches = []
    current_batch = ""
    for sentence in sentences:
        if len(current_batch) + len(sentence) + 1 <= max_tokens_per_batch:  # Add 1 for the joining space
            current_batch += sentence + " "  # Add sentence to the current batch
        else:
            if current_batch:  # Avoid emitting an empty batch when a single sentence exceeds the limit
                batches.append(current_batch.strip())
            current_batch = sentence + " "  # Start a new batch with the current sentence
    if current_batch:
        batches.append(current_batch.strip())  # Add the last batch
    return batches
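# Illustrative example (not from the original app): with a budget of 20 characters,
# two short sentences land in separate batches:
#   split_text_into_batches("First sentence. Second sentence.", 20)
#   -> ["First sentence.", "Second sentence."]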
def run_translation(file_uploader, input_text, source_language, target_language):
    # Prefer the uploaded file over the textbox when both are provided
    if file_uploader is not None:
        with open(file_uploader.name, "r", encoding="utf-8") as file:
            input_text = file.read()

    # Map UI language names to FLORES-style language codes used by IndicTrans2
    lang_code_map = {
        "Hindi": "hin_Deva",
        "Punjabi": "pan_Guru",
        "English": "eng_Latn",
    }
    src_lang = lang_code_map[source_language]
    tgt_lang = lang_code_map[target_language]

    max_tokens_per_batch = 256
    batches = split_text_into_batches(input_text, max_tokens_per_batch)

    translated_text = ""
    for batch in batches:
        # Normalize and tag the batch for the chosen language pair
        batch_preprocessed = ip.preprocess_batch([batch], src_lang=src_lang, tgt_lang=tgt_lang)
        inputs = tokenizer(
            batch_preprocessed,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate the translation with beam search
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=5,
                num_return_sequences=1,
            )

        # Decode with the target-language tokenizer and restore the surface form
        with tokenizer.as_target_tokenizer():
            decoded_tokens = tokenizer.batch_decode(
                generated_tokens.detach().cpu().tolist(),
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        translations = ip.postprocess_batch(decoded_tokens, lang=tgt_lang)
        translated_text += " ".join(translations) + " "

    output = translated_text.strip()

    # Also write the translation to a downloadable text file
    _output_name = "result.txt"
    with open(_output_name, "w", encoding="utf-8") as out_file:
        out_file.write(output)
    return output, _output_name
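# Illustrative standalone call (assumed usage, outside the Gradio UI); pass None
# for the file argument when translating raw text with the indic-indic model:
#   text, path = run_translation(None, "यह एक परीक्षण है।", "Hindi", "Punjabi")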
# Define Gradio UI
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            file_uploader = gr.File(label="Upload a text file (Optional)")
            input_text = gr.Textbox(label="Input text", lines=5, placeholder="Enter text here...")
            source_language = gr.Dropdown(
                label="Source language",
                choices=["Hindi", "Punjabi", "English"],
                value="Hindi",
            )
            target_language = gr.Dropdown(
                label="Target language",
                choices=["Hindi", "Punjabi", "English"],
                value="English",
            )
            btn = gr.Button("Translate")
        with gr.Column():
            output_text = gr.Textbox(label="Translated text", lines=5)
            output_file = gr.File(label="Translated text file")

    btn.click(
        fn=run_translation,
        inputs=[file_uploader, input_text, source_language, target_language],
        outputs=[output_text, output_file],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
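# Note: launch() prints a local URL when the script is run directly (e.g. as
# app.py in a Hugging Face Space); passing share=True additionally creates a
# temporary public link when running outside Spaces.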