import gradio as gr

from vid2persona import init
from vid2persona.pipeline import vlm
from vid2persona.pipeline import llm

# One-time start-up work that runs at import time, before the UI is built.
# Presumably loads the chat model used by llm.chat — TODO confirm in
# vid2persona.init.
init.init_model("HuggingFaceH4/zephyr-7b-beta")
# Presumably authenticates against Google Cloud for the Vertex AI (Gemini)
# calls made by vlm.get_traits — TODO confirm.
init.auth_gcp()
# NOTE(review): presumably populates init.gcp_project_id,
# init.gcp_project_location and init.hf_access_token, which the handlers
# below read — verify against vid2persona.init.
init.get_env_vars()
# Location of the prompt templates; passed to both vlm.get_traits and llm.chat.
prompt_tpl_path = "vid2persona/prompts"

async def extract_traits(video_path):
    """Extract a persona's traits from a video and reset the chat controls.

    Calls ``vlm.get_traits`` with the GCP project settings held by ``init``
    and the prompt templates under ``prompt_tpl_path``.

    Args:
        video_path: The value of the ``gr.Video`` input (the uploaded clip).

    Returns:
        A list matching the outputs ``[traits, chatbot, user_input, clear,
        regen, stop]``. It holds the extracted traits, an empty chat history,
        an emptied and enabled textbox, and three enabled buttons.
    """
    traits = await vlm.get_traits(
        init.gcp_project_id,
        init.gcp_project_location,
        video_path,
        prompt_tpl_path
    )
    # When the result lists characters, only the first one becomes the
    # persona. Guard against an empty list: indexing it would raise
    # IndexError and abort the whole event. In that case the raw result is
    # shown as-is instead.
    if 'characters' in traits and traits['characters']:
        traits = traits['characters'][0]

    return [
        traits, [],
        gr.Textbox("", interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True),
        gr.Button(interactive=True)
    ]

async def conversation(
    message: str, messages: list, traits: dict,
    model_id: str, max_input_token_length: int, 
    max_new_tokens: int, temperature: float, 
    top_p: float, top_k: float, repetition_penalty: float, 
):
    """Stream the persona's reply to a freshly submitted user message.

    A new ``[message, ""]`` pair is appended to a copy of the history. Its
    reply half is then filled in chunk by chunk from ``llm.chat``. Every
    yielded value is a list matching the outputs
    ``[chatbot, user_input, clear, regen]``.

    Args:
        message: The text the user just submitted.
        messages: The current chat history from the Chatbot component.
        traits: Persona traits shown in the traits JSON component.
        model_id, max_input_token_length, max_new_tokens, temperature,
        top_p, top_k, repetition_penalty: Generation settings, forwarded
            unchanged to ``llm.chat``.
    """
    history = messages + [[message, ""]]

    def ui_state(text, enabled):
        # The clear/regen buttons are always toggled together.
        return [history, text, gr.Button(interactive=enabled), gr.Button(interactive=enabled)]

    # Show the new turn immediately and lock the buttons; the submitted text
    # stays in the textbox until the first chunk arrives.
    yield ui_state(message, False)

    stream = llm.chat(
        message, history, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token
    )
    async for chunk in stream:
        history[-1][1] += chunk
        yield ui_state("", False)

    # Generation finished: unlock the buttons again.
    yield ui_state("", True)

async def regen_conversation(
    messages: list, traits: dict,
    model_id: str, max_input_token_length: int, 
    max_new_tokens: int, temperature: float, 
    top_p: float, top_k: float, repetition_penalty: float, 
):
    """Discard the last reply and stream a new one for the same user turn.

    Every yielded value is a list matching the outputs
    ``[chatbot, user_input, clear, regen]``. With an empty history, nothing
    is yielded.

    Args:
        messages: The current chat history from the Chatbot component.
        traits: Persona traits shown in the traits JSON component.
        model_id, max_input_token_length, max_new_tokens, temperature,
        top_p, top_k, repetition_penalty: Generation settings, forwarded
            unchanged to ``llm.chat``.
    """
    if len(messages) == 0:
        return

    last_user_turn = messages[-1][0]
    # Replace the final pair with the same user text and an empty reply.
    history = messages[:-1] + [[last_user_turn, ""]]

    def ui_state(enabled):
        # The textbox is cleared on every update; the buttons move together.
        return [history, "", gr.Button(interactive=enabled), gr.Button(interactive=enabled)]

    yield ui_state(False)

    stream = llm.chat(
        last_user_turn, history, traits,
        prompt_tpl_path, model_id,
        max_input_token_length, max_new_tokens,
        temperature, top_p, top_k,
        repetition_penalty, hf_token=init.hf_access_token
    )
    async for chunk in stream:
        history[-1][1] += chunk
        yield ui_state(False)

    yield ui_state(True)

with gr.Blocks(css="styles.css", theme=gr.themes.Soft()) as demo:
    # Title and project description.
    gr.Markdown("Vid2Persona", elem_classes=["md-center", "h1-font"])

    gr.Markdown("This project breathes life into video characters by using AI to describe their personality and then chat with you as them. "
                "[Gemini 1.0 Pro Vision model on Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/overview) is used "
                "to grasp traits of video characters, then [HuggingFaceH4/zephyr-7b-beta](https://huggingface.co/HuggingFaceH4/zephyr-7b-beta) model "
                "is used to make conversation with them.",)

    gr.Markdown("This space is modified to be working on Hugging Face [ZeroGPU](https://huggingface.co/zero-gpu-explorers). If you wish to run "
                "the same application on your own machine, please check out the [project repository](https://github.com/deep-diver/Vid2Persona). "
                "You can interact with other LLMs to make conversation besides HuggingFaceH4/zephyr-7b-beta by running them locally, or by "
                "connecting them through remotely hosted within Text Generation Inference framework as [Hugging Face PRO](https://huggingface.co/blog/inference-pro) user.")

    # Video upload and the traits extracted from it.
    with gr.Column(elem_classes=["group"]):
        with gr.Row():
            video = gr.Video(label="upload short video clip", max_length=180)
            traits = gr.Json(label="extracted traits")

        with gr.Row():
            trait_gen = gr.Button("generate  traits")

    # Chat area. The buttons and textbox start disabled; extract_traits
    # enables them once traits are available.
    with gr.Column(elem_classes=["group"]):
        chatbot = gr.Chatbot([], label="chatbot", elem_id="chatbot", elem_classes=["chatbot-no-label"])
        with gr.Row():
            clear = gr.Button("clear conversation", interactive=False)
            regen = gr.Button("regenerate the last", interactive=False)
            stop = gr.Button("stop", interactive=False)
        user_input = gr.Textbox(placeholder="ask anything", interactive=False, elem_classes=["textbox-no-label", "textbox-no-top-bottom-borders"])

        with gr.Accordion("parameters' control pane", open=False):
            # Hidden, so the default model ID is always what gets passed on.
            model_id = gr.Dropdown(choices=init.ALLOWED_LLM_FOR_HF_PRO_ACCOUNTS, value="HuggingFaceH4/zephyr-7b-beta", label="Model ID", visible=False)

            with gr.Row():
                max_input_token_length = gr.Slider(minimum=1024, maximum=4096, value=4096, label="max-input-tokens")
                max_new_tokens = gr.Slider(minimum=128, maximum=2048, value=256, label="max-new-tokens")

            with gr.Row():
                # NOTE(review): temperature=0 and repetition_penalty=0 are
                # selectable here; confirm llm.chat accepts them, since many
                # samplers require strictly positive values.
                temperature = gr.Slider(minimum=0, maximum=2, step=0.1, value=0.6, label="temperature")
                # top-p is a cumulative probability mass, so values above 1
                # are meaningless.
                top_p = gr.Slider(minimum=0, maximum=1, step=0.1, value=0.9, label="top-p")
                # top-k counts candidate tokens: use integer steps and a range
                # that actually contains the default of 50. The previous
                # 0..2 range with 0.1 steps made the default unreachable.
                top_k = gr.Slider(minimum=1, maximum=100, step=1, value=50, label="top-k")
                repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.1, value=1.2, label="repetition-penalty")

    with gr.Row():
        gr.Markdown(
            "[![GitHub Repo](https://img.shields.io/badge/GitHub%20Repo-gray?style=for-the-badge&logo=github&link=https://github.com/deep-diver/Vid2Persona)](https://github.com/deep-diver/Vid2Persona) "
            "[![Chansung](https://img.shields.io/badge/Chansung-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/algo_diver)](https://twitter.com/algo_diver) "
            "[![Sayak](https://img.shields.io/badge/Sayak-blue?style=for-the-badge&logo=twitter&link=https://twitter.com/RisingSayak)](https://twitter.com/RisingSayak )",
            elem_id="bottom-md"
        )

    # Extract traits, reset the chat and enable the chat controls.
    trait_gen.click(
        extract_traits,
        [video],
        [traits, chatbot, user_input, clear, regen, stop],
        concurrency_limit=5,
    )

    # Stream a reply to the submitted message; kept as `conv` so stop can
    # cancel it.
    conv = user_input.submit(
        conversation,
        [
            user_input, chatbot, traits,
            model_id, max_input_token_length,
            max_new_tokens, temperature,
            top_p, top_k, repetition_penalty,
        ],
        [chatbot, user_input, clear, regen],
        concurrency_limit=5,
    )

    # Empty the chat history and disable clear/regen until there is history.
    clear.click(
        lambda: [
            gr.Chatbot([]),
            gr.Button(interactive=False),
            gr.Button(interactive=False),
        ],
        None, [chatbot, clear, regen],
        concurrency_limit=5,
    )

    # Re-generate the last reply; kept as `conv_regen` so stop can cancel it.
    conv_regen = regen.click(
        regen_conversation,
        [
            chatbot, traits,
            model_id, max_input_token_length,
            max_new_tokens, temperature,
            top_p, top_k, repetition_penalty,
        ],
        [chatbot, user_input, clear, regen],
        concurrency_limit=5,
    )

    # Cancel any in-flight generation and re-enable the buttons that the
    # streaming handlers disabled.
    stop.click(
        lambda: [
            gr.Button(interactive=True),
            gr.Button(interactive=True),
            gr.Button(interactive=True),
        ], None, [clear, regen, stop],
        cancels=[conv, conv_regen],
        concurrency_limit=5,
    )

    # Sample clips; cache_examples=True has Gradio precompute extract_traits
    # outputs for them.
    gr.Examples(
        [["assets/sample1.mp4"], ["assets/sample2.mp4"], ["assets/sample3.mp4"], ["assets/sample4.mp4"]],
        video,
        [traits, chatbot, user_input, clear, regen, stop],
        extract_traits,
        cache_examples=True
    )

# Enable request queuing (at most 256 pending requests), then start the app.
# Blocks.queue returns the same Blocks object, so this matches the chained
# form demo.queue(...).launch(...).
demo.queue(max_size=256)
demo.launch(debug=True)