# Hugging Face Space app (scraped page header: "Spaces: Sleeping").
import gradio as gr
from transformers import MarianMTModel, MarianTokenizer
import torch
import nltk
from nltk.tokenize import sent_tokenize

# Download the Punkt sentence-tokenizer data used by sent_tokenize.
# NLTK >= 3.8.2 loads it from the "punkt_tab" resource rather than the legacy
# pickled "punkt" one, so fetch both to work across NLTK versions.
nltk.download('punkt', quiet=True)
nltk.download('punkt_tab', quiet=True)
# Cache for storing models and tokenizers, keyed by Hugging Face model id.
models_cache = {}


def load_model(model_name):
    """Return the ``(model, tokenizer)`` pair for *model_name*, loading it on first use.

    The pair is memoized in ``models_cache`` so each checkpoint is loaded only
    once per process. When CUDA is available the model is moved to the GPU
    before it is cached.
    """
    if model_name in models_cache:
        return models_cache[model_name]

    tok = MarianTokenizer.from_pretrained(model_name)
    mt_model = MarianMTModel.from_pretrained(model_name)
    if torch.cuda.is_available():
        mt_model = mt_model.to('cuda')

    pair = (mt_model, tok)
    models_cache[model_name] = pair
    return pair
def _batched(items, size):
    """Yield successive slices of *items* holding at most *size* elements."""
    for start in range(0, len(items), size):
        yield items[start:start + size]


def translate_text(model_name, text, batch_size=8):
    """Translate *text* sentence by sentence with the MarianMT checkpoint *model_name*.

    The input is split into sentences with NLTK's ``sent_tokenize``. The
    sentences are translated in batches of up to *batch_size*, and the
    translations are joined with single spaces.

    Args:
        model_name: Hugging Face model id, e.g. ``"Helsinki-NLP/opus-mt-tc-big-en-tr"``.
        text: Source text to translate.
        batch_size: Maximum number of sentences sent to one ``generate`` call
            (values below 1 are treated as 1).

    Returns:
        The translated text. Returns a prompt message when the model or text is
        missing or blank. Returns ``"Error: ..."`` if translation raises: the
        string is displayed directly in the UI, so errors are reported rather
        than raised.
    """
    # Whitespace-only input would tokenize to zero sentences and silently return "".
    if not model_name or not text or not text.strip():
        return "Please select a model and provide text for translation."
    try:
        model, tokenizer = load_model(model_name)
        sentences = sent_tokenize(text)
        translated_sentences = []
        for batch in _batched(sentences, max(1, batch_size)):
            # padding aligns sentences of different lengths within the batch;
            # truncation keeps an over-long sentence within the model's maximum
            # input length instead of failing the whole request.
            tokens = tokenizer(batch, return_tensors="pt", padding=True, truncation=True)
            # Put the inputs on whatever device load_model placed the model.
            tokens = tokens.to(model.device)
            generated = model.generate(**tokens)
            translated_sentences.extend(
                tokenizer.batch_decode(generated, skip_special_tokens=True)
            )
        return " ".join(translated_sentences)
    except Exception as e:
        return f"Error: {e}"
# Model options as (display label, Hugging Face model id) pairs.
model_options = [
    ("English to Turkish", "Helsinki-NLP/opus-mt-tc-big-en-tr"),
    ("Turkish to English", "Helsinki-NLP/opus-mt-tc-big-tr-en"),
    # Add other models here...
]

# Create Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# Real-Time Sentence Translation")
    with gr.Row():
        model_dropdown = gr.Dropdown(
            label="Select Translation Model",
            # (label, value) tuples: the UI shows the readable label while the
            # handler still receives the model id, because type="value".
            choices=model_options,
            type="value",
        )
    with gr.Row():
        input_text = gr.Textbox(
            label="Enter text (complete sentences)",
            lines=5,
            placeholder="Type here...",
        )
    with gr.Row():
        translate_button = gr.Button("Translate")
        clear_button = gr.Button("Clear")
    output_text = gr.Textbox(label="Translated Text", interactive=False)

    def clear_inputs():
        """Return empty strings for the input and output textboxes."""
        return "", ""

    translate_button.click(
        fn=translate_text,
        inputs=[model_dropdown, input_text],
        outputs=output_text,
    )
    clear_button.click(
        fn=clear_inputs,
        inputs=[],
        outputs=[input_text, output_text],
    )

# Run the Gradio app
demo.launch()