import os from unsloth import FastVisionModel import torch from PIL import Image from datasets import load_dataset from transformers import TextStreamer import matplotlib.pyplot as plt import gradio as gr # Load the model model, tokenizer = FastVisionModel.from_pretrained( "0llheaven/Llama-3.2-11B-Vision-Radiology-mini", load_in_4bit=True, use_gradient_checkpointing="unsloth", ).to("cpu") # เปลี่ยนโหมดของโมเดลเป็นสำหรับ inference FastVisionModel.for_inference(model) # ตัวแปรสำหรับแคช cached_image = None cached_response = None # ฟังก์ชันประมวลผลภาพและสร้างคำอธิบาย def predict_radiology_description(image, instruction): global cached_image, cached_response try: current_image_tensor = torch.tensor(image.getdata()) # ตรวจสอบว่าภาพเหมือนเดิมและข้อความเหมือนเดิมหรือไม่ if cached_image is not None and torch.equal(cached_image, current_image_tensor): # ใช้ cached_response กับ text ใหม่ return cached_response # เตรียมข้อความในรูปแบบที่โมเดลรองรับ messages = [{"role": "user", "content": [ {"type": "image"}, {"type": "text", "text": instruction} ]}] input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True) # เตรียม input สำหรับโมเดล inputs = tokenizer( image, input_text, add_special_tokens=False, return_tensors="pt", ).to("cpu") # ใช้ TextStreamer สำหรับการพยากรณ์ text_streamer = TextStreamer(tokenizer, skip_prompt=True) # ทำนายข้อความ output_ids = model.generate( **inputs, streamer=text_streamer, max_new_tokens=256, use_cache=True, temperature=1.5, min_p=0.1 ) # แปลงข้อความที่สร้างเป็นผลลัพธ์ generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True) cached_image = current_image_tensor # แคชภาพเป็น Tensor cached_response = generated_text.replace("assistant", "\n\nAssistant").strip() return cached_response except Exception as e: return f"Error: {str(e)}" # ฟังก์ชัน ChatBot def chat_process(image, instruction, history=None): if history is None: history = [] # ประมวลผลภาพและคำสั่ง response = predict_radiology_description(image, instruction) # อัปเดตประวัติ history.append((instruction, response)) return history, history import warnings warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers") # UI ของ Gradio with gr.Blocks() as demo: gr.Markdown("# 🩻 Radiology Image ChatBot") gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.") gr.Markdown("Example instruction : You are an expert radiographer. Describe accurately what you see in this image.") with gr.Row(): with gr.Column(): # อัปโหลดรูปภาพ image_input = gr.Image(type="pil", label="Upload Radiology Image") # ป้อนคำสั่ง (instruction) instruction_input = gr.Textbox( label="Instruction", value="You are an expert radiographer. Describe accurately what you see in this image.", placeholder="Provide specific instructions..." ) with gr.Column(): # แสดงประวัติ Chat chatbot = gr.Chatbot(label="Chat History") with gr.Row(): clear_btn = gr.Button("Clear") submit_btn = gr.Button("Submit") # การทำงานของปุ่ม Submit พร้อมล้างเฉพาะข้อความใน instruction_input submit_btn.click( lambda image, instruction, history: ( *chat_process(image, instruction, history), image, # รีเซ็ตค่า image_input "" ), inputs=[image_input, instruction_input, chatbot], outputs=[chatbot, chatbot, image_input, instruction_input] ) # การทำงานของปุ่ม Clear clear_btn.click( lambda: (None, None, None, None), inputs=[], outputs=[chatbot, chatbot, image_input, instruction_input] ) # รันแอป demo.launch(debug=True)