adirik commited on
Commit
ef08078
·
1 Parent(s): 4777db1

fix postprocessing

Browse files
Files changed (1) hide show
  1. app.py +2 -3
app.py CHANGED
@@ -24,13 +24,12 @@ def query_image(img, text_queries, score_threshold):
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
 
27
- target_sizes = torch.Tensor([[768, 768]])
28
  outputs.logits = outputs.logits.cpu()
29
  outputs.pred_boxes = outputs.pred_boxes.cpu()
30
  results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
31
  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
32
 
33
- img = cv2.resize(img, (768, 768), interpolation = cv2.INTER_AREA)
34
  font = cv2.FONT_HERSHEY_SIMPLEX
35
 
36
  for box, score, label in zip(boxes, scores, labels):
@@ -60,7 +59,7 @@ can also use the score threshold slider to set a threshold to filter out low pro
60
  """
61
  demo = gr.Interface(
62
  query_image,
63
- inputs=[gr.Image(shape=(768, 768)), "text", gr.Slider(0, 1, value=0.1)],
64
  outputs="image",
65
  title="Zero-Shot Object Detection with OWL-ViT",
66
  description=description,
 
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
 
27
+ target_sizes = torch.Tensor([img.shape[:2]])
28
  outputs.logits = outputs.logits.cpu()
29
  outputs.pred_boxes = outputs.pred_boxes.cpu()
30
  results = processor.post_process(outputs=outputs, target_sizes=target_sizes)
31
  boxes, scores, labels = results[0]["boxes"], results[0]["scores"], results[0]["labels"]
32
 
 
33
  font = cv2.FONT_HERSHEY_SIMPLEX
34
 
35
  for box, score, label in zip(boxes, scores, labels):
 
59
  """
60
  demo = gr.Interface(
61
  query_image,
62
+ inputs=[gr.Image(), "text", gr.Slider(0, 1, value=0.1)],
63
  outputs="image",
64
  title="Zero-Shot Object Detection with OWL-ViT",
65
  description=description,