# NOTE: non-code page residue (Spaces status lines, file size, commit hashes,
# and a line-number ruler) from the scraped source was removed here.
import io
import os
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt
from transformers import AutoFeatureExtractor, YolosForObjectDetection
from PIL import Image
# Box-outline colors as matplotlib-style RGB floats in [0, 1]; the drawing
# code cycles through these when there are more detections than colors.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]
def process_class_list(classes_string: str):
    """Parse a comma-separated class filter string into a list of class names.

    Args:
        classes_string: User input such as "car, dog". An empty or
            whitespace-only string means "no filter".

    Returns:
        List of stripped, non-empty class-name strings; [] when no filter
        was given. Empty fragments (e.g. from "a,,b") are dropped so they
        cannot silently match nothing downstream.
    """
    if not classes_string or not classes_string.strip():
        return []
    return [name for name in (part.strip() for part in classes_string.split(",")) if name]
def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS object detection on an image and return an annotated image.

    Args:
        img: Input image as a numpy array (from the Gradio Image component).
        prob_threshold: Minimum class probability for a detection to be kept.
        classes_to_show: Comma-separated class names to display; "" shows all.

    Returns:
        PIL image with bounding boxes and labels drawn on it.
    """
    # Load the feature extractor and model once and cache them on the
    # function object; the original code re-instantiated (and potentially
    # re-downloaded) both on every single request.
    if not hasattr(model_inference, "_cached"):
        model_inference._cached = (
            AutoFeatureExtractor.from_pretrained("hustvl/yolos-small-dwr"),
            YolosForObjectDetection.from_pretrained("hustvl/yolos-small-dwr"),
        )
    feature_extractor, model = model_inference._cached

    img = Image.fromarray(img)
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values
    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)

    # Drop the trailing "no object" class, then keep detections whose best
    # class probability clears the user-chosen threshold.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # Rescale predicted boxes to the original image size; PIL's .size is
    # (width, height), post_process expects (height, width).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    return plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)
def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw bounding boxes and class labels on an image.

    Args:
        pil_img: PIL image to annotate.
        prob: Per-detection class probabilities, shape (N, num_classes).
        boxes: Per-detection (xmin, ymin, xmax, ymax) boxes, shape (N, 4).
        model: Detection model; model.config.id2label maps class id -> name.
        classes_list: Class names to draw; an empty list means draw all.

    Returns:
        PIL image rendering of the annotated matplotlib figure.
    """
    fig = plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100  # repeat the palette so zip never runs short
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        cl = p.argmax()
        object_class = model.config.id2label[cl.item()]
        # Skip classes the user did not ask for (empty filter = keep all).
        if classes_list and object_class not in classes_list:
            continue
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        text = f"{object_class}: {p[cl]:0.2f}"
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")
    res_img = fig2img(fig)
    # Close the figure explicitly: the original leaked one open matplotlib
    # figure per request, growing memory for the lifetime of the app.
    plt.close(fig)
    return res_img
def fig2img(fig):
    """Render a matplotlib figure into a PIL Image via an in-memory buffer."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    return Image.open(buffer)
description = """Upload an image and get the detected classes"""
title = """Object Detection"""

# Create examples list from "examples/" directory
# example_list = [["examples/" + example] for example in os.listdir("examples")]
# example_list = [["carplane.webp"]]

# UI widgets: image in/out, optional class filter, and confidence slider.
input_image = gr.components.Image(label="Upload an image")
output_image = gr.components.Image()
class_filter_box = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)
threshold_slider = gr.components.Slider(
    minimum=0,
    maximum=1.0,
    step=0.01,
    value=0.7,
    label="Probability Threshold",
)

# gr.Examples([['carplane.webp'], ['CTH.png']], inputs=input_image)
demo = gr.Interface(
    fn=model_inference,
    # Argument order must match model_inference(img, prob_threshold, classes_to_show).
    inputs=[input_image, threshold_slider, class_filter_box],
    outputs=output_image,
    title=title,
    description=description,
    # examples=example_list
)
demo.launch()