sergiopaniego (HF Staff) committed
Commit 36f3f37 · verified · Parent: b553066

Update app.py

Files changed (1):
  1. app.py (+145 -36)
app.py CHANGED
@@ -1,33 +1,53 @@
 import random
 import requests
 import json
+import ast
 
 import matplotlib.pyplot as plt
 from PIL import Image, ImageDraw, ImageFont
 
 import gradio as gr
 import torch
-from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration
+from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration, AutoModelForCausalLM
 from qwen_vl_utils import process_vision_info
 from spaces import GPU
 from gradio.themes.ocean import Ocean
 
 # --- Config ---
-model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
-    "Qwen/Qwen2.5-VL-3B-Instruct", torch_dtype="auto", device_map="auto"
+model_qwen_id = "Qwen/Qwen2.5-VL-3B-Instruct"
+model_moondream_id = "vikhyatk/moondream2"
+
+model_qwen = Qwen2_5_VLForConditionalGeneration.from_pretrained(
+    model_qwen_id, torch_dtype="auto", device_map="auto"
+)
+model_moondream = AutoModelForCausalLM.from_pretrained(
+    model_moondream_id,
+    revision="2025-06-21",
+    trust_remote_code=True,
+    device_map={"": "cuda"}
 )
 
+def extract_model_short_name(model_id):
+    return model_id.split("/")[-1].replace("-", " ").replace("_", " ")
+
+model_qwen_name = extract_model_short_name(model_qwen_id)  # → "Qwen2.5 VL 3B Instruct"
+model_moondream_name = extract_model_short_name(model_moondream_id)  # → "moondream2"
+
+
 min_pixels = 224 * 224
-max_pixels = 512 * 512
-processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+max_pixels = 1024 * 1024
+processor_qwen = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-3B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)
+#processor_moondream = AutoProcessor.from_pretrained("vikhyatk/moondream2", trust_remote_code=True)
 
 label2color = {}
+vivid_colors = ["#e6194b", "#3cb44b", "#0082c8", "#f58231", "#911eb4", "#46f0f0", "#f032e6", "#d2f53c", "#fabebe", "#008080", "#e6beff", "#aa6e28", "#fffac8", "#800000", "#aaffc3", "#808000", "#ffd8b1", "#000080", "#808080", "#000000"]
 
 def get_color(label, explicit_color=None):
     if explicit_color:
         return explicit_color
     if label not in label2color:
-        label2color[label] = "#" + ''.join(random.choices('0123456789ABCDEF', k=6))
+        index = len(label2color) % len(vivid_colors)
+        label2color[label] = vivid_colors[index]
     return label2color[label]
 
 def create_annotated_image(image, json_data, height, width):
@@ -47,11 +67,8 @@ def create_annotated_image(image, json_data, height, width):
     draw = ImageDraw.Draw(draw_image)
 
     try:
-        print(1)
-        print('int(12 * scale_factor)', int(12 * scale_factor))
-        font = ImageFont.truetype("arial.ttf", int(12 * scale_factor))
+        font = ImageFont.truetype("DejaVuSans-Bold.ttf", int(12 * scale_factor))
     except:
-        print(2)
         font = ImageFont.load_default()
 
     for item in bbox_data:
@@ -84,10 +101,51 @@ def create_annotated_image(image, json_data, height, width):
 
     return draw_image
 
+def create_annotated_image_normalized(image, json_data, label="object", explicit_color=None):
+    if not isinstance(json_data, dict):
+        return image
+
+    original_width, original_height = image.size
+    scale_factor = max(original_width, original_height) / 512
+    draw_image = image.copy()
+    draw = ImageDraw.Draw(draw_image)
+
+    try:
+        font = ImageFont.truetype("DejaVuSans-Bold.ttf", int(12 * scale_factor))
+    except:
+        font = ImageFont.load_default()
+
+    color = get_color(label, explicit_color)
+
+    for point in json_data.get("points", []):
+        x = int(point["x"] * original_width)
+        y = int(point["y"] * original_height)
+        radius = int(4 * scale_factor)
+        draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color, outline=color)
+
+    for item in json_data.get("objects", []):
+        x_min = int(item["x_min"] * original_width)
+        y_min = int(item["y_min"] * original_height)
+        x_max = int(item["x_max"] * original_width)
+        y_max = int(item["y_max"] * original_height)
+        draw.rectangle([x_min, y_min, x_max, y_max], outline=color, width=int(2 * scale_factor))
+        draw.text((x_min, max(0, y_min - int(15 * scale_factor))), label, fill=color, font=font)
+
+    if "reasoning" in json_data:
+        for grounding in json_data["reasoning"].get("grounding", []):
+            for x_norm, y_norm in grounding.get("points", []):
+                x = int(x_norm * original_width)
+                y = int(y_norm * original_height)
+                radius = int(4 * scale_factor)
+                draw.ellipse([x - radius, y - radius, x + radius, y + radius], fill=color, outline=color)
+
+    return draw_image
+
+
+
 @GPU
-def detect(image, prompt):
-    STANDARD_SIZE = (512, 512)
-    image.thumbnail(STANDARD_SIZE)
+def detect_qwen(image, prompt):
+
     messages = [
         {
             "role": "user",
@@ -98,21 +156,21 @@ def detect(image, prompt):
         }
     ]
 
-    text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    text = processor_qwen.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
     image_inputs, video_inputs = process_vision_info(messages)
-    inputs = processor(
+    inputs = processor_qwen(
         text=[text],
         images=image_inputs,
        videos=video_inputs,
         padding=True,
         return_tensors="pt",
-    ).to(model.device)
+    ).to(model_qwen.device)
 
-    generated_ids = model.generate(**inputs, max_new_tokens=1024)
+    generated_ids = model_qwen.generate(**inputs, max_new_tokens=1024)
     generated_ids_trimmed = [
         out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
     ]
-    output_text = processor.batch_decode(
+    output_text = processor_qwen.batch_decode(
         generated_ids_trimmed, do_sample=True, skip_special_tokens=True, clean_up_tokenization_spaces=False
     )[0]
 
@@ -123,6 +181,30 @@ def detect(image, prompt):
 
     return annotated_image, output_text
 
+
+@GPU
+def detect_moondream(image, prompt, category_input):
+    if category_input in ["Object Detection", "Visual Grounding + Object Detection"]:
+        output_text = model_moondream.detect(image=image, object=prompt)
+    elif category_input == "Visual Grounding + Keypoint Detection":
+        output_text = model_moondream.point(image=image, object=prompt)
+    else:
+        output_text = model_moondream.query(image=image, question=prompt, reasoning=True)
+
+    annotated_image = create_annotated_image_normalized(image=image, json_data=output_text, label="object", explicit_color=None)
+
+    return annotated_image, output_text
+
+@GPU
+def detect(image, prompt_model_1, prompt_model_2, category_input):
+    STANDARD_SIZE = (1024, 1024)
+    image.thumbnail(STANDARD_SIZE)
+
+    annotated_image_model_1, output_text_model_1 = detect_qwen(image, prompt_model_1)
+    annotated_image_model_2, output_text_model_2 = detect_moondream(image, prompt_model_2, category_input)
+
+    return annotated_image_model_1, output_text_model_1, annotated_image_model_2, output_text_model_2
+
 css_hide_share = """
 button#gradio-share-link-button-0 {
     display: none !important;
@@ -132,43 +214,70 @@ button#gradio-share-link-button-0 {
 # --- Gradio Interface ---
 with gr.Blocks(theme=Ocean(), css=css_hide_share) as demo:
 
-    gr.Markdown("# Object Understanding with Vision Language Models")
+    gr.Markdown("# 👓 Object Understanding with Vision Language Models")
     gr.Markdown("### Explore object detection, visual grounding, keypoint detection, and/or object counting through natural language prompts.")
     gr.Markdown("""
-    *Powered by Qwen2.5-VL*
-    *Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
+    *Powered by [Qwen2.5-VL 3B](https://huggingface.co/Qwen/Qwen2.5-VL-3B-Instruct) and [Moondream 2B (revision="2025-06-21")](https://huggingface.co/vikhyatk/moondream2). Inspired by the tutorial [Object Detection and Visual Grounding with Qwen 2.5](https://pyimagesearch.com/2025/06/09/object-detection-and-visual-grounding-with-qwen-2-5/) on PyImageSearch.*
+    *Moondream 2B uses the [moondream.py API](https://huggingface.co/vikhyatk/moondream2/blob/main/moondream.py), selecting `detect` for categories with "Object Detection" `point` for the ones with "Keypoint Detection", and reasoning-based querying for all others.*
     """)
 
     with gr.Row():
-        with gr.Column(scale=1):
+        with gr.Column(scale=2):
             image_input = gr.Image(label="Upload an image", type="pil", height=400)
-            prompt_input = gr.Textbox(label="Enter your prompt", placeholder="e.g., Detect all red cars in the image")
-            category_input = gr.Textbox(label="Category", interactive=False)
+            prompt_input_model_1 = gr.Textbox(
+                label=f"Enter your prompt for {model_qwen_name}",
+                placeholder="e.g., Detect all red cars in the image"
+            )
+
+            prompt_input_model_2 = gr.Textbox(
+                label=f"Enter your prompt for {model_moondream_name}",
+                placeholder="e.g., Detect all blue cars in the image"
+            )
+
+
+            categories = [
+                "Object Detection",
+                "Object Counting",
+                "Visual Grounding + Keypoint Detection",
+                "Visual Grounding + Object Detection",
+                "General query"
+            ]
+
+            category_input = gr.Dropdown(
+                choices=categories,
+                label="Category",
+                interactive=True
+            )
             generate_btn = gr.Button(value="Generate")
 
         with gr.Column(scale=1):
-            output_image = gr.Image(type="pil", label="Annotated image", height=400)
-            output_textbox = gr.Textbox(label="Model response", lines=10)
+            output_image_model_1 = gr.Image(type="pil", label=f"Annotated image for {model_qwen_name}", height=400)
+            output_textbox_model_1 = gr.Textbox(label=f"Model response for {model_qwen_name}", lines=10)
+
+        with gr.Column(scale=1):
+            output_image_model_2 = gr.Image(type="pil", label=f"Annotated image for {model_moondream_name}", height=400)
+            output_textbox_model_2 = gr.Textbox(label=f"Model response for {model_moondream_name}", lines=10)
 
     gr.Markdown("### Examples")
     example_prompts = [
-        ["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "Object Detection"],
-        ["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"],
-        ["examples/example_1.jpg", "Count the number of red cars in the image.", "Object Counting"],
-        ["examples/example_2.JPG", "Count the number of blue candies in the image.", "Object Counting"],
-        ["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
-        ["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
-        ["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"],
-        ["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"],
+        ["examples/example_1.jpg", "Detect all objects in the image and return their locations and labels.", "Detect all objects in the image and return their locations and labels.", "Object Detection"],
+        ["examples/example_2.JPG", "Detect all the individual candies in the image and return their locations and labels.", "Detect all the individual candies in the image and return their locations and labels.", "Object Detection"],
+        ["examples/example_1.jpg", "Count the number of red cars in the image.", "Count the number of red cars in the image.", "Object Counting"],
+        ["examples/example_2.JPG", "Count the number of blue candies in the image.", "Count the number of blue candies in the image.", "Object Counting"],
+        ["examples/example_1.jpg", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Identify the red cars in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
+        ["examples/example_2.JPG", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Identify the blue candies in this image, detect their key points and return their positions in the form of points.", "Visual Grounding + Keypoint Detection"],
+        ["examples/example_1.jpg", "Detect the red car that is leading in this image and return its location and label.", "Detect the red car that is leading in this image and return its location and label.", "Visual Grounding + Object Detection"],
+        ["examples/example_2.JPG", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Detect the blue candy located at the top of the group in this image and return its location and label.", "Visual Grounding + Object Detection"],
     ]
 
     gr.Examples(
         examples=example_prompts,
-        inputs=[image_input, prompt_input, category_input],
+        inputs=[image_input, prompt_input_model_1, prompt_input_model_2, category_input],
        label="Click an example to populate the input"
    )
 
-    generate_btn.click(fn=detect, inputs=[image_input, prompt_input], outputs=[output_image, output_textbox])
+    generate_btn.click(fn=detect, inputs=[image_input, prompt_input_model_1, prompt_input_model_2, category_input], outputs=[output_image_model_1, output_textbox_model_1, output_image_model_2, output_textbox_model_2])
 
 if __name__ == "__main__":
     demo.launch()
+
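
A minimal, standalone sketch of the normalized-coordinate annotation path added in this commit, usable without loading either model. The dict shapes mirror what detect_moondream feeds into create_annotated_image_normalized (an "objects" list with normalized x_min/y_min/x_max/y_max, a "points" list with normalized x/y); the sample values and the draw_normalized helper are illustrative only and are not part of app.py.

# Standalone sketch (not part of app.py): exercises the normalized-coordinate drawing
# that create_annotated_image_normalized performs, using hand-written sample results.
from PIL import Image, ImageDraw

# Sample results in the shape the commit's helper consumes:
# an "objects" list with box corners normalized to [0, 1],
# a "points" list with coordinates normalized to [0, 1].
sample_detection = {"objects": [{"x_min": 0.12, "y_min": 0.34, "x_max": 0.45, "y_max": 0.78}]}
sample_points = {"points": [{"x": 0.50, "y": 0.25}]}

def draw_normalized(image, result, color="#e6194b"):
    # Scale normalized coordinates to pixel space, as the new helper does.
    width, height = image.size
    draw = ImageDraw.Draw(image)
    for obj in result.get("objects", []):
        box = [obj["x_min"] * width, obj["y_min"] * height,
               obj["x_max"] * width, obj["y_max"] * height]
        draw.rectangle(box, outline=color, width=3)
    for pt in result.get("points", []):
        x, y = pt["x"] * width, pt["y"] * height
        draw.ellipse([x - 4, y - 4, x + 4, y + 4], fill=color)
    return image

if __name__ == "__main__":
    canvas = Image.new("RGB", (640, 480), "white")
    draw_normalized(canvas, sample_detection)
    draw_normalized(canvas, sample_points, color="#3cb44b")
    canvas.save("annotated_preview.png")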