Spaces · Running on T4
Prasanna Sridhar committed
Commit 346623e · 1 parent: c469934

remove unused imports
app.py CHANGED
@@ -1,16 +1,10 @@
 import spaces
 import gradio as gr
-import copy
 import random
 import torch
-import
-from PIL import Image, ImageDraw, ImageFont
-import torchvision.transforms.functional as F
+from PIL import Image
 import numpy as np
 import argparse
-import json
-import plotly.express as px
-import pandas as pd
 from util.slconfig import SLConfig, DictAction
 from util.misc import nested_tensor_from_tensor_list
 import datasets.transforms as T
@@ -258,14 +252,14 @@ def count(image, text, prompts, state, device):
 def count(image, text, prompts, state, device):
 
     keywords = "" # do not handle this for now
-
+
     # Handle no prompt case.
     if prompts is None:
         prompts = {"image": image, "points": []}
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
     input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
-
+
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
     input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
     exemplars = [exemplars["exemplars"].to(device)]
@@ -278,7 +272,7 @@ def count(image, text, prompts, state, device):
             [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
-
+
     ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
     logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
     boxes = model_output["pred_boxes"][0]
@@ -288,7 +282,7 @@ def count(image, text, prompts, state, device):
     box_mask = logits.max(dim=-1).values > CONF_THRESH
     logits = logits[box_mask, :].cpu().numpy()
     boxes = boxes[box_mask, :].cpu().numpy()
-
+
     # Plot results.
     (w, h) = image.size
     det_map = np.zeros((h, w))
@@ -327,7 +321,7 @@ def count(image, text, prompts, state, device):
     if len(text.strip()) > 0:
         out_label += " text"
         if exemplars[0].size()[0] == 1:
-            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
+            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
         elif exemplars[0].size()[0] > 1:
             out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
         else:
@@ -339,7 +333,7 @@ def count(image, text, prompts, state, device):
         out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
     else:
         out_label = "Nothing specified to detect."
-
+
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]), new_submit_btn, gr.Tab(visible=True), step_3, state)
 
 @spaces.GPU
@@ -351,11 +345,11 @@ def count_main(image, text, prompts, device):
     input_image, _ = transform(image, {"exemplars": torch.tensor([])})
     input_image = input_image.unsqueeze(0).to(device)
     exemplars = get_box_inputs(prompts["points"])
-
+
     input_image_exemplars, exemplars = transform(prompts["image"], {"exemplars": torch.tensor(exemplars)})
     input_image_exemplars = input_image_exemplars.unsqueeze(0).to(device)
     exemplars = [exemplars["exemplars"].to(device)]
-
+
     with torch.no_grad():
         model_output = model(
             nested_tensor_from_tensor_list(input_image),
@@ -364,7 +358,7 @@ def count_main(image, text, prompts, device):
             [torch.tensor([0]).to(device) for _ in range(len(input_image))],
             captions=[text + " ."] * len(input_image),
         )
-
+
     ind_to_filter = get_ind_to_filter(text, model_output["token"][0].word_ids, keywords)
     logits = model_output["pred_logits"].sigmoid()[0][:, ind_to_filter]
     boxes = model_output["pred_boxes"][0]
@@ -374,7 +368,7 @@ def count_main(image, text, prompts, device):
     box_mask = logits.max(dim=-1).values > CONF_THRESH
     logits = logits[box_mask, :].cpu().numpy()
     boxes = boxes[box_mask, :].cpu().numpy()
-
+
     # Plot results.
     (w, h) = image.size
     det_map = np.zeros((h, w))
@@ -395,7 +389,7 @@ def count_main(image, text, prompts, device):
     if len(text.strip()) > 0:
         out_label += " text"
         if exemplars[0].size()[0] == 1:
-            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
+            out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplar."
         elif exemplars[0].size()[0] > 1:
             out_label += " and " + str(exemplars[0].size()[0]) + " visual exemplars."
         else:
@@ -407,7 +401,7 @@ def count_main(image, text, prompts, device):
         out_label += " " + str(exemplars[0].size()[0]) + " visual exemplars."
     else:
         out_label = "Nothing specified to detect."
-
+
     return (gr.Image(output_img, visible=True, label=out_label, show_label=True), gr.Number(label="Predicted Count", visible=True, value=boxes.shape[0]))
 
 def remove_label(image):
@@ -452,20 +446,20 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
             with gr.Tab("Step 1", visible=True) as step_1:
                 input_image = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=False, width="30vw")
                 gr.Markdown('# Click "Count" to count the strawberries.')
-
+
             with gr.Column():
                 with gr.Tab("Output Image"):
                     detected_instances = gr.Image(label="Detected Instances", show_label='True', interactive=False, visible=True, width="40vw")
-
+
                 with gr.Row():
                     input_text = gr.Textbox(label="What would you like to count?", value="strawberry", interactive=True)
                     pred_count = gr.Number(label="Predicted Count", visible=False)
                     submit_btn = gr.Button("Count", variant="primary", interactive=True)
-
+
                 submit_btn.click(fn=remove_label, inputs=[detected_instances], outputs=[detected_instances]).then(fn=count, inputs=[input_image, input_text, exemplar_image, state, device], outputs=[detected_instances, pred_count, submit_btn, step_2, step_3, state])
                 exemplar_image.change(check_submit_btn, inputs=[exemplar_image, state], outputs=[submit_btn])
     with gr.Tab("App", visible=True) as main_app:
-
+
         gr.Markdown(
             """
             # <center>CountGD: Multi-Modal Open-World Counting
@@ -476,7 +470,7 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
            Limitation: this app does not support fine-grained counting based on attributes or visual grounding inputs yet. Note: if the exemplar and text conflict each other, both will be counted.</center>
            """
        )
-
+
        with gr.Row():
            with gr.Column():
                input_image_main = gr.Image(type='pil', label='Input Image', show_label='True', value="strawberry.jpg", interactive=True)
@@ -490,6 +484,6 @@ with gr.Blocks(title="CountGD: Multi-Modal Open-World Counting", theme="soft", h
        gr.Examples(label="Examples: click on a row to load the example. Add visual exemplars by drawing boxes on the loaded \"Visual Exemplar Image.\"", examples=examples, inputs=[input_image_main, input_text_main, exemplar_image_main])
        submit_btn_main.click(fn=remove_label, inputs=[detected_instances_main], outputs=[detected_instances_main]).then(fn=count_main, inputs=[input_image_main, input_text_main, exemplar_image_main, device], outputs=[detected_instances_main, pred_count_main])
        clear_btn_main.add([input_image_main, input_text_main, exemplar_image_main, detected_instances_main, pred_count_main])
-
+
 
 demo.queue().launch(allowed_paths=['back-icon.jpg', 'paste-icon.jpg', 'upload-icon.jpg', 'button-legend.jpg'])
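For reference, the modules dropped above (copy, json, plotly.express, pandas, torchvision.transforms.functional, and PIL's ImageDraw/ImageFont) are, per the commit message, no longer referenced anywhere in app.py; only `from PIL import Image` is kept. The sketch below is a minimal, illustrative way to flag such unused top-level imports with Python's standard-library ast module. The script name and the simple "bound name never reappears as a bare identifier" heuristic are assumptions of this example; a real linter such as pyflakes or Ruff (rule F401) covers cases this skips, for example re-exports, `__all__`, and names referenced only inside strings.

# find_unused_imports.py -- illustrative helper, not part of this repo.
# Heuristic: a top-level import is "unused" if the name it binds never
# appears as a bare identifier anywhere else in the module.
import ast
import sys


def unused_top_level_imports(path):
    with open(path, encoding="utf-8") as fh:
        tree = ast.parse(fh.read())

    imported = {}  # bound name -> line number of the import
    for node in tree.body:
        if isinstance(node, ast.Import):
            for alias in node.names:
                # "import plotly.express as px" binds "px"; "import torch" binds "torch".
                imported[(alias.asname or alias.name).split(".")[0]] = node.lineno
        elif isinstance(node, ast.ImportFrom):
            for alias in node.names:
                # "from PIL import Image, ImageDraw" binds "Image" and "ImageDraw".
                imported[alias.asname or alias.name] = node.lineno

    # Every bare name used anywhere in the module (np.zeros(...) counts as "np").
    used = {node.id for node in ast.walk(tree) if isinstance(node, ast.Name)}

    return {name: line for name, line in imported.items() if name not in used}


if __name__ == "__main__":
    for name, line in sorted(unused_top_level_imports(sys.argv[1]).items(),
                             key=lambda item: item[1]):
        print(f"line {line}: '{name}' imported but never used")

Run as, for example, python find_unused_imports.py app.py against the pre-commit file; it should flag the same bindings this commit deletes.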