Spaces:
Runtime error
Runtime error
fix postprocessing
Browse files
app.py
CHANGED
@@ -24,13 +24,12 @@ def query_image(img, text_queries, score_threshold):
|
|
24 |
with torch.no_grad():
|
25 |
outputs = model(**inputs)
|
26 |
|
27 |
-
target_sizes = torch.Tensor([[
|
28 |
outputs.logits = outputs.logits.cpu()
|
29 |
outputs.pred_boxes = outputs.pred_boxes.cpu()
|
30 |
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
|
31 |
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
|
32 |
|
33 |
-
img = cv2.resize(img, (768, 768), interpolation = cv2.INTER_AREA)
|
34 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
35 |
|
36 |
for box, score, label in zip(boxes, scores, labels):
|
@@ -60,7 +59,7 @@ can also use the score threshold slider to set a threshold to filter out low pro
|
|
60 |
"""
|
61 |
demo = gr.Interface(
|
62 |
query_image,
|
63 |
-
inputs=[gr.Image(
|
64 |
outputs="image",
|
65 |
title="Zero-Shot Object Detection with OWL-ViT",
|
66 |
description=description,
|
|
|
24 |
with torch.no_grad():
|
25 |
outputs = model(**inputs)
|
26 |
|
27 |
+
target_sizes = torch.Tensor([img.shape[:2]])
|
28 |
outputs.logits = outputs.logits.cpu()
|
29 |
outputs.pred_boxes = outputs.pred_boxes.cpu()
|
30 |
results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
|
31 |
boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
|
32 |
|
|
|
33 |
font = cv2.FONT_HERSHEY_SIMPLEX
|
34 |
|
35 |
for box, score, label in zip(boxes, scores, labels):
|
|
|
59 |
"""
|
60 |
demo = gr.Interface(
|
61 |
query_image,
|
62 |
+
inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
|
63 |
outputs="image",
|
64 |
title="Zero-Shot Object Detection with OWL-ViT",
|
65 |
description=description,
|