import gradio as gr
import spaces
import os
import time
from PIL import Image
from models.mllava import MLlavaProcessor, LlavaForConditionalGeneration, chat_mllava, MLlavaForConditionalGeneration
from typing import List
# One Hub checkpoint id shared by the processor and the model so the two
# cannot drift apart. Both objects are created once at import time and are
# read (and, for `model`, rebound) by `generate`.
_MODEL_ID = "TIGER-Lab/Mantis-llava-7b-v1.1"
processor = MLlavaProcessor.from_pretrained(_MODEL_ID)
model = LlavaForConditionalGeneration.from_pretrained(_MODEL_ID)

@spaces.GPU
def generate(text:str, images:List[Image.Image], history: List[dict], **kwargs):
    """Stream the model's reply for one chat turn.

    Decorated with ``spaces.GPU`` (Hugging Face Spaces GPU allocation); the
    model is moved to CUDA inside the decorated call.

    Args:
        text: New user message forwarded to ``chat_mllava``. ``bot`` passes
            ``None`` because the pending user message is already the last
            entry of ``history``.
        images: Images for the whole conversation. Falsy values (e.g. an
            empty list) are normalised to ``None`` before calling
            ``chat_mllava``. NOTE(review): ``bot`` passes the entries gathered
            by ``get_chat_images`` (presumably file paths from the chatbot),
            not PIL images as the annotation says -- confirm what
            ``chat_mllava`` accepts.
        history: Conversation as a list of ``{"role": ..., "text": ...}`` dicts.
        **kwargs: Generation options forwarded to ``chat_mllava``.

    Yields:
        Each output chunk of ``chat_mllava``'s stream. The caller overwrites
        the bot message with every chunk, so these are presumably cumulative
        partial replies.

    Returns:
        The last yielded chunk (or ``text`` if nothing was yielded), available
        as the generator's ``StopIteration.value``.
    """
    # Only `model` is rebound here; `processor` is merely read, so it needs
    # no `global` declaration.
    global model
    model = model.to("cuda")
    if not images:
        images = None
    # Use dedicated loop names instead of re-binding the `text`/`history`
    # parameters, which made the inputs unreadable after the loop started.
    output = text
    for output, _updated_history in chat_mllava(
        text, images, model, processor, history=history, stream=True, **kwargs
    ):
        yield output

    return output

def enable_next_image(uploaded_images, image):
    """Record one more uploaded image and lock the input box.

    ``image`` is appended to ``uploaded_images`` in place, and that same list
    is returned together with a cleared, non-interactive
    ``MultimodalTextbox`` update. (Not wired to any event in ``build_demo``.)
    """
    uploaded_images.append(image)
    locked_input = gr.MultimodalTextbox(value=None, interactive=False)
    return uploaded_images, locked_input

def add_message(history, message):
    """Append a multimodal textbox submission to the chatbot history.

    Every uploaded file becomes its own ``[(file,), None]`` entry. A
    ``[text, None]`` entry follows when text was typed, so the text is always
    the last entry of the turn. Entries are lists rather than tuples because
    ``bot`` later writes the reply into ``history[-1][1]`` in place.

    Returns:
        The mutated ``history`` and an update that clears the textbox.
    """
    uploaded = message["files"]
    if uploaded:
        history.extend([(file,), None] for file in uploaded)
    typed = message["text"]
    if typed:
        history.append([typed, None])
    return history, gr.MultimodalTextbox(value=None)

def print_like_dislike(x: gr.LikeData):
    """Log a chatbot like/dislike event (message index, value, liked flag) to stdout."""
    index, value, liked = x.index, x.value, x.liked
    print(index, value, liked)


def get_chat_history(history):
    """Convert chatbot history into the role/text message list for the model.

    Only entries whose user part is a string (typed text) are converted.
    Image entries (tuples) are skipped here and gathered separately by
    ``get_chat_images``. Every text entry except the last element of
    ``history`` must already carry a bot reply. The last one is the pending
    message: it must not have a reply yet, and it is paired with an empty
    assistant message.

    Args:
        history: Chatbot history as a list of ``[user, bot]`` pairs.

    Returns:
        A list of ``{"role": "user" | "assistant", "text": ...}`` dicts.

    Raises:
        ValueError: if the history does not have that shape. Explicit raises
            replace the previous ``assert`` statements, which are skipped under
            ``python -O`` and would then forward ``None`` as assistant text.

    NOTE(review): if generation fails before producing output, that turn's
    text entry keeps a ``None`` reply. Every later call then raises here until
    the chat is cleared.
    """
    chat_history = []
    last_index = len(history) - 1
    for i, message in enumerate(history):
        user_text = message[0]
        if not isinstance(user_text, str):
            continue  # image entry; handled by get_chat_images
        chat_history.append({"role": "user", "text": user_text})
        bot_text = message[1]
        if i != last_index:
            # Completed turn: the reply must exist.
            if not bot_text:
                raise ValueError("The bot message is not provided, internal error")
            chat_history.append({"role": "assistant", "text": bot_text})
        else:
            # Pending turn: the model is about to fill this reply in.
            if bot_text:
                raise ValueError("the bot message internal error, get: {}".format(bot_text))
            chat_history.append({"role": "assistant", "text": ""})
    return chat_history

def get_chat_images(history):
    """Return every image file referenced in the chatbot history, in order.

    Image entries are those whose user part is a tuple. All items of each such
    tuple are collected, and text entries are ignored.
    """
    return [
        item
        for message in history
        if isinstance(message[0], tuple)
        for item in message[0]
    ]
            
_IMAGE_TOKEN = "<image>"


def _collect_pending_turn(history):
    """Return ``(text, images)`` for the unanswered entries at the end of ``history``.

    Walks backwards until it reaches an entry that already has a bot reply.
    Text entries are joined with single spaces in their original order and
    stripped. Items of image-tuple entries are returned in their original
    order (entries from ``add_message`` hold one file each).
    """
    text = ""
    images = []
    for message in reversed(history):
        if message[1]:
            break
        if isinstance(message[0], str):
            text = message[0] + " " + text
        elif isinstance(message[0], tuple):
            images.extend(message[0])
    images.reverse()
    return text.strip(), images


def _drop_trailing_placeholders(text, count):
    """Remove the last ``count`` (>= 1) occurrences of ``<image>`` from ``text``."""
    # rsplit with maxsplit splits at the right-most `count` tokens; joining
    # the pieces with "" deletes exactly those tokens.
    return "".join(text.rsplit(_IMAGE_TOKEN, count))


def bot(history):
    """Generate the assistant reply for the pending turn, streaming into the chatbot.

    Steps:
    1. Collect the text and images of the unanswered turn.
    2. Make the number of ``<image>`` placeholders match the number of images:
       missing ones are prepended and extra trailing ones are removed, each
       with a ``gr.Warning``. The adjusted text is written back to
       ``history[-1][0]``.
    3. Stream ``generate`` over the whole conversation.

    Args:
        history: Chatbot history as a list of ``[user, bot]`` pairs; mutated
            in place.

    Yields:
        ``history`` after each streamed chunk is stored in ``history[-1][1]``.

    Raises:
        gr.Error: if the pending turn contains no text.
    """
    print(history)
    text, images = _collect_pending_turn(history)
    if not text:
        raise gr.Error("Please enter a message")

    # NOTE(review): if the pending turn spans several text entries (possible
    # only after an earlier turn failed), only history[-1][0] is rewritten
    # with the merged text below.
    num_placeholders = text.count(_IMAGE_TOKEN)
    if num_placeholders < len(images):
        gr.Warning("The number of images uploaded is more than the number of <image> placeholders in the text. Will automatically prepend <image> to the text.")
        text = "<image> " * (len(images) - num_placeholders) + text
        history[-1][0] = text
    elif num_placeholders > len(images):
        gr.Warning("The number of images uploaded is less than the number of <image> placeholders in the text. Will automatically remove extra <image> placeholders from the text.")
        text = _drop_trailing_placeholders(text, num_placeholders - len(images))
        history[-1][0] = text

    chat_history = get_chat_history(history)
    chat_images = get_chat_images(history)
    generation_kwargs = {
        "max_new_tokens": 4096,
        "temperature": 0.2,
        "top_p": 1.0,
        "do_sample": True,
    }
    print(None, chat_images, chat_history, generation_kwargs)

    for partial_reply in generate(None, chat_images, chat_history, **generation_kwargs):
        history[-1][1] = partial_reply
        time.sleep(0.05)  # presumably throttles UI refreshes between chunks
        yield history
        
def build_demo():
    """Build the Mantis chat UI as a ``gr.Blocks`` app and return it (not launched).

    Layout:
    - an intro/links header next to a sample image;
    - a chat section with a ``Chatbot`` and an image-only
      ``MultimodalTextbox``;
    - Send/Clear buttons;
    - clickable examples.

    Submitting the textbox and clicking Send run the same pipeline:
    ``add_message``, then, only if that succeeded, ``bot`` to stream the
    reply.
    """
    with gr.Blocks() as demo:

        with gr.Row():
            with gr.Column():
                gr.Markdown(""" # Mantis
        Mantis is a multimodal conversational AI model that can chat with users about images and text. It's optimized for multi-image reasoning, where interleaved text and images can be used to generate responses.

        | [Github](https://github.com/TIGER-AI-Lab/Mantis) | [Blog](https://tiger-ai-lab.github.io/Blog/mantis) | [Models](https://huggingface.co/collections/TIGER-Lab/mantis-6619b0834594c878cdb1d6e4) |                   
                """)
            with gr.Column():
                gr.Image("./barchart.jpeg")

        gr.Markdown("""## Chat with Mantis
        Mantis supports interleaved text-image input format, where you can simply use the placeholder `<image>` to indicate the position of uploaded images.
        The model is optimized for multi-image reasoning, while preserving the ability to chat about text and images in a single conversation.
        (The model currently serving is [Mantis-llava-7b-v1.1](https://huggingface.co/TIGER-Lab/Mantis-llava-7b-v1.1))
        """)

        chatbot = gr.Chatbot(line_breaks=True)
        chat_input = gr.MultimodalTextbox(interactive=True, file_types=["image"], placeholder="Enter message or upload images. Please use <image> to indicate the position of uploaded images", show_label=True)

        # Enter in the textbox: append the message, then stream the reply.
        # `.success` skips `bot` if `add_message` raised.
        chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
        chat_msg.success(bot, chatbot, chatbot, api_name="bot_response")

        chatbot.like(print_like_dislike, None, None)

        with gr.Row():
            send_button = gr.Button("Send")
            gr.ClearButton([chatbot, chat_input])

        # Send button: same pipeline as the submit path above. Uses `.success`
        # (previously `.then`) so `bot` does not run after a failed
        # `add_message`.
        # NOTE(review): this reuses api_name="bot_response"; Gradio presumably
        # de-duplicates it with a suffix -- confirm, and give it an explicit
        # distinct name if API clients rely on it.
        send_button.click(
            add_message, [chatbot, chat_input], [chatbot, chat_input]
        ).success(
            bot, chatbot, chatbot, api_name="bot_response"
        )

        gr.Examples(
            examples=[
                {
                    "text": "<image> <image> How many dices are there in image 1 and image 2 respectively?",
                    "files": ["./examples/image10.jpg", "./examples/image11.jpg"]
                },
                {
                    "text": "<image> <image> <image> Which image shows a different mood of character from the others?",
                    "files": ["./examples/image12.jpg", "./examples/image13.jpg", "./examples/image14.jpg"]
                },
                {
                    "text": "<image> <image> What's the difference between these two images? Please describe as much as you can.",
                    "files": ["./examples/image1.jpg", "./examples/image2.jpg"]
                },
                {
                    "text": "<image> <image> Which image shows an older dog?",
                    "files": ["./examples/image8.jpg", "./examples/image9.jpg"]
                },
                {
                    # No <image> placeholders here: `bot` prepends them
                    # automatically (with a warning).
                    "text": "Write a description for the given image sequence in a single paragraph, what is happening in this episode?",
                    "files": ["./examples/image3.jpg", "./examples/image4.jpg", "./examples/image5.jpg", "./examples/image6.jpg", "./examples/image7.jpg"]
                },
            ],
            inputs=[chat_input],
        )

    return demo
    

if __name__ == "__main__":
    # Script entry point: build the Gradio Blocks UI and start its server
    # with default `launch()` settings.
    demo = build_demo()
    demo.launch()