File size: 2,482 Bytes
6784e11
922d2f8
9d9db66
adca989
6784e11
 
9d9db66
 
6784e11
22c47ed
 
 
 
 
 
 
 
 
 
 
 
3602b81
5320025
6784e11
9d9db66
 
 
 
 
 
adca989
 
 
9d9db66
 
6784e11
22c47ed
4199c36
 
 
 
 
582c0ea
22c47ed
 
 
 
 
 
 
 
 
 
 
 
922d2f8
 
22c47ed
 
922d2f8
22c47ed
 
6784e11
 
 
f2e12a9
9761ba8
83aa999
40b9503
6784e11
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from collections import Counter

import gradio as gr
import pandas as pd
import torch
from PIL import Image, ImageDraw
from transformers import AutoImageProcessor, AutoModelForObjectDetection

# Single source of truth for the checkpoint, so the processor and the model
# are always loaded from the same weights.
MODEL_CHECKPOINT = "hustvl/yolos-small"

# Loaded once at import time and shared by every call to detect().
image_processor = AutoImageProcessor.from_pretrained(MODEL_CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(MODEL_CHECKPOINT)

# Box outline palette: detect() hands out one colour per reported label,
# in order of decreasing label frequency.
colors = "red orange yellow green blue indigo violet brown black slategray".split()

def detect(image, threshold=0.9, max_labels=10):
    """Run object detection on *image* and annotate the most frequent labels.

    Args:
        image: Input PIL image, or ``None`` when nothing was uploaded.
        threshold: Score threshold passed to the image processor's
            ``post_process_object_detection`` (default keeps the original 0.9).
        max_labels: Maximum number of distinct labels to draw and report,
            most frequent first.

    Returns:
        A tuple ``(annotated_image, counts_df, count_results)``:
        ``annotated_image`` is an RGB copy of the input with boxes drawn,
        ``counts_df`` is a DataFrame with ``label`` / ``counts`` columns, and
        ``count_results`` maps each reported label to its count in descending
        order of count. All three are ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None, None, None

    # Draw on an RGB copy: the caller's image is left untouched and modes such
    # as RGBA / L / P are normalised before preprocessing and drawing.
    image = image.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL's size is (width, height); reversed to (height, width) here.
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs, threshold=threshold, target_sizes=target_sizes
    )[0]

    # Resolve each detection's class name once; reused for counting and drawing.
    label_names = [model.config.id2label[label.item()] for label in results["labels"]]

    # most_common is a stable top-n, so ties keep first-seen order exactly like
    # a stable sort by count (descending) followed by slicing.
    count_results = dict(Counter(label_names).most_common(max_labels))
    # Modulo keeps colour lookup safe if max_labels exceeds the palette size.
    label2color = {
        name: colors[idx % len(colors)] for idx, name in enumerate(count_results)
    }

    draw = ImageDraw.Draw(image)
    for label_name, box in zip(label_names, results["boxes"]):
        if label_name not in count_results:
            continue
        x1, y1, x2, y2 = (round(coord, 4) for coord in box.tolist())
        draw.rectangle((x1, y1, x2, y2), outline=label2color[label_name], width=2)
        draw.text((x1, y1), label_name, fill="white")

    df = pd.DataFrame({
        "label": list(count_results),
        "counts": list(count_results.values()),
    })

    return image, df, count_results

# Gradio UI: one PIL image in. Three outputs come back, in the order of
# detect()'s return tuple: the annotated image, a per-label count bar chart,
# and the raw label -> count mapping shown as text.
demo = gr.Interface(
    fn=detect,
    examples=["examples/football.jpg", "examples/cats.jpg"],
    # gr.Image replaces the deprecated gr.inputs.Image namespace and matches
    # the output component below.
    inputs=[gr.Image(label="Input image", type="pil", shape=(400, 360))],
    outputs=[
        gr.Image(label="Output image"),
        gr.BarPlot(
            show_label=False,
            x="label",
            y="counts",
            x_title="Labels",
            y_title="Counts",
            vertical=False,
        ),
        gr.Textbox(show_label=False),
    ],
    title="YOLO Object Detection",
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or `gradio app.py` reload mode) has no launch side effect.
if __name__ == "__main__":
    demo.launch()