File size: 3,069 Bytes
14bf43e
af9e027
14bf43e
af9e027
 
 
 
 
14bf43e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28f91bd
 
 
14bf43e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37995d7
275ca05
14bf43e
 
e65e9d4
14bf43e
af9e027
14bf43e
 
 
 
 
85808e2
37995d7
 
 
af9e027
37995d7
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import functools
import io
import itertools

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, YolosForObjectDetection


# Bounding-box outline colours as RGB triples with components in [0, 1].
# plot_results steps through this palette, one colour per detection, and
# wraps around once it runs out.
COLORS = [
    list(rgb)
    for rgb in (
        (0.000, 0.447, 0.741),
        (0.850, 0.325, 0.098),
        (0.929, 0.694, 0.125),
        (0.494, 0.184, 0.556),
        (0.466, 0.674, 0.188),
        (0.301, 0.745, 0.933),
    )
]


def process_class_list(classes_string: str):
    """Parse the comma-separated class filter typed by the user.

    Args:
        classes_string: Raw textbox value, e.g. ``"car, dog"``. ``None``, an
            empty string or a whitespace-only string all mean "no filter".

    Returns:
        The class names with surrounding whitespace removed. Empty entries
        (from blank input or stray commas such as ``"car,,dog,"``) are dropped.
        An empty list means every class should be shown.
    """
    if not classes_string:
        return []
    # Drop empty names: keeping "" would make the filter non-empty while
    # matching no real class, which hides every detection.
    return [name.strip() for name in classes_string.split(",") if name.strip()]


@functools.lru_cache(maxsize=1)
def _load_detector():
    """Load the YOLOS feature extractor and model once and return both.

    The result is cached, so the checkpoint is loaded a single time per
    process rather than on every request.
    """
    checkpoint = "hustvl/yolos-small-dwr"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = YolosForObjectDetection.from_pretrained(checkpoint)
    return feature_extractor, model


def model_inference(img, prob_threshold, classes_to_show):
    """Run YOLOS object detection on an image and return an annotated plot.

    Args:
        img: Input image as an array accepted by ``PIL.Image.fromarray``
            (as delivered by the Gradio ``Image`` input), or ``None`` when no
            image was provided.
        prob_threshold: A detection is kept only if its highest class
            probability is strictly greater than this value.
        classes_to_show: Comma-separated class names to keep. An empty string
            keeps all classes.

    Returns:
        A PIL image of the input with the kept detections drawn on it, or
        ``None`` if no image was given.
    """
    if img is None:
        return None

    feature_extractor, model = _load_detector()

    img = Image.fromarray(img)
    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    # Attention maps are not used anywhere below, so don't request them.
    with torch.no_grad():
        outputs = model(pixel_values)

    # Softmax over classes per query. The last logit is dropped; in
    # DETR-style heads it is the "no object" class.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # PIL's size is (width, height); reverse it to (height, width) for
    # post_process, which rescales the boxes to the original image.
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    return plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)


def plot_results(pil_img, prob, boxes, model, classes_list):
    """Draw detection boxes and class labels on top of an image.

    Args:
        pil_img: Image to draw on.
        prob: Per-detection class-probability tensor. ``prob[i].argmax()`` is
            taken as the predicted class id of detection ``i``.
        boxes: Per-detection boxes (tensor) unpacked as
            ``(xmin, ymin, xmax, ymax)`` in the pixel coordinates of
            ``pil_img``.
        model: Object whose ``config.id2label`` maps class ids to names.
        classes_list: Class names to draw. An empty list (or ``None``) draws
            every class.

    Returns:
        The rendered figure as a PIL image.
    """
    # Use an explicit figure instead of pyplot's implicit "current" figure,
    # and always close it. pyplot keeps every opened figure registered until
    # it is closed, so a long-running server would otherwise leak one figure
    # per request.
    fig, ax = plt.subplots(figsize=(16, 10))
    try:
        ax.imshow(pil_img)
        # cycle() gives each detection a colour without a fixed upper limit.
        for p, (xmin, ymin, xmax, ymax), c in zip(
            prob, boxes.tolist(), itertools.cycle(COLORS)
        ):
            cl = p.argmax().item()
            object_class = model.config.id2label[cl]

            if classes_list and object_class not in classes_list:
                continue

            ax.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3
                )
            )
            text = f"{object_class}: {float(p[cl]):0.2f}"
            ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
        ax.axis("off")
        return fig2img(fig)
    finally:
        plt.close(fig)


def fig2img(fig):
    """Render a Matplotlib figure into an in-memory image and open it with PIL.

    The figure is saved in Matplotlib's default ``savefig`` format to a
    ``BytesIO`` buffer. The returned PIL image reads from that buffer.
    """
    png_stream = io.BytesIO()
    fig.savefig(png_stream)
    # Rewind so PIL starts reading at the beginning of the encoded image.
    png_stream.seek(0)
    return Image.open(png_stream)


# Gradio UI: an image, a probability threshold and an optional class filter
# go in; the annotated image comes out.
description = """Object Detection"""
title = """Object Detection"""
image_in = gr.components.Image(label="Upload an image")
image_out = gr.components.Image()
prob_threshold_slider = gr.components.Slider(
    minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold"
)

classes_to_show = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)

# With more than one input component, Gradio expects each example to be a
# list holding one value per input (image, threshold, class filter). A flat
# list of file names is only accepted when there is a single input.
examples = [
    ["00_plane.jpg", 0.7, ""],
    ["01_car.jpg", 0.7, ""],
]

demo = gr.Interface(
    fn=model_inference,
    inputs=[image_in, prob_threshold_slider, classes_to_show],
    outputs=image_out,
    title=title,
    examples=examples,
    description=description,
)
demo.launch()