import torch
import spaces

import gradio as gr
from threading import Thread
import re
import time
import tempfile
import os

from transformers import pipeline
from transformers.pipelines.audio_utils import ffmpeg_read

from PIL import Image

from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration, TextIteratorStreamer
# Load the LLaVA-NeXT (v1.6, Mistral-7B) processor and model once at import
# time. The weights are loaded in half precision and moved to the first CUDA
# device. NOTE(review): neither `processor` nor `model` is referenced again in
# the visible part of this file.
processor = LlavaNextProcessor.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf"
)
model = LlavaNextForConditionalGeneration.from_pretrained(
    "llava-hf/llava-v1.6-mistral-7b-hf",
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
).to("cuda:0")

# --- Speech-recognition (ASR) settings ---------------------------------------
# Checkpoint passed as `model=` when the ASR pipeline is built.
ASR_MODEL_NAME = "openai/whisper-large-v3"
# Length (in seconds) of each audio chunk, passed as `chunk_length_s=`.
ASR_CHUNK_LENGTH_S = 30
# Batch size for ASR inference; presumably meant for the `batch_size`
# argument of the pipeline call -- verify against the caller.
ASR_BATCH_SIZE = 8
# Size limit in megabytes (per the name).
# NOTE(review): not referenced anywhere in the visible code.
TEMP_FILE_LIMIT_MB = 1_000

from huggingface_hub import InferenceClient
"""
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
"""
# Hosted-inference client used by `respond` to stream chat completions from
# the Zephyr 7B chat model.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# Use CUDA device index 0 when a GPU is available; otherwise fall back to CPU.
device = "cpu" if not torch.cuda.is_available() else 0

# Speech-recognition pipeline, built once at import time and reused by
# `transcribe` for every request.
asr_pl = pipeline(
    "automatic-speech-recognition",
    model=ASR_MODEL_NAME,
    device=device,
    chunk_length_s=ASR_CHUNK_LENGTH_S,
)

# Title and description displayed by both the transcription and the chat
# interfaces.
application_title = "Enlight Innovations Limited -- Demo"
application_description = "This demo is designed to illustrate our basic idea and feasibility in implementation."

@spaces.GPU
def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    """Stream a chat reply to ``message`` from the module-level ``client``.

    The conversation sent to the model is: the system message, every
    non-empty user/assistant turn from ``history`` in order, then the new
    user ``message``.

    Args:
        message: The new user message.
        history: Previous turns as ``(user_text, assistant_text)`` pairs;
            empty/None entries within a pair are skipped.
        system_message: Content of the leading ``system`` role message.
        max_tokens: Forwarded to ``client.chat_completion`` as ``max_tokens``.
        temperature: Forwarded as ``temperature``.
        top_p: Forwarded as ``top_p``.

    Yields:
        The response text accumulated so far, once per streamed chunk.
    """
    messages = [{"role": "system", "content": system_message}]

    for user_text, assistant_text in history:
        if user_text:
            messages.append({"role": "user", "content": user_text})
        if assistant_text:
            messages.append({"role": "assistant", "content": assistant_text})

    messages.append({"role": "user", "content": message})

    response = ""

    # Use a distinct loop variable so the `message` parameter is not shadowed.
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # A streamed delta may carry no text (its `content` is optional in
        # huggingface_hub); concatenating None to a str would raise TypeError.
        if token:
            response += token
        yield response

@spaces.GPU
def transcribe(inputs, task):
    """Run the module-level ASR pipeline on ``inputs`` and return its text.

    Args:
        inputs: Audio to process, passed straight to ``asr_pl``. The UI's
            audio component is configured with ``type="filepath"``, so this
            is presumably a path to an audio file.
        task: Forwarded to the model as ``generate_kwargs={"task": task}``;
            the UI offers ``"transcribe"`` and ``"translate"``.

    Returns:
        The ``"text"`` field of the pipeline result.

    Raises:
        gr.Error: If ``inputs`` is ``None`` (no audio was submitted).
    """
    if inputs is None:
        raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")

    # Use the function's own `inputs` argument and the module's
    # ASR_BATCH_SIZE constant (the previous names were undefined).
    result = asr_pl(
        inputs,
        batch_size=ASR_BATCH_SIZE,
        generate_kwargs={"task": task},
        return_timestamps=True,
    )
    return result["text"]


# Input widgets for the transcription tab: the audio itself, a selector for
# where the audio comes from, and the ASR task to perform.
audio_input = gr.Audio(sources=["upload", "microphone"], type="filepath", label="Audio: from microphone")
audio_input_choice = gr.Radio(["audio file", "microphone"], label="Audio")
task_input_choice = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")


def _transcribe_from_ui(audio_path, audio_source, task):
    """Adapt the three Interface inputs to ``transcribe(inputs, task)``.

    The Interface below passes one argument per input component (audio,
    audio-source choice, task), while ``transcribe`` accepts only the audio
    and the task. The audio-source choice only affects how the audio widget
    is displayed, so it is dropped here.
    """
    return transcribe(audio_path, task)


transcribe_interface = gr.Interface(
    fn=_transcribe_from_ui,
    inputs=[
        audio_input,
        audio_input_choice,
        task_input_choice,
    ],
    outputs="text",
    title=application_title,
    description=application_description,
    allow_flagging="never",
)


"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""

# Extra controls for the chat tab. They are listed in `additional_inputs`
# in the same order as the trailing parameters of `respond`
# (system_message, max_tokens, temperature, top_p).
chatbot_sys_output = gr.Textbox(
    label="System message",
    value="You are a friendly Chatbot.",
)
chatbot_max_tokens = gr.Slider(
    label="Max new tokens",
    minimum=1,
    maximum=2048,
    step=1,
    value=512,
)
chatbot_temperature = gr.Slider(
    label="Temperature",
    minimum=0.1,
    maximum=4.0,
    step=0.1,
    value=0.7,
)
chatbot_top_p = gr.Slider(
    label="Top-p (nucleus sampling)",
    minimum=0.1,
    maximum=1.0,
    step=0.05,
    value=0.95,
)

# Chat tab: streams replies produced by `respond`.
chat_interface = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        chatbot_sys_output,
        chatbot_max_tokens,
        chatbot_temperature,
        chatbot_top_p,
    ],
    title=application_title,
    description=application_description,
)

with gr.Blocks() as demo:
    # Present the two interfaces as tabs of a single app.
    gr.TabbedInterface(
        interface_list=[transcribe_interface, chat_interface],
        tab_names=["Step 1: Transcribe", "Step 2: Extract"],
    )

    def update_audio_input(audio_input_choice):
        """Return an audio component matching the selected audio source.

        Returns ``None`` when the choice is neither "audio file" nor
        "microphone".
        """
        # choice -> (sources, label); the component is only constructed
        # after a matching entry has been found.
        settings_by_choice = {
            "audio file": ("upload", "Audio: from file"),
            "microphone": ("microphone", "Audio: from microphone"),
        }
        settings = settings_by_choice.get(audio_input_choice)
        if settings is None:
            return None
        source, label = settings
        return gr.Audio(sources=source, type="filepath", label=label)

    # Swap the audio widget whenever the user changes the source selector.
    audio_input_choice.input(
        fn=update_audio_input,
        inputs=audio_input_choice,
        outputs=audio_input,
    )


if __name__ == "__main__":
    # Enable the request queue, then start serving the app.
    queued_demo = demo.queue()
    queued_demo.launch()