from typing import Tuple

import gradio as gr
import numpy as np
import supervision as sv
import torch
from PIL import Image
from transformers import SamModel, SamProcessor

from utils.efficient_sam import load, inference_with_box

MARKDOWN = """
# EfficientSAM vs. SAM

This is a demo for ⚔️ SAM Battlegrounds - a speed and accuracy comparison between 
[EfficientSAM](https://arxiv.org/abs/2312.00863) and 
[SAM](https://arxiv.org/abs/2304.02643).
"""

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
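# Load both segmentation models once at startup: SAM (ViT-H) from the
# Hugging Face Hub, EfficientSAM via the local utils.efficient_sam helper.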
SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
EFFICIENT_SAM_MODEL = load(device=DEVICE)
MASK_ANNOTATOR = sv.MaskAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX)
BOX_ANNOTATOR = sv.BoundingBoxAnnotator(
    color=sv.Color.red(),
    color_lookup=sv.ColorLookup.INDEX)


def annotate_image(image: np.ndarray, detections: sv.Detections) -> np.ndarray:
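    # Gradio passes RGB arrays, while the supervision annotators follow the
    # OpenCV BGR convention, so flip channels before annotating and back after.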
    bgr_image = image[:, :, ::-1]
    annotated_bgr_image = MASK_ANNOTATOR.annotate(
        scene=bgr_image, detections=detections)
    annotated_bgr_image = BOX_ANNOTATOR.annotate(
        scene=annotated_bgr_image, detections=detections)
    return annotated_bgr_image[:, :, ::-1]


def efficient_sam_inference(
    image: np.ndarray,
    x_min: int,
    y_min: int,
    x_max: int,
    y_max: int
) -> np.ndarray:
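    # The box prompt is passed as two corner points:
    # [[x_min, y_min], [x_max, y_max]].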
    box = np.array([[x_min, y_min], [x_max, y_max]])
    mask = inference_with_box(image, box, EFFICIENT_SAM_MODEL, DEVICE)
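    # Add a leading axis so the mask matches the (N, H, W) layout expected by
    # sv.Detections; the bounding box is recovered from the mask itself.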
    mask = mask[np.newaxis, ...]
    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
    return annotate_image(image=image, detections=detections)


def sam_inference(
    image: np.ndarray,
    x_min: int,
    y_min: int,
    x_max: int,
    y_max: int
) -> np.ndarray:
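    # The processor resizes the image to SAM's input resolution and rescales
    # the box prompt to match.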
    input_boxes = [[[x_min, y_min, x_max, y_max]]]
    inputs = SAM_PROCESSOR(
        Image.fromarray(image),
        input_boxes=[input_boxes],
        return_tensors="pt"
    ).to(DEVICE)

    with torch.no_grad():
        outputs = SAM_MODEL(**inputs)

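    # Rescale the predicted masks back to the original image size and keep the
    # first mask returned for the single box prompt.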
    mask = SAM_PROCESSOR.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu()
    )[0][0][0].numpy()
    mask = mask[np.newaxis, ...]
    detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
    return annotate_image(image=image, detections=detections)


def inference(
    image: np.ndarray,
    x_min: int,
    y_min: int,
    x_max: int,
    y_max: int
) -> Tuple[np.ndarray, np.ndarray]:
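    # Used by gr.Examples, which needs a single function producing both outputs.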
    return (
        efficient_sam_inference(image, x_min, y_min, x_max, y_max),
        sam_inference(image, x_min, y_min, x_max, y_max)
    )


def clear(_: np.ndarray) -> Tuple[None, None]:
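    # Reset both output panes whenever a new input image is provided.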
    return None, None


with gr.Blocks() as demo:
    gr.Markdown(MARKDOWN)
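    # Layout: input image and box coordinates on the left, EfficientSAM and
    # SAM results rendered side by side.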
    with gr.Tab(label="Box prompt"):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image()
                with gr.Accordion(label="Box", open=False):
                    with gr.Row():
                        x_min_number = gr.Number(label="x_min")
                        y_min_number = gr.Number(label="y_min")
                        x_max_number = gr.Number(label="x_max")
                        y_max_number = gr.Number(label="y_max")
            efficient_sam_output_image = gr.Image(label="EfficientSAM")
            sam_output_image = gr.Image(label="SAM")
        with gr.Row():
            submit_button = gr.Button("Submit")

        gr.Examples(
            fn=inference,
            examples=[
                [
                    'https://media.roboflow.com/efficient-sam/beagle.jpeg',
                    69,
                    26,
                    625,
                    704
                ],
                [
                    'https://media.roboflow.com/efficient-sam/corgi.jpg',
                    801,
                    510,
                    1782,
                    993
                ],
                [
                    'https://media.roboflow.com/efficient-sam/horses.jpg',
                    814,
                    696,
                    1523,
                    1183
                ],
                [
                    'https://media.roboflow.com/efficient-sam/bears.jpg',
                    653,
                    874,
                    1173,
                    1229
                ]
            ],
            inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
            outputs=[efficient_sam_output_image, sam_output_image],
        )

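    # A single click runs both models so their outputs can be compared side by
    # side; changing the input image clears any stale results.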
    submit_button.click(
        efficient_sam_inference,
        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
        outputs=efficient_sam_output_image
    )
    submit_button.click(
        sam_inference,
        inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
        outputs=sam_output_image
    )
    input_image.change(
        clear,
        inputs=input_image,
        outputs=[efficient_sam_output_image, sam_output_image]
    )

demo.launch(debug=False, show_error=True)