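"""Gradio demo for object detection with the YOLOS (hustvl/yolos-small-dwr) model.

Upload an image, set a probability threshold, and optionally restrict which
classes are shown; the app returns the image annotated with bounding boxes.
"""
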
import io
import os
import torch
import numpy as np
import gradio as gr
import matplotlib.pyplot as plt

from transformers import AutoFeatureExtractor, YolosForObjectDetection
from PIL import Image


# RGB colors (values in [0, 1]) cycled through when drawing bounding boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]


def process_class_list(classes_string: str):
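    """Split a comma-separated string of class names into a list of stripped names."""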
    if classes_string == "":
        return []
    classes_list = classes_string.split(",")
    classes_list = [x.strip() for x in classes_list]
    return classes_list


def model_inference(img, prob_threshold, classes_to_show):
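    """Run YOLOS detection on an image and return it annotated with bounding boxes."""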
    # Note: the checkpoint is reloaded on every request; for faster inference it
    # could be loaded once at module level instead.
    feature_extractor = AutoFeatureExtractor.from_pretrained("hustvl/yolos-small-dwr")
    model = YolosForObjectDetection.from_pretrained("hustvl/yolos-small-dwr")

    img = Image.fromarray(img)

    pixel_values = feature_extractor(img, return_tensors="pt").pixel_values

    with torch.no_grad():
        outputs = model(pixel_values, output_attentions=True)

    # Drop the final "no object" logit and keep predictions whose best class
    # probability exceeds the user-selected threshold.
    probas = outputs.logits.softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > prob_threshold

    # Rescale the predicted boxes to the original image size (height, width).
    target_sizes = torch.tensor(img.size[::-1]).unsqueeze(0)
    postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
    bboxes_scaled = postprocessed_outputs[0]["boxes"]

    classes_list = process_class_list(classes_to_show)
    res_img = plot_results(img, probas[keep], bboxes_scaled[keep], model, classes_list)

    return res_img


def plot_results(pil_img, prob, boxes, model, classes_list):
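    """Draw bounding boxes and class labels on the image and return a PIL image."""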
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100  # repeat the palette so every box gets a color
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        cl = p.argmax()
        object_class = model.config.id2label[cl.item()]

        # An empty class list means "show every class"; otherwise filter by name.
        if len(classes_list) > 0 and object_class not in classes_list:
            continue

        ax.add_patch(
            plt.Rectangle(
                (xmin, ymin), xmax - xmin, ymax - ymin, fill=False, color=c, linewidth=3
            )
        )
        text = f"{object_class}: {p[cl]:0.2f}"
        ax.text(xmin, ymin, text, fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5))
    plt.axis("off")
    return fig2img(plt.gcf())


def fig2img(fig):
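    """Render a Matplotlib figure to a PIL Image via an in-memory PNG buffer."""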
    buf = io.BytesIO()
    fig.savefig(buf)
    buf.seek(0)
    img = Image.open(buf)
    plt.close(fig)  # release the figure so repeated requests do not leak memory
    return img


description = """Upload an image to detect objects: the model draws bounding boxes with class labels and confidence scores."""
title = """Object Detection"""

# Create examples list from "examples/" directory
# example_list = [["examples/" + example] for example in os.listdir("examples")]


# Gradio UI: image input/output, an optional class filter, and a probability threshold.
image_in = gr.components.Image(label="Upload an image")
image_out = gr.components.Image()
classes_to_show = gr.components.Textbox(
    placeholder="e.g. car, dog",
    label="Classes to filter (leave empty to detect all classes)",
)
prob_threshold_slider = gr.components.Slider(
    minimum=0, maximum=1.0, step=0.01, value=0.7, label="Probability Threshold"
)
inputs = [image_in, prob_threshold_slider, classes_to_show]

gr.Interface(
    fn=model_inference,
    inputs=inputs,
    outputs=image_out,
    title=title,
    description=description,
    # examples=example_list
).launch()