import gradio as gr
import edge_tts
import asyncio
import tempfile
import os
from huggingface_hub import InferenceClient
import re
from streaming_stt_nemo import Model
import torch
import random
from openai import OpenAI
import subprocess

# Language code used as the key into the speech-to-text engine table below.
default_lang = "en"

# Speech-to-text engines keyed by language code. The table is built at import
# time and holds a single entry: the model for `default_lang`.
engines = { default_lang: Model(default_lang) }

def transcribe(audio):
    """Transcribe recorded audio to text with the default-language engine.

    Args:
        audio: Audio reference handed to the engine's ``stt_file`` (the UI
            supplies a file path), or ``None`` when nothing was recorded.

    Returns:
        The first element of the ``stt_file`` result, or ``""`` when
        ``audio`` is ``None``.
    """
    if audio is None:
        return ""
    # Look the engine up through `default_lang`, the same constant that keys
    # `engines`, so the two cannot drift apart if the default ever changes.
    model = engines[default_lang]
    return model.stt_file(audio)[0]

# Hugging Face access token taken from the environment; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")

def client_fn(model):
    """Build the API client that matches a model label from the UI dropdown.

    Labels are matched by substring. The self-hosted service is tested first
    because its label also contains "Llama"; the remaining keywords are tried
    in order, and Phi-3 mini is the fallback when nothing matches.

    Args:
        model: Human-readable model label (e.g. "Mixtral 8x7B").

    Returns:
        An ``OpenAI`` client for the self-hosted service, otherwise an
        ``InferenceClient`` bound to the selected Hugging Face repository.
    """
    if "Llama 3 8B Service" in model:
        return OpenAI(
            base_url="http://52.76.81.56:60002/v1",
            api_key="token-abc123",
        )

    # (keyword, repository) pairs, checked top to bottom.
    hub_repos = (
        ("Llama", "meta-llama/Meta-Llama-3-8B-Instruct"),
        ("Mistral", "mistralai/Mistral-7B-Instruct-v0.2"),
        ("Phi", "microsoft/Phi-3-mini-4k-instruct"),
        ("Mixtral", "mistralai/Mixtral-8x7B-Instruct-v0.1"),
    )
    for keyword, repo_id in hub_repos:
        if keyword in model:
            return InferenceClient(repo_id)

    return InferenceClient("microsoft/Phi-3-mini-4k-instruct")

def randomize_seed_fn(seed: int) -> int:
    """Return a new random seed in the inclusive range 0..999999.

    The incoming ``seed`` value is not used; it is only part of the signature.
    """
    return random.randint(0, 999999)

# System prompt used by models(): sent as the "system" message to the chat
# endpoint, or placed at the top of the text prompt for the other models.
# NOTE: this is runtime text delivered to the model, so any wording change
# (including typo fixes) changes what the model receives.
system_instructions1 = """
[SYSTEM] You are OPTIMUS Prime a personal AI voice assistant, Created by Jaward.
Keep conversation friendly, short, clear, and concise. 
Avoid unnecessary introductions and answer the user's questions directly. 
Respond in a normal, conversational manner while being friendly and helpful.
Remember previous parts of the conversation and use that context in your responses.
Your creator Jaward is an AI/ML Research Engineer at Linksoul AI. He is currently specializing in Artificial Intelligence (AI) research more specifically training and optimizing advance AI systems. He aspires to build not just human-like intelligence but AI Systems that augment human intelligence. He has contributed greatly to the opensource community with first-principles code implementations of AI/ML research papers. He did his first internship at Beijing Academy of Artificial Intelligence as an AI Researher where he contributed in cutting-edge AI research leading to him contributing to an insightful paper (AUTOAGENTS - A FRAMEWORK FOR AUTOMATIC AGENT GENERATION). The paper got accepted this year at IJCAI(International Joint Conference On AI). He is currently doing internship at LinkSoul AI - a small opensource AI Research startup in Beijing.
[USER]
"""

# In-memory chat history: a list of {"role": ..., "content": ...} dicts that
# models() appends to and clear_history() resets. It is a module-level global,
# so every caller in this process reads and writes the same list.
conversation_history = []

# Maximum number of stored history entries (user and assistant messages
# combined), i.e. the last 10 exchanges.
_MAX_HISTORY_MESSAGES = 20


def _record_exchange(user_text, assistant_text):
    """Append one user/assistant exchange to the shared history and trim it."""
    conversation_history.append({"role": "user", "content": user_text})
    conversation_history.append({"role": "assistant", "content": assistant_text})
    # Drop the oldest entries in place so the prompt stays within token
    # limits; the slice is empty (a no-op) while the history is short enough.
    del conversation_history[:-_MAX_HISTORY_MESSAGES]


def models(text, model="Llama 3B Service", seed=42):
    """Generate the assistant's reply to ``text`` and record the exchange.

    The "Llama 3 8B Service" label is answered through the OpenAI-compatible
    chat endpoint using structured messages. Every other label is answered
    through Hugging Face text generation, with the history flattened into a
    single prompt string.

    Args:
        text: The user's utterance.
        model: Model label as shown in the UI dropdown.
            NOTE(review): the default "Llama 3B Service" contains "Llama" but
            not "Llama 3 8B Service", so it takes the text-generation path --
            possibly a typo; confirm the intended default.
        seed: Ignored as given; it is replaced by ``randomize_seed_fn`` and
            the new value is only used by the text-generation path.

    Returns:
        The assistant's reply text.
    """
    seed = int(randomize_seed_fn(seed))
    client = client_fn(model)

    if "Llama 3 8B Service" in model:
        messages = [
            {"role": "system", "content": system_instructions1},
            *conversation_history,
            {"role": "user", "content": text},
        ]
        completion = client.chat.completions.create(
            model="/data/shared/huggingface/hub/models--meta-llama--Meta-Llama-3-8B-Instruct/snapshots/c4a54320a52ed5f88b7a2f84496903ea4ff07b45/",
            messages=messages,
        )
        reply = completion.choices[0].message.content
    else:
        # These clients take a plain prompt, so the history is rendered as
        # "User: ..." / "Assistant: ..." lines.
        history_text = "\n".join([f"{'User' if msg['role'] == 'user' else 'Assistant'}: {msg['content']}" for msg in conversation_history])
        formatted_prompt = f"{system_instructions1}\n\nConversation history:\n{history_text}\n\nUser: {text}\nOPTIMUS:"

        stream = client.text_generation(
            formatted_prompt,
            max_new_tokens=300,
            seed=seed,
            stream=True,
            details=True,
            return_full_text=False,
        )
        # Concatenate streamed tokens, skipping the end-of-sequence marker.
        reply = "".join(
            chunk.token.text for chunk in stream if chunk.token.text != "</s>"
        )

    _record_exchange(text, reply)
    return reply

async def respond(audio, model, seed):
    """Turn a recorded utterance into a spoken reply.

    Transcribes the audio, asks the selected model for a reply, synthesizes
    the reply with edge-tts and returns the path of the resulting audio file.

    Args:
        audio: Recording reference from the UI, or ``None``.
        model: Model label selected in the dropdown.
        seed: Seed value forwarded to ``models``.

    Returns:
        Path of the synthesized audio file, or ``None`` when there is no
        audio or the transcription is empty.
    """
    if audio is None:
        return None
    user = transcribe(audio)
    if not user:
        return None
    reply = models(user, model, seed)
    communicate = edge_tts.Communicate(reply, voice="en-US-ChristopherNeural")
    # Only reserve a unique path here and let the handle close before edge-tts
    # writes to it: reopening a still-open NamedTemporaryFile by name fails on
    # Windows, and there is no reason to hold the descriptor across the await.
    # delete=False keeps the file so the UI can play it; it is never cleaned up.
    # NOTE(review): suffix is ".wav" but edge-tts presumably writes MP3 data --
    # confirm whether the extension should change.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_file:
        tmp_path = tmp_file.name
    await communicate.save(tmp_path)
    return tmp_path

# Supported target languages for seamless-expressive speech translation.
# Maps the display name (these keys also populate the "Target Language"
# dropdown) to the language code passed as --tgt_lang to expressivity_predict.
LANGUAGE_CODES = {
    "English": "eng",
    "Spanish": "spa",
    "French": "fra",
    "German": "deu",
    "Italian": "ita",
    "Chinese": "cmn"
}

def translate_speech(audio_file, target_language):
    """
    Translate input speech (audio file) to the specified target language.

    Runs the ``expressivity_predict`` command-line tool on ``audio_file``.

    Args:
        audio_file: Path of the recorded audio, or ``None``.
        target_language: Display name of the target language; must be a key
            of ``LANGUAGE_CODES``.

    Returns:
        Path of the translated audio file, or ``None`` when no audio was
        given or the tool produced no output file.

    Raises:
        KeyError: If ``target_language`` is not in ``LANGUAGE_CODES``.
        subprocess.CalledProcessError: If the tool exits with a non-zero code.
    """
    if audio_file is None:
        return None

    language_code = LANGUAGE_CODES[target_language]
    # Write into a fresh directory per call. A fixed file name in the working
    # directory would let overlapping requests overwrite each other's result.
    # The directory is not removed afterwards, so the UI can serve the file.
    output_dir = tempfile.mkdtemp(prefix="translated_audio_")
    output_file = os.path.join(output_dir, "translated_audio.wav")

    command = [
        "expressivity_predict",
        audio_file,
        "--tgt_lang", language_code,
        "--model_name", "seamless_expressivity",
        "--vocoder_name", "vocoder_pretssel",
        "--gated-model-dir", "seamlessmodel",
        "--output_path", output_file,
    ]

    # Argument list (no shell), so the file path is never shell-interpreted.
    subprocess.run(command, check=True)

    if os.path.exists(output_file):
        print(f"File created successfully: {output_file}")
        return output_file

    print(f"File not found: {output_file}")
    return None

def clear_history():
    """Reset the shared conversation history to a new empty list.

    Returns:
        A tuple of four ``None`` values.
    """
    global conversation_history
    conversation_history = list()
    return (None,) * 4

def voice_assistant_tab():
    """Return the Markdown heading shown when the Voice Assistant tab is selected."""
    heading = "# <center><b>Hello, I am Optimus Prime your personal AI voice assistant</b></center>"
    return heading

def speech_translation_tab():
    """Return the Markdown heading shown when the Speech Translation tab is selected."""
    heading = "# <center><b>Hear how you sound in another language</b></center>"
    return heading

# Two-tab Gradio UI: a voice assistant and a speech translator. A shared
# Markdown heading at the top is swapped whenever a tab is selected.
with gr.Blocks(css="style.css") as demo:
    description = gr.Markdown("# <center><b>Hello, I am Optimus Prime your personal AI voice assistant</b></center>")
    
    with gr.Tabs() as tabs:
        with gr.TabItem("Voice Assistant") as voice_assistant:
            # Model labels; client_fn() and models() match them by substring.
            select = gr.Dropdown([
                'Llama 3 8B Service',
                'Mixtral 8x7B',
                'Llama 3 8B',
                'Mistral 7B v0.3',
                'Phi 3 mini',
            ],
            value="Llama 3 8B Service",
            label="Model"
            )
            # Hidden slider: its value is passed to respond(), but models()
            # replaces it with a random seed via randomize_seed_fn().
            seed = gr.Slider(
            label="Seed",
            minimum=0,
            maximum=999999,
            step=1,
            value=0,
            visible=False
            )
            # NOTE(review): `input` shadows the Python builtin of the same name
            # at module scope; renaming it would also affect the commented-out
            # clear button below.
            input = gr.Audio(label="User", sources="microphone", type="filepath", waveform_options=False)
            output = gr.Audio(label="AI", type="filepath",
                            interactive=False,
                            autoplay=True,
                            elem_classes="audio")

            # Wire microphone + model + seed into respond(); the returned file
            # path feeds the autoplaying output player.
            # live=True presumably re-runs on input change -- confirm in Gradio docs.
            gr.Interface(
                fn=respond, 
                inputs=[input, select, seed],
                outputs=[output],
                live=True
            )
        
        with gr.TabItem("Speech Translation") as speech_translation:
            input_audio = gr.Audio(label="User", sources="microphone", type="filepath", waveform_options=False)
            # Choices are the display names (keys) of LANGUAGE_CODES.
            target_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="Spanish",
                label="Target Language"
            )
            output_audio = gr.Audio(label="Translated Audio",
                                    interactive=False,
                                    autoplay=True,
                                    elem_classes="audio")
            
            # Wire microphone + target language into translate_speech().
            gr.Interface(
                fn=translate_speech,
                inputs=[input_audio, target_lang],
                outputs=[output_audio],
                live=True
            )

    # Disabled "Clear" button. If re-enabled it would call clear_history(),
    # whose four None return values line up with the four outputs listed.
    # clear_button = gr.Button("Clear")
    # clear_button.click(
    #     fn=clear_history,
    #     inputs=[],
    #     outputs=[input, output, input_audio, output_audio],
    #     api_name="clear"
    # )

    # Swap the page heading to match whichever tab was just selected.
    voice_assistant.select(fn=voice_assistant_tab, inputs=None, outputs=description)
    speech_translation.select(fn=speech_translation_tab, inputs=None, outputs=description)

# Script entry point: enable the request queue (max_size=200) and start the
# app. Nothing is launched when this module is merely imported.
if __name__ == "__main__":
    demo.queue(max_size=200).launch()