File size: 2,106 Bytes
19a011f
 
 
ddc4ca6
19a011f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ddc4ca6
19a011f
 
 
 
 
ddc4ca6
19a011f
 
 
 
 
 
 
 
 
 
 
 
 
ddc4ca6
19a011f
 
 
 
 
 
 
 
ddc4ca6
19a011f
b37b4db
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import io

import torch
import gradio as gr
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image

from transformers import AutoFeatureExtractor, AutoModelForObjectDetection

# Use the non-interactive Agg backend: this script only ever renders figures
# into in-memory images, it never opens a plot window.
plt.switch_backend('Agg')

# Pre/post-processor and detection model for the YOLOS-tiny checkpoint, loaded
# once at import time and reused for every frame.
_CHECKPOINT = "hustvl/yolos-tiny"
extractor = AutoFeatureExtractor.from_pretrained(_CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(_CHECKPOINT)

# Palette of RGB triples (floats in [0, 1]) used, in order, for the outline
# colour of successive bounding boxes.
COLORS = [
    [0.000, 0.447, 0.741],
    [0.850, 0.325, 0.098],
    [0.929, 0.694, 0.125],
    [0.494, 0.184, 0.556],
    [0.466, 0.674, 0.188],
    [0.301, 0.745, 0.933],
]

# Minimum confidence score passed to post-processing; weaker detections are dropped.
PRED_THRESHOLD = 0.9

def fig2img(fig):
    """Render a matplotlib figure into a PIL image.

    The figure is written with ``savefig``'s default format and settings into
    an in-memory buffer, which is then rewound and opened with PIL.
    """
    stream = io.BytesIO()
    fig.savefig(stream)
    stream.seek(0)
    # The buffer is intentionally left open: PIL's Image.open is lazy and
    # keeps reading from it until the pixel data is actually loaded.
    return Image.open(stream)

def composite_predictions(img, processed_predictions, label_id=1):
    """Draw the detections of one class on top of ``img`` and return the result.

    Args:
        img: The frame to draw on; passed straight to ``imshow`` (``process``
            supplies a PIL image).
        processed_predictions: One entry of ``post_process_object_detection``'s
            output: a mapping with ``"boxes"``, ``"scores"`` and ``"labels"``
            tensors. Each box is unpacked as ``(xmin, ymin, xmax, ymax)``.
        label_id: Class id to keep; all other detections are dropped.
            Defaults to ``1``, the id this demo treats as "person".

    Returns:
        A PIL image of the rendered figure (see ``fig2img``).
    """
    # Boolean mask selecting only detections of the requested class.
    keep = processed_predictions["labels"] == label_id
    boxes = processed_predictions["boxes"][keep].tolist()
    scores = processed_predictions["scores"][keep].tolist()
    labels = [model.config.id2label[x] for x in processed_predictions["labels"][keep].tolist()]

    # Work on an explicit Figure instead of pyplot's global "current figure":
    # Gradio may run several handler calls at once, and plt.gcf()/plt.close()
    # with no argument could then grab or close another call's figure.
    # try/finally releases the figure even if rendering fails, so a live
    # stream does not accumulate open figures.
    fig = plt.figure(figsize=(16, 10))
    try:
        axis = fig.add_subplot(111)
        axis.imshow(img)
        for index, (score, (xmin, ymin, xmax, ymax), label) in enumerate(zip(scores, boxes, labels)):
            # Cycle through the palette so every box is drawn, however many there are.
            color = COLORS[index % len(COLORS)]
            axis.add_patch(
                plt.Rectangle(
                    (xmin, ymin), xmax - xmin, ymax - ymin,
                    fill=False, color=color, linewidth=3,
                )
            )
            axis.text(
                xmin, ymin, f"{label}: {score:0.2f}",
                fontsize=15, bbox=dict(facecolor="yellow", alpha=0.5),
            )
        axis.axis("off")
        return fig2img(fig)
    finally:
        plt.close(fig)

def process(img):
    """Run detection on one webcam frame and return the annotated frame.

    Args:
        img: PIL image delivered by the Gradio webcam component, or ``None``
            when no frame is available.

    Returns:
        The annotated PIL image from ``composite_predictions``, or ``None``
        if ``img`` is ``None``.
    """
    # An empty input component can hand the handler None; there is nothing to detect.
    if img is None:
        return None

    # Pure inference: disable autograd so each streamed frame doesn't build a graph.
    with torch.no_grad():
        inputs = extractor(images=img, return_tensors="pt")
        outputs = model(**inputs)

    # PIL's size is (width, height); post-processing expects (height, width) per image.
    img_size = torch.tensor([tuple(reversed(img.size))])
    processed = extractor.post_process_object_detection(
        outputs, threshold=PRED_THRESHOLD, target_sizes=img_size
    )

    # Composite image and prediction bounding boxes + labels
    return composite_predictions(img, processed[0])

# Live webcam UI: each streamed frame arrives as a PIL image, is passed to
# `process`, and the annotated result is displayed as the output image.
webcam_input = gr.Image(source="webcam", streaming=True, type='pil')
demo = gr.Interface(
    fn=process,
    inputs=[webcam_input],
    outputs=["image"],
    live=True,
)
demo.launch()