import gradio as gr
import torch
import torchaudio
import numpy as np
from transformers import AutoProcessor, SeamlessM4Tv2Model

class TranslationModel:
    """Text-to-speech and speech-to-speech translation backed by Meta's
    SeamlessM4T v2 large model.

    Exposes ``translate_text`` and ``translate_audio``; both return a
    ``(sample_rate, numpy_waveform)`` tuple suitable for a ``gr.Audio``
    output of ``type="numpy"``.
    """

    # Sample rate the SeamlessM4T speech encoder expects for audio input.
    INPUT_AUDIO_SAMPLE_RATE = 16000

    def __init__(self):
        self.model_name = "facebook/seamless-m4t-v2-large"
        print("Loading model...")
        self.processor = AutoProcessor.from_pretrained(self.model_name)
        self.model = SeamlessM4Tv2Model.from_pretrained(self.model_name)
        # Sample rate of the waveforms produced by the model's vocoder.
        self.sample_rate = self.model.config.sampling_rate

        # UI display name -> SeamlessM4T language code (ISO 639-3 style).
        self.languages = {
            "English": "eng",
            "Spanish": "spa",
            "French": "fra",
            "German": "deu",
            "Italian": "ita",
            "Portuguese": "por",
            "Russian": "rus",
            "Chinese": "cmn",
            "Japanese": "jpn",
            "Korean": "kor"
        }

    def translate_text(self, text, src_lang, tgt_lang, progress=gr.Progress()):
        """Translate ``text`` from ``src_lang`` into spoken ``tgt_lang``.

        Args:
            text: Source text to translate.
            src_lang / tgt_lang: Display names; must be keys of ``self.languages``.
            progress: Gradio progress tracker (injected by Gradio at call time).

        Returns:
            ``(sample_rate, waveform)`` tuple for a numpy Audio component.

        Raises:
            gr.Error: On empty input or any processing/generation failure.
        """
        # Validate before the try block so this message is not re-wrapped below.
        if not text or not text.strip():
            raise gr.Error("Please enter some text to translate")
        try:
            progress(0.3, desc="Processing...")
            inputs = self.processor(
                text=text, src_lang=self.languages[src_lang], return_tensors="pt"
            )
            progress(0.6, desc="Generating...")
            # inference_mode: no autograd bookkeeping during generation.
            with torch.inference_mode():
                audio_array = (
                    self.model.generate(**inputs, tgt_lang=self.languages[tgt_lang])[0]
                    .cpu()
                    .numpy()
                    .squeeze()
                )
            progress(1.0, desc="Complete")
            return (self.sample_rate, audio_array)
        except Exception as e:
            # Surface the failure in the UI; chain the cause for server logs.
            raise gr.Error(str(e)) from e

    def translate_audio(self, audio_path, tgt_lang, progress=gr.Progress()):
        """Translate speech in the file at ``audio_path`` into ``tgt_lang``.

        Args:
            audio_path: Filesystem path to the uploaded/recorded audio.
            tgt_lang: Display name; must be a key of ``self.languages``.
            progress: Gradio progress tracker (injected by Gradio at call time).

        Returns:
            ``(sample_rate, waveform)`` tuple for a numpy Audio component.

        Raises:
            gr.Error: If no file was provided or processing fails.
        """
        if not audio_path:
            raise gr.Error("Please upload an audio file")
        try:
            progress(0.3, desc="Processing...")
            audio, orig_freq = torchaudio.load(audio_path)
            audio = torchaudio.functional.resample(
                audio, orig_freq=orig_freq, new_freq=self.INPUT_AUDIO_SAMPLE_RATE
            )

            progress(0.6, desc="Translating...")
            # Pass sampling_rate explicitly so the feature extractor does not
            # have to assume a default rate for the resampled waveform.
            inputs = self.processor(
                audios=audio,
                sampling_rate=self.INPUT_AUDIO_SAMPLE_RATE,
                return_tensors="pt",
            )
            with torch.inference_mode():
                audio_array = (
                    self.model.generate(**inputs, tgt_lang=self.languages[tgt_lang])[0]
                    .cpu()
                    .numpy()
                    .squeeze()
                )
            progress(1.0, desc="Complete")
            return (self.sample_rate, audio_array)
        except Exception as e:
            raise gr.Error(str(e)) from e

# Custom stylesheet injected into the Gradio Blocks app (see create_ui).
# Defines the color palette via CSS variables, then restyles Gradio's
# container, header, buttons, inputs, forms, footer, and tab navigation.
css = """
:root {
    --primary-color: #2D3648;
    --secondary-color: #5E6AD2;
    --background-color: #F5F7FF;
    --text-color: #2D3648;
    --border-radius: 12px;
    --spacing: 20px;
}

.gradio-container {
    background-color: var(--background-color) !important;
}

.main-container {
    max-width: 1200px !important;
    margin: 0 auto !important;
    padding: var(--spacing) !important;
}

.app-header {
    text-align: center;
    padding: 40px 20px;
    background: linear-gradient(45deg, var(--primary-color), var(--secondary-color));
    border-radius: var(--border-radius);
    color: white !important;
    margin-bottom: var(--spacing);
}

.app-title {
    font-size: 2.5em;
    font-weight: 700;
    margin-bottom: 10px;
    color: white !important;
}

.app-subtitle {
    font-size: 1.2em;
    opacity: 0.9;
    color: white !important;
}

.content-block {
    background: white;
    padding: var(--spacing);
    border-radius: var(--border-radius);
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05);
    margin-bottom: var(--spacing);
}

.gr-button {
    background: var(--secondary-color) !important;
    border: none !important;
    color: white !important;
}

.gr-button:hover {
    box-shadow: 0 4px 10px rgba(94, 106, 210, 0.3) !important;
    transform: translateY(-1px);
}

.gr-input, .gr-select {
    border-radius: 8px !important;
    border: 2px solid #E5E7EB !important;
    padding: 12px !important;
}

.gr-input:focus, .gr-select:focus {
    border-color: var(--secondary-color) !important;
    box-shadow: 0 0 0 3px rgba(94, 106, 210, 0.1) !important;
}

.gr-form {
    background: white !important;
    padding: var(--spacing) !important;
    border-radius: var(--border-radius) !important;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05) !important;
}

.gr-box {
    border-radius: var(--border-radius) !important;
    border: none !important;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.05) !important;
}

.footer {
    text-align: center;
    color: var(--text-color);
    padding: var(--spacing);
    opacity: 0.8;
}

/* Custom Tabs Styling */
.tab-nav {
    background: white !important;
    padding: 10px !important;
    border-radius: var(--border-radius) !important;
    margin-bottom: var(--spacing) !important;
}

.tab-nav button {
    border-radius: 8px !important;
    padding: 12px 24px !important;
}

.tab-nav button.selected {
    background: var(--secondary-color) !important;
    color: white !important;
}
"""

def create_ui():
    """Construct the Gradio Blocks app and wire up its event handlers."""
    model = TranslationModel()
    language_names = sorted(model.languages.keys())

    def language_dropdown(default_value, label_text):
        # Every language selector shares the same alphabetized choice list.
        return gr.Dropdown(
            choices=language_names,
            value=default_value,
            label=label_text
        )

    with gr.Blocks(css=css, title="AI Language Translator") as demo:
        gr.HTML(
            """
            <div class="app-header">
                <div class="app-title">AI Language Translator</div>
                <div class="app-subtitle">Powered by Neural Machine Translation</div>
            </div>
            """
        )

        with gr.Tabs():
            # Tab 1: type text, hear it spoken in another language.
            with gr.Tab("Text to Speech"):
                with gr.Column(variant="panel"):
                    gr.Markdown("### Enter Text")
                    text_box = gr.Textbox(
                        label="",
                        placeholder="Type or paste your text here...",
                        lines=4
                    )

                    with gr.Row():
                        source_language = language_dropdown("English", "From")
                        target_language = language_dropdown("Spanish", "To")

                    text_translate_button = gr.Button("Translate", size="lg")

                    gr.Markdown("### Translation Output")
                    text_result_audio = gr.Audio(
                        label="",
                        type="numpy",
                        show_download_button=True
                    )

            # Tab 2: upload or record speech, hear it translated.
            with gr.Tab("Speech to Speech"):
                with gr.Column(variant="panel"):
                    gr.Markdown("### Upload Audio")
                    uploaded_audio = gr.Audio(
                        label="",
                        type="filepath",
                        sources=["upload", "microphone"]
                    )

                    speech_target_language = language_dropdown("English", "Translate to")

                    audio_translate_button = gr.Button("Translate Audio", size="lg")

                    gr.Markdown("### Translation Output")
                    speech_result_audio = gr.Audio(
                        label="",
                        type="numpy",
                        show_download_button=True
                    )

        gr.HTML(
            """
            <div class="footer">
                Built with ❤️ using Meta's SeamlessM4T and Gradio
            </div>
            """
        )

        # Route button clicks to the model's translation entry points.
        text_translate_button.click(
            fn=model.translate_text,
            inputs=[text_box, source_language, target_language],
            outputs=text_result_audio
        )

        audio_translate_button.click(
            fn=model.translate_audio,
            inputs=[uploaded_audio, speech_target_language],
            outputs=speech_result_audio
        )

    return demo

if __name__ == "__main__":
    # Build the interface, enable request queuing, and start the server.
    app = create_ui()
    app.queue()
    app.launch()