yuragoithf commited on
Commit
dd04c12
·
1 Parent(s): 1cb1590

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -12
app.py CHANGED
@@ -30,9 +30,7 @@ def process_class_list(classes_string: str):
30
  def model_inference(img, prob_threshold, classes_to_show):
31
  feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/yolos-small-dwr")
32
  model = YolosForObjectDetection.from_pretrained(f"hustvl/yolos-small-dwr")
33
-
34
  img = Image.fromarray(img)
35
-
36
  pixel_values = feature_extractor(img, return_tensors="pt").pixel_values
37
 
38
  with torch.no_grad():
@@ -40,11 +38,9 @@ def model_inference(img, prob_threshold, classes_to_show):
40
 
41
  probas = outputs.logits.softmax(-1)[0, :, :-1]
42
  keep = probas.max(-1).values > prob_threshold
43
-
44
  target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
45
  postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
46
  bboxes_scaled = postprocessed_outputs[0]["boxes"]
47
-
48
  classes_list = process_class_list(classes_to_show)
49
  res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
50
 
@@ -59,16 +55,10 @@ def plot_results(pil_img, prob, boxes, model, classes_list):
59
  for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
60
  cl = p.argmax()
61
  object_class = model.config.id2label[cl.item()]
62
-
63
  if len(classes_list) > 0:
64
  if object_class not in classes_list:
65
  continue
66
-
67
- ax.add_patch(
68
- plt.Rectangle(
69
- (xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3
70
- )
71
- )
72
  text = f"{object_class}: {p[cl]:0.2f}"
73
  ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
74
  plt.axis("off")
@@ -90,7 +80,7 @@ title = """Object Detection"""
90
  # example_list = [["examples/" + example] for example in os.listdir("examples")]
91
  example_list = [["carplane.webp"]]
92
 
93
- image_in = gr.components.Image(label="Upload an image")
94
  image_out = gr.components.Image()
95
  classes_to_show = gr.components.Textbox(placeholder="e.g. car, dog", label="Classes to filter (leave empty to detect all classes)")
96
  prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold")
 
30
  def model_inference(img, prob_threshold, classes_to_show):
31
  feature_extractor = AutoFeatureExtractor.from_pretrained(f"hustvl/yolos-small-dwr")
32
  model = YolosForObjectDetection.from_pretrained(f"hustvl/yolos-small-dwr")
 
33
  img = Image.fromarray(img)
 
34
  pixel_values = feature_extractor(img, return_tensors="pt").pixel_values
35
 
36
  with torch.no_grad():
 
38
 
39
  probas = outputs.logits.softmax(-1)[0, :, :-1]
40
  keep = probas.max(-1).values > prob_threshold
 
41
  target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
42
  postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
43
  bboxes_scaled = postprocessed_outputs[0]["boxes"]
 
44
  classes_list = process_class_list(classes_to_show)
45
  res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
46
 
 
55
  for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
56
  cl = p.argmax()
57
  object_class = model.config.id2label[cl.item()]
 
58
  if len(classes_list) > 0:
59
  if object_class not in classes_list:
60
  continue
61
+ ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
 
 
 
 
 
62
  text = f"{object_class}: {p[cl]:0.2f}"
63
  ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
64
  plt.axis("off")
 
80
  # example_list = [["examples/" + example] for example in os.listdir("examples")]
81
  example_list = [["carplane.webp"]]
82
 
83
+ image_in = [gr.components.Image(label="Upload an image")]
84
  image_out = gr.components.Image()
85
  classes_to_show = gr.components.Textbox(placeholder="e.g. car, dog", label="Classes to filter (leave empty to detect all classes)")
86
  prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold")