Prasanna Sridhar committed on
Commit 346623e · 1 Parent(s): c469934

remove unused imports

Files changed (1)
  1. app.py +19 -25
app.py CHANGED
@@ -1,16 +1,10 @@
 import spaces
 import gradio as gr
-import copy
 import random
 import torch
-import PIL
-from PIL import Image, ImageDraw, ImageFont
-import torchvision.transforms.functional as F
+from PIL import Image
 import numpy as np
 import argparse
-import json
-import plotly.express as px
-import pandas as pd
 from util.slconfig import SLConfig, DictAction
 from util.misc import nested_tensor_from_tensor_list
 import datasets.transforms as T
@@ -258,14 +252,14 @@ def get_ind_to_filter(text, word_ids, keywords):
 def count(image, text, prompts, state, device):
 
 keywords = "" # do not handle this for now
-
+
 # Handle no prompt case.
 if prompts is None:
 prompts = {"image": image, "points": []}
 input_image, _ = transform(image, {"exemplars": torch.tensor([])})
 input_image = input_image.unsqueeze(0).to(device)
 exemplars = get_box_inputs(prompts["points"])
-
+
 input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
 input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
 exemplars = [exemplars["exemplars"].to(device)]
@@ -278,7 +272,7 @@ def count(image, text, prompts, state, device):
 [torch.tensor([0]).to(device) for _ in range(len(input_image))],
 captions=[text + " ."] * len(input_image),
 )
-
+
 ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
 logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
 boxes = model_output["pred_boxes"][0]
@@ -288,7 +282,7 @@ def count(image, text, prompts, state, device):
 box_mask = logits.max(dim=-1).values > CONF_THRESH
 logits = logits[box_mask, :].cpu().numpy()
 boxes = boxes[box_mask, :].cpu().numpy()
-
+
 # Plot results.
 (w, h) = image.size
 det_map = np.zeros((h, w))
@@ -327,7 +321,7 @@ def count(image, text, prompts, state, device):
 if len(text.strip()) > 0:
 out_label += " text"
 if exemplars[0].size()[0] == 1:
-out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
+out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
 elif exemplars[0].size()[0] > 1:
 out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
 else:
@@ -339,7 +333,7 @@ def count(image, text, prompts, state, device):
 out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
 else:
 out_label = "Nothing specified to detect."
-
+
 return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
 @spaces.GPU
@@ -351,11 +345,11 @@ def count_main(image, text, prompts, device):
 input_image, _ = transform(image, {"exemplars": torch.tensor([])})
 input_image = input_image.unsqueeze(0).to(device)
 exemplars = get_box_inputs(prompts["points"])
-
+
 input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
 input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
 exemplars = [exemplars["exemplars"].to(device)]
-
+
 with torch.no_grad():
 model_output = model(
 nested_tensor_from_tensor_list(input_image),
@@ -364,7 +358,7 @@ def count_main(image, text, prompts, device):
 [torch.tensor([0]).to(device) for _ in range(len(input_image))],
 captions=[text + " ."] * len(input_image),
 )
-
+
 ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
 logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
 boxes = model_output["pred_boxes"][0]
@@ -374,7 +368,7 @@ def count_main(image, text, prompts, device):
 box_mask = logits.max(dim=-1).values > CONF_THRESH
 logits = logits[box_mask, :].cpu().numpy()
 boxes = boxes[box_mask, :].cpu().numpy()
-
+
 # Plot results.
 (w, h) = image.size
 det_map = np.zeros((h, w))
@@ -395,7 +389,7 @@ def count_main(image, text, prompts, device):
 if len(text.strip()) > 0:
 out_label += " text"
 if exemplars[0].size()[0] == 1:
-out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
+out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
 elif exemplars[0].size()[0] > 1:
 out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
 else:
@@ -407,7 +401,7 @@ def count_main(image, text, prompts, device):
 out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
 else:
 out_label = "Nothing specified to detect."
-
+
 return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]))
 
 def remove_label(image):
@@ -452,20 +446,20 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
 with gr.Tab("Step 1", visible=True) as step_1:
 input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False, width="30vw")
 gr.Markdown('# Click "Count" to count the strawberries.')
-
+
 with gr.Column():
 with gr.Tab("Output Image"):
 detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True, width="40vw")
-
+
 with gr.Row():
 input_text = gr.Textbox(label="What would you like to count?", value="strawberry", interactive=True)
 pred_count = gr.Number(label="Predicted Count", visible=False)
 submit_btn = gr.Button("Count", variant="primary", interactive=True)
-
+
 submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state, device], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
 exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
 with gr.Tab("App", visible=True) as main_app:
-
+
 gr.Markdown(
 """
 # <center>CountGD: Multi-Modal Open-World Counting
@@ -476,7 +470,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
 Limitation: this app does not support fine-grained counting based on attributes or visual grounding inputs yet. Note: if the exemplar and text conflict each other, both will be counted.</center>
 """
 )
-
+
 with gr.Row():
 with gr.Column():
 input_image_main = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=True)
@@ -490,6 +484,6 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
 gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
 submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main, device], outputs=[detected_instances_main, pred_count_main])
 clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
-
+
 
 demo.queue().launch(allowed_paths=['back-icon.jpg', 'paste-icon.jpg', 'upload-icon.jpg', 'button-legend.jpg'])
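The diff drops imports that app.py never references (copy, PIL, the PIL drawing helpers, torchvision.transforms.functional, json, plotly.express, pandas) and keeps a plain from PIL import Image. As a hedged aside, the sketch below shows one way such dead top-level imports can be spotted with Python's standard ast module; the helper name find_unused_imports is illustrative and not part of this repository.

import ast
import sys

def find_unused_imports(path):
    """Return (line, name) pairs for top-level import bindings never used as a name."""
    with open(path, encoding="utf-8") as fh:
        tree = ast.parse(fh.read())

    imported = {}  # bound name -> line where it was imported
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                # "import a.b" binds the top-level package name "a"
                name = alias.asname or alias.name.split(".")[0]
                imported[name] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                if alias.name == "*":  # wildcard imports cannot be checked this way
                    continue
                imported[alias.asname or alias.name] = node.lineno

    # Any attribute access such as np.zeros still produces a Name node for "np",
    # so simple membership in the set of Name ids is enough for this rough check.
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}
    return sorted((line, name) for name, line in imported.items() if name not in used)

if __name__ == "__main__":
    for line, name in find_unused_imports(sys.argv[1] if len(sys.argv) > 1 else "app.py"):
        print(f"line {line}: '{name}' imported but unused")

A check this rough misses names used only inside strings or eval, so a dedicated linter (pyflakes/flake8/Ruff report this as F401) remains the safer tool for an actual cleanup.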