jameslahm commited on
Commit
0580ff2
·
verified ·
1 Parent(s): f4ea8b0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +130 -19
app.py CHANGED
@@ -11,12 +11,9 @@ import spaces
11
 
12
  @spaces.GPU
13
  def init_model(model_id, is_pf=False):
14
- if not is_pf:
15
- path = hf_hub_download(repo_id="jameslahm/yoloe", filename=f"{model_id}-seg.pt")
16
- model = YOLOE(path)
17
- else:
18
- path = hf_hub_download(repo_id="jameslahm/yoloe", filename=f"{model_id}-seg-pf.pt")
19
- model = YOLOE(path)
20
  model.eval()
21
  model.to("cuda")
22
  return model
@@ -72,13 +69,13 @@ def app():
72
  with gr.Tab("Visual") as visual_tab:
73
  with gr.Row():
74
  visual_prompt_type = gr.Dropdown(choices=["bboxes", "masks"], value="bboxes", label="Visual Type", interactive=True)
75
- visual_usage_type = gr.Radio(choices=["Intra-Image", "Inter-Image"], value="Intra-Image", label="Intra/Inter Image", interactive=True)
76
 
77
  with gr.Tab("Prompt-Free") as prompt_free_tab:
78
  gr.HTML(
79
  """
80
  <p style='text-align: center'>
81
- Prompt-Free Mode is On
82
  </p>
83
  """, show_label=False)
84
 
@@ -123,11 +120,10 @@ def app():
123
  return gr.update(value="Text"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
124
 
125
  def update_visual_image_visiblity(visual_prompt_type, visual_usage_type):
126
- use_target = gr.update(visible=True) if visual_usage_type == "Inter-Image" else gr.update(visible=False)
127
  if visual_prompt_type == "bboxes":
128
- return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), use_target
129
  elif visual_prompt_type == "masks":
130
- return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), use_target
131
 
132
  def update_pf_image_visibility():
133
  return gr.update(value="Prompt-free"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
@@ -159,10 +155,10 @@ def app():
159
 
160
  def update_visual_usage_type(visual_usage_type):
161
  if visual_usage_type == "Intra-Image":
162
- return gr.update(visible=False, value=None)
163
- if visual_usage_type == "Inter-Image":
164
- return gr.update(visible=True, value=None)
165
- return gr.update(visible=False, value=None)
166
 
167
  visual_prompt_type.change(
168
  fn=update_visual_prompt_type,
@@ -176,9 +172,10 @@ def app():
176
  outputs=[target_image]
177
  )
178
 
179
- def run_inference(raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type):
180
  # add text/built-in prompts
181
  if prompt_type == "Text" or prompt_type == "Prompt-free":
 
182
  image = raw_image
183
  if prompt_type == "Prompt-free":
184
  with open('tools/ram_tag_list.txt', 'r') as f:
@@ -190,9 +187,14 @@ def app():
190
  }
191
  # add visual prompt
192
  elif prompt_type == "Visual":
 
 
193
  if visual_prompt_type == "bboxes":
194
  image, points = box_image["image"], box_image["points"]
195
  points = np.array(points)
 
 
 
196
  prompts = {
197
  "bboxes": np.array([p[[0, 1, 3, 4]] for p in points if p[2] == 2]),
198
  }
@@ -202,6 +204,9 @@ def app():
202
  masks = np.array(masks.convert("L"))
203
  masks = binary_fill_holes(masks).astype(np.uint8)
204
  masks[masks > 0] = 1
 
 
 
205
  prompts = {
206
  "masks": masks[None]
207
  }
@@ -209,10 +214,116 @@ def app():
209
 
210
  yoloe_infer.click(
211
  fn=run_inference,
212
- inputs=[raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type],
213
  outputs=[output_image],
214
  )
215
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
 
217
  gradio_app = gr.Blocks()
218
  with gradio_app:
@@ -241,7 +352,7 @@ with gradio_app:
241
  )
242
  gr.Markdown(
243
  """
244
- Drawing **multiple** boxes or handcrafted shapes as visual prompt in an image is also supported.
245
  """
246
  )
247
  with gr.Row():
@@ -249,4 +360,4 @@ with gradio_app:
249
  app()
250
 
251
  if __name__ == '__main__':
252
- gradio_app.launch(allowed_paths=["figures"])
 
11
 
12
  @spaces.GPU
13
  def init_model(model_id, is_pf=False):
14
+ filename = f"{model_id}-seg.pt" if not is_pf else f"{model_id}-seg-pf.pt"
15
+ path = hf_hub_download(repo_id="jameslahm/yoloe", filename=filename)
16
+ model = YOLOE(path)
 
 
 
17
  model.eval()
18
  model.to("cuda")
19
  return model
 
69
  with gr.Tab("Visual") as visual_tab:
70
  with gr.Row():
71
  visual_prompt_type = gr.Dropdown(choices=["bboxes", "masks"], value="bboxes", label="Visual Type", interactive=True)
72
+ visual_usage_type = gr.Radio(choices=["Intra-Image", "Cross-Image"], value="Intra-Image", label="Intra/Cross Image", interactive=True)
73
 
74
  with gr.Tab("Prompt-Free") as prompt_free_tab:
75
  gr.HTML(
76
  """
77
  <p style='text-align: center'>
78
+ <b>Prompt-Free Mode is On</b>
79
  </p>
80
  """, show_label=False)
81
 
 
120
  return gr.update(value="Text"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
121
 
122
  def update_visual_image_visiblity(visual_prompt_type, visual_usage_type):
 
123
  if visual_prompt_type == "bboxes":
124
+ return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=True), gr.update(visible=False), gr.update(visible=(visual_usage_type == "Cross-Image"))
125
  elif visual_prompt_type == "masks":
126
+ return gr.update(value="Visual"), gr.update(visible=False), gr.update(visible=False), gr.update(visible=True), gr.update(visible=(visual_usage_type == "Cross-Image"))
127
 
128
  def update_pf_image_visibility():
129
  return gr.update(value="Prompt-free"), gr.update(visible=True), gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
 
155
 
156
  def update_visual_usage_type(visual_usage_type):
157
  if visual_usage_type == "Intra-Image":
158
+ return gr.update(visible=False)
159
+ if visual_usage_type == "Cross-Image":
160
+ return gr.update(visible=True)
161
+ return gr.update(visible=False)
162
 
163
  visual_prompt_type.change(
164
  fn=update_visual_prompt_type,
 
172
  outputs=[target_image]
173
  )
174
 
175
+ def run_inference(raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type, visual_usage_type):
176
  # add text/built-in prompts
177
  if prompt_type == "Text" or prompt_type == "Prompt-free":
178
+ target_image = None
179
  image = raw_image
180
  if prompt_type == "Prompt-free":
181
  with open('tools/ram_tag_list.txt', 'r') as f:
 
187
  }
188
  # add visual prompt
189
  elif prompt_type == "Visual":
190
+ if visual_usage_type != "Cross-Image":
191
+ target_image = None
192
  if visual_prompt_type == "bboxes":
193
  image, points = box_image["image"], box_image["points"]
194
  points = np.array(points)
195
+ if len(points) == 0:
196
+ gr.Warning("No boxes are provided. No image output.", visible=True)
197
+ return gr.update(value=None)
198
  prompts = {
199
  "bboxes": np.array([p[[0, 1, 3, 4]] for p in points if p[2] == 2]),
200
  }
 
204
  masks = np.array(masks.convert("L"))
205
  masks = binary_fill_holes(masks).astype(np.uint8)
206
  masks[masks > 0] = 1
207
+ if masks.sum() == 0:
208
+ gr.Warning("No masks are provided. No image output.", visible=True)
209
+ return gr.update(value=None)
210
  prompts = {
211
  "masks": masks[None]
212
  }
 
214
 
215
  yoloe_infer.click(
216
  fn=run_inference,
217
+ inputs=[raw_image, box_image, mask_image, target_image, texts, model_id, image_size, conf_thresh, iou_thresh, prompt_type, visual_prompt_type, visual_usage_type],
218
  outputs=[output_image],
219
  )
220
 
221
+ ###################### Examples ##########################
222
+ text_examples = gr.Examples(
223
+ examples=[[
224
+ "ultralytics/assets/bus.jpg",
225
+ "person,bus",
226
+ "yoloe-v8l",
227
+ 640,
228
+ 0.25,
229
+ 0.7]],
230
+ inputs=[raw_image, texts, model_id, image_size, conf_thresh, iou_thresh],
231
+ visible=True, cache_examples=False, label="Text Prompt Examples")
232
+
233
+ box_examples = gr.Examples(
234
+ examples=[[
235
+ {"image": "ultralytics/assets/bus_box.jpg", "points": [[235, 408, 2, 342, 863, 3]]},
236
+ "ultralytics/assets/zidane.jpg",
237
+ "yoloe-v8l",
238
+ 640,
239
+ 0.2,
240
+ 0.7,
241
+ ]],
242
+ inputs=[box_image, target_image, model_id, image_size, conf_thresh, iou_thresh],
243
+ visible=False, cache_examples=False, label="Box Visual Prompt Examples")
244
+
245
+ mask_examples = gr.Examples(
246
+ examples=[[
247
+ {"background": "ultralytics/assets/bus.jpg", "layers": ["ultralytics/assets/bus_mask.png"], "composite": "ultralytics/assets/bus_composite.jpg"},
248
+ "ultralytics/assets/zidane.jpg",
249
+ "yoloe-v8l",
250
+ 640,
251
+ 0.15,
252
+ 0.7,
253
+ ]],
254
+ inputs=[mask_image, target_image, model_id, image_size, conf_thresh, iou_thresh],
255
+ visible=False, cache_examples=False, label="Mask Visual Prompt Examples")
256
+
257
+ pf_examples = gr.Examples(
258
+ examples=[[
259
+ "ultralytics/assets/bus.jpg",
260
+ "yoloe-v8l",
261
+ 640,
262
+ 0.25,
263
+ 0.7,
264
+ ]],
265
+ inputs=[raw_image, model_id, image_size, conf_thresh, iou_thresh],
266
+ visible=False, cache_examples=False, label="Prompt-free Examples")
267
+
268
+ # Components update
269
+ def load_box_example(visual_usage_type):
270
+ return (gr.update(visible=True, value={"image": "ultralytics/assets/bus_box.jpg", "points": [[235, 408, 2, 342, 863, 3]]}),
271
+ gr.update(visible=(visual_usage_type=="Cross-Image")))
272
+
273
+ def load_mask_example(visual_usage_type):
274
+ return gr.update(visible=True), gr.update(visible=(visual_usage_type=="Cross-Image"))
275
+
276
+ box_examples.load_input_event.then(
277
+ fn=load_box_example,
278
+ inputs=visual_usage_type,
279
+ outputs=[box_image, target_image]
280
+ )
281
+
282
+ mask_examples.load_input_event.then(
283
+ fn=load_mask_example,
284
+ inputs=visual_usage_type,
285
+ outputs=[mask_image, target_image]
286
+ )
287
+
288
+ # Examples update
289
+ def update_text_examples():
290
+ return gr.Dataset(visible=True), gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=False)
291
+
292
+ def update_pf_examples():
293
+ return gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=True)
294
+
295
+ def update_visual_examples(visual_prompt_type):
296
+ if visual_prompt_type == "bboxes":
297
+ return gr.Dataset(visible=False), gr.Dataset(visible=True), gr.Dataset(visible=False), gr.Dataset(visible=False),
298
+ elif visual_prompt_type == "masks":
299
+ return gr.Dataset(visible=False), gr.Dataset(visible=False), gr.Dataset(visible=True), gr.Dataset(visible=False),
300
+
301
+ text_tab.select(
302
+ fn=update_text_examples,
303
+ inputs=None,
304
+ outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
305
+ )
306
+ visual_tab.select(
307
+ fn=update_visual_examples,
308
+ inputs=[visual_prompt_type],
309
+ outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
310
+ )
311
+ prompt_free_tab.select(
312
+ fn=update_pf_examples,
313
+ inputs=None,
314
+ outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
315
+ )
316
+ visual_prompt_type.change(
317
+ fn=update_visual_examples,
318
+ inputs=[visual_prompt_type],
319
+ outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
320
+ )
321
+ visual_usage_type.change(
322
+ fn=update_visual_examples,
323
+ inputs=[visual_prompt_type],
324
+ outputs=[text_examples.dataset, box_examples.dataset, mask_examples.dataset, pf_examples.dataset]
325
+ )
326
+
327
 
328
  gradio_app = gr.Blocks()
329
  with gradio_app:
 
352
  )
353
  gr.Markdown(
354
  """
355
+ Drawing **multiple** boxes or handcrafted shapes as visual prompt in an image is also supported, which leads to more accurate prompt.
356
  """
357
  )
358
  with gr.Row():
 
360
  app()
361
 
362
  if __name__ == '__main__':
363
+ gradio_app.launch(allowed_paths=["figures"])