import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import os
from threading import Thread

# Hugging Face Hub repository id of the checkpoint loaded below.
MODEL_ID = "THUDM/glm-4v-9b"

# HTML banner rendered at the top of the page.
# (A stray closing </a> with no matching opening tag was removed.)
TITLE = '<br><center>🚀 Coin Generative Recognition</center>'

# HTML shown as the placeholder of the empty chat window.
DESCRIPTION = """
<center>
<p>
A Space for Vision/Multimodal
<br>
<br>
✨ Tips: Send messages or upload multiple IMAGES at a time.
<br>
✨ Tips: Please increase MAX LENGTH when dealing with files.
<br>
🤙 Supported Format: png, jpg, webp
<br>
🙇‍♂️ May be rebuilding from time to time.
</p>
</center>"""

# Page-level stylesheet handed to gr.Blocks: centers headings and caps image size.
CSS = """
h1 {
    text-align: center;
    display: block;
}
img {
    max-width: 100%; /* Make sure images are not wider than their container */
    height: auto; /* Maintain aspect ratio */
    max-height: 300px; /* Limit the height of images */
}
"""

# Load the checkpoint once at import time, place it on CUDA device 0 and
# switch it to inference mode. `trust_remote_code=True` is passed because the
# repository ships its own modelling/tokenizer code.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.bfloat16,
).to(0).eval()

# Tokenizer matching the checkpoint; also used below to build the chat prompt.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)


def merge_images(paths):
    """Concatenate images side by side into a single RGB image.

    Each file is opened, converted to RGB and pasted left to right in the
    order given. The canvas is as wide as the sum of all widths and as tall
    as the tallest image; images are pasted at the top edge, so any area
    below a shorter image keeps the canvas's initial fill.

    Args:
        paths: Iterable of image file paths (anything ``Image.open`` accepts).

    Returns:
        A new ``PIL.Image.Image`` in mode ``'RGB'``.

    Raises:
        ValueError: If ``paths`` yields no images.
    """
    images = []
    for path in paths:
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the converted image.
        with Image.open(path) as source:
            images.append(source.convert('RGB'))

    if not images:
        raise ValueError("merge_images requires at least one image path")

    total_width = sum(im.width for im in images)
    max_height = max(im.height for im in images)
    merged = Image.new('RGB', (total_width, max_height))

    x_offset = 0
    for im in images:
        merged.paste(im, (x_offset, 0))
        x_offset += im.width
    return merged

def mode_load(paths):
    """Validate uploaded file paths and merge them into one image.

    Args:
        paths: Sequence of file path strings.

    Returns:
        A ``("image", merged)`` tuple, where ``merged`` is the result of
        ``merge_images(paths)``.

    Raises:
        gr.Error: If ``paths`` is empty or any path does not end in a
            supported image extension.
    """
    if not paths:
        raise gr.Error("Please upload one or more images.")

    # Include the dot so that names merely ending in e.g. "png" (without being
    # a ".png" file) are not mistaken for images.
    supported = ('.png', '.jpg', '.jpeg', '.webp')
    if not all(path.lower().endswith(supported) for path in paths):
        raise gr.Error("Unsupported file types. Please upload only images.")

    return "image", merge_images(paths)

@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_length: int, top_p: float, top_k: int, penalty: float):
    """Stream the model's answer about the uploaded image(s).

    Args:
        message: Multimodal textbox payload, read as a mapping with a
            ``"text"`` entry and a ``"files"`` entry (list of file paths).
        history: Previous chat turns supplied by the chat interface. Unused:
            each request is sent to the model as a single-turn conversation.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            decoding instead of sampling.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        top_p: Nucleus-sampling threshold (only used when sampling).
        top_k: Top-k cutoff (only used when sampling).
        penalty: Repetition penalty; must be greater than 0.

    Yields:
        The response text accumulated so far, once per streamed chunk.

    Raises:
        gr.Error: If no file was uploaded, if ``penalty`` is not positive, or
            (via ``mode_load``) if a file is not a supported image.
    """
    files = message["files"]
    if not files:
        raise gr.Error("Please upload one or more images.")
    if penalty <= 0:
        # Fail here with a readable message rather than inside the worker
        # thread, where the UI would only see a streamer timeout.
        raise gr.Error("Repetition penalty must be greater than 0.")

    # mode_load validates the extensions and also handles a single upload, so
    # no separate one-image path is needed (the former one was unreachable).
    _, image = mode_load(files)
    conversation = [{"role": "user", "image": image, "content": message["text"]}]

    model_inputs = tokenizer.apply_chat_template(
        conversation,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True)

    # Cast explicitly so generation always receives the numeric types named in
    # the signature, whatever the UI delivered.
    generate_kwargs = dict(
        max_length=int(max_length),
        streamer=streamer,
        repetition_penalty=float(penalty),
        # Hard-coded stop-token ids — presumably the checkpoint's end-of-text /
        # role markers; TODO confirm against the tokenizer.
        eos_token_id=[151329, 151336, 151338],
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            top_p=float(top_p),
            top_k=int(top_k),
            temperature=float(temperature),
        )
    else:
        # A zero temperature is not valid for sampling; decode greedily.
        generate_kwargs["do_sample"] = False
    gen_kwargs = {**model_inputs, **generate_kwargs}

    def _generate():
        # Grad mode is per-thread, so disable it inside the worker itself
        # instead of around the yielding loop in the caller.
        with torch.no_grad():
            model.generate(**gen_kwargs)

    thread = Thread(target=_generate)
    thread.start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    # The streamer is exhausted once generation has ended; reap the worker.
    thread.join()

# Conversation pane; shows DESCRIPTION while the chat is still empty.
chatbot = gr.Chatbot(
    placeholder=DESCRIPTION,
    height=600,
    label="Chatbox",
)

# Input box accepting text together with any number of uploaded files.
chat_input = gr.MultimodalTextbox(
    file_count="multiple",
    show_label=False,
    placeholder="Enter message or upload images...",
    interactive=True,
)

# Example rows for gr.Examples: each row holds a single multimodal message
# (one value for the chat input) pairing the same prompt with a back/front
# image pair.
EXAMPLES = [
    [{"text": "Give me Country,Denomination and year as json format.", "files": list(image_pair)}]
    for image_pair in (
        ("./135_back.jpg", "./135_front.jpg"),
        ("./141_back.jpg", "./141_front.jpg"),
    )
]

# Page layout: title banner, the multimodal chat interface wired to
# stream_chat, and clickable examples that fill the chat input.
with gr.Blocks(css=CSS, theme="soft", fill_height=True) as demo:
    gr.HTML(TITLE)
    # The additional inputs are passed to stream_chat positionally after
    # (message, history): temperature, max_length, top_p, top_k, penalty.
    # They are created with render=False so the chat interface can place them
    # inside the "Parameters" accordion.
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(label="⚙️ Parameters", open=False, render=False),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.8,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=1024,
                maximum=8192,
                step=1,
                value=4096,
                label="Max Length",
                render=False,
            ),
            gr.Slider(
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
                label="top_p",
                render=False,
            ),
            gr.Slider(
                minimum=1,
                maximum=20,
                step=1,
                value=10,
                label="top_k",
                render=False,
            ),
            gr.Slider(
                # A repetition penalty of 0 is rejected by generation, so the
                # smallest selectable value is one step above it.
                minimum=0.1,
                maximum=2.0,
                step=0.1,
                value=1.0,
                label="Repetition penalty",
                render=False,
            ),
        ],
    )
    gr.Examples(EXAMPLES, [chat_input])

if __name__ == "__main__":
    # Enable request queuing with the queue's API endpoints closed, then start
    # the server locally (no public share link, no API docs page).
    demo.queue(api_open=False)
    demo.launch(show_api=False, share=False)