import os
from unsloth import FastVisionModel
import torch
from PIL import Image
from datasets import load_dataset
from transformers import TextStreamer
import matplotlib.pyplot as plt

import gradio as gr

# Load the model
model, tokenizer = FastVisionModel.from_pretrained(
    "0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
    load_in_4bit=True,
    use_gradient_checkpointing="unsloth",
).to("cpu")

# เปลี่ยนโหมดของโมเดลเป็นสำหรับ inference
FastVisionModel.for_inference(model)

# ตัวแปรสำหรับแคช
cached_image = None
cached_response = None

# ฟังก์ชันประมวลผลภาพและสร้างคำอธิบาย
def predict_radiology_description(image, instruction):
    global cached_image, cached_response

    try:

        current_image_tensor = torch.tensor(image.getdata())

        # ตรวจสอบว่าภาพเหมือนเดิมและข้อความเหมือนเดิมหรือไม่
        if cached_image is not None and torch.equal(cached_image, current_image_tensor):
            # ใช้ cached_response กับ text ใหม่
            return cached_response

        # เตรียมข้อความในรูปแบบที่โมเดลรองรับ
        messages = [{"role": "user", "content": [
            {"type": "image"},
            {"type": "text", "text": instruction}
        ]}]
        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

        # เตรียม input สำหรับโมเดล
        inputs = tokenizer(
            image,
            input_text,
            add_special_tokens=False,
            return_tensors="pt",
        ).to("cpu")

        # ใช้ TextStreamer สำหรับการพยากรณ์
        text_streamer = TextStreamer(tokenizer, skip_prompt=True)

        # ทำนายข้อความ
        output_ids = model.generate(
            **inputs,
            streamer=text_streamer,
            max_new_tokens=256,
            use_cache=True,
            temperature=1.5,
            min_p=0.1
        )

        # แปลงข้อความที่สร้างเป็นผลลัพธ์
        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        cached_image = current_image_tensor  # แคชภาพเป็น Tensor
        cached_response = generated_text.replace("assistant", "\n\nAssistant").strip()
        return cached_response
    except Exception as e:
        return f"Error: {str(e)}"

# ฟังก์ชัน ChatBot
def chat_process(image, instruction, history=None):
    if history is None:
        history = []

    # ประมวลผลภาพและคำสั่ง
    response = predict_radiology_description(image, instruction)

    # อัปเดตประวัติ
    history.append((instruction, response))
    return history, history

import warnings
warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers")

# UI ของ Gradio
with gr.Blocks() as demo:
    gr.Markdown("# 🩻 Radiology Image ChatBot")
    gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.")
    gr.Markdown("Example instruction : You are an expert radiographer. Describe accurately what you see in this image.")

    with gr.Row():
        with gr.Column():
            # อัปโหลดรูปภาพ
            image_input = gr.Image(type="pil", label="Upload Radiology Image")
            # ป้อนคำสั่ง (instruction)
            instruction_input = gr.Textbox(
                label="Instruction",
                value="You are an expert radiographer. Describe accurately what you see in this image.",
                placeholder="Provide specific instructions..."
            )
        with gr.Column():
            # แสดงประวัติ Chat
            chatbot = gr.Chatbot(label="Chat History")

    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Submit")

    # การทำงานของปุ่ม Submit พร้อมล้างเฉพาะข้อความใน instruction_input
    submit_btn.click(
        lambda image, instruction, history: (
            *chat_process(image, instruction, history),
            image,  # รีเซ็ตค่า image_input
            ""
        ),
        inputs=[image_input, instruction_input, chatbot],
        outputs=[chatbot, chatbot, image_input, instruction_input]
    )

    # การทำงานของปุ่ม Clear
    clear_btn.click(
        lambda: (None, None, None, None),
        inputs=[],
        outputs=[chatbot, chatbot, image_input, instruction_input]
    )

# รันแอป
demo.launch(debug=True)