Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Point prompt mode ready for review
Browse files- app.py +186 -68
- utils/draw.py +32 -0
- utils/efficient_sam.py +33 -0
app.py
CHANGED
@@ -7,7 +7,8 @@ import torch
|
|
7 |
from PIL import Image
|
8 |
from transformers import SamModel, SamProcessor
|
9 |
|
10 |
-
from utils.efficient_sam import load, inference_with_box
|
|
|
11 |
|
12 |
MARKDOWN = """
|
13 |
# EfficientSAM sv. SAM
|
@@ -17,28 +18,74 @@ This is a demo for ⚔️ SAM Battlegrounds - a speed and accuracy comparison be
|
|
17 |
[SAM](https://arxiv.org/abs/2304.02643).
|
18 |
"""
|
19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
20 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
21 |
SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
|
22 |
SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
23 |
EFFICIENT_SAM_MODEL = load(device=DEVICE)
|
24 |
MASK_ANNOTATOR = sv.MaskAnnotator(
|
25 |
-
color=
|
26 |
-
color_lookup=sv.ColorLookup.INDEX)
|
27 |
-
BOX_ANNOTATOR = sv.BoundingBoxAnnotator(
|
28 |
-
color=sv.Color.red(),
|
29 |
color_lookup=sv.ColorLookup.INDEX)
|
30 |
|
31 |
|
32 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
bgr_image = image[:, :, ::-1]
|
34 |
annotated_bgr_image = MASK_ANNOTATOR.annotate(
|
35 |
scene=bgr_image, detections=detections)
|
36 |
-
annotated_bgr_image =
|
37 |
-
scene=annotated_bgr_image,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
return annotated_bgr_image[:, :, ::-1]
|
39 |
|
40 |
|
41 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
image: np.ndarray,
|
43 |
x_min: int,
|
44 |
y_min: int,
|
@@ -49,10 +96,17 @@ def efficient_sam_inference(
|
|
49 |
mask = inference_with_box(image, box, EFFICIENT_SAM_MODEL, DEVICE)
|
50 |
mask = mask[np.newaxis, ...]
|
51 |
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
|
52 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
|
54 |
|
55 |
-
def
|
56 |
image: np.ndarray,
|
57 |
x_min: int,
|
58 |
y_min: int,
|
@@ -76,10 +130,17 @@ def sam_inference(
|
|
76 |
)[0][0][0].numpy()
|
77 |
mask = mask[np.newaxis, ...]
|
78 |
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
|
79 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
|
81 |
|
82 |
-
def
|
83 |
image: np.ndarray,
|
84 |
x_min: int,
|
85 |
y_min: int,
|
@@ -87,8 +148,46 @@ def inference(
|
|
87 |
y_max: int
|
88 |
) -> Tuple[np.ndarray, np.ndarray]:
|
89 |
return (
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
)
|
93 |
|
94 |
|
@@ -96,73 +195,92 @@ def clear(_: np.ndarray) -> Tuple[None, None]:
|
|
96 |
return None, None
|
97 |
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
with gr.Blocks() as demo:
|
100 |
gr.Markdown(MARKDOWN)
|
101 |
with gr.Tab(label="Box prompt"):
|
102 |
with gr.Row():
|
103 |
with gr.Column():
|
104 |
-
|
105 |
with gr.Accordion(label="Box", open=False):
|
106 |
with gr.Row():
|
107 |
-
x_min_number
|
108 |
-
y_min_number
|
109 |
-
x_max_number
|
110 |
-
y_max_number
|
111 |
-
|
112 |
-
|
113 |
with gr.Row():
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
gr.Examples(
|
117 |
-
fn=
|
118 |
-
examples=
|
119 |
-
|
120 |
-
|
121 |
-
69,
|
122 |
-
26,
|
123 |
-
625,
|
124 |
-
704
|
125 |
-
],
|
126 |
-
[
|
127 |
-
'https://media.roboflow.com/efficient-sam/corgi.jpg',
|
128 |
-
801,
|
129 |
-
510,
|
130 |
-
1782,
|
131 |
-
993
|
132 |
-
],
|
133 |
-
[
|
134 |
-
'https://media.roboflow.com/efficient-sam/horses.jpg',
|
135 |
-
814,
|
136 |
-
696,
|
137 |
-
1523,
|
138 |
-
1183
|
139 |
-
],
|
140 |
-
[
|
141 |
-
'https://media.roboflow.com/efficient-sam/bears.jpg',
|
142 |
-
653,
|
143 |
-
874,
|
144 |
-
1173,
|
145 |
-
1229
|
146 |
-
]
|
147 |
-
],
|
148 |
-
inputs=[input_image, x_min_number, y_min_number, x_max_number, y_max_number],
|
149 |
-
outputs=[efficient_sam_output_image, sam_output_image],
|
150 |
)
|
151 |
|
152 |
-
|
153 |
-
|
154 |
-
inputs=
|
155 |
-
outputs=
|
156 |
)
|
157 |
-
|
158 |
-
|
159 |
-
inputs=
|
160 |
-
outputs=
|
161 |
)
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
clear,
|
164 |
-
inputs=
|
165 |
-
outputs=[
|
166 |
)
|
167 |
|
168 |
demo.launch(debug=False, show_error=True)
|
|
|
7 |
from PIL import Image
|
8 |
from transformers import SamModel, SamProcessor
|
9 |
|
10 |
+
from utils.efficient_sam import load, inference_with_box, inference_with_point
|
11 |
+
from utils.draw import draw_circle, calculate_dynamic_circle_radius
|
12 |
|
13 |
MARKDOWN = """
|
14 |
# EfficientSAM sv. SAM
|
|
|
18 |
[SAM](https://arxiv.org/abs/2304.02643).
|
19 |
"""
|
20 |
|
21 |
+
BOX_EXAMPLES = [
|
22 |
+
['https://media.roboflow.com/efficient-sam/corgi.jpg', 801, 510, 1782, 993],
|
23 |
+
['https://media.roboflow.com/efficient-sam/horses.jpg', 814, 696, 1523, 1183],
|
24 |
+
['https://media.roboflow.com/efficient-sam/bears.jpg', 653, 874, 1173, 1229]
|
25 |
+
]
|
26 |
+
|
27 |
+
POINT_EXAMPLES = [
|
28 |
+
['https://media.roboflow.com/efficient-sam/corgi.jpg', 1291, 751],
|
29 |
+
['https://media.roboflow.com/efficient-sam/horses.jpg', 1168, 939],
|
30 |
+
['https://media.roboflow.com/efficient-sam/bears.jpg', 913, 1051]
|
31 |
+
]
|
32 |
+
|
33 |
+
PROMPT_COLOR = sv.Color.from_hex("#D3D3D3")
|
34 |
+
MASK_COLOR = sv.Color.from_hex("#FF0000")
|
35 |
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
36 |
SAM_MODEL = SamModel.from_pretrained("facebook/sam-vit-huge").to(DEVICE)
|
37 |
SAM_PROCESSOR = SamProcessor.from_pretrained("facebook/sam-vit-huge")
|
38 |
EFFICIENT_SAM_MODEL = load(device=DEVICE)
|
39 |
MASK_ANNOTATOR = sv.MaskAnnotator(
|
40 |
+
color=MASK_COLOR,
|
|
|
|
|
|
|
41 |
color_lookup=sv.ColorLookup.INDEX)
|
42 |
|
43 |
|
44 |
+
def annotate_image_with_box_prompt_result(
|
45 |
+
image: np.ndarray,
|
46 |
+
detections: sv.Detections,
|
47 |
+
x_min: int,
|
48 |
+
y_min: int,
|
49 |
+
x_max: int,
|
50 |
+
y_max: int
|
51 |
+
) -> np.ndarray:
|
52 |
+
h, w, _ = image.shape
|
53 |
bgr_image = image[:, :, ::-1]
|
54 |
annotated_bgr_image = MASK_ANNOTATOR.annotate(
|
55 |
scene=bgr_image, detections=detections)
|
56 |
+
annotated_bgr_image = sv.draw_rectangle(
|
57 |
+
scene=annotated_bgr_image,
|
58 |
+
rect=sv.Rect(
|
59 |
+
x=x_min,
|
60 |
+
y=y_min,
|
61 |
+
width=int(x_max - x_min),
|
62 |
+
height=int(y_max - y_min),
|
63 |
+
),
|
64 |
+
color=PROMPT_COLOR,
|
65 |
+
thickness=sv.calculate_dynamic_line_thickness(resolution_wh=(w, h))
|
66 |
+
)
|
67 |
return annotated_bgr_image[:, :, ::-1]
|
68 |
|
69 |
|
70 |
+
def annotate_image_with_point_prompt_result(
|
71 |
+
image: np.ndarray,
|
72 |
+
detections: sv.Detections,
|
73 |
+
x: int,
|
74 |
+
y: int
|
75 |
+
) -> np.ndarray:
|
76 |
+
h, w, _ = image.shape
|
77 |
+
bgr_image = image[:, :, ::-1]
|
78 |
+
annotated_bgr_image = MASK_ANNOTATOR.annotate(
|
79 |
+
scene=bgr_image, detections=detections)
|
80 |
+
annotated_bgr_image = draw_circle(
|
81 |
+
scene=annotated_bgr_image,
|
82 |
+
center=sv.Point(x=x, y=y),
|
83 |
+
radius=calculate_dynamic_circle_radius(resolution_wh=(w, h)),
|
84 |
+
color=PROMPT_COLOR)
|
85 |
+
return annotated_bgr_image[:, :, ::-1]
|
86 |
+
|
87 |
+
|
88 |
+
def efficient_sam_box_inference(
|
89 |
image: np.ndarray,
|
90 |
x_min: int,
|
91 |
y_min: int,
|
|
|
96 |
mask = inference_with_box(image, box, EFFICIENT_SAM_MODEL, DEVICE)
|
97 |
mask = mask[np.newaxis, ...]
|
98 |
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
|
99 |
+
return annotate_image_with_box_prompt_result(
|
100 |
+
image=image,
|
101 |
+
detections=detections,
|
102 |
+
x_max=x_max,
|
103 |
+
x_min=x_min,
|
104 |
+
y_max=y_max,
|
105 |
+
y_min=y_min
|
106 |
+
)
|
107 |
|
108 |
|
109 |
+
def sam_box_inference(
|
110 |
image: np.ndarray,
|
111 |
x_min: int,
|
112 |
y_min: int,
|
|
|
130 |
)[0][0][0].numpy()
|
131 |
mask = mask[np.newaxis, ...]
|
132 |
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
|
133 |
+
return annotate_image_with_box_prompt_result(
|
134 |
+
image=image,
|
135 |
+
detections=detections,
|
136 |
+
x_max=x_max,
|
137 |
+
x_min=x_min,
|
138 |
+
y_max=y_max,
|
139 |
+
y_min=y_min
|
140 |
+
)
|
141 |
|
142 |
|
143 |
+
def box_inference(
|
144 |
image: np.ndarray,
|
145 |
x_min: int,
|
146 |
y_min: int,
|
|
|
148 |
y_max: int
|
149 |
) -> Tuple[np.ndarray, np.ndarray]:
|
150 |
return (
|
151 |
+
efficient_sam_box_inference(image, x_min, y_min, x_max, y_max),
|
152 |
+
sam_box_inference(image, x_min, y_min, x_max, y_max)
|
153 |
+
)
|
154 |
+
|
155 |
+
|
156 |
+
def efficient_sam_point_inference(image: np.ndarray, x: int, y: int) -> np.ndarray:
|
157 |
+
point = np.array([[x, y]])
|
158 |
+
mask = inference_with_point(image, point, EFFICIENT_SAM_MODEL, DEVICE)
|
159 |
+
mask = mask[np.newaxis, ...]
|
160 |
+
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
|
161 |
+
return annotate_image_with_point_prompt_result(
|
162 |
+
image=image, detections=detections, x=x, y=y)
|
163 |
+
|
164 |
+
|
165 |
+
def sam_point_inference(image: np.ndarray, x: int, y: int) -> np.ndarray:
|
166 |
+
input_points = [[[x, y]]]
|
167 |
+
inputs = SAM_PROCESSOR(
|
168 |
+
Image.fromarray(image),
|
169 |
+
input_points=[input_points],
|
170 |
+
return_tensors="pt"
|
171 |
+
).to(DEVICE)
|
172 |
+
|
173 |
+
with torch.no_grad():
|
174 |
+
outputs = SAM_MODEL(**inputs)
|
175 |
+
|
176 |
+
mask = SAM_PROCESSOR.image_processor.post_process_masks(
|
177 |
+
outputs.pred_masks.cpu(),
|
178 |
+
inputs["original_sizes"].cpu(),
|
179 |
+
inputs["reshaped_input_sizes"].cpu()
|
180 |
+
)[0][0][0].numpy()
|
181 |
+
mask = mask[np.newaxis, ...]
|
182 |
+
detections = sv.Detections(xyxy=sv.mask_to_xyxy(masks=mask), mask=mask)
|
183 |
+
return annotate_image_with_point_prompt_result(
|
184 |
+
image=image, detections=detections, x=x, y=y)
|
185 |
+
|
186 |
+
|
187 |
+
def point_inference(image: np.ndarray, x: int, y: int) -> Tuple[np.ndarray, np.ndarray]:
|
188 |
+
return (
|
189 |
+
efficient_sam_point_inference(image, x, y),
|
190 |
+
sam_point_inference(image, x, y)
|
191 |
)
|
192 |
|
193 |
|
|
|
195 |
return None, None
|
196 |
|
197 |
|
198 |
+
box_input_image = gr.Image()
|
199 |
+
x_min_number = gr.Number(label="x_min")
|
200 |
+
y_min_number = gr.Number(label="y_min")
|
201 |
+
x_max_number = gr.Number(label="x_max")
|
202 |
+
y_max_number = gr.Number(label="y_max")
|
203 |
+
box_inputs = [box_input_image, x_min_number, y_min_number, x_max_number, y_max_number]
|
204 |
+
|
205 |
+
point_input_image = gr.Image()
|
206 |
+
x_number = gr.Number(label="x")
|
207 |
+
y_number = gr.Number(label="y")
|
208 |
+
point_inputs = [point_input_image, x_number, y_number]
|
209 |
+
|
210 |
+
|
211 |
with gr.Blocks() as demo:
|
212 |
gr.Markdown(MARKDOWN)
|
213 |
with gr.Tab(label="Box prompt"):
|
214 |
with gr.Row():
|
215 |
with gr.Column():
|
216 |
+
box_input_image.render()
|
217 |
with gr.Accordion(label="Box", open=False):
|
218 |
with gr.Row():
|
219 |
+
x_min_number.render()
|
220 |
+
y_min_number.render()
|
221 |
+
x_max_number.render()
|
222 |
+
y_max_number.render()
|
223 |
+
efficient_sam_box_output_image = gr.Image(label="EfficientSAM")
|
224 |
+
sam_box_output_image = gr.Image(label="SAM")
|
225 |
with gr.Row():
|
226 |
+
submit_box_inference_button = gr.Button("Submit")
|
227 |
+
gr.Examples(
|
228 |
+
fn=box_inference,
|
229 |
+
examples=BOX_EXAMPLES,
|
230 |
+
inputs=box_inputs,
|
231 |
+
outputs=[efficient_sam_box_output_image, sam_box_output_image],
|
232 |
+
)
|
233 |
+
with gr.Tab(label="Point prompt"):
|
234 |
+
with gr.Row():
|
235 |
+
with gr.Column():
|
236 |
+
point_input_image.render()
|
237 |
+
with gr.Accordion(label="Point", open=False):
|
238 |
+
with gr.Row():
|
239 |
+
x_number.render()
|
240 |
+
y_number.render()
|
241 |
+
efficient_sam_point_output_image = gr.Image(label="EfficientSAM")
|
242 |
+
sam_point_output_image = gr.Image(label="SAM")
|
243 |
+
with gr.Row():
|
244 |
+
submit_point_inference_button = gr.Button("Submit")
|
245 |
gr.Examples(
|
246 |
+
fn=point_inference,
|
247 |
+
examples=POINT_EXAMPLES,
|
248 |
+
inputs=point_inputs,
|
249 |
+
outputs=[efficient_sam_point_output_image, sam_point_output_image],
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
250 |
)
|
251 |
|
252 |
+
submit_box_inference_button.click(
|
253 |
+
efficient_sam_box_inference,
|
254 |
+
inputs=box_inputs,
|
255 |
+
outputs=efficient_sam_box_output_image
|
256 |
)
|
257 |
+
submit_box_inference_button.click(
|
258 |
+
sam_box_inference,
|
259 |
+
inputs=box_inputs,
|
260 |
+
outputs=sam_box_output_image
|
261 |
)
|
262 |
+
|
263 |
+
submit_point_inference_button.click(
|
264 |
+
efficient_sam_point_inference,
|
265 |
+
inputs=point_inputs,
|
266 |
+
outputs=efficient_sam_point_output_image
|
267 |
+
)
|
268 |
+
submit_point_inference_button.click(
|
269 |
+
sam_point_inference,
|
270 |
+
inputs=point_inputs,
|
271 |
+
outputs=sam_point_output_image
|
272 |
+
)
|
273 |
+
|
274 |
+
box_input_image.change(
|
275 |
+
clear,
|
276 |
+
inputs=box_input_image,
|
277 |
+
outputs=[efficient_sam_box_output_image, sam_box_output_image]
|
278 |
+
)
|
279 |
+
|
280 |
+
point_input_image.change(
|
281 |
clear,
|
282 |
+
inputs=point_input_image,
|
283 |
+
outputs=[efficient_sam_point_output_image, sam_point_output_image]
|
284 |
)
|
285 |
|
286 |
demo.launch(debug=False, show_error=True)
|
utils/draw.py
ADDED
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import cv2
|
4 |
+
import numpy as np
|
5 |
+
import supervision as sv
|
6 |
+
|
7 |
+
|
8 |
+
def draw_circle(
|
9 |
+
scene: np.ndarray, center: sv.Point, color: sv.Color, radius: int = 2
|
10 |
+
) -> np.ndarray:
|
11 |
+
cv2.circle(
|
12 |
+
scene,
|
13 |
+
center=center.as_xy_int_tuple(),
|
14 |
+
radius=radius,
|
15 |
+
color=color.as_bgr(),
|
16 |
+
thickness=-1,
|
17 |
+
)
|
18 |
+
return scene
|
19 |
+
|
20 |
+
|
21 |
+
def calculate_dynamic_circle_radius(resolution_wh: Tuple[int, int]) -> int:
|
22 |
+
min_dimension = min(resolution_wh)
|
23 |
+
if min_dimension < 480:
|
24 |
+
return 4
|
25 |
+
if min_dimension < 720:
|
26 |
+
return 8
|
27 |
+
if min_dimension < 1080:
|
28 |
+
return 8
|
29 |
+
if min_dimension < 2160:
|
30 |
+
return 16
|
31 |
+
else:
|
32 |
+
return 16
|
utils/efficient_sam.py
CHANGED
@@ -45,3 +45,36 @@ def inference_with_box(
|
|
45 |
max_predicted_iou = curr_predicted_iou
|
46 |
selected_mask_using_predicted_iou = all_masks[m]
|
47 |
return selected_mask_using_predicted_iou
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
max_predicted_iou = curr_predicted_iou
|
46 |
selected_mask_using_predicted_iou = all_masks[m]
|
47 |
return selected_mask_using_predicted_iou
|
48 |
+
|
49 |
+
|
50 |
+
def inference_with_point(
|
51 |
+
image: np.ndarray,
|
52 |
+
point: np.ndarray,
|
53 |
+
model: torch.jit.ScriptModule,
|
54 |
+
device: torch.device
|
55 |
+
) -> np.ndarray:
|
56 |
+
pts_sampled = torch.reshape(torch.tensor(point), [1, 1, -1, 2])
|
57 |
+
max_num_pts = pts_sampled.shape[2]
|
58 |
+
pts_labels = torch.ones(1, 1, max_num_pts)
|
59 |
+
img_tensor = ToTensor()(image)
|
60 |
+
|
61 |
+
predicted_logits, predicted_iou = model(
|
62 |
+
img_tensor[None, ...].to(device),
|
63 |
+
pts_sampled.to(device),
|
64 |
+
pts_labels.to(device),
|
65 |
+
)
|
66 |
+
predicted_logits = predicted_logits.cpu()
|
67 |
+
all_masks = torch.ge(torch.sigmoid(predicted_logits[0, 0, :, :, :]), 0.5).numpy()
|
68 |
+
predicted_iou = predicted_iou[0, 0, ...].cpu().detach().numpy()
|
69 |
+
|
70 |
+
max_predicted_iou = -1
|
71 |
+
selected_mask_using_predicted_iou = None
|
72 |
+
for m in range(all_masks.shape[0]):
|
73 |
+
curr_predicted_iou = predicted_iou[m]
|
74 |
+
if (
|
75 |
+
curr_predicted_iou > max_predicted_iou
|
76 |
+
or selected_mask_using_predicted_iou is None
|
77 |
+
):
|
78 |
+
max_predicted_iou = curr_predicted_iou
|
79 |
+
selected_mask_using_predicted_iou = all_masks[m]
|
80 |
+
return selected_mask_using_predicted_iou
|