0llheaven committed · verified
Commit bd9df5d · 1 Parent(s): b250b49

Update app.py

Files changed (1)
  1. app.py +114 -68
app.py CHANGED
@@ -1,87 +1,133 @@
-import spaces
-import gradio as gr
+import os
+from unsloth import FastVisionModel
 import torch
 from PIL import Image
-from transformers import AutoModelForImageTextToText, MllamaForConditionalGeneration, AutoProcessor
+from datasets import load_dataset
 from transformers import TextStreamer
-from torchvision.transforms import Resize
-from unsloth import FastVisionModel
 
-# Define the model and processor
-model_id = "0llheaven/Llama-3.2-11B-Vision-Radiology-mini"
+import matplotlib.pyplot as plt
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-print(f"Using device: {device}")
-
-# device = "cuda" if torch.cuda.is_available() else "cpu"
-model = AutoModelForImageTextToText.from_pretrained(
-    model_id,
-    # load_in_4bit=True,
-    torch_dtype=torch.float32 if device.type == "cpu" else torch.bfloat16,
-    device_map=device,
-).to(device)
-
-# if device.type == "cuda":
-#     model.gradient_checkpointing_enable()
-model.gradient_checkpointing_enable()
-
-processor = AutoProcessor.from_pretrained(model_id)
-
-# @spaces.GPU(duration=120)
-# Function to process the image and generate the description
-def generate_description(image: Image.Image, instruction: str):
-
-    FastVisionModel.for_inference(model)
-    print("Loading tokenizer...")
-    base_model, tokenizer = FastVisionModel.from_pretrained(
-        "unsloth/Llama-3.2-11B-Vision-Instruct",
-        # load_in_4bit = True,
-        use_gradient_checkpointing = "unsloth",
-    )
-
-    image = image.convert("RGB")
-    # image = Resize((224, 224))(image)
-
-    # Create the message to pass to the model
-    instruction = "You are an expert radiographer. Describe accurately what you see in this image."
-    messages = [
-        {"role": "user", "content": [
+import gradio as gr
+
+# Load the model
+model, tokenizer = FastVisionModel.from_pretrained(
+    "0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
+    load_in_4bit=True,
+    use_gradient_checkpointing="unsloth",
+)
+
+# Switch the model into inference mode
+FastVisionModel.for_inference(model)
+
+# Cache variables
+cached_image = None
+cached_response = None
+
+# Process the image and generate a description
+def predict_radiology_description(image, instruction):
+    global cached_image, cached_response
+
+    try:
+        current_image_tensor = torch.tensor(image.getdata())
+
+        # Check whether the image is the same as the previous one
+        if cached_image is not None and torch.equal(cached_image, current_image_tensor):
+            # Reuse the cached response
+            return cached_response
+
+        # Build the message in the format the model expects
+        messages = [{"role": "user", "content": [
             {"type": "image"},
             {"type": "text", "text": instruction}
-        ]}
-    ]
-
-    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-    # input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = tokenizer(
-        image,
-        input_text,
-        add_special_tokens=False,
-        return_tensors="pt"
-    ).to(device)
-
-    # Generate the output from the model
-    # output = model.generate(**inputs, max_new_tokens=256)
-    text_streamer = TextStreamer(tokenizer, skip_prompt=True)
-    outputs = model.generate(
-        **inputs,
+        ]}]
+        input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+        # Prepare the inputs for the model
+        inputs = tokenizer(
+            image,
+            input_text,
+            add_special_tokens=False,
+            return_tensors="pt",
+        ).to("cuda")
+
+        # Stream tokens to stdout as they are generated
+        text_streamer = TextStreamer(tokenizer, skip_prompt=True)
+
+        # Generate the description
+        output_ids = model.generate(
+            **inputs,
            streamer=text_streamer,
            max_new_tokens=256,
            use_cache=True,
            temperature=1.5,
            min_p=0.1
        )
-    return tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
-
-# Define Gradio interface
-interface = gr.Interface(
-    fn=generate_description,
-    inputs=gr.Image(type="pil", label="Upload an Image"),
-    outputs=gr.Textbox(label="Generated Description"),
-    # live=True,
-    title="Radiology Image Description Generator",
-    description="Upload an image and provide an instruction to generate a description using a vision-language model."
-)
-
-# Launch the interface
-interface.launch()
+
+        # Decode the generated tokens into text
+        generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+
+        cached_image = current_image_tensor  # cache the image as a tensor
+        cached_response = generated_text.replace("assistant", "\n\nAssistant").strip()
+        return cached_response
+    except Exception as e:
+        return f"Error: {str(e)}"
+
+# ChatBot function
+def chat_process(image, instruction, history=None):
+    if history is None:
+        history = []
+
+    # Process the image and instruction
+    response = predict_radiology_description(image, instruction)
+
+    # Update the chat history
+    history.append((instruction, response))
+    return history, history
+
+import warnings
+warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers")
+
+# Gradio UI
+with gr.Blocks() as demo:
+    gr.Markdown("# 🩻 Radiology Image ChatBot")
+    gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.")
+    gr.Markdown("Example instruction: You are an expert radiographer. Describe accurately what you see in this image.")
+
+    with gr.Row():
+        with gr.Column():
+            # Image upload
+            image_input = gr.Image(type="pil", label="Upload Radiology Image")
+            # Instruction input
+            instruction_input = gr.Textbox(
+                label="Instruction",
+                value="You are an expert radiographer. Describe accurately what you see in this image.",
+                placeholder="Provide specific instructions..."
+            )
+        with gr.Column():
+            # Chat history display
+            chatbot = gr.Chatbot(label="Chat History")
+
+    with gr.Row():
+        clear_btn = gr.Button("Clear")
+        submit_btn = gr.Button("Submit")
+
+    # Submit: run the chat, keep the uploaded image, and clear only the instruction text
+    submit_btn.click(
+        lambda image, instruction, history: (
+            *chat_process(image, instruction, history),
+            image,  # keep the current image_input value
+            ""
+        ),
+        inputs=[image_input, instruction_input, chatbot],
+        outputs=[chatbot, chatbot, image_input, instruction_input]
+    )
+
+    # Clear: reset the chat history, image, and instruction
+    clear_btn.click(
+        lambda: (None, None, None, None),
+        inputs=[],
+        outputs=[chatbot, chatbot, image_input, instruction_input]
+    )
+
+    # Launch the app
+    demo.launch(debug=True)
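
For a quick smoke test outside the Gradio UI, the new predict_radiology_description function can be called directly. A minimal sketch, assuming the module-level model setup in app.py has already run in the same process (e.g. a notebook cell after executing the file), a CUDA-capable GPU with the unsloth dependencies installed, and a hypothetical local file sample_xray.png:

    # Hypothetical usage sketch; not part of the commit.
    from PIL import Image

    # Load a sample radiograph (hypothetical file name).
    image = Image.open("sample_xray.png").convert("RGB")
    instruction = ("You are an expert radiographer. "
                   "Describe accurately what you see in this image.")

    # First call runs generation; tokens also stream to stdout via TextStreamer.
    report = predict_radiology_description(image, instruction)
    print(report)

Note that the cache key is the image alone (torch.tensor(image.getdata())), so a second call with the same image returns the cached text even if the instruction changes.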