majinyu committed
Commit 2c0c15f · 1 Parent(s): 3e575ea

add a checkbox to make grounded-sam optional

Files changed (1)
  1. app.py +46 -26
app.py CHANGED
@@ -125,7 +125,10 @@ def draw_box(box, draw, label):
 
 
 @torch.no_grad()
-def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grounding_dino_model, sam_model):
+def inference(
+        raw_image, specified_tags, do_det_seg,
+        tagging_model_type, tagging_model, grounding_dino_model, sam_model
+):
     print(f"Start processing, image size {raw_image.size}")
     raw_image = raw_image.convert("RGB")
 
@@ -155,6 +158,13 @@ def inference(raw_image, specified_tags, tagging_model_type, tagging_model, grou
     print(f"Tags: {tags}")
     print(f"Caption: {caption}")
 
+    # return
+    if not do_det_seg:
+        if tagging_model_type == "RAM":
+            return tags.replace(", ", " | "), tags_chinese.replace(", ", " | "), None
+        else:
+            return tags.replace(", ", " | "), caption, None
+
     # run groundingDINO
     transform = T.Compose([
         T.RandomResize([800], max_size=1333),
@@ -255,14 +265,17 @@ if __name__ == "__main__":
         <br>
         Welcome to the RAM/Tag2Text + Grounded-SAM demo! <br><br>
         <li>
-            <b>Recognize Anything Model + Grounded-SAM:</b> Upload your image to get the <b>English and Chinese tags</b> (by RAM) and <b>masks and boxes</b> (by Grounded-SAM)!
+            <b>Recognize Anything Model:</b> Upload your image to get the <b>English and Chinese tags</b>!
         </li>
         <li>
-            <b>Tag2Text Model + Grounded-SAM:</b> Upload your image to get the <b>tags and caption</b> (by Tag2Text) and <b>masks and boxes</b> (by Grounded-SAM)!
+            <b>Tag2Text Model:</b> Upload your image to get the <b>tags and caption</b>!
             (Optional: Specify tags to get the corresponding caption.)
         </li>
+        <li>
+            <b>Grounded-SAM:</b> Tick the checkbox to get <b>boxes</b> and <b>masks</b> of tags!
+        </li>
         <br>
-        Note: this demo may take up to minutes to inference. If you do not need masks and boxes, visit <a href='https://huggingface.co/spaces/xinyu1205/Recognize_Anything-Tag2Text/' target='_blank'>this demo</a>.
+        Great thanks to <a href='https://huggingface.co/majinyu' target='_blank'>Ma Jinyu</a>, the major contributor of this demo!
     """  # noqa
 
     article = """
@@ -277,11 +290,17 @@ if __name__ == "__main__":
         </p>
     """  # noqa
 
-    def inference_with_ram(img):
-        return inference(img, None, "RAM", ram_model, grounding_dino_model, sam_model)
+    def inference_with_ram(img, do_det_seg):
+        return inference(
+            img, None, do_det_seg,
+            "RAM", ram_model, grounding_dino_model, sam_model
+        )
 
-    def inference_with_t2t(img, input_tags):
-        return inference(img, input_tags, "Tag2Text", tag2text_model, grounding_dino_model, sam_model)
+    def inference_with_t2t(img, input_tags, do_det_seg):
+        return inference(
+            img, input_tags, do_det_seg,
+            "Tag2Text", tag2text_model, grounding_dino_model, sam_model
+        )
 
     with gr.Blocks(title="Recognize Anything Model") as demo:
         ###############
@@ -293,6 +312,7 @@ if __name__ == "__main__":
         with gr.Row():
             with gr.Column():
                 ram_in_img = gr.Image(type="pil")
+                ram_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
                 with gr.Row():
                     ram_btn_run = gr.Button(value="Run")
                     ram_btn_clear = gr.ClearButton()
@@ -302,12 +322,12 @@ if __name__ == "__main__":
                ram_out_biaoqian = gr.Textbox(label="标签")
                gr.Examples(
                    examples=[
-                       ["images/demo1.jpg"],
-                       ["images/demo2.jpg"],
-                       ["images/demo4.jpg"],
+                       ["images/demo1.jpg", True],
+                       ["images/demo2.jpg", True],
+                       ["images/demo4.jpg", True],
                    ],
                    fn=inference_with_ram,
-                   inputs=[ram_in_img],
+                   inputs=[ram_in_img, ram_opt_det_seg],
                    outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img],
                    cache_examples=True
                )
@@ -317,6 +337,7 @@ if __name__ == "__main__":
            with gr.Column():
                t2t_in_img = gr.Image(type="pil")
                t2t_in_tag = gr.Textbox(label="User Specified Tags (Optional, separated by comma)")
+               t2t_opt_det_seg = gr.Checkbox(label="Get Boxes and Masks with Grounded-SAM", value=True)
                with gr.Row():
                    t2t_btn_run = gr.Button(value="Run")
                    t2t_btn_clear = gr.ClearButton()
@@ -326,12 +347,12 @@ if __name__ == "__main__":
                t2t_out_cap = gr.Textbox(label="Caption")
                gr.Examples(
                    examples=[
-                       ["images/demo4.jpg", ""],
-                       ["images/demo4.jpg", "power line"],
-                       ["images/demo4.jpg", "track, train"],
+                       ["images/demo4.jpg", "", True],
+                       ["images/demo4.jpg", "power line", False],
+                       ["images/demo4.jpg", "track, train", False],
                    ],
                    fn=inference_with_t2t,
-                   inputs=[t2t_in_img, t2t_in_tag],
+                   inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
                    outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img],
                    cache_examples=True
                )
@@ -344,23 +365,22 @@ if __name__ == "__main__":
        # run inference
        ram_btn_run.click(
            fn=inference_with_ram,
-           inputs=[ram_in_img],
+           inputs=[ram_in_img, ram_opt_det_seg],
            outputs=[ram_out_tag, ram_out_biaoqian, ram_out_img]
        )
        t2t_btn_run.click(
            fn=inference_with_t2t,
-           inputs=[t2t_in_img, t2t_in_tag],
+           inputs=[t2t_in_img, t2t_in_tag, t2t_opt_det_seg],
            outputs=[t2t_out_tag, t2t_out_cap, t2t_out_img]
        )
 
-       ram_btn_clear.add([
-           ram_in_img, t2t_in_img, t2t_in_tag,
-           ram_out_img, ram_out_tag, ram_out_biaoqian, t2t_out_img, t2t_out_tag, t2t_out_cap
-       ])
-       t2t_btn_clear.add([
-           ram_in_img, t2t_in_img, t2t_in_tag,
-           ram_out_img, ram_out_tag, ram_out_biaoqian, t2t_out_img, t2t_out_tag, t2t_out_cap
-       ])
+       # hide or show image output
+       ram_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[ram_opt_det_seg], outputs=[ram_out_img])
+       t2t_opt_det_seg.change(fn=lambda b: gr.update(visible=b), inputs=[t2t_opt_det_seg], outputs=[t2t_out_img])
 
    return demo
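
For reference, the sketch below distills the pattern this commit applies: the checkbox value is passed into the inference function so the heavy Grounded-SAM stage can be skipped with an early return, and a .change() handler hides the image output when the box is unticked. This is a minimal, hypothetical example rather than code from app.py; slow_detect, run, and the component names are placeholders.

# Minimal sketch of the optional-pipeline-stage pattern (hypothetical code,
# not taken from app.py): a checkbox gates an expensive step and toggles
# the visibility of the output that step would fill.
import gradio as gr

def slow_detect(image):
    # stand-in for the expensive Grounded-SAM detection + segmentation step
    return image

def run(image, do_det_seg):
    tags = "dog | grass | ball"  # stand-in for RAM/Tag2Text tagging
    if not do_det_seg:
        # skip the heavy stage entirely and leave the image output empty
        return tags, None
    return tags, slow_detect(image)

with gr.Blocks() as demo:
    in_img = gr.Image(type="pil")
    opt_det_seg = gr.Checkbox(label="Get Boxes and Masks", value=True)
    btn_run = gr.Button(value="Run")
    out_tags = gr.Textbox(label="Tags")
    out_img = gr.Image(label="Boxes and Masks")

    btn_run.click(fn=run, inputs=[in_img, opt_det_seg], outputs=[out_tags, out_img])
    # hide or show the image output to match the checkbox
    opt_det_seg.change(fn=lambda b: gr.update(visible=b),
                       inputs=[opt_det_seg], outputs=[out_img])

demo.launch()

Returning None for the skipped output, rather than changing the number of outputs, keeps the click handler's signature identical in both modes, which is why the commit can reuse the same Textbox and Image components with and without Grounded-SAM.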