File size: 3,117 Bytes
14bf43e
af9e027
14bf43e
af9e027
 
 
 
 
14bf43e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28f91bd
 
 
14bf43e
 
 
81c41e8
14bf43e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cc1d35d
 
43005c7
275ca05
14bf43e
 
e65e9d4
14bf43e
af9e027
14bf43e
 
 
 
 
0aed315
a0a4808
1436621
85808e2
cc1d35d
37995d7
 
1436621
a0a4808
d3e12ae
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import io
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from transformers import AutoFeatureExtractor, YolosForObjectDetection
from PIL import Image


# RGB palette (channel values in [0, 1], matplotlib convention) cycled
# through when drawing detection bounding boxes in plot_results.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def process_class_list(classes_string: str) -> list[str]:
    """Parse a comma-separated string of class names into a clean list.

    Each entry is stripped of surrounding whitespace. Empty entries
    (produced by inputs like "car,,dog", trailing commas, or a
    whitespace-only string) are dropped, since an empty class name can
    never match a model label and would silently filter out everything.

    Args:
        classes_string: User-supplied text, e.g. "car, dog".

    Returns:
        List of non-empty, stripped class names; [] for empty input.
    """
    if not classes_string:
        return []
    # Strip each piece, then discard pieces that stripped down to "".
    return [name for name in (part.strip() for part in classes_string.split(",")) if name]


def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS object detection on an image and render the detections.

    Args:
        img: Input image as a numpy array (as delivered by the Gradio
            Image component).
        prob_threshold: Minimum class probability for a detection to be kept.
        classes_to_show: Comma-separated class names to display; empty
            string shows all classes.

    Returns:
        A PIL image with bounding boxes and labels drawn on it.
    """
    # NOTE(review): weights are re-loaded on every inference call; consider
    # caching the extractor/model at module level to avoid repeated loads.
    feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small-dwr")
    model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small-dwr")

    img = Image.fromarray(img)

    # BUG FIX: return_tensors must be "pt" (PyTorch tensors) — "pil" is not
    # a valid value and the model cannot consume non-tensor pixel_values.
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)

    # Per-query class probabilities; the last logit is the "no object" class
    # and is dropped before thresholding.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # post_process rescales boxes back to the original image size (H, W).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)

    return res_img


def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw detection boxes and class labels onto the image.

    Args:
        pil_img: Source PIL image to draw on.
        prob: Per-detection class-probability tensors (no-object class removed).
        boxes: Matching (xmin, ymin, xmax, ymax) boxes in image coordinates.
        model: Detection model, used for its id2label mapping.
        classes_list: If non-empty, only these class names are drawn.

    Returns:
        A PIL image of the rendered matplotlib figure.
    """
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    axes = plt.gca()
    palette = COLORS * 100
    for score_vec, box, color in zip(prob, boxes.tolist(), palette):
        best = score_vec.argmax()
        object_class = model.config.id2label[best.item()]

        # Skip detections the user did not ask for.
        if classes_list and object_class not in classes_list:
            continue

        x0, y0, x1, y1 = box
        rect = plt.Rectangle(
            (x0, y0), x1 - x0, y1 - y0, fill=False, color=color, linewidth=3
        )
        axes.add_patch(rect)
        axes.text(
            x0,
            y0,
            f"{object_class}: {score_vec[best]:0.2f}",
            fontsize=15,
            bbox=dict(facecolor="yellow", alpha=0.5),
        )
    plt.axis("off")
    return fig2img(plt.gcf())


def fig2img(fig):
    """Render a matplotlib figure into a PIL image via an in-memory PNG buffer."""
    buffer = io.BytesIO()
    fig.savefig(buffer)
    buffer.seek(0)
    return Image.open(buffer)


# --- Gradio UI wiring (module-level side effects: builds and launches the app) ---

# Text shown in the interface header.
description = """Upload an image and get the predicted classes"""
title = """Object Detection"""

# Input image widget and output image widget for rendered detections.
image_in = gr.components.Image(label="Upload an image")
image_out = gr.components.Image()
# Minimum detection probability; matches the default threshold of 0.7.
prob_threshold_slider = gr.components.Slider(
    minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold"
)

# Comma-separated class filter; empty means show every detected class.
classes_to_show = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)

# Input order must match model_inference(img, prob_threshold, classes_to_show).
inputs = [image_in, prob_threshold_slider, classes_to_show]


# Example files are expected to sit next to this script — TODO confirm they ship with the app.
gr.Interface(fn=model_inference,
             inputs=inputs,
             outputs=image_out,
             title=title,
             description=description,
             examples=["carplane.webp", "CTH.png"]
).launch()