import gradio as gr
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
import langid

# Seq2seq translation models, looked up as models[source_lang][target_lang].
# Each supported language (en/fr/es) maps to the other two via a
# Helsinki-NLP OPUS-MT checkpoint; all six are loaded once, at import time.
models = {
    source: {
        target: AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
        for target, checkpoint in checkpoints.items()
    }
    for source, checkpoints in {
        "en": {
            "fr": "Helsinki-NLP/opus-mt-en-fr",
            "es": "Helsinki-NLP/opus-mt-en-es",
        },
        "fr": {
            "en": "Helsinki-NLP/opus-mt-fr-en",
            "es": "Helsinki-NLP/opus-mt-fr-es",
        },
        "es": {
            "en": "Helsinki-NLP/opus-mt-es-en",
            "fr": "Helsinki-NLP/opus-mt-es-fr",
        },
    }.items()
}

# Tokenizers paired 1:1 with the translation models, looked up as
# tokenizers[source_lang][target_lang]; loaded once, at import time, from the
# same Helsinki-NLP OPUS-MT checkpoints.
tokenizers = {
    source: {
        target: AutoTokenizer.from_pretrained(checkpoint)
        for target, checkpoint in checkpoints.items()
    }
    for source, checkpoints in {
        "en": {
            "fr": "Helsinki-NLP/opus-mt-en-fr",
            "es": "Helsinki-NLP/opus-mt-en-es",
        },
        "fr": {
            "en": "Helsinki-NLP/opus-mt-fr-en",
            "es": "Helsinki-NLP/opus-mt-fr-es",
        },
        "es": {
            "en": "Helsinki-NLP/opus-mt-es-en",
            "fr": "Helsinki-NLP/opus-mt-es-fr",
        },
    }.items()
}

def translate(input_text, source_lang, target_lang):
    """Translate *input_text* from *source_lang* into *target_lang*.

    Args:
        input_text: The text to translate.
        source_lang: Language code of the input ("en", "fr" or "es").
        target_lang: Language code to translate into; must differ from
            *source_lang*.

    Returns:
        The translated text as a single string.

    Raises:
        KeyError: If no model/tokenizer is loaded for the language pair.
    """
    tokenizer = tokenizers[source_lang][target_lang]
    model = models[source_lang][target_lang]
    # truncation=True caps the encoded input at the tokenizer's
    # model_max_length, so over-long text is cut instead of producing a
    # sequence longer than the model can accept.
    inputs = tokenizer(input_text, return_tensors="pt", truncation=True)
    translated_tokens = model.generate(**inputs)
    # A single input string yields a single output sequence; decode it directly.
    return tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

def translate_text(input_text):
    """Detect the language of *input_text* and translate it into the other two.

    Args:
        input_text: Text entered by the user; may be empty, blank or ``None``.

    Returns:
        An ``(english, french, spanish)`` tuple of strings, in the order the
        UI's output textboxes expect. The box for the detected source
        language is left empty. Blank input yields three empty strings
        without running language detection or any model. If the detected
        language is not English, French or Spanish, every box receives an
        explanatory message instead of a translation.
    """
    # Nothing to detect or translate; skip langid and the models entirely.
    if input_text is None or not input_text.strip():
        return "", "", ""

    detected_lang, _ = langid.classify(input_text)

    # Supported language codes -> output labels, in output-tuple order.
    labels = {"en": "English", "fr": "French", "es": "Spanish"}

    if detected_lang not in labels:
        # Surface the problem in the UI instead of leaving every box blank.
        message = (
            "Language not supported for translation "
            f"(detected: {detected_lang})."
        )
        return message, message, message

    translations = dict.fromkeys(labels.values(), "")
    # Translate into every supported language except the detected one.
    for target_lang, label in labels.items():
        if target_lang != detected_lang:
            translations[label] = translate(input_text, detected_lang, target_lang)

    return translations["English"], translations["French"], translations["Spanish"]

def clear_textboxes():
    """Reset the input box and all three translation boxes.

    Returns:
        One empty string per output component wired to the Clear button
        (input, English, French, Spanish). Gradio assigns return values to
        outputs positionally, so the count must match the four outputs.
    """
    return "", "", "", ""

# Two-column UI: the input text and Translate button on the left, one output
# textbox per language plus a Clear button on the right.
with gr.Blocks() as demo:
    with gr.Row():
        # Left column: user input.
        with gr.Column():
            text_to_translate = gr.Textbox(label="Text to Translate")
            translate_btn = gr.Button(value="Translate")
        # Right column: results, in English / French / Spanish order.
        with gr.Column():
            translation_en = gr.Textbox(label="Translation to English")
            translation_fr = gr.Textbox(label="Translation to French")
            translation_es = gr.Textbox(label="Translation to Spanish")
            clear_btn = gr.Button(value="Clear")
    # translate_text returns a 3-tuple; Gradio maps it positionally onto these
    # outputs, so this list must stay in (English, French, Spanish) order.
    translate_btn.click(
        fn=translate_text, 
        inputs=[text_to_translate], 
        outputs=[translation_en, translation_fr, translation_es]
    )

    # Gradio maps clear_textboxes' return values positionally onto these four
    # components, so it must return exactly four values.
    clear_btn.click(
        fn=clear_textboxes, 
        inputs=None, 
        outputs=[text_to_translate, translation_en, translation_fr, translation_es]
    )

# Start the server only when this file is executed as a script, so importing
# the module (e.g. from tests) does not block on launching the app.
# share=True additionally asks Gradio to create a publicly shareable link.
if __name__ == "__main__":
    demo.launch(share=True)