import io
import os
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from transformers import AutoFeatureExtractor, YolosForObjectDetection
from PIL import Image
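
# Gradio demo: object detection with YOLOS ("hustvl/yolos-small-dwr").
# The user uploads an image, sets a probability threshold and an optional
# class filter, and gets back the image with the predicted boxes drawn on it.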


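# Color palette cycled through when drawing the bounding boxes.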
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def process_class_list(classes_string: str):
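    """Turn a comma-separated string of class names into a list of stripped names."""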
    if classes_string == "":
        return []
    classes_list = classes_string.split(",")
    classes_list = [x.strip() for x in classes_list]
    return classes_list


# Load the feature extractor and model once at import time instead of on
# every request.
feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small-dwr")
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small-dwr")


def model_inference(img, prob_threshold, classes_to_show):
    img = Image.fromarray(img)
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)

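    # Drop the "no object" class, keep detections above the probability
    # threshold, and rescale the predicted boxes to the original image size.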
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]
    classes_list = process_class_list(classes_to_show)
    res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)

    return res_img


def plot_results(pil_img, prob, boxes, model, classes_list):
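    """Draw the kept detections on the image and return the result as a PIL image."""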
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        cl = p.argmax()
        object_class = model.config.id2label[cl.item()]
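        # Skip detections whose class is not in the user-provided filter (if any).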
        if classes_list and object_class not in classes_list:
            continue
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3))
        text = f"{object_class}: {p[cl]:0.2f}"
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")
    return fig2img(plt.gcf())


def fig2img(fig):
    """Render a Matplotlib figure into an in-memory PIL image."""
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # free the figure so repeated requests do not leak memory
    return img


description = """Upload an image and get back the detected objects drawn as labeled bounding boxes"""
title = """Object Detection"""

# Create examples list from "examples/" directory
# example_list = [["examples/" + example] for example in os.listdir("examples")]
# example_list = [["carplane.webp"]]

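# Gradio UI components: image input/output, a probability-threshold slider,
# and an optional comma-separated class filter.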
image_in = gr.components.Image(label="Upload an image")
image_out = gr.components.Image()
classes_to_show = gr.components.Textbox(placeholder="e.g. car, dog", label="Classes to filter (leave empty to detect all classes)")
prob_threshold_slider = gr.components.Slider(minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold")
inputs = [image_in, prob_threshold_slider, classes_to_show]
# gr.Examples([['carplane.webp'], ['CTH.png']], inputs=image_in)

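# Build the Gradio interface and launch the app.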
gr.Interface(fn=model_inference,
             inputs=inputs,
             outputs=image_out,
             title=title,
             description=description,
             # examples=example_list
).launch()
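
# Running this script starts the Gradio server and prints a local URL to open in a browser.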