0llheaven committed on
Commit 7edff58 · verified · 1 Parent(s): 52f1588

Update app.py

Files changed (1):
  1. app.py +68 -85
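
The hunk below starts at line 5 of app.py, so the names it uses without importing them here (torch, FastVisionModel) must come from lines 1–4, which this diff does not show. A minimal sketch of what those lines presumably contain (an assumption, not part of the commit):

# Assumed contents of app.py lines 1-4 (not shown in this diff):
import torch                          # torch.tensor / torch.device / torch.equal below
from unsloth import FastVisionModel   # provides from_pretrained and for_inference
from PIL import Image                 # named as context in the hunk header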
app.py CHANGED
@@ -5,129 +5,112 @@ from PIL import Image
  from datasets import load_dataset
  from transformers import TextStreamer
  import matplotlib.pyplot as plt
-
  import gradio as gr

- # Load the model
- model, tokenizer = FastVisionModel.from_pretrained(
-     "0llheaven/Llama-3.2-11B-Vision-Radiology-mini",
-     load_in_4bit=True,
-     use_gradient_checkpointing="unsloth",
- ).to("cpu")

- # Switch the model into inference mode
- FastVisionModel.for_inference(model)

- # Cache variables
- cached_image = None
- cached_response = None

- # Process the image and generate a description
- def predict_radiology_description(image, instruction):
-     global cached_image, cached_response
-
-     try:
-         current_image_tensor = torch.tensor(image.getdata())

-         # Check whether the image is unchanged (the instruction text is not compared)
-         if cached_image is not None and torch.equal(cached_image, current_image_tensor):
-             # Reuse cached_response even if the instruction text is new
-             return cached_response

-         # Prepare the prompt in the format the model expects
          messages = [{"role": "user", "content": [
              {"type": "image"},
              {"type": "text", "text": instruction}
          ]}]
          input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

-         # Prepare the model inputs
          inputs = tokenizer(
              image,
              input_text,
              add_special_tokens=False,
              return_tensors="pt",
-         ).to("cpu")

-         # Use a TextStreamer for the prediction
          text_streamer = TextStreamer(tokenizer, skip_prompt=True)

-         # Generate the text
          output_ids = model.generate(
              **inputs,
              streamer=text_streamer,
-             max_new_tokens=256,
-             use_cache=True,
-             temperature=1.5,
-             min_p=0.1
          )

-         # Decode the generated tokens into text
          generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)

-         cached_image = current_image_tensor  # Cache the image as a tensor
-         cached_response = generated_text.replace("assistant", "\n\nAssistant").strip()
-         return cached_response
      except Exception as e:
          return f"Error: {str(e)}"

- # ChatBot function
- def chat_process(image, instruction, history=None):
-     if history is None:
-         history = []
-
-     # Process the image and instruction
-     response = predict_radiology_description(image, instruction)

-     # Update the chat history
-     history.append((instruction, response))
-     return history, history

- import warnings
- warnings.filterwarnings("ignore", category=UserWarning, module="gradio.helpers")

- # Gradio UI
- with gr.Blocks() as demo:
-     gr.Markdown("# 🩻 Radiology Image ChatBot")
-     gr.Markdown("Upload a radiology image and provide an instruction for the AI to describe the findings.")
-     gr.Markdown("Example instruction: You are an expert radiographer. Describe accurately what you see in this image.")

-     with gr.Row():
-         with gr.Column():
-             # Image upload
-             image_input = gr.Image(type="pil", label="Upload Radiology Image")
-             # Instruction input
-             instruction_input = gr.Textbox(
-                 label="Instruction",
-                 value="You are an expert radiographer. Describe accurately what you see in this image.",
-                 placeholder="Provide specific instructions..."
              )
-         with gr.Column():
-             # Chat history display
-             chatbot = gr.Chatbot(label="Chat History")

-     with gr.Row():
-         clear_btn = gr.Button("Clear")
-         submit_btn = gr.Button("Submit")
-
-     # Submit button: run the chat, then clear only the instruction_input text
-     submit_btn.click(
-         lambda image, instruction, history: (
-             *chat_process(image, instruction, history),
-             image,  # keep the current image_input value
-             ""
-         ),
-         inputs=[image_input, instruction_input, chatbot],
-         outputs=[chatbot, chatbot, image_input, instruction_input]
-     )

-     # Clear button: reset the chat, image, and instruction
-     clear_btn.click(
-         lambda: (None, None, None, None),
-         inputs=[],
-         outputs=[chatbot, chatbot, image_input, instruction_input]
      )

- # Run the app
- demo.launch(debug=True)
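
A quirk in the removed version above: the cache key is the image tensor alone, so resubmitting the same image with a different instruction returned the stale cached_response. The new version below drops the cache and reseeds on every call instead. For reference, a minimal sketch of a cache keyed on both inputs, written against the old predict_radiology_description(image, instruction) signature (_cache and cached_predict are hypothetical, in neither version):

# Hypothetical cache keyed on (image bytes, instruction) rather than the image alone
_cache = {}

def cached_predict(image, instruction):
    key = (image.tobytes(), instruction)
    if key not in _cache:
        _cache[key] = predict_radiology_description(image, instruction)
    return _cache[key]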
 
  from datasets import load_dataset
  from transformers import TextStreamer
  import matplotlib.pyplot as plt
  import gradio as gr
+ import random
+ import numpy as np

+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+ # Seed every RNG source so repeated runs give reproducible output
+ def set_seed(seed_value=42):
+     random.seed(seed_value)
+     np.random.seed(seed_value)
+     torch.manual_seed(seed_value)
+     torch.cuda.manual_seed_all(seed_value)
+
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False

+ set_seed(42)

+ model, tokenizer = FastVisionModel.from_pretrained(
+     "0llheaven/llama-3.2-11B-Vision-Instruct-Finetune",
+     load_in_4bit=True,
+     use_gradient_checkpointing="unsloth",
+ )

+ FastVisionModel.for_inference(model)

+ instruction = "You are an expert radiographer. Describe accurately what you see in this image."

+ def predict_radiology_description(image, temperature, use_top_p, top_p_value, use_min_p, min_p_value):
+     try:
+         # Reseed before every call so identical inputs give identical output
+         set_seed(42)
          messages = [{"role": "user", "content": [
              {"type": "image"},
              {"type": "text", "text": instruction}
          ]}]
          input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

          inputs = tokenizer(
              image,
              input_text,
              add_special_tokens=False,
              return_tensors="pt",
+         ).to(device)  # use the device selected above rather than assuming CUDA

          text_streamer = TextStreamer(tokenizer, skip_prompt=True)

+         # Collect sampling options from the UI controls
+         generate_kwargs = {
+             "max_new_tokens": 512,
+             "use_cache": True,
+             "temperature": temperature,
+         }
+         if use_top_p:
+             generate_kwargs["top_p"] = top_p_value
+         if use_min_p:
+             generate_kwargs["min_p"] = min_p_value
+
          output_ids = model.generate(
              **inputs,
              streamer=text_streamer,
+             **generate_kwargs
          )

          generated_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+         return generated_text.replace("assistant", "\n\nassistant").strip()

      except Exception as e:
          return f"Error: {str(e)}"

+ with gr.Blocks() as interface:
+     gr.Markdown("<h1><center>Radiology Image Description Generator</center></h1>")
+     gr.Markdown("Upload a radiology image, adjust temperature and top-p, and the model will describe the findings in the image.")
+     with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(type="pil", label="Upload")
+         with gr.Column():
+             output_text = gr.Textbox(label="Generated Description")

+     with gr.Row():
+         with gr.Column(scale=0.5):
+             temperature_slider = gr.Slider(0.1, 2.0, step=0.1, value=1.0, label="temperature")

+             use_top_p_checkbox = gr.Checkbox(label="Use top-p", value=True)
+             top_p_slider = gr.Slider(0.1, 1.0, step=0.05, value=0.9, label="top-p")

+             use_min_p_checkbox = gr.Checkbox(label="Use min-p", value=False)
+             min_p_slider = gr.Slider(0.0, 1.0, step=0.05, value=0.1, label="min-p", visible=False)

+     # Show each slider only while its checkbox is ticked
+     use_top_p_checkbox.change(
+         lambda use_top_p: gr.update(visible=use_top_p),
+         inputs=use_top_p_checkbox,
+         outputs=top_p_slider
+     )
+     use_min_p_checkbox.change(
+         lambda use_min_p: gr.update(visible=use_min_p),
+         inputs=use_min_p_checkbox,
+         outputs=min_p_slider
      )

+     generate_button = gr.Button("Generate Description")

+     # Wire the button to the prediction function
+     generate_button.click(
+         predict_radiology_description,
+         inputs=[image_input, temperature_slider, use_top_p_checkbox, top_p_slider, use_min_p_checkbox, min_p_slider],
+         outputs=output_text
      )

+ # Launch the Gradio app
+ interface.launch(share=True, debug=True)
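
For a quick check outside the UI, the new entry point can be called directly with the same values the sliders start from (the image path here is hypothetical):

from PIL import Image

img = Image.open("chest_xray.png")  # hypothetical sample image
# temperature=1.0, top-p on at 0.9, min-p off (its value is then ignored) — the UI defaults
print(predict_radiology_description(img, 1.0, True, 0.9, False, 0.1))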