File size: 1,784 Bytes
6784e11
9d9db66
adca989
6784e11
 
9d9db66
 
6784e11
3602b81
5320025
6784e11
9d9db66
 
 
 
 
 
adca989
 
 
6784e11
 
9d9db66
 
6784e11
adca989
 
9d9db66
 
6784e11
4199c36
97f1b93
4199c36
 
 
 
 
29e68f7
 
57c05d4
adca989
68cda73
6784e11
 
 
bc50303
68cda73
6784e11
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
import gradio as gr
from transformers import AutoImageProcessor, AutoModelForObjectDetection
from PIL import Image, ImageDraw
import torch

# Single checkpoint name shared by the pre-processor and the detection model,
# so the two can never drift apart.
_CHECKPOINT = 'hustvl/yolos-small'

# Loaded once at import time and reused by every detect() call.
image_processor = AutoImageProcessor.from_pretrained(_CHECKPOINT)
model = AutoModelForObjectDetection.from_pretrained(_CHECKPOINT)

def detect(image):
    """Detect objects in an image and outline them with labelled boxes.

    Runs the module-level ``model`` on ``image``, keeps detections whose
    score exceeds 0.9, and draws a red rectangle plus the class name for
    each of them.

    Args:
        image: PIL image to analyse. It is not modified; the boxes are
            drawn on a copy.

    Returns:
        A 4-tuple ``(labels, scores, boxes, annotated)``: the ``"labels"``,
        ``"scores"`` and ``"boxes"`` entries of the post-processed results,
        followed by the annotated copy of ``image``. Boxes are corner
        coordinates ``(x_min, y_min, x_max, y_max)``.

    Raises:
        ValueError: If ``image`` is ``None``.
    """
    if image is None:
        raise ValueError("detect() requires an image, got None")

    inputs = image_processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping so no graph is built and the
    # returned tensors carry no grad_fn.
    with torch.no_grad():
        outputs = model(**inputs)

    # PIL reports (width, height); reverse it to (height, width) for the
    # post-processing step, which rescales boxes to the image size.
    target_sizes = torch.tensor([image.size[::-1]])
    results = image_processor.post_process_object_detection(
        outputs,
        threshold=0.9,
        target_sizes=target_sizes,
    )[0]

    # Draw on a copy so the caller's image is left untouched.
    annotated = image.copy()
    draw = ImageDraw.Draw(annotated)

    for label, box in zip(results["labels"], results["boxes"]):
        x1, y1, x2, y2 = (round(coord, 4) for coord in box.tolist())
        label_name = model.config.id2label[label.item()]
        draw.rectangle((x1, y1, x2, y2), outline="red", width=3)
        # The class name is anchored at the box's bottom-left corner.
        draw.text((x1, y2), label_name, fill="white")

    return results["labels"], results["scores"], results["boxes"], annotated

# Gradio UI: one image in; labels, scores and boxes rendered as text, plus the
# annotated image, out.
demo = gr.Interface(
    fn=detect,
    # gr.inputs.* is the deprecated legacy namespace (removed in Gradio 4);
    # the top-level component accepts the same label/type arguments.
    inputs=[gr.Image(label="Input image", type="pil")],
    outputs=["text", "text", "text", "image"],
    title="Object Counts in Image",
)

demo.launch()