import torch
import gradio as gr
from flash_vstream.serve.demo import Chat, title_markdown, block_css
from flash_vstream.constants import *
from flash_vstream.conversation import conv_templates, Conversation
import os
from PIL import Image
import tempfile
import imageio
import shutil


# Model identifier/path handed to ``Chat`` when the handler is built below.
model_path = "IVGSZ/Flash-VStream-7b"
# Quantised-loading switches forwarded to ``Chat`` as ``load_8bit`` / ``load_4bit``.
load_8bit, load_4bit = False, False

def save_image_to_local(image):
    """Re-save an image into the local ``temp`` directory as a ``.jpg`` file.

    Args:
        image: Anything ``PIL.Image.open`` accepts (a path or a file object).

    Returns:
        The relative path ``temp/<random>.jpg`` of the written file.
    """
    # Function-scope import keeps this block self-contained.
    import uuid

    # Do not rely on module start-up code having created the directory.
    os.makedirs('temp', exist_ok=True)
    # uuid4 replaces the private ``tempfile._get_candidate_names`` helper.
    filename = os.path.join('temp', uuid.uuid4().hex + '.jpg')
    # The context manager closes the source file handle once the copy is saved.
    with Image.open(image) as opened:
        opened.save(filename)
    return filename


def save_video_to_local(video_path):
    """Copy a video into the local ``temp`` directory under a unique name.

    Args:
        video_path: Path of the existing video file to copy.

    Returns:
        The relative path ``temp/<random>.mp4`` of the copy.
    """
    # Function-scope import keeps this block self-contained.
    import uuid

    # Do not rely on module start-up code having created the directory.
    os.makedirs('temp', exist_ok=True)
    # uuid4 replaces the private ``tempfile._get_candidate_names`` helper.
    filename = os.path.join('temp', uuid.uuid4().hex + '.mp4')
    shutil.copyfile(video_path, filename)
    return filename


def generate(video, textbox_in, first_run, state, state_, images_tensor):
    """Run one chat turn and return the updated session values.

    Args:
        video: Path of the attached video, or a falsy value when none is attached.
        textbox_in: Instruction typed by the user. When empty, the text of the
            last entry in ``state_.messages`` is reused as the instruction and
            that entry is removed (the Regenerate button chains into this
            function with an empty textbox).
        first_run: Ignored on input; recomputed from ``state.messages``.
        state: Conversation rendered in the chatbot, or a non-``Conversation``
            placeholder before the first turn.
        state_: Conversation handed to ``handler.generate``.
        images_tensor: Frames tensor kept from an earlier turn; it is reused
            when no video is attached to this turn.

    Returns:
        A 7-tuple matching the outputs wired to this callback:
        ``(state, state_, chatbot_messages, first_run, textbox_update,
        images_tensor, video_update)``. ``first_run`` is always ``False`` and
        both updates clear their component.

    Raises:
        gr.Error: If no instruction was entered and there is no earlier
            message to reuse.
    """
    # Create the conversations before reading them: on the very first click
    # the session values are not Conversation objects yet.
    if not isinstance(state, Conversation):
        state = conv_templates[conv_mode].copy()
        state_ = conv_templates[conv_mode].copy()
        images_tensor = []

    # False when the instruction is recycled from the history, in which case
    # the user message is already present in ``state``.
    is_new_turn = True
    if not textbox_in:
        if not state_.messages:
            # A bare string cannot satisfy the 7 outputs of this callback;
            # raising surfaces the message to the user instead.
            raise gr.Error("Please enter instruction")
        textbox_in = state_.messages.pop(-1)[1]
        is_new_turn = False

    # Evaluate once so every video-dependent step sees the same answer.
    has_video = bool(video) and os.path.exists(video)

    first_run = not state.messages

    text_en_in = textbox_in.replace("picture", "image")

    if has_video:
        image_processor = handler.image_processor
        video_tensor = handler._get_rawvideo_dec(video, image_processor, max_frames=MAX_IMAGE_LENGTH)
        images_tensor = image_processor(video_tensor, return_tensors='pt')['pixel_values'].to(handler.model.device, dtype=torch.float16)
        print("video_tensor", video_tensor.shape)
        # Prefix the prompt with the image placeholder token for this video.
        text_en_in = DEFAULT_IMAGE_TOKEN + '\n' + text_en_in

    text_en_out, state_ = handler.generate(images_tensor, text_en_in, first_run=first_run, state=state_)
    # Store the full reply as the assistant entry of the model-side history.
    state_.messages[-1] = (state_.roles[1], text_en_out)

    # Only the text before the first '#' is shown in the chatbot.
    textbox_out = text_en_out.split('#')[0]

    show_images = ""
    if has_video:
        filename = save_video_to_local(video)
        show_images += f'<video controls playsinline width="500" style="display: inline-block;"  src="./file={filename}"></video>'

    if is_new_turn:
        state.append_message(state.roles[0], textbox_in + "\n" + show_images)
    state.append_message(state.roles[1], textbox_out)

    return (state, state_, state.to_gradio_chatbot(), False, gr.update(value=None, interactive=True), images_tensor, gr.update(value=None, interactive=True))


def regenerate(state, state_):
    """Drop the latest message from both conversations before a re-run.

    Args:
        state: Conversation rendered in the chatbot, or ``None`` when no turn
            has happened yet.
        state_: Conversation handed to the model handler, or ``None``.

    Returns:
        ``(state, state_, chatbot_messages, first_run)`` where ``first_run`` is
        ``True`` exactly when ``state`` has no messages left.
    """
    if state is None or state_ is None:
        # Nothing has been said yet, so there is nothing to remove.
        return state, state_, [], True

    # Guard each pop so an already-empty history does not raise IndexError.
    if state.messages:
        state.messages.pop(-1)
    if state_.messages:
        state_.messages.pop(-1)

    return state, state_, state.to_gradio_chatbot(), not state.messages


def clear_history(state, state_):
    """Reset the session to a blank conversation.

    The incoming ``state`` and ``state_`` are discarded; two fresh copies of
    the ``conv_mode`` template take their place.

    Returns:
        ``(video_update, textbox_update, first_run, state, state_,
        chatbot_messages, images_tensor)`` with both components cleared,
        ``first_run`` set to ``True`` and an empty ``images_tensor`` list.
    """
    fresh_state = conv_templates[conv_mode].copy()
    fresh_state_ = conv_templates[conv_mode].copy()

    blank_video = gr.update(value=None, interactive=True)
    blank_textbox = gr.update(value=None, interactive=True)

    return (
        blank_video,
        blank_textbox,
        True,
        fresh_state,
        fresh_state_,
        fresh_state.to_gradio_chatbot(),
        [],
    )


# Key of the conversation template copied for every new chat session.
conv_mode = "vicuna_v1"
# Single chat handler shared by all callbacks (presumably loads the model
# weights here - confirm against ``Chat``).
handler = Chat(model_path, conv_mode=conv_mode, load_4bit=load_4bit, load_8bit=load_8bit)
# Directory the save_*_to_local helpers write into; ``exist_ok`` avoids the
# check-then-create race of a separate ``os.path.exists`` test.
os.makedirs("temp", exist_ok=True)

# Current and peak CUDA memory occupied by tensors after building the handler.
print(torch.cuda.memory_allocated())
print(torch.cuda.max_memory_allocated())

with gr.Blocks(title='Flash-VStream', theme=gr.themes.Soft(), css=block_css) as demo:
    gr.Markdown(title_markdown)

    # Per-session values threaded through the callbacks.
    state = gr.State()
    state_ = gr.State()
    first_run = gr.State()
    images_tensor = gr.State()

    # Main area: video input on the left, chat on the right.
    with gr.Row():
        with gr.Column(scale=3):
            video = gr.Video(label="Input Video")

        with gr.Column(scale=7):
            chatbot = gr.Chatbot(label="Flash-VStream", bubble_full_width=True).style(height=700)
            with gr.Row():
                with gr.Column(scale=8):
                    textbox = gr.Textbox(
                        show_label=False,
                        placeholder="Enter text and press Send",
                        container=False,
                    )
                with gr.Column(scale=2, min_width=50):
                    submit_btn = gr.Button(value="Send", variant="primary", interactive=True)

    with gr.Row(visible=True) as button_row:
        regenerate_btn = gr.Button(value="🔄  Regenerate", interactive=True)
        clear_btn = gr.Button(value="🗑️  Clear history", interactive=True)

    # Example clips live in an ``examples`` folder next to this script.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    example_rows = [
        [f"{cur_dir}/examples/video1.mp4", "Describe the video briefly."],
        [f"{cur_dir}/examples/video4.mp4", "What is the boy doing?"],
        [f"{cur_dir}/examples/video5.mp4", "Why is this video funny?"],
    ]

    # One Examples widget per clip, side by side in a single row.
    with gr.Row():
        for example_row in example_rows:
            gr.Examples(
                examples=[example_row],
                inputs=[video, textbox],
            )

    # ``generate`` is wired identically for Send and for the Regenerate chain.
    generate_inputs = [video, textbox, first_run, state, state_, images_tensor]
    generate_outputs = [state, state_, chatbot, first_run, textbox, images_tensor, video]

    submit_btn.click(generate, generate_inputs, generate_outputs)

    # Regenerate first trims the history, then re-runs ``generate``.
    regenerate_btn.click(
        regenerate,
        [state, state_],
        [state, state_, chatbot, first_run],
    ).then(generate, generate_inputs, generate_outputs)

    clear_btn.click(
        clear_history,
        [state, state_],
        [video, textbox, first_run, state, state_, chatbot, images_tensor],
    )


# Disabled alternative: mount ``demo`` on an existing web app instead of
# launching it standalone.
# app = gr.mount_gradio_app(app, demo, path="/")
# Start the Gradio server for ``demo`` with the default launch settings.
demo.launch()