Spaces:
Running
Running
Initial commit
Browse files- .gitattributes +1 -0
- app.py +419 -0
- ckpts/sam_vit_b_01ec64.pth +3 -0
- ckpts/zim_vit_b_2043/decoder.onnx +3 -0
- ckpts/zim_vit_b_2043/encoder.onnx +3 -0
- ckpts/zim_vit_l_2092/decoder.onnx +3 -0
- ckpts/zim_vit_l_2092/encoder.onnx +3 -0
- config/__init__.py +1 -0
- config/config.py +66 -0
- examples/example1.jpg +0 -0
- examples/example2.jpg +0 -0
- examples/example3.jpg +0 -0
- examples/example4.jpg +0 -0
- examples/example5.jpg +0 -0
- examples/example6.jpg +0 -0
- examples/example7.jpg +0 -0
- examples/example8.jpg +0 -0
- packages.txt +1 -0
- pre-requirements.txt +3 -0
- requirements.txt +7 -0
- zim/__init__.py +9 -0
- zim/automatic_mask_generator.py +378 -0
- zim/build_model.py +29 -0
- zim/modeling/decoder.py +90 -0
- zim/modeling/encoder.py +62 -0
- zim/modeling/zim.py +190 -0
- zim/predictor.py +275 -0
- zim/utils/__init__.py +10 -0
- zim/utils/amg.py +373 -0
- zim/utils/argparser.py +96 -0
- zim/utils/print.py +20 -0
- zim/utils/utils.py +148 -0
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
ckpts/ filter=lfs diff=lfs merge=lfs -text
|
app.py
ADDED
@@ -0,0 +1,419 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
import os, sys
|
10 |
+
sys.path.append(os.getcwd())
|
11 |
+
|
12 |
+
# Gradio demo, comparison SAM vs ZIM
|
13 |
+
import os
|
14 |
+
import torch
|
15 |
+
import gradio as gr
|
16 |
+
from gradio_image_prompter import ImagePrompter
|
17 |
+
import numpy as np
|
18 |
+
import cv2
|
19 |
+
from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
|
20 |
+
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
|
21 |
+
from zim.utils import show_mat_anns
|
22 |
+
|
23 |
+
def get_shortest_axis(image):
|
24 |
+
h, w, _ = image.shape
|
25 |
+
return h if h < w else w
|
26 |
+
|
27 |
+
def reset_image(image, prompts):
|
28 |
+
if image is None:
|
29 |
+
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
30 |
+
else:
|
31 |
+
image = image['image']
|
32 |
+
zim_predictor.set_image(image)
|
33 |
+
sam_predictor.set_image(image)
|
34 |
+
prompts = dict()
|
35 |
+
black = np.zeros(image.shape[:2], dtype=np.uint8)
|
36 |
+
|
37 |
+
return (image, image, image, image, black, black, black, black, prompts)
|
38 |
+
|
39 |
+
def reset_example_image(image, prompts):
|
40 |
+
if image is None:
|
41 |
+
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
|
42 |
+
|
43 |
+
zim_predictor.set_image(image)
|
44 |
+
sam_predictor.set_image(image)
|
45 |
+
prompts = dict()
|
46 |
+
black = np.zeros(image.shape[:2], dtype=np.uint8)
|
47 |
+
|
48 |
+
image_dict = {}
|
49 |
+
image_dict['image'] = image
|
50 |
+
image_dict['prompts'] = prompts
|
51 |
+
|
52 |
+
return (image, image_dict, image, image, image, black, black, black, black, prompts)
|
53 |
+
|
54 |
+
def run_amg(image):
|
55 |
+
gr.Info('Checkout ZIM Auto Mask tab.', duration=3)
|
56 |
+
zim_masks = zim_mask_generator.generate(image)
|
57 |
+
zim_masks_vis = show_mat_anns(image, zim_masks)
|
58 |
+
|
59 |
+
sam_masks = sam_mask_generator.generate(image)
|
60 |
+
sam_masks_vis = show_mat_anns(image, sam_masks)
|
61 |
+
|
62 |
+
return zim_masks_vis, sam_masks_vis
|
63 |
+
|
64 |
+
|
65 |
+
def run_model(image, prompts):
|
66 |
+
if not prompts:
|
67 |
+
raise gr.Error(f'Please input any point or BBox')
|
68 |
+
gr.Info('Checkout ZIM Mask tab.', duration=3)
|
69 |
+
point_coords = None
|
70 |
+
point_labels = None
|
71 |
+
boxes = None
|
72 |
+
|
73 |
+
if "point" in prompts:
|
74 |
+
point_coords, point_labels = [], []
|
75 |
+
|
76 |
+
for type, pts in prompts["point"]:
|
77 |
+
point_coords.append(pts)
|
78 |
+
point_labels.append(type)
|
79 |
+
point_coords = np.array(point_coords)
|
80 |
+
point_labels = np.array(point_labels)
|
81 |
+
|
82 |
+
if "bbox" in prompts:
|
83 |
+
boxes = prompts['bbox']
|
84 |
+
boxes = np.array(boxes)
|
85 |
+
|
86 |
+
if "scribble" in prompts:
|
87 |
+
point_coords, point_labels = [], []
|
88 |
+
|
89 |
+
for pts in prompts["scribble"]:
|
90 |
+
point_coords.append(np.flip(pts))
|
91 |
+
point_labels.append(1)
|
92 |
+
if len(point_coords) == 0:
|
93 |
+
raise gr.Error("Please input any scribbles.")
|
94 |
+
point_coords = np.array(point_coords)
|
95 |
+
point_labels = np.array(point_labels)
|
96 |
+
|
97 |
+
# run ZIM
|
98 |
+
zim_mask, _, _ = zim_predictor.predict(
|
99 |
+
point_coords=point_coords,
|
100 |
+
point_labels=point_labels,
|
101 |
+
box=boxes,
|
102 |
+
multimask_output=False,
|
103 |
+
)
|
104 |
+
zim_mask = np.squeeze(zim_mask, axis=0)
|
105 |
+
zim_mask = np.uint8(zim_mask * 255)
|
106 |
+
|
107 |
+
# run SAM
|
108 |
+
sam_mask, _, _ = sam_predictor.predict(
|
109 |
+
point_coords=point_coords,
|
110 |
+
point_labels=point_labels,
|
111 |
+
box=boxes,
|
112 |
+
multimask_output=False,
|
113 |
+
)
|
114 |
+
sam_mask = np.squeeze(sam_mask, axis=0)
|
115 |
+
sam_mask = np.uint8(sam_mask * 255)
|
116 |
+
|
117 |
+
return zim_mask, sam_mask
|
118 |
+
|
119 |
+
def reset_scribble(image, scribble, prompts):
|
120 |
+
# scribble = dict()
|
121 |
+
for k in prompts.keys():
|
122 |
+
prompts[k] = []
|
123 |
+
|
124 |
+
for k, v in scribble.items():
|
125 |
+
scribble[k] = None
|
126 |
+
|
127 |
+
black = np.zeros(image.shape[:3], dtype=np.uint8)
|
128 |
+
|
129 |
+
return scribble, black, black
|
130 |
+
|
131 |
+
def update_scribble(image, scribble, prompts):
|
132 |
+
if "point" in prompts:
|
133 |
+
del prompts["point"]
|
134 |
+
|
135 |
+
if "bbox" in prompts:
|
136 |
+
del prompts["bbox"]
|
137 |
+
|
138 |
+
prompts = dict() # reset prompt
|
139 |
+
scribble_mask = scribble["layers"][0][..., -1] > 0
|
140 |
+
|
141 |
+
scribble_coords = np.argwhere(scribble_mask)
|
142 |
+
n_points = min(len(scribble_coords), 24)
|
143 |
+
indices = np.linspace(0, len(scribble_coords)-1, n_points, dtype=int)
|
144 |
+
scribble_sampled = scribble_coords[indices]
|
145 |
+
|
146 |
+
prompts["scribble"] = scribble_sampled
|
147 |
+
|
148 |
+
zim_mask, sam_mask = run_model(image, prompts)
|
149 |
+
|
150 |
+
return zim_mask, sam_mask, prompts
|
151 |
+
|
152 |
+
|
153 |
+
def draw_point(img, pt, size, color):
|
154 |
+
# draw circle with white boundary region
|
155 |
+
cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 1.3), (255, 255, 255), -1)
|
156 |
+
cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 0.9), color, -1)
|
157 |
+
|
158 |
+
|
159 |
+
def draw_images(image, mask, prompts):
|
160 |
+
if len(prompts) == 0 or mask.shape[1] == 1:
|
161 |
+
return image, image, image
|
162 |
+
|
163 |
+
minor = get_shortest_axis(image)
|
164 |
+
size = int(minor / 80)
|
165 |
+
|
166 |
+
image = np.float32(image)
|
167 |
+
|
168 |
+
def blending(image, mask):
|
169 |
+
mask = np.float32(mask) / 255
|
170 |
+
blended_image = np.zeros_like(image, dtype=np.float32)
|
171 |
+
blended_image[:, :, :] = [108, 0, 192]
|
172 |
+
blended_image = (image * 0.5) + (blended_image * 0.5)
|
173 |
+
|
174 |
+
img_with_mask = mask[:, :, None] * blended_image + (1 - mask[:, :, None]) * image
|
175 |
+
img_with_mask = np.uint8(img_with_mask)
|
176 |
+
|
177 |
+
return img_with_mask
|
178 |
+
|
179 |
+
img_with_mask = blending(image, mask)
|
180 |
+
img_with_point = img_with_mask.copy()
|
181 |
+
|
182 |
+
if "point" in prompts:
|
183 |
+
for type, pts in prompts["point"]:
|
184 |
+
if type == "Positive":
|
185 |
+
color = (0, 0, 255)
|
186 |
+
draw_point(img_with_point, pts, size, color)
|
187 |
+
elif type == "Negative":
|
188 |
+
color = (255, 0, 0)
|
189 |
+
draw_point(img_with_point, pts, size, color)
|
190 |
+
|
191 |
+
size = int(minor / 200)
|
192 |
+
|
193 |
+
return (
|
194 |
+
img,
|
195 |
+
img_with_mask,
|
196 |
+
)
|
197 |
+
|
198 |
+
def get_point_or_box_prompts(img, prompts):
|
199 |
+
image, img_prompts = img['image'], img['points']
|
200 |
+
point_prompts = []
|
201 |
+
box_prompts = []
|
202 |
+
for prompt in img_prompts:
|
203 |
+
for p in range(len(prompt)):
|
204 |
+
prompt[p] = int(prompt[p])
|
205 |
+
if prompt[2] == 2 and prompt[5] == 3: # box prompt
|
206 |
+
if len(box_prompts) != 0:
|
207 |
+
raise gr.Error("Please input only one BBox.", duration=3)
|
208 |
+
box_prompts.append([prompt[0], prompt[1], prompt[3], prompt[4]])
|
209 |
+
elif prompt[2] == 1 and prompt[5] == 4: # Positive point prompt
|
210 |
+
point_prompts.append((1, (prompt[0], prompt[1])))
|
211 |
+
elif prompt[2] == 0 and prompt[5] == 4: # Negative point prompt
|
212 |
+
point_prompts.append((0, (prompt[0], prompt[1])))
|
213 |
+
|
214 |
+
if "scribble" in prompts:
|
215 |
+
del prompts["scribble"]
|
216 |
+
|
217 |
+
if len(point_prompts) > 0:
|
218 |
+
prompts['point'] = point_prompts
|
219 |
+
elif 'point' in prompts:
|
220 |
+
del prompts['point']
|
221 |
+
|
222 |
+
if len(box_prompts) > 0:
|
223 |
+
prompts['bbox'] = box_prompts
|
224 |
+
elif 'bbox' in prompts:
|
225 |
+
del prompts['bbox']
|
226 |
+
|
227 |
+
zim_mask, sam_mask = run_model(image, prompts)
|
228 |
+
|
229 |
+
return image, zim_mask, sam_mask, prompts
|
230 |
+
|
231 |
+
def get_examples():
|
232 |
+
assets_dir = os.path.join(os.path.dirname(__file__), 'examples')
|
233 |
+
images = os.listdir(assets_dir)
|
234 |
+
return [os.path.join(assets_dir, img) for img in images]
|
235 |
+
|
236 |
+
if __name__ == "__main__":
|
237 |
+
backbone = "vit_b"
|
238 |
+
|
239 |
+
# load ZIM
|
240 |
+
ckpt_mat = "ckpts/zim_vit_b_2043"
|
241 |
+
zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
|
242 |
+
if torch.cuda.is_available():
|
243 |
+
zim.cuda()
|
244 |
+
zim_predictor = ZimPredictor(zim)
|
245 |
+
zim_mask_generator = ZimAutomaticMaskGenerator(
|
246 |
+
zim,
|
247 |
+
pred_iou_thresh=0.7,
|
248 |
+
points_per_batch=8,
|
249 |
+
stability_score_thresh=0.9,
|
250 |
+
)
|
251 |
+
|
252 |
+
# load SAM
|
253 |
+
ckpt_sam = "ckpts/sam_vit_b_01ec64.pth"
|
254 |
+
sam = sam_model_registry[backbone](checkpoint=ckpt_sam)
|
255 |
+
if torch.cuda.is_available():
|
256 |
+
sam.cuda()
|
257 |
+
sam_predictor = SamPredictor(sam)
|
258 |
+
sam_mask_generator = SamAutomaticMaskGenerator(
|
259 |
+
sam,
|
260 |
+
points_per_batch=8,
|
261 |
+
)
|
262 |
+
|
263 |
+
with gr.Blocks() as demo:
|
264 |
+
gr.Markdown("# <center> [Demo] ZIM: Zero-Shot Image Matting for Anything")
|
265 |
+
|
266 |
+
prompts = gr.State(dict())
|
267 |
+
img = gr.Image(visible=False)
|
268 |
+
example_image = gr.Image(visible=False)
|
269 |
+
|
270 |
+
with gr.Row():
|
271 |
+
with gr.Column():
|
272 |
+
# Point and Bbox prompt
|
273 |
+
with gr.Tab(label="Point or Box"):
|
274 |
+
img_with_point_or_box = ImagePrompter(
|
275 |
+
label="query image",
|
276 |
+
sources="upload"
|
277 |
+
)
|
278 |
+
interactions = "Left Click (Pos) | Middle/Right Click (Neg) | Press Move (Box)"
|
279 |
+
gr.Markdown("<h3 style='text-align: center'> {} </h3>".format(interactions))
|
280 |
+
run_bttn = gr.Button("Run")
|
281 |
+
amg_bttn = gr.Button("Automatic Mask Generation")
|
282 |
+
|
283 |
+
# Scribble prompt
|
284 |
+
with gr.Tab(label="Scribble"):
|
285 |
+
img_with_scribble = gr.ImageEditor(
|
286 |
+
label="Scribble",
|
287 |
+
brush=gr.Brush(colors=["#00FF00"], default_size=15),
|
288 |
+
sources="upload",
|
289 |
+
transforms=None,
|
290 |
+
layers=False
|
291 |
+
)
|
292 |
+
interactions = "Press Move (Scribble)"
|
293 |
+
gr.Markdown("<h3 style='text-align: center'> Step 1. Select Draw button </h3>")
|
294 |
+
gr.Markdown("<h3 style='text-align: center'> Step 2. {} </h3>".format(interactions))
|
295 |
+
scribble_bttn = gr.Button("Run")
|
296 |
+
scribble_reset_bttn = gr.Button("Reset Scribbles")
|
297 |
+
amg_scribble_bttn = gr.Button("Automatic Mask Generation")
|
298 |
+
|
299 |
+
# Example image
|
300 |
+
gr.Examples(get_examples(), inputs=[example_image])
|
301 |
+
|
302 |
+
# with gr.Row():
|
303 |
+
with gr.Column():
|
304 |
+
with gr.Tab(label="ZIM Image"):
|
305 |
+
img_with_zim_mask = gr.Image(
|
306 |
+
label="ZIM Image",
|
307 |
+
interactive=False
|
308 |
+
)
|
309 |
+
|
310 |
+
with gr.Tab(label="ZIM Mask"):
|
311 |
+
zim_mask = gr.Image(
|
312 |
+
label="ZIM Mask",
|
313 |
+
image_mode="L",
|
314 |
+
interactive=False
|
315 |
+
)
|
316 |
+
with gr.Tab(label="ZIM Auto Mask"):
|
317 |
+
zim_amg = gr.Image(
|
318 |
+
label="ZIM Auto Mask",
|
319 |
+
interactive=False
|
320 |
+
)
|
321 |
+
|
322 |
+
with gr.Column():
|
323 |
+
with gr.Tab(label="SAM Image"):
|
324 |
+
img_with_sam_mask = gr.Image(
|
325 |
+
label="SAM image",
|
326 |
+
interactive=False
|
327 |
+
)
|
328 |
+
|
329 |
+
with gr.Tab(label="SAM Mask"):
|
330 |
+
sam_mask = gr.Image(
|
331 |
+
label="SAM Mask",
|
332 |
+
image_mode="L",
|
333 |
+
interactive=False
|
334 |
+
)
|
335 |
+
|
336 |
+
with gr.Tab(label="SAM Auto Mask"):
|
337 |
+
sam_amg = gr.Image(
|
338 |
+
label="SAM Auto Mask",
|
339 |
+
interactive=False
|
340 |
+
)
|
341 |
+
|
342 |
+
example_image.change(
|
343 |
+
reset_example_image,
|
344 |
+
[example_image, prompts],
|
345 |
+
[
|
346 |
+
img,
|
347 |
+
img_with_point_or_box,
|
348 |
+
img_with_scribble,
|
349 |
+
img_with_zim_mask,
|
350 |
+
img_with_sam_mask,
|
351 |
+
zim_amg,
|
352 |
+
sam_amg,
|
353 |
+
zim_mask,
|
354 |
+
sam_mask,
|
355 |
+
prompts,
|
356 |
+
]
|
357 |
+
)
|
358 |
+
|
359 |
+
img_with_point_or_box.upload(
|
360 |
+
reset_image,
|
361 |
+
[img_with_point_or_box, prompts],
|
362 |
+
[
|
363 |
+
img,
|
364 |
+
img_with_scribble,
|
365 |
+
img_with_zim_mask,
|
366 |
+
img_with_sam_mask,
|
367 |
+
zim_amg,
|
368 |
+
sam_amg,
|
369 |
+
zim_mask,
|
370 |
+
sam_mask,
|
371 |
+
prompts,
|
372 |
+
],
|
373 |
+
)
|
374 |
+
|
375 |
+
amg_bttn.click(
|
376 |
+
run_amg,
|
377 |
+
[img],
|
378 |
+
[zim_amg, sam_amg]
|
379 |
+
)
|
380 |
+
amg_scribble_bttn.click(
|
381 |
+
run_amg,
|
382 |
+
[img],
|
383 |
+
[zim_amg, sam_amg]
|
384 |
+
)
|
385 |
+
|
386 |
+
run_bttn.click(
|
387 |
+
get_point_or_box_prompts,
|
388 |
+
[img_with_point_or_box, prompts],
|
389 |
+
[img, zim_mask, sam_mask, prompts]
|
390 |
+
)
|
391 |
+
|
392 |
+
zim_mask.change(
|
393 |
+
draw_images,
|
394 |
+
[img, zim_mask, prompts],
|
395 |
+
[
|
396 |
+
img, img_with_zim_mask,
|
397 |
+
],
|
398 |
+
)
|
399 |
+
sam_mask.change(
|
400 |
+
draw_images,
|
401 |
+
[img, sam_mask, prompts],
|
402 |
+
[
|
403 |
+
img, img_with_sam_mask,
|
404 |
+
],
|
405 |
+
)
|
406 |
+
|
407 |
+
scribble_reset_bttn.click(
|
408 |
+
reset_scribble,
|
409 |
+
[img, img_with_scribble, prompts],
|
410 |
+
[img_with_scribble, zim_mask, sam_mask],
|
411 |
+
)
|
412 |
+
scribble_bttn.click(
|
413 |
+
update_scribble,
|
414 |
+
[img, img_with_scribble, prompts],
|
415 |
+
[zim_mask, sam_mask, prompts],
|
416 |
+
)
|
417 |
+
|
418 |
+
demo.queue()
|
419 |
+
demo.launch()
|
ckpts/sam_vit_b_01ec64.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
|
3 |
+
size 375042383
|
ckpts/zim_vit_b_2043/decoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a99d14bac3dc28b17809020b1711b1991bd5b7eb18f678c3624ce508748cfa11
|
3 |
+
size 19330176
|
ckpts/zim_vit_b_2043/encoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6123ea10722dd43a0f8bbb96cd6a0532d8a65e2b1e8339f0cd6593d494807809
|
3 |
+
size 360680416
|
ckpts/zim_vit_l_2092/decoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c24a42d79053d20594760ad9afdef9cca985f29493f3e1e0fe3462ca291b57f7
|
3 |
+
size 19330176
|
ckpts/zim_vit_l_2092/encoder.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6b0360b3c32dfa11555fb1de27cd3533cec1e932204da0dec2b3772045dc7db3
|
3 |
+
size 1235692110
|
config/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from config.config import generate_config
|
config/config.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
from easydict import EasyDict as edict
|
9 |
+
|
10 |
+
config_ = edict()
|
11 |
+
|
12 |
+
"""
|
13 |
+
Common configs
|
14 |
+
"""
|
15 |
+
config_.data_root = "/mnt/tmp"
|
16 |
+
config_.use_ddp = True
|
17 |
+
config_.use_amp = False
|
18 |
+
config_.local_rank = 0
|
19 |
+
config_.world_size = 1
|
20 |
+
config_.random_seed = 3407
|
21 |
+
"""
|
22 |
+
Network configs
|
23 |
+
"""
|
24 |
+
config_.network = edict()
|
25 |
+
config_.network.encoder = "vit_b"
|
26 |
+
config_.network.decoder = "zim"
|
27 |
+
config_.network.encode_kernel = 21
|
28 |
+
"""
|
29 |
+
Evaluation configs
|
30 |
+
"""
|
31 |
+
config_.eval = edict()
|
32 |
+
config_.eval.workers = 4
|
33 |
+
config_.eval.image_size = 1024
|
34 |
+
config_.eval.prompt_type = "point,bbox"
|
35 |
+
config_.eval.model_list = "zim,sam"
|
36 |
+
config_.eval.zim_weights = ""
|
37 |
+
config_.eval.sam_weights = ""
|
38 |
+
"""
|
39 |
+
Dataset configs
|
40 |
+
"""
|
41 |
+
config_.dataset = edict()
|
42 |
+
config_.dataset.valset = "MicroMat3K"
|
43 |
+
config_.dataset.data_type = "fine,coarse"
|
44 |
+
config_.dataset.data_list_txt = "data_list.txt"
|
45 |
+
|
46 |
+
|
47 |
+
def remove_prefix(text, prefix):
|
48 |
+
if text.startswith(prefix):
|
49 |
+
return text[len(prefix) :]
|
50 |
+
return text
|
51 |
+
|
52 |
+
|
53 |
+
def generate_config(args):
|
54 |
+
# merge args & config
|
55 |
+
for k, v in args.items():
|
56 |
+
if k.startswith("network_"):
|
57 |
+
config_["network"][remove_prefix(k, "network_")] = v
|
58 |
+
elif k.startswith("eval_"):
|
59 |
+
config_["eval"][remove_prefix(k, "eval_")] = v
|
60 |
+
elif k.startswith("dataset_"):
|
61 |
+
config_["dataset"][remove_prefix(k, "dataset_")] = v
|
62 |
+
elif k == "amp":
|
63 |
+
config_["use_amp"] = v
|
64 |
+
else:
|
65 |
+
config_[k] = v
|
66 |
+
return config_
|
examples/example1.jpg
ADDED
examples/example2.jpg
ADDED
examples/example3.jpg
ADDED
examples/example4.jpg
ADDED
examples/example5.jpg
ADDED
examples/example6.jpg
ADDED
examples/example7.jpg
ADDED
examples/example8.jpg
ADDED
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
wget
|
pre-requirements.txt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
torchvision
|
3 |
+
numpy==1.24.4
|
requirements.txt
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
easydict
|
2 |
+
opencv-python
|
3 |
+
gradio==4.38.1
|
4 |
+
gradio-image-prompter
|
5 |
+
fastapi==0.112.2
|
6 |
+
git+https://github.com/facebookresearch/segment-anything.git
|
7 |
+
onnxruntime-gpu==1.17.0
|
zim/__init__.py
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .build_model import build_zim_model, zim_model_registry
|
8 |
+
from .predictor import ZimPredictor
|
9 |
+
from .automatic_mask_generator import ZimAutomaticMaskGenerator
|
zim/automatic_mask_generator.py
ADDED
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torchvision.ops.boxes import batched_nms, box_area
|
13 |
+
|
14 |
+
from typing import Any, Dict, List, Optional, Tuple
|
15 |
+
|
16 |
+
from .modeling.zim import Zim
|
17 |
+
from .predictor import ZimPredictor
|
18 |
+
from .utils.amg import (
|
19 |
+
MaskData,
|
20 |
+
area_from_rle,
|
21 |
+
batch_iterator,
|
22 |
+
batched_mask_to_box,
|
23 |
+
box_xyxy_to_xywh,
|
24 |
+
build_all_layer_point_grids,
|
25 |
+
calculate_stability_score,
|
26 |
+
coco_encode_rle,
|
27 |
+
generate_crop_boxes,
|
28 |
+
is_box_near_crop_edge,
|
29 |
+
mask_to_rle_pytorch,
|
30 |
+
remove_small_regions,
|
31 |
+
rle_to_mask,
|
32 |
+
uncrop_boxes_xyxy,
|
33 |
+
uncrop_masks,
|
34 |
+
uncrop_points,
|
35 |
+
)
|
36 |
+
|
37 |
+
|
38 |
+
class ZimAutomaticMaskGenerator:
|
39 |
+
def __init__(
|
40 |
+
self,
|
41 |
+
model: Zim,
|
42 |
+
points_per_side: Optional[int] = 32,
|
43 |
+
points_per_batch: int = 64,
|
44 |
+
pred_iou_thresh: float = 0.88,
|
45 |
+
stability_score_thresh: float = 0.9,
|
46 |
+
stability_score_offset: float = 0.1,
|
47 |
+
box_nms_thresh: float = 0.7,
|
48 |
+
crop_n_layers: int = 0,
|
49 |
+
crop_nms_thresh: float = 0.7,
|
50 |
+
crop_overlap_ratio: float = 512 / 1500,
|
51 |
+
crop_n_points_downscale_factor: int = 1,
|
52 |
+
point_grids: Optional[List[np.ndarray]] = None,
|
53 |
+
min_mask_region_area: int = 0,
|
54 |
+
output_mode: str = "binary_mask",
|
55 |
+
) -> None:
|
56 |
+
"""
|
57 |
+
Using a SAM model, generates masks for the entire image.
|
58 |
+
Generates a grid of point prompts over the image, then filters
|
59 |
+
low quality and duplicate masks. The default settings are chosen
|
60 |
+
for SAM with a ViT-H backbone.
|
61 |
+
|
62 |
+
Arguments:
|
63 |
+
model (Sam): The SAM model to use for mask prediction.
|
64 |
+
points_per_side (int or None): The number of points to be sampled
|
65 |
+
along one side of the image. The total number of points is
|
66 |
+
points_per_side**2. If None, 'point_grids' must provide explicit
|
67 |
+
point sampling.
|
68 |
+
points_per_batch (int): Sets the number of points run simultaneously
|
69 |
+
by the model. Higher numbers may be faster but use more GPU memory.
|
70 |
+
pred_iou_thresh (float): A filtering threshold in [0,1], using the
|
71 |
+
model's predicted mask quality.
|
72 |
+
stability_score_thresh (float): A filtering threshold in [0,1], using
|
73 |
+
the stability of the mask under changes to the cutoff used to binarize
|
74 |
+
the model's mask predictions.
|
75 |
+
stability_score_offset (float): The amount to shift the cutoff when
|
76 |
+
calculated the stability score.
|
77 |
+
box_nms_thresh (float): The box IoU cutoff used by non-maximal
|
78 |
+
suppression to filter duplicate masks.
|
79 |
+
crop_n_layers (int): If >0, mask prediction will be run again on
|
80 |
+
crops of the image. Sets the number of layers to run, where each
|
81 |
+
layer has 2**i_layer number of image crops.
|
82 |
+
crop_nms_thresh (float): The box IoU cutoff used by non-maximal
|
83 |
+
suppression to filter duplicate masks between different crops.
|
84 |
+
crop_overlap_ratio (float): Sets the degree to which crops overlap.
|
85 |
+
In the first crop layer, crops will overlap by this fraction of
|
86 |
+
the image length. Later layers with more crops scale down this overlap.
|
87 |
+
crop_n_points_downscale_factor (int): The number of points-per-side
|
88 |
+
sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
|
89 |
+
point_grids (list(np.ndarray) or None): A list over explicit grids
|
90 |
+
of points used for sampling, normalized to [0,1]. The nth grid in the
|
91 |
+
list is used in the nth crop layer. Exclusive with points_per_side.
|
92 |
+
min_mask_region_area (int): If >0, postprocessing will be applied
|
93 |
+
to remove disconnected regions and holes in masks with area smaller
|
94 |
+
than min_mask_region_area. Requires opencv.
|
95 |
+
output_mode (str): The form masks are returned in. Can be 'binary_mask',
|
96 |
+
'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
|
97 |
+
For large resolutions, 'binary_mask' may consume large amounts of
|
98 |
+
memory.
|
99 |
+
"""
|
100 |
+
|
101 |
+
assert (points_per_side is None) != (
|
102 |
+
point_grids is None
|
103 |
+
), "Exactly one of points_per_side or point_grid must be provided."
|
104 |
+
if points_per_side is not None:
|
105 |
+
self.point_grids = build_all_layer_point_grids(
|
106 |
+
points_per_side,
|
107 |
+
crop_n_layers,
|
108 |
+
crop_n_points_downscale_factor,
|
109 |
+
)
|
110 |
+
elif point_grids is not None:
|
111 |
+
self.point_grids = point_grids
|
112 |
+
else:
|
113 |
+
raise ValueError("Can't have both points_per_side and point_grid be None.")
|
114 |
+
|
115 |
+
assert output_mode in [
|
116 |
+
"binary_mask",
|
117 |
+
"uncompressed_rle",
|
118 |
+
"coco_rle",
|
119 |
+
], f"Unknown output_mode {output_mode}."
|
120 |
+
if output_mode == "coco_rle":
|
121 |
+
from pycocotools import mask as mask_utils # type: ignore # noqa: F401
|
122 |
+
|
123 |
+
if min_mask_region_area > 0:
|
124 |
+
import cv2 # type: ignore # noqa: F401
|
125 |
+
|
126 |
+
self.predictor = ZimPredictor(model)
|
127 |
+
self.points_per_batch = points_per_batch
|
128 |
+
self.pred_iou_thresh = pred_iou_thresh
|
129 |
+
self.stability_score_thresh = stability_score_thresh
|
130 |
+
self.stability_score_offset = stability_score_offset
|
131 |
+
self.box_nms_thresh = box_nms_thresh
|
132 |
+
self.crop_n_layers = crop_n_layers
|
133 |
+
self.crop_nms_thresh = crop_nms_thresh
|
134 |
+
self.crop_overlap_ratio = crop_overlap_ratio
|
135 |
+
self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
|
136 |
+
self.min_mask_region_area = min_mask_region_area
|
137 |
+
self.output_mode = output_mode
|
138 |
+
|
139 |
+
@torch.no_grad()
|
140 |
+
def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
|
141 |
+
"""
|
142 |
+
Generates masks for the given image.
|
143 |
+
|
144 |
+
Arguments:
|
145 |
+
image (np.ndarray): The image to generate masks for, in HWC uint8 format.
|
146 |
+
|
147 |
+
Returns:
|
148 |
+
list(dict(str, any)): A list over records for masks. Each record is
|
149 |
+
a dict containing the following keys:
|
150 |
+
segmentation (dict(str, any) or np.ndarray): The mask. If
|
151 |
+
output_mode='binary_mask', is an array of shape HW. Otherwise,
|
152 |
+
is a dictionary containing the RLE.
|
153 |
+
bbox (list(float)): The box around the mask, in XYWH format.
|
154 |
+
area (int): The area in pixels of the mask.
|
155 |
+
predicted_iou (float): The model's own prediction of the mask's
|
156 |
+
quality. This is filtered by the pred_iou_thresh parameter.
|
157 |
+
point_coords (list(list(float))): The point coordinates input
|
158 |
+
to the model to generate this mask.
|
159 |
+
stability_score (float): A measure of the mask's quality. This
|
160 |
+
is filtered on using the stability_score_thresh parameter.
|
161 |
+
crop_box (list(float)): The crop of the image used to generate
|
162 |
+
the mask, given in XYWH format.
|
163 |
+
"""
|
164 |
+
|
165 |
+
# Generate masks
|
166 |
+
mask_data = self._generate_masks(image)
|
167 |
+
|
168 |
+
# Filter small disconnected regions and holes in masks
|
169 |
+
if self.min_mask_region_area > 0:
|
170 |
+
mask_data = self.postprocess_small_regions(
|
171 |
+
mask_data,
|
172 |
+
self.min_mask_region_area,
|
173 |
+
max(self.box_nms_thresh, self.crop_nms_thresh),
|
174 |
+
)
|
175 |
+
|
176 |
+
# Encode masks
|
177 |
+
if self.output_mode == "coco_rle":
|
178 |
+
mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
|
179 |
+
elif self.output_mode == "binary_mask":
|
180 |
+
mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
|
181 |
+
else:
|
182 |
+
mask_data["segmentations"] = mask_data["rles"]
|
183 |
+
|
184 |
+
# Write mask records
|
185 |
+
curr_anns = []
|
186 |
+
for idx in range(len(mask_data["segmentations"])):
|
187 |
+
ann = {
|
188 |
+
"segmentation": mask_data["segmentations"][idx],
|
189 |
+
"logit": mask_data["logits"][idx],
|
190 |
+
"area": area_from_rle(mask_data["rles"][idx]),
|
191 |
+
"bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
|
192 |
+
"predicted_iou": mask_data["iou_preds"][idx].item(),
|
193 |
+
"point_coords": [mask_data["points"][idx].tolist()],
|
194 |
+
"stability_score": mask_data["stability_score"][idx].item(),
|
195 |
+
"crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
|
196 |
+
}
|
197 |
+
curr_anns.append(ann)
|
198 |
+
|
199 |
+
return curr_anns
|
200 |
+
|
201 |
+
def _generate_masks(self, image: np.ndarray) -> MaskData:
|
202 |
+
orig_size = image.shape[:2]
|
203 |
+
crop_boxes, layer_idxs = generate_crop_boxes(
|
204 |
+
orig_size, self.crop_n_layers, self.crop_overlap_ratio
|
205 |
+
)
|
206 |
+
|
207 |
+
# Iterate over image crops
|
208 |
+
data = MaskData()
|
209 |
+
for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
|
210 |
+
crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
|
211 |
+
data.cat(crop_data)
|
212 |
+
|
213 |
+
# Remove duplicate masks between crops
|
214 |
+
if len(crop_boxes) > 1:
|
215 |
+
# Prefer masks from smaller crops
|
216 |
+
scores = 1 / box_area(data["crop_boxes"])
|
217 |
+
scores = scores.to(data["boxes"].device)
|
218 |
+
keep_by_nms = batched_nms(
|
219 |
+
data["boxes"].float(),
|
220 |
+
scores,
|
221 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
222 |
+
iou_threshold=self.crop_nms_thresh,
|
223 |
+
)
|
224 |
+
data.filter(keep_by_nms)
|
225 |
+
|
226 |
+
data.to_numpy()
|
227 |
+
return data
|
228 |
+
|
229 |
+
def _process_crop(
|
230 |
+
self,
|
231 |
+
image: np.ndarray,
|
232 |
+
crop_box: List[int],
|
233 |
+
crop_layer_idx: int,
|
234 |
+
orig_size: Tuple[int, ...],
|
235 |
+
) -> MaskData:
|
236 |
+
# Crop the image and calculate embeddings
|
237 |
+
x0, y0, x1, y1 = crop_box
|
238 |
+
cropped_im = image[y0:y1, x0:x1, :]
|
239 |
+
cropped_im_size = cropped_im.shape[:2]
|
240 |
+
self.predictor.set_image(cropped_im)
|
241 |
+
|
242 |
+
# Get points for this crop
|
243 |
+
points_scale = np.array(cropped_im_size)[None, ::-1]
|
244 |
+
points_for_image = self.point_grids[crop_layer_idx] * points_scale
|
245 |
+
|
246 |
+
# Generate masks for this crop in batches
|
247 |
+
data = MaskData()
|
248 |
+
for (points,) in batch_iterator(self.points_per_batch, points_for_image):
|
249 |
+
batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
|
250 |
+
data.cat(batch_data)
|
251 |
+
del batch_data
|
252 |
+
self.predictor.reset_image()
|
253 |
+
|
254 |
+
# Remove duplicates within this crop.
|
255 |
+
keep_by_nms = batched_nms(
|
256 |
+
data["boxes"].float(),
|
257 |
+
data["iou_preds"],
|
258 |
+
torch.zeros_like(data["boxes"][:, 0]), # categories
|
259 |
+
iou_threshold=self.box_nms_thresh,
|
260 |
+
)
|
261 |
+
data.filter(keep_by_nms)
|
262 |
+
|
263 |
+
# Return to the original image frame
|
264 |
+
data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
|
265 |
+
data["points"] = uncrop_points(data["points"], crop_box)
|
266 |
+
data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
|
267 |
+
|
268 |
+
return data
|
269 |
+
|
270 |
+
def _process_batch(
|
271 |
+
self,
|
272 |
+
points: np.ndarray,
|
273 |
+
im_size: Tuple[int, ...],
|
274 |
+
crop_box: List[int],
|
275 |
+
orig_size: Tuple[int, ...],
|
276 |
+
) -> MaskData:
|
277 |
+
orig_h, orig_w = orig_size
|
278 |
+
|
279 |
+
# Run model on this batch
|
280 |
+
transformed_points = self.predictor.transform.apply_coords(points, im_size)
|
281 |
+
in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
|
282 |
+
in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
|
283 |
+
masks, iou_preds, _ = self.predictor.predict_torch(
|
284 |
+
in_points[:, None, :],
|
285 |
+
in_labels[:, None],
|
286 |
+
multimask_output=True,
|
287 |
+
return_logits=True,
|
288 |
+
)
|
289 |
+
|
290 |
+
# Serialize predictions and store in MaskData
|
291 |
+
data = MaskData(
|
292 |
+
masks=masks.flatten(0, 1),
|
293 |
+
logits=(masks.flatten(0, 1) * 255).byte(),
|
294 |
+
iou_preds=iou_preds.flatten(0, 1),
|
295 |
+
points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
|
296 |
+
)
|
297 |
+
del masks
|
298 |
+
|
299 |
+
# Filter by predicted IoU
|
300 |
+
if self.pred_iou_thresh > 0.0:
|
301 |
+
keep_mask = data["iou_preds"] > self.pred_iou_thresh
|
302 |
+
data.filter(keep_mask)
|
303 |
+
|
304 |
+
# Calculate stability score
|
305 |
+
data["stability_score"] = calculate_stability_score(
|
306 |
+
data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
|
307 |
+
)
|
308 |
+
if self.stability_score_thresh > 0.0:
|
309 |
+
keep_mask = data["stability_score"] >= self.stability_score_thresh
|
310 |
+
data.filter(keep_mask)
|
311 |
+
|
312 |
+
# Threshold masks and calculate boxes
|
313 |
+
data["masks"] = data["masks"] > self.predictor.model.mask_threshold
|
314 |
+
data["boxes"] = batched_mask_to_box(data["masks"])
|
315 |
+
|
316 |
+
# Filter boxes that touch crop boundaries
|
317 |
+
keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
|
318 |
+
if not torch.all(keep_mask):
|
319 |
+
data.filter(keep_mask)
|
320 |
+
|
321 |
+
# Compress to RLE
|
322 |
+
data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
|
323 |
+
data["logits"] = uncrop_masks(data["logits"], crop_box, orig_h, orig_w)
|
324 |
+
data["rles"] = mask_to_rle_pytorch(data["masks"])
|
325 |
+
del data["masks"]
|
326 |
+
|
327 |
+
return data
|
328 |
+
|
329 |
+
@staticmethod
|
330 |
+
def postprocess_small_regions(
|
331 |
+
mask_data: MaskData, min_area: int, nms_thresh: float
|
332 |
+
) -> MaskData:
|
333 |
+
"""
|
334 |
+
Removes small disconnected regions and holes in masks, then reruns
|
335 |
+
box NMS to remove any new duplicates.
|
336 |
+
|
337 |
+
Edits mask_data in place.
|
338 |
+
|
339 |
+
Requires open-cv as a dependency.
|
340 |
+
"""
|
341 |
+
if len(mask_data["rles"]) == 0:
|
342 |
+
return mask_data
|
343 |
+
|
344 |
+
# Filter small disconnected regions and holes
|
345 |
+
new_masks = []
|
346 |
+
scores = []
|
347 |
+
for rle in mask_data["rles"]:
|
348 |
+
mask = rle_to_mask(rle)
|
349 |
+
|
350 |
+
mask, changed = remove_small_regions(mask, min_area, mode="holes")
|
351 |
+
unchanged = not changed
|
352 |
+
mask, changed = remove_small_regions(mask, min_area, mode="islands")
|
353 |
+
unchanged = unchanged and not changed
|
354 |
+
|
355 |
+
new_masks.append(torch.as_tensor(mask).unsqueeze(0))
|
356 |
+
# Give score=0 to changed masks and score=1 to unchanged masks
|
357 |
+
# so NMS will prefer ones that didn't need postprocessing
|
358 |
+
scores.append(float(unchanged))
|
359 |
+
|
360 |
+
# Recalculate boxes and remove any new duplicates
|
361 |
+
masks = torch.cat(new_masks, dim=0)
|
362 |
+
boxes = batched_mask_to_box(masks)
|
363 |
+
keep_by_nms = batched_nms(
|
364 |
+
boxes.float(),
|
365 |
+
torch.as_tensor(scores),
|
366 |
+
torch.zeros_like(boxes[:, 0]), # categories
|
367 |
+
iou_threshold=nms_thresh,
|
368 |
+
)
|
369 |
+
|
370 |
+
# Only recalculate RLEs for masks that have changed
|
371 |
+
for i_mask in keep_by_nms:
|
372 |
+
if scores[i_mask] == 0.0:
|
373 |
+
mask_torch = masks[i_mask].unsqueeze(0)
|
374 |
+
mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
|
375 |
+
mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
|
376 |
+
mask_data.filter(keep_by_nms)
|
377 |
+
|
378 |
+
return mask_data
|
zim/build_model.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
import os
|
10 |
+
import torch
|
11 |
+
|
12 |
+
from .modeling.zim import Zim
|
13 |
+
from .modeling.encoder import ZIM_Encoder
|
14 |
+
from .modeling.decoder import ZIM_Decoder
|
15 |
+
|
16 |
+
def build_zim_model(checkpoint):
|
17 |
+
|
18 |
+
encoder = ZIM_Encoder(os.path.join(checkpoint, "encoder.onnx"))
|
19 |
+
decoder = ZIM_Decoder(os.path.join(checkpoint, "decoder.onnx"))
|
20 |
+
net = Zim(encoder, decoder)
|
21 |
+
|
22 |
+
return net
|
23 |
+
|
24 |
+
zim_model_registry = {
|
25 |
+
"default": build_zim_model,
|
26 |
+
"vit_l": build_zim_model,
|
27 |
+
"vit_b": build_zim_model,
|
28 |
+
}
|
29 |
+
|
zim/modeling/decoder.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
import torch
|
10 |
+
from typing import Any, Callable
|
11 |
+
import onnxruntime
|
12 |
+
import numpy as np
|
13 |
+
|
14 |
+
def np2tensor(np_array, device):
|
15 |
+
return torch.from_numpy(np_array).to(device)
|
16 |
+
|
17 |
+
def tensor2np(torch_tensor):
|
18 |
+
if torch_tensor is None:
|
19 |
+
return None
|
20 |
+
|
21 |
+
return torch_tensor.detach().cpu().numpy()
|
22 |
+
|
23 |
+
class ZIM_Decoder():
|
24 |
+
def __init__(self, onnx_path, num_threads=16):
|
25 |
+
self.onnx_path = onnx_path
|
26 |
+
|
27 |
+
sessionOptions = onnxruntime.SessionOptions()
|
28 |
+
sessionOptions.intra_op_num_threads = num_threads
|
29 |
+
sessionOptions.inter_op_num_threads = num_threads
|
30 |
+
providers = ["CPUExecutionProvider"]
|
31 |
+
|
32 |
+
self.ort_session = onnxruntime.InferenceSession(
|
33 |
+
onnx_path, sess_options=sessionOptions, providers=providers
|
34 |
+
)
|
35 |
+
self.num_mask_tokens = 4
|
36 |
+
|
37 |
+
def cuda(self, device_id=0):
|
38 |
+
providers = [
|
39 |
+
(
|
40 |
+
"CUDAExecutionProvider",
|
41 |
+
{
|
42 |
+
"device_id": device_id,
|
43 |
+
},
|
44 |
+
),
|
45 |
+
]
|
46 |
+
|
47 |
+
self.ort_session.set_providers(providers)
|
48 |
+
|
49 |
+
def forward(
|
50 |
+
self,
|
51 |
+
interm_feats,
|
52 |
+
image_embeddings,
|
53 |
+
points,
|
54 |
+
boxes,
|
55 |
+
attn_mask,
|
56 |
+
):
|
57 |
+
device = image_embeddings.device
|
58 |
+
|
59 |
+
ort_inputs = {
|
60 |
+
"feat_D0": tensor2np(interm_feats[0]),
|
61 |
+
"feat_D1": tensor2np(interm_feats[1]),
|
62 |
+
"feat_D2": tensor2np(interm_feats[2]),
|
63 |
+
"image_embeddings": tensor2np(image_embeddings),
|
64 |
+
"attn_mask": tensor2np(attn_mask),
|
65 |
+
}
|
66 |
+
|
67 |
+
if points is not None:
|
68 |
+
point_coords, point_labels = points
|
69 |
+
ort_inputs["point_coords"] = tensor2np(point_coords.float())
|
70 |
+
ort_inputs["point_labels"] = tensor2np(point_labels.float())
|
71 |
+
|
72 |
+
# add paddings as done in SAM
|
73 |
+
padding_point = np.zeros((ort_inputs["point_coords"].shape[0], 1, 2), dtype=np.float32) - 0.5
|
74 |
+
padding_label = -np.ones((ort_inputs["point_labels"].shape[0], 1), dtype=np.float32)
|
75 |
+
ort_inputs["point_coords"] = np.concatenate([ort_inputs["point_coords"], padding_point], axis=1)
|
76 |
+
ort_inputs["point_labels"] = np.concatenate([ort_inputs["point_labels"], padding_label], axis=1)
|
77 |
+
|
78 |
+
if boxes is not None:
|
79 |
+
ort_inputs["point_coords"] = tensor2np(boxes.reshape(-1, 2, 2))
|
80 |
+
ort_inputs["point_labels"] = np.array([[2, 3]], dtype=np.float32).repeat(boxes.shape[0], 0)
|
81 |
+
|
82 |
+
masks, iou_predictions = self.ort_session.run(None, ort_inputs)
|
83 |
+
|
84 |
+
masks = np2tensor(masks, device)
|
85 |
+
iou_predictions = np2tensor(iou_predictions, device)
|
86 |
+
|
87 |
+
return masks, iou_predictions
|
88 |
+
|
89 |
+
__call__: Callable[..., Any] = forward
|
90 |
+
|
zim/modeling/encoder.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
import torch
|
10 |
+
from typing import Any, Callable
|
11 |
+
import onnxruntime
|
12 |
+
|
13 |
+
def np2tensor(np_array, device):
|
14 |
+
return torch.from_numpy(np_array).to(device)
|
15 |
+
|
16 |
+
def tensor2np(torch_tensor):
|
17 |
+
return torch_tensor.detach().cpu().numpy()
|
18 |
+
|
19 |
+
class ZIM_Encoder():
|
20 |
+
def __init__(self, onnx_path, num_threads=16):
|
21 |
+
self.onnx_path = onnx_path
|
22 |
+
|
23 |
+
sessionOptions = onnxruntime.SessionOptions()
|
24 |
+
sessionOptions.intra_op_num_threads = num_threads
|
25 |
+
sessionOptions.inter_op_num_threads = num_threads
|
26 |
+
providers = ["CPUExecutionProvider"]
|
27 |
+
|
28 |
+
self.ort_session = onnxruntime.InferenceSession(
|
29 |
+
onnx_path, sess_options=sessionOptions, providers=providers
|
30 |
+
)
|
31 |
+
|
32 |
+
def cuda(self, device_id=0):
|
33 |
+
providers = [
|
34 |
+
(
|
35 |
+
"CUDAExecutionProvider",
|
36 |
+
{
|
37 |
+
"device_id": device_id,
|
38 |
+
},
|
39 |
+
),
|
40 |
+
]
|
41 |
+
|
42 |
+
self.ort_session.set_providers(providers)
|
43 |
+
|
44 |
+
def forward(
|
45 |
+
self,
|
46 |
+
image,
|
47 |
+
):
|
48 |
+
device = image.device
|
49 |
+
|
50 |
+
ort_inputs = {
|
51 |
+
"image": tensor2np(image),
|
52 |
+
}
|
53 |
+
image_embeddings, feat_D0, feat_D1, feat_D2 = self.ort_session.run(None, ort_inputs)
|
54 |
+
|
55 |
+
image_embeddings = np2tensor(image_embeddings, device)
|
56 |
+
feat_D0 = np2tensor(feat_D0, device)
|
57 |
+
feat_D1 = np2tensor(feat_D1, device)
|
58 |
+
feat_D2 = np2tensor(feat_D2, device)
|
59 |
+
|
60 |
+
return image_embeddings, (feat_D0, feat_D1, feat_D2)
|
61 |
+
|
62 |
+
__call__: Callable[..., Any] = forward
|
zim/modeling/zim.py
ADDED
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import nn
|
12 |
+
from torch.nn import functional as F
|
13 |
+
from typing import Any, Dict, List
|
14 |
+
|
15 |
+
def gaussian(sigma=6):
|
16 |
+
"""
|
17 |
+
2D Gaussian Kernel Generation.
|
18 |
+
"""
|
19 |
+
size = 6 * sigma + 3
|
20 |
+
x = torch.arange(0, size, 1)
|
21 |
+
y = x[:, None]
|
22 |
+
x0, y0 = 3 * sigma + 1, 3 * sigma + 1
|
23 |
+
g = torch.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
|
24 |
+
return g
|
25 |
+
|
26 |
+
class Zim(nn.Module):
|
27 |
+
def __init__(
|
28 |
+
self,
|
29 |
+
encoder,
|
30 |
+
decoder,
|
31 |
+
*,
|
32 |
+
image_size: int = 1024,
|
33 |
+
pixel_mean: List[float] = [123.675, 116.28, 103.53],
|
34 |
+
pixel_std: List[float] = [58.395, 57.12, 57.375],
|
35 |
+
) -> None:
|
36 |
+
"""
|
37 |
+
SAM predicts object masks from an image and input prompts.
|
38 |
+
|
39 |
+
Arguments:
|
40 |
+
encoder : The backbone used to encode the
|
41 |
+
image into image embeddings that allow for efficient mask prediction.
|
42 |
+
decoder : Predicts masks from the image embeddings and given prompts.
|
43 |
+
pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
|
44 |
+
pixel_std (list(float)): Std values for normalizing pixels in the input image.
|
45 |
+
"""
|
46 |
+
super().__init__()
|
47 |
+
self.encoder = encoder
|
48 |
+
self.decoder = decoder
|
49 |
+
self.output_activation = nn.Sigmoid()
|
50 |
+
|
51 |
+
self.image_size = image_size
|
52 |
+
self.register_buffer(
|
53 |
+
"pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
|
54 |
+
)
|
55 |
+
self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
|
56 |
+
|
57 |
+
self.mask_threshold: float = 0.5
|
58 |
+
self.image_format: str = "RGB"
|
59 |
+
self.num_mask_tokens = decoder.num_mask_tokens
|
60 |
+
|
61 |
+
self.encode_stride = 16
|
62 |
+
self.encode_kernel = 21
|
63 |
+
self.attn_mask_size = 64
|
64 |
+
self.g = gaussian(self.encode_kernel)
|
65 |
+
|
66 |
+
self.output_conv = nn.Conv2d(
|
67 |
+
self.num_mask_tokens,
|
68 |
+
self.num_mask_tokens,
|
69 |
+
kernel_size=1, stride=1, padding=0,
|
70 |
+
)
|
71 |
+
|
72 |
+
@property
|
73 |
+
def device(self) -> Any:
|
74 |
+
return self.pixel_mean.device
|
75 |
+
|
76 |
+
def cuda(self, device_id=None):
|
77 |
+
if type(device_id) == torch.device:
|
78 |
+
device_id = device_id.index
|
79 |
+
|
80 |
+
if device_id is None:
|
81 |
+
device_id = 0
|
82 |
+
|
83 |
+
device = torch.device(f"cuda:{device_id}")
|
84 |
+
super(Zim, self).cuda(device)
|
85 |
+
|
86 |
+
self.encoder.cuda(device_id)
|
87 |
+
self.decoder.cuda(device_id)
|
88 |
+
|
89 |
+
return self
|
90 |
+
|
91 |
+
def postprocess_masks(
|
92 |
+
self, masks: torch.Tensor, input_size: List[int], original_size: torch.Tensor
|
93 |
+
) -> torch.Tensor:
|
94 |
+
"""
|
95 |
+
Remove padding and upscale masks to the original image size.
|
96 |
+
|
97 |
+
Arguments:
|
98 |
+
masks (torch.Tensor): Batched masks from the decoder,
|
99 |
+
in BxCxHxW format.
|
100 |
+
input_size (tuple(int, int)): The size of the image input to the
|
101 |
+
model, in (H, W) format. Used to remove padding.
|
102 |
+
original_size (tuple(int, int)): The original size of the image
|
103 |
+
before resizing for input to the model, in (H, W) format.
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
(torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
|
107 |
+
is given by original_size.
|
108 |
+
"""
|
109 |
+
masks = F.interpolate(
|
110 |
+
masks,
|
111 |
+
(self.image_size, self.image_size),
|
112 |
+
mode="bilinear",
|
113 |
+
align_corners=False,
|
114 |
+
)
|
115 |
+
masks = masks[..., : input_size[0], : input_size[1]]
|
116 |
+
masks = F.interpolate(
|
117 |
+
masks, original_size, mode="bilinear", align_corners=False
|
118 |
+
)
|
119 |
+
return masks
|
120 |
+
|
121 |
+
def preprocess(self, x: torch.Tensor) -> torch.Tensor:
|
122 |
+
"""Normalize pixel values and pad to a square input."""
|
123 |
+
# Normalize colors
|
124 |
+
x = (x - self.pixel_mean) / self.pixel_std
|
125 |
+
|
126 |
+
# Pad
|
127 |
+
h, w = x.shape[-2:]
|
128 |
+
padh = self.image_size - h
|
129 |
+
padw = self.image_size - w
|
130 |
+
x = F.pad(x, (0, padw, 0, padh))
|
131 |
+
return x
|
132 |
+
|
133 |
+
def bbox_attn_mask(self, boxes):
|
134 |
+
"""Prompt-aware Masked Attention: box prompt (binary attn mask) """
|
135 |
+
bs = boxes.shape[0]
|
136 |
+
attn_mask = torch.zeros((bs, self.attn_mask_size, self.attn_mask_size), device=boxes.device)
|
137 |
+
|
138 |
+
# attn_weight = attn_weight.masked_fill(m.logical_not(), -1e4)
|
139 |
+
|
140 |
+
for n in range(bs):
|
141 |
+
xmin, ymin, xmax, ymax = boxes[n]
|
142 |
+
|
143 |
+
xmin, xmax = min(xmin, xmax), max(xmin, xmax)
|
144 |
+
ymin, ymax = min(ymin, ymax), max(ymin, ymax)
|
145 |
+
|
146 |
+
xmin, xmax = int(xmin / self.encode_stride), int(xmax / self.encode_stride)
|
147 |
+
ymin, ymax = int(ymin / self.encode_stride), int(ymax / self.encode_stride)
|
148 |
+
|
149 |
+
xmin, ymin = max(0, xmin), max(0, ymin)
|
150 |
+
xmax = min(self.attn_mask_size, xmax+1)
|
151 |
+
ymax = min(self.attn_mask_size, ymax+1)
|
152 |
+
|
153 |
+
attn_mask[n, ymin:ymax, xmin:xmax] = 1
|
154 |
+
|
155 |
+
return attn_mask
|
156 |
+
|
157 |
+
def point_attn_mask(self, point_coords):
|
158 |
+
"""Prompt-aware Masked Attention: point prompt (soft attn mask) """
|
159 |
+
bs = point_coords.shape[0]
|
160 |
+
attn_mask = torch.zeros((bs, self.attn_mask_size, self.attn_mask_size), device=point_coords.device)
|
161 |
+
|
162 |
+
if self.g.device != point_coords.device:
|
163 |
+
self.g = self.g.to(point_coords.device)
|
164 |
+
|
165 |
+
for n in range(bs):
|
166 |
+
for point in point_coords[n]:
|
167 |
+
x, y = int(point[0] / self.encode_stride), int(point[1].item() / self.encode_stride)
|
168 |
+
|
169 |
+
# outside image boundary
|
170 |
+
if x < 0 or y < 0 or x >= self.attn_mask_size or y >= self.attn_mask_size:
|
171 |
+
continue
|
172 |
+
|
173 |
+
# upper left
|
174 |
+
ul = int(round(x - 3 * self.encode_kernel - 1)), int(round(y - 3 * self.encode_kernel - 1))
|
175 |
+
# bottom right
|
176 |
+
br = int(round(x + 3 * self.encode_kernel + 2)), int(round(y + 3 * self.encode_kernel + 2))
|
177 |
+
|
178 |
+
c, d = int(max(0, -ul[0])), int(min(br[0], self.attn_mask_size) - ul[0])
|
179 |
+
a, b = int(max(0, -ul[1])), int(min(br[1], self.attn_mask_size) - ul[1])
|
180 |
+
|
181 |
+
cc, dd = int(max(0, ul[0])), int(min(br[0], self.attn_mask_size))
|
182 |
+
aa, bb = int(max(0, ul[1])), int(min(br[1], self.attn_mask_size))
|
183 |
+
|
184 |
+
attn_mask[n, aa:bb, cc:dd] = torch.maximum(
|
185 |
+
attn_mask[n, aa:bb, cc:dd], self.g[a:b, c:d]
|
186 |
+
)
|
187 |
+
|
188 |
+
return attn_mask
|
189 |
+
|
190 |
+
|
zim/predictor.py
ADDED
@@ -0,0 +1,275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
from torch.nn import functional as F
|
13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
14 |
+
from typing import Optional, Tuple, List
|
15 |
+
|
16 |
+
from .utils import ResizeLongestSide
|
17 |
+
|
18 |
+
class ZimPredictor:
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
model,
|
22 |
+
) -> None:
|
23 |
+
"""
|
24 |
+
Uses SAM to calculate the image embedding for an image, and then
|
25 |
+
allow repeated, efficient mask prediction given prompts.
|
26 |
+
|
27 |
+
Arguments:
|
28 |
+
sam_model (Sam): The model to use for mask prediction.
|
29 |
+
"""
|
30 |
+
super().__init__()
|
31 |
+
self.model = model.module if isinstance(model, DDP) else model
|
32 |
+
self.transform = ResizeLongestSide(self.model.image_size)
|
33 |
+
self.reset_image()
|
34 |
+
|
35 |
+
def set_image(
|
36 |
+
self,
|
37 |
+
image: np.ndarray,
|
38 |
+
image_format: str = "RGB",
|
39 |
+
) -> None:
|
40 |
+
"""
|
41 |
+
Calculates the image embeddings for the provided image, allowing
|
42 |
+
masks to be predicted with the 'predict' method.
|
43 |
+
|
44 |
+
Arguments:
|
45 |
+
image (np.ndarray): The image for calculating masks. Expects an
|
46 |
+
image in HWC uint8 format, with pixel values in [0, 255].
|
47 |
+
image_format (str): The color format of the image, in ['RGB', 'BGR'].
|
48 |
+
"""
|
49 |
+
assert image_format in [
|
50 |
+
"RGB",
|
51 |
+
"BGR",
|
52 |
+
], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
|
53 |
+
if image_format != self.model.image_format:
|
54 |
+
image = image[..., ::-1]
|
55 |
+
|
56 |
+
# Transform the image to the form expected by the model
|
57 |
+
input_image = self.transform.apply_image(image)
|
58 |
+
input_image_torch = torch.as_tensor(input_image, device=self.device)
|
59 |
+
input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
|
60 |
+
|
61 |
+
self.set_torch_image(input_image_torch, image.shape[:2])
|
62 |
+
|
63 |
+
@torch.no_grad()
|
64 |
+
def set_torch_image(
|
65 |
+
self,
|
66 |
+
transformed_image: torch.Tensor,
|
67 |
+
original_image_size: Tuple[int, ...],
|
68 |
+
) -> None:
|
69 |
+
"""
|
70 |
+
Calculates the image embeddings for the provided image, allowing
|
71 |
+
masks to be predicted with the 'predict' method. Expects the input
|
72 |
+
image to be already transformed to the format expected by the model.
|
73 |
+
|
74 |
+
Arguments:
|
75 |
+
transformed_image (torch.Tensor): The input image, with shape
|
76 |
+
1x3xHxW, which has been transformed with ResizeLongestSide.
|
77 |
+
original_image_size (tuple(int, int)): The size of the image
|
78 |
+
before transformation, in (H, W) format.
|
79 |
+
"""
|
80 |
+
assert (
|
81 |
+
len(transformed_image.shape) == 4
|
82 |
+
and transformed_image.shape[1] == 3
|
83 |
+
and max(*transformed_image.shape[2:]) == self.model.image_size
|
84 |
+
), f"set_torch_image input must be BCHW with long side {self.model.image_size}."
|
85 |
+
self.reset_image()
|
86 |
+
|
87 |
+
self.original_size = original_image_size
|
88 |
+
self.input_size = tuple(transformed_image.shape[-2:])
|
89 |
+
input_image = self.model.preprocess(transformed_image)
|
90 |
+
self.features, self.interm_feats = self.model.encoder(input_image)
|
91 |
+
self.is_image_set = True
|
92 |
+
|
93 |
+
def predict(
|
94 |
+
self,
|
95 |
+
point_coords: Optional[np.ndarray] = None,
|
96 |
+
point_labels: Optional[np.ndarray] = None,
|
97 |
+
box: Optional[np.ndarray] = None,
|
98 |
+
multimask_output: bool = True,
|
99 |
+
return_logits: bool = False,
|
100 |
+
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
101 |
+
"""
|
102 |
+
Predict masks for the given input prompts, using the currently set image.
|
103 |
+
|
104 |
+
Arguments:
|
105 |
+
point_coords (np.ndarray or None): A Nx2 array of point prompts to the
|
106 |
+
model. Each point is in (X,Y) in pixels.
|
107 |
+
point_labels (np.ndarray or None): A length N array of labels for the
|
108 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
109 |
+
background point.
|
110 |
+
box (np.ndarray or None): A length 4 array given a box prompt to the
|
111 |
+
model, in XYXY format.
|
112 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
113 |
+
coming from a previous prediction iteration. Has form 1xHxW, where
|
114 |
+
for SAM, H=W=256.
|
115 |
+
multimask_output (bool): If true, the model will return three masks.
|
116 |
+
For ambiguous input prompts (such as a single click), this will often
|
117 |
+
produce better masks than a single prediction. If only a single
|
118 |
+
mask is needed, the model's predicted quality score can be used
|
119 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
120 |
+
input prompts, multimask_output=False can give better results.
|
121 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
122 |
+
instead of a binary mask.
|
123 |
+
|
124 |
+
Returns:
|
125 |
+
(np.ndarray): The output masks in CxHxW format, where C is the
|
126 |
+
number of masks, and (H, W) is the original image size.
|
127 |
+
(np.ndarray): An array of length C containing the model's
|
128 |
+
predictions for the quality of each mask.
|
129 |
+
(np.ndarray): An array of shape CxHxW, where C is the number
|
130 |
+
of masks and H=W=256. These low resolution logits can be passed to
|
131 |
+
a subsequent iteration as mask input.
|
132 |
+
"""
|
133 |
+
if not self.is_image_set:
|
134 |
+
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
135 |
+
|
136 |
+
# Transform input prompts
|
137 |
+
coords_torch = None
|
138 |
+
labels_torch = None
|
139 |
+
box_torch = None
|
140 |
+
|
141 |
+
if point_coords is not None:
|
142 |
+
assert (
|
143 |
+
point_labels is not None
|
144 |
+
), "point_labels must be supplied if point_coords is supplied."
|
145 |
+
point_coords = self.transform.apply_coords(point_coords, self.original_size)
|
146 |
+
coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
|
147 |
+
labels_torch = torch.as_tensor(point_labels, dtype=torch.float, device=self.device)
|
148 |
+
coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
|
149 |
+
if box is not None:
|
150 |
+
box = self.transform.apply_boxes(box, self.original_size)
|
151 |
+
box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
|
152 |
+
|
153 |
+
masks, iou_predictions, low_res_masks = self.predict_torch(
|
154 |
+
coords_torch,
|
155 |
+
labels_torch,
|
156 |
+
box_torch,
|
157 |
+
multimask_output,
|
158 |
+
return_logits=return_logits,
|
159 |
+
)
|
160 |
+
|
161 |
+
masks_np = masks[0].detach().cpu().numpy()
|
162 |
+
iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
|
163 |
+
low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
|
164 |
+
|
165 |
+
return masks_np, iou_predictions_np, low_res_masks_np
|
166 |
+
|
167 |
+
@torch.no_grad()
|
168 |
+
def predict_torch(
|
169 |
+
self,
|
170 |
+
point_coords: Optional[torch.Tensor],
|
171 |
+
point_labels: Optional[torch.Tensor],
|
172 |
+
boxes: Optional[torch.Tensor] = None,
|
173 |
+
multimask_output: bool = True,
|
174 |
+
return_logits: bool = False,
|
175 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
176 |
+
"""
|
177 |
+
Predict masks for the given input prompts, using the currently set image.
|
178 |
+
Input prompts are batched torch tensors and are expected to already be
|
179 |
+
transformed to the input frame using ResizeLongestSide.
|
180 |
+
|
181 |
+
Arguments:
|
182 |
+
point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
|
183 |
+
model. Each point is in (X,Y) in pixels.
|
184 |
+
point_labels (torch.Tensor or None): A BxN array of labels for the
|
185 |
+
point prompts. 1 indicates a foreground point and 0 indicates a
|
186 |
+
background point.
|
187 |
+
boxes (np.ndarray or None): A Bx4 array given a box prompt to the
|
188 |
+
model, in XYXY format.
|
189 |
+
mask_input (np.ndarray): A low resolution mask input to the model, typically
|
190 |
+
coming from a previous prediction iteration. Has form Bx1xHxW, where
|
191 |
+
for SAM, H=W=256. Masks returned by a previous iteration of the
|
192 |
+
predict method do not need further transformation.
|
193 |
+
multimask_output (bool): If true, the model will return three masks.
|
194 |
+
For ambiguous input prompts (such as a single click), this will often
|
195 |
+
produce better masks than a single prediction. If only a single
|
196 |
+
mask is needed, the model's predicted quality score can be used
|
197 |
+
to select the best mask. For non-ambiguous prompts, such as multiple
|
198 |
+
input prompts, multimask_output=False can give better results.
|
199 |
+
return_logits (bool): If true, returns un-thresholded masks logits
|
200 |
+
instead of a binary mask.
|
201 |
+
|
202 |
+
Returns:
|
203 |
+
(torch.Tensor): The output masks in BxCxHxW format, where C is the
|
204 |
+
number of masks, and (H, W) is the original image size.
|
205 |
+
(torch.Tensor): An array of shape BxC containing the model's
|
206 |
+
predictions for the quality of each mask.
|
207 |
+
(torch.Tensor): An array of shape BxCxHxW, where C is the number
|
208 |
+
of masks and H=W=256. These low res logits can be passed to
|
209 |
+
a subsequent iteration as mask input.
|
210 |
+
"""
|
211 |
+
if not self.is_image_set:
|
212 |
+
raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
|
213 |
+
|
214 |
+
if point_coords is not None:
|
215 |
+
points = (point_coords, point_labels)
|
216 |
+
attn_mask = self.model.point_attn_mask(point_coords)
|
217 |
+
else:
|
218 |
+
points = None
|
219 |
+
attn_mask = self.model.bbox_attn_mask(boxes)
|
220 |
+
|
221 |
+
# Embed prompts
|
222 |
+
masks, iou_predictions = self.model.decoder(
|
223 |
+
interm_feats=self.interm_feats,
|
224 |
+
image_embeddings=self.features,
|
225 |
+
points=points,
|
226 |
+
boxes=boxes,
|
227 |
+
attn_mask=attn_mask,
|
228 |
+
)
|
229 |
+
|
230 |
+
# Select the correct mask or masks for output
|
231 |
+
if multimask_output:
|
232 |
+
mask_slice = slice(0, None)
|
233 |
+
else:
|
234 |
+
mask_slice = slice(0, 1)
|
235 |
+
|
236 |
+
masks = masks[:, mask_slice, :, :]
|
237 |
+
iou_predictions = iou_predictions[:, mask_slice]
|
238 |
+
|
239 |
+
low_res_masks = F.interpolate(masks, scale_factor=2, mode='bilinear', align_corners=False)
|
240 |
+
|
241 |
+
masks = self.model.postprocess_masks(
|
242 |
+
masks,
|
243 |
+
input_size=self.input_size,
|
244 |
+
original_size=self.original_size,
|
245 |
+
)
|
246 |
+
|
247 |
+
return masks.sigmoid(), iou_predictions, low_res_masks.sigmoid()
|
248 |
+
|
249 |
+
def get_image_embedding(self) -> torch.Tensor:
|
250 |
+
"""
|
251 |
+
Returns the image embeddings for the currently set image, with
|
252 |
+
shape 1xCxHxW, where C is the embedding dimension and (H,W) are
|
253 |
+
the embedding spatial dimension of SAM (typically C=256, H=W=64).
|
254 |
+
"""
|
255 |
+
if not self.is_image_set:
|
256 |
+
raise RuntimeError(
|
257 |
+
"An image must be set with .set_image(...) to generate an embedding."
|
258 |
+
)
|
259 |
+
assert self.features is not None, "Features must exist if an image has been set."
|
260 |
+
return self.features
|
261 |
+
|
262 |
+
@property
|
263 |
+
def device(self) -> torch.device:
|
264 |
+
return self.model.device
|
265 |
+
|
266 |
+
def reset_image(self) -> None:
|
267 |
+
"""Resets the currently set image."""
|
268 |
+
self.is_image_set = False
|
269 |
+
self.features = None
|
270 |
+
self.interm_feats = None
|
271 |
+
self.orig_h = None
|
272 |
+
self.orig_w = None
|
273 |
+
self.input_h = None
|
274 |
+
self.input_w = None
|
275 |
+
|
zim/utils/__init__.py
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
# All rights reserved.
|
3 |
+
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
from .argparser import get_parser
|
8 |
+
from .print import print_once, pretty
|
9 |
+
from .utils import AverageMeter, ResizeLongestSide
|
10 |
+
from .amg import show_mat_anns
|
zim/utils/amg.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
This source code is based on code from the Segment Anything Model (SAM)
|
4 |
+
(https://github.com/facebookresearch/segment-anything).
|
5 |
+
|
6 |
+
This source code is licensed under the license found in the
|
7 |
+
LICENSE file in the root directory of this source tree.
|
8 |
+
"""
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import cv2
|
13 |
+
|
14 |
+
import math
|
15 |
+
from copy import deepcopy
|
16 |
+
from itertools import product
|
17 |
+
from typing import Any, Dict, Generator, ItemsView, List, Tuple
|
18 |
+
|
19 |
+
|
20 |
+
class MaskData:
|
21 |
+
"""
|
22 |
+
A structure for storing masks and their related data in batched format.
|
23 |
+
Implements basic filtering and concatenation.
|
24 |
+
"""
|
25 |
+
|
26 |
+
def __init__(self, **kwargs) -> None:
|
27 |
+
for v in kwargs.values():
|
28 |
+
assert isinstance(
|
29 |
+
v, (list, np.ndarray, torch.Tensor)
|
30 |
+
), "MaskData only supports list, numpy arrays, and torch tensors."
|
31 |
+
self._stats = dict(**kwargs)
|
32 |
+
|
33 |
+
def __setitem__(self, key: str, item: Any) -> None:
|
34 |
+
assert isinstance(
|
35 |
+
item, (list, np.ndarray, torch.Tensor)
|
36 |
+
), "MaskData only supports list, numpy arrays, and torch tensors."
|
37 |
+
self._stats[key] = item
|
38 |
+
|
39 |
+
def __delitem__(self, key: str) -> None:
|
40 |
+
del self._stats[key]
|
41 |
+
|
42 |
+
def __getitem__(self, key: str) -> Any:
|
43 |
+
return self._stats[key]
|
44 |
+
|
45 |
+
def items(self) -> ItemsView[str, Any]:
|
46 |
+
return self._stats.items()
|
47 |
+
|
48 |
+
def filter(self, keep: torch.Tensor) -> None:
|
49 |
+
for k, v in self._stats.items():
|
50 |
+
if v is None:
|
51 |
+
self._stats[k] = None
|
52 |
+
elif isinstance(v, torch.Tensor):
|
53 |
+
self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
|
54 |
+
elif isinstance(v, np.ndarray):
|
55 |
+
self._stats[k] = v[keep.detach().cpu().numpy()]
|
56 |
+
elif isinstance(v, list) and keep.dtype == torch.bool:
|
57 |
+
self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
|
58 |
+
elif isinstance(v, list):
|
59 |
+
self._stats[k] = [v[i] for i in keep]
|
60 |
+
else:
|
61 |
+
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
62 |
+
|
63 |
+
def cat(self, new_stats: "MaskData") -> None:
|
64 |
+
for k, v in new_stats.items():
|
65 |
+
if k not in self._stats or self._stats[k] is None:
|
66 |
+
self._stats[k] = deepcopy(v)
|
67 |
+
elif isinstance(v, torch.Tensor):
|
68 |
+
self._stats[k] = torch.cat([self._stats[k], v], dim=0)
|
69 |
+
elif isinstance(v, np.ndarray):
|
70 |
+
self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
|
71 |
+
elif isinstance(v, list):
|
72 |
+
self._stats[k] = self._stats[k] + deepcopy(v)
|
73 |
+
else:
|
74 |
+
raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
|
75 |
+
|
76 |
+
def to_numpy(self) -> None:
|
77 |
+
for k, v in self._stats.items():
|
78 |
+
if isinstance(v, torch.Tensor):
|
79 |
+
self._stats[k] = v.detach().cpu().numpy()
|
80 |
+
|
81 |
+
|
82 |
+
def is_box_near_crop_edge(
|
83 |
+
boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
|
84 |
+
) -> torch.Tensor:
|
85 |
+
"""Filter masks at the edge of a crop, but not at the edge of the original image."""
|
86 |
+
crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
|
87 |
+
orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
|
88 |
+
boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
|
89 |
+
near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
|
90 |
+
near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
|
91 |
+
near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
|
92 |
+
return torch.any(near_crop_edge, dim=1)
|
93 |
+
|
94 |
+
|
95 |
+
def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
|
96 |
+
box_xywh = deepcopy(box_xyxy)
|
97 |
+
box_xywh[2] = box_xywh[2] - box_xywh[0]
|
98 |
+
box_xywh[3] = box_xywh[3] - box_xywh[1]
|
99 |
+
return box_xywh
|
100 |
+
|
101 |
+
|
102 |
+
def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
|
103 |
+
assert len(args) > 0 and all(
|
104 |
+
len(a) == len(args[0]) for a in args
|
105 |
+
), "Batched iteration must have inputs of all the same size."
|
106 |
+
n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
|
107 |
+
for b in range(n_batches):
|
108 |
+
yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
|
109 |
+
|
110 |
+
|
111 |
+
def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
|
112 |
+
"""
|
113 |
+
Encodes masks to an uncompressed RLE, in the format expected by
|
114 |
+
pycoco tools.
|
115 |
+
"""
|
116 |
+
# Put in fortran order and flatten h,w
|
117 |
+
b, h, w = tensor.shape
|
118 |
+
tensor = tensor.permute(0, 2, 1).flatten(1)
|
119 |
+
|
120 |
+
# Compute change indices
|
121 |
+
diff = tensor[:, 1:] ^ tensor[:, :-1]
|
122 |
+
change_indices = diff.nonzero()
|
123 |
+
|
124 |
+
# Encode run length
|
125 |
+
out = []
|
126 |
+
for i in range(b):
|
127 |
+
cur_idxs = change_indices[change_indices[:, 0] == i, 1]
|
128 |
+
cur_idxs = torch.cat(
|
129 |
+
[
|
130 |
+
torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
131 |
+
cur_idxs + 1,
|
132 |
+
torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
|
133 |
+
]
|
134 |
+
)
|
135 |
+
btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
|
136 |
+
counts = [] if tensor[i, 0] == 0 else [0]
|
137 |
+
counts.extend(btw_idxs.detach().cpu().tolist())
|
138 |
+
out.append({"size": [h, w], "counts": counts})
|
139 |
+
return out
|
140 |
+
|
141 |
+
|
142 |
+
def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
|
143 |
+
"""Compute a binary mask from an uncompressed RLE."""
|
144 |
+
h, w = rle["size"]
|
145 |
+
mask = np.empty(h * w, dtype=bool)
|
146 |
+
idx = 0
|
147 |
+
parity = False
|
148 |
+
for count in rle["counts"]:
|
149 |
+
mask[idx : idx + count] = parity
|
150 |
+
idx += count
|
151 |
+
parity ^= True
|
152 |
+
mask = mask.reshape(w, h)
|
153 |
+
return mask.transpose() # Put in C order
|
154 |
+
|
155 |
+
|
156 |
+
def area_from_rle(rle: Dict[str, Any]) -> int:
|
157 |
+
return sum(rle["counts"][1::2])
|
158 |
+
|
159 |
+
|
160 |
+
def calculate_stability_score(
|
161 |
+
masks: torch.Tensor, mask_threshold: float, threshold_offset: float
|
162 |
+
) -> torch.Tensor:
|
163 |
+
"""
|
164 |
+
Computes the stability score for a batch of masks. The stability
|
165 |
+
score is the IoU between the binary masks obtained by thresholding
|
166 |
+
the predicted mask logits at high and low values.
|
167 |
+
"""
|
168 |
+
# One mask is always contained inside the other.
|
169 |
+
# Save memory by preventing unnecessary cast to torch.int64
|
170 |
+
intersections = (
|
171 |
+
(masks > (mask_threshold + threshold_offset))
|
172 |
+
.sum(-1, dtype=torch.int16)
|
173 |
+
.sum(-1, dtype=torch.int32)
|
174 |
+
)
|
175 |
+
unions = (
|
176 |
+
(masks > (mask_threshold - threshold_offset))
|
177 |
+
.sum(-1, dtype=torch.int16)
|
178 |
+
.sum(-1, dtype=torch.int32)
|
179 |
+
)
|
180 |
+
return intersections / unions
|
181 |
+
|
182 |
+
|
183 |
+
def build_point_grid(n_per_side: int) -> np.ndarray:
|
184 |
+
"""Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
|
185 |
+
offset = 1 / (2 * n_per_side)
|
186 |
+
points_one_side = np.linspace(offset, 1 - offset, n_per_side)
|
187 |
+
points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
|
188 |
+
points_y = np.tile(points_one_side[:, None], (1, n_per_side))
|
189 |
+
points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
|
190 |
+
return points
|
191 |
+
|
192 |
+
|
193 |
+
def build_all_layer_point_grids(
|
194 |
+
n_per_side: int, n_layers: int, scale_per_layer: int
|
195 |
+
) -> List[np.ndarray]:
|
196 |
+
"""Generates point grids for all crop layers."""
|
197 |
+
points_by_layer = []
|
198 |
+
for i in range(n_layers + 1):
|
199 |
+
n_points = int(n_per_side / (scale_per_layer**i))
|
200 |
+
points_by_layer.append(build_point_grid(n_points))
|
201 |
+
return points_by_layer
|
202 |
+
|
203 |
+
|
204 |
+
def generate_crop_boxes(
|
205 |
+
im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
|
206 |
+
) -> Tuple[List[List[int]], List[int]]:
|
207 |
+
"""
|
208 |
+
Generates a list of crop boxes of different sizes. Each layer
|
209 |
+
has (2**i)**2 boxes for the ith layer.
|
210 |
+
"""
|
211 |
+
crop_boxes, layer_idxs = [], []
|
212 |
+
im_h, im_w = im_size
|
213 |
+
short_side = min(im_h, im_w)
|
214 |
+
|
215 |
+
# Original image
|
216 |
+
crop_boxes.append([0, 0, im_w, im_h])
|
217 |
+
layer_idxs.append(0)
|
218 |
+
|
219 |
+
def crop_len(orig_len, n_crops, overlap):
|
220 |
+
return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
|
221 |
+
|
222 |
+
for i_layer in range(n_layers):
|
223 |
+
n_crops_per_side = 2 ** (i_layer + 1)
|
224 |
+
overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
|
225 |
+
|
226 |
+
crop_w = crop_len(im_w, n_crops_per_side, overlap)
|
227 |
+
crop_h = crop_len(im_h, n_crops_per_side, overlap)
|
228 |
+
|
229 |
+
crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
|
230 |
+
crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
|
231 |
+
|
232 |
+
# Crops in XYWH format
|
233 |
+
for x0, y0 in product(crop_box_x0, crop_box_y0):
|
234 |
+
box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
|
235 |
+
crop_boxes.append(box)
|
236 |
+
layer_idxs.append(i_layer + 1)
|
237 |
+
|
238 |
+
return crop_boxes, layer_idxs
|
239 |
+
|
240 |
+
|
241 |
+
def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
242 |
+
x0, y0, _, _ = crop_box
|
243 |
+
offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
|
244 |
+
# Check if boxes has a channel dimension
|
245 |
+
if len(boxes.shape) == 3:
|
246 |
+
offset = offset.unsqueeze(1)
|
247 |
+
return boxes + offset
|
248 |
+
|
249 |
+
|
250 |
+
def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
|
251 |
+
x0, y0, _, _ = crop_box
|
252 |
+
offset = torch.tensor([[x0, y0]], device=points.device)
|
253 |
+
# Check if points has a channel dimension
|
254 |
+
if len(points.shape) == 3:
|
255 |
+
offset = offset.unsqueeze(1)
|
256 |
+
return points + offset
|
257 |
+
|
258 |
+
|
259 |
+
def uncrop_masks(
|
260 |
+
masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
|
261 |
+
) -> torch.Tensor:
|
262 |
+
x0, y0, x1, y1 = crop_box
|
263 |
+
if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
|
264 |
+
return masks
|
265 |
+
# Coordinate transform masks
|
266 |
+
pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
|
267 |
+
pad = (x0, pad_x - x0, y0, pad_y - y0)
|
268 |
+
return torch.nn.functional.pad(masks, pad, value=0)
|
269 |
+
|
270 |
+
|
271 |
+
def remove_small_regions(
|
272 |
+
mask: np.ndarray, area_thresh: float, mode: str
|
273 |
+
) -> Tuple[np.ndarray, bool]:
|
274 |
+
"""
|
275 |
+
Removes small disconnected regions and holes in a mask. Returns the
|
276 |
+
mask and an indicator of if the mask has been modified.
|
277 |
+
"""
|
278 |
+
import cv2 # type: ignore
|
279 |
+
|
280 |
+
assert mode in ["holes", "islands"]
|
281 |
+
correct_holes = mode == "holes"
|
282 |
+
working_mask = (correct_holes ^ mask).astype(np.uint8)
|
283 |
+
n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
|
284 |
+
sizes = stats[:, -1][1:] # Row 0 is background label
|
285 |
+
small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
|
286 |
+
if len(small_regions) == 0:
|
287 |
+
return mask, False
|
288 |
+
fill_labels = [0] + small_regions
|
289 |
+
if not correct_holes:
|
290 |
+
fill_labels = [i for i in range(n_labels) if i not in fill_labels]
|
291 |
+
# If every region is below threshold, keep largest
|
292 |
+
if len(fill_labels) == 0:
|
293 |
+
fill_labels = [int(np.argmax(sizes)) + 1]
|
294 |
+
mask = np.isin(regions, fill_labels)
|
295 |
+
return mask, True
|
296 |
+
|
297 |
+
|
298 |
+
def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
|
299 |
+
from pycocotools import mask as mask_utils # type: ignore
|
300 |
+
|
301 |
+
h, w = uncompressed_rle["size"]
|
302 |
+
rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
|
303 |
+
rle["counts"] = rle["counts"].decode("utf-8") # Necessary to serialize with json
|
304 |
+
return rle
|
305 |
+
|
306 |
+
|
307 |
+
def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
|
308 |
+
"""
|
309 |
+
Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
|
310 |
+
an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
|
311 |
+
"""
|
312 |
+
# torch.max below raises an error on empty inputs, just skip in this case
|
313 |
+
if torch.numel(masks) == 0:
|
314 |
+
return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
|
315 |
+
|
316 |
+
# Normalize shape to CxHxW
|
317 |
+
shape = masks.shape
|
318 |
+
h, w = shape[-2:]
|
319 |
+
if len(shape) > 2:
|
320 |
+
masks = masks.flatten(0, -3)
|
321 |
+
else:
|
322 |
+
masks = masks.unsqueeze(0)
|
323 |
+
|
324 |
+
# Get top and bottom edges
|
325 |
+
in_height, _ = torch.max(masks, dim=-1)
|
326 |
+
in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
|
327 |
+
bottom_edges, _ = torch.max(in_height_coords, dim=-1)
|
328 |
+
in_height_coords = in_height_coords + h * (~in_height)
|
329 |
+
top_edges, _ = torch.min(in_height_coords, dim=-1)
|
330 |
+
|
331 |
+
# Get left and right edges
|
332 |
+
in_width, _ = torch.max(masks, dim=-2)
|
333 |
+
in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
|
334 |
+
right_edges, _ = torch.max(in_width_coords, dim=-1)
|
335 |
+
in_width_coords = in_width_coords + w * (~in_width)
|
336 |
+
left_edges, _ = torch.min(in_width_coords, dim=-1)
|
337 |
+
|
338 |
+
# If the mask is empty the right edge will be to the left of the left edge.
|
339 |
+
# Replace these boxes with [0, 0, 0, 0]
|
340 |
+
empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
|
341 |
+
out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
|
342 |
+
out = out * (~empty_filter).unsqueeze(-1)
|
343 |
+
|
344 |
+
# Return to original shape
|
345 |
+
if len(shape) > 2:
|
346 |
+
out = out.reshape(*shape[:-2], 4)
|
347 |
+
else:
|
348 |
+
out = out[0]
|
349 |
+
|
350 |
+
return out
|
351 |
+
|
352 |
+
def show_mat_anns(image, anns):
|
353 |
+
if len(anns) == 0:
|
354 |
+
return np.zeros_like(image) + 128
|
355 |
+
|
356 |
+
sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
|
357 |
+
|
358 |
+
image = image.astype(np.float32)
|
359 |
+
colorized_mat = np.zeros_like(image)
|
360 |
+
|
361 |
+
for ann in sorted_anns:
|
362 |
+
color = (np.random.random(3) * 255).astype(np.float32)
|
363 |
+
if 'logit' in ann:
|
364 |
+
mat = ann['logit'].astype(np.float32) / 255.
|
365 |
+
else:
|
366 |
+
mat = ann['segmentation'].astype(np.float32)
|
367 |
+
|
368 |
+
color_mat = np.zeros_like(image) + color[None, None]
|
369 |
+
colorized_mat = color_mat * mat[:, :, None] + colorized_mat * (1. - mat[:, :, None])
|
370 |
+
|
371 |
+
colorized_mat = np.uint8(colorized_mat)
|
372 |
+
|
373 |
+
return colorized_mat
|
zim/utils/argparser.py
ADDED
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import os
|
9 |
+
import argparse
|
10 |
+
from config.config import config_
|
11 |
+
|
12 |
+
def str2bool(v):
|
13 |
+
if isinstance(v, bool):
|
14 |
+
return v
|
15 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
16 |
+
return True
|
17 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
18 |
+
return False
|
19 |
+
else:
|
20 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
21 |
+
|
22 |
+
|
23 |
+
def get_parser(verbose=False):
|
24 |
+
p = argparse.ArgumentParser("argparser", add_help=False)
|
25 |
+
|
26 |
+
p.add_argument(
|
27 |
+
"--data-root", type=str, default=config_.data_root, help="data root directory"
|
28 |
+
)
|
29 |
+
p.add_argument(
|
30 |
+
"--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0"))
|
31 |
+
)
|
32 |
+
p.add_argument(
|
33 |
+
"--amp", type=str2bool, default=True
|
34 |
+
)
|
35 |
+
p.add_argument(
|
36 |
+
"--ddp", action="store_true"
|
37 |
+
)
|
38 |
+
p.add_argument(
|
39 |
+
"--random-seed", type=int, default=config_.random_seed
|
40 |
+
)
|
41 |
+
|
42 |
+
# network config
|
43 |
+
p.add_argument(
|
44 |
+
"--network-encoder",
|
45 |
+
type=str,
|
46 |
+
default=config_.network.encoder,
|
47 |
+
choices=["vit_b", "vit_l"],
|
48 |
+
)
|
49 |
+
p.add_argument(
|
50 |
+
"--network-decoder",
|
51 |
+
type=str,
|
52 |
+
default=config_.network.decoder,
|
53 |
+
choices=["zim", "sam"],
|
54 |
+
)
|
55 |
+
p.add_argument(
|
56 |
+
"--network-encode-kernel",
|
57 |
+
type=int,
|
58 |
+
default=config_.network.encode_kernel,
|
59 |
+
)
|
60 |
+
|
61 |
+
# evaluation config
|
62 |
+
p.add_argument(
|
63 |
+
"--eval-workers", type=int, default=config_.eval.workers,
|
64 |
+
)
|
65 |
+
p.add_argument(
|
66 |
+
"--eval-image-size", type=int, default=config_.eval.image_size,
|
67 |
+
)
|
68 |
+
p.add_argument(
|
69 |
+
"--eval-prompt-type", type=str, default=config_.eval.prompt_type,
|
70 |
+
)
|
71 |
+
p.add_argument(
|
72 |
+
"--eval-model-list", type=str, default=config_.eval.model_list,
|
73 |
+
)
|
74 |
+
p.add_argument(
|
75 |
+
"--eval-zim-weights",
|
76 |
+
type=str,
|
77 |
+
default=config_.eval.zim_weights,
|
78 |
+
)
|
79 |
+
p.add_argument(
|
80 |
+
"--eval-sam-weights",
|
81 |
+
type=str,
|
82 |
+
default=config_.eval.sam_weights,
|
83 |
+
)
|
84 |
+
|
85 |
+
# dataset config
|
86 |
+
p.add_argument(
|
87 |
+
"--dataset-valset", type=str, default=config_.dataset.valset,
|
88 |
+
)
|
89 |
+
p.add_argument(
|
90 |
+
"--dataset-data-type", type=str, default=config_.dataset.data_type,
|
91 |
+
)
|
92 |
+
p.add_argument(
|
93 |
+
"--dataset-data-list-txt", type=str, default=config_.dataset.data_list_txt,
|
94 |
+
)
|
95 |
+
|
96 |
+
return p
|
zim/utils/print.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import torch
|
9 |
+
|
10 |
+
def print_once(message):
|
11 |
+
if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
|
12 |
+
print(message)
|
13 |
+
|
14 |
+
def pretty(d, indent=0):
|
15 |
+
for key, value in d.items():
|
16 |
+
print_once("\t" * indent + str(key))
|
17 |
+
if isinstance(value, dict):
|
18 |
+
pretty(value, indent + 1)
|
19 |
+
else:
|
20 |
+
print_once("\t" * (indent + 1) + str(value))
|
zim/utils/utils.py
ADDED
@@ -0,0 +1,148 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Copyright (c) 2024-present Naver Cloud Corp.
|
3 |
+
|
4 |
+
This source code is licensed under the license found in the
|
5 |
+
LICENSE file in the root directory of this source tree.
|
6 |
+
"""
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch
|
10 |
+
from torch.nn import functional as F
|
11 |
+
from torchvision.transforms.functional import resize, to_pil_image, InterpolationMode
|
12 |
+
from copy import deepcopy
|
13 |
+
from typing import Optional, Tuple, List
|
14 |
+
|
15 |
+
class ResizeLongestSide:
|
16 |
+
"""
|
17 |
+
Resizes images to the longest side 'target_length', as well as provides
|
18 |
+
methods for resizing coordinates and boxes. Provides methods for
|
19 |
+
transforming both numpy array and batched torch tensors.
|
20 |
+
"""
|
21 |
+
|
22 |
+
def __init__(self, target_length: int) -> None:
|
23 |
+
self.target_length = target_length
|
24 |
+
|
25 |
+
def apply_image(self, image: np.ndarray) -> np.ndarray:
|
26 |
+
"""
|
27 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
28 |
+
"""
|
29 |
+
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
|
30 |
+
return np.array(resize(to_pil_image(image), target_size))
|
31 |
+
|
32 |
+
def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
33 |
+
"""
|
34 |
+
Expects a numpy array of length 2 in the final dimension. Requires the
|
35 |
+
original image size in (H, W) format.
|
36 |
+
"""
|
37 |
+
old_h, old_w = original_size
|
38 |
+
new_h, new_w = self.get_preprocess_shape(
|
39 |
+
original_size[0], original_size[1], self.target_length
|
40 |
+
)
|
41 |
+
coords = deepcopy(coords).astype(float)
|
42 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
43 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
44 |
+
return coords
|
45 |
+
|
46 |
+
def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
|
47 |
+
"""
|
48 |
+
Expects a numpy array shape Bx4. Requires the original image size
|
49 |
+
in (H, W) format.
|
50 |
+
"""
|
51 |
+
boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
|
52 |
+
return boxes.reshape(-1, 4)
|
53 |
+
|
54 |
+
def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
|
55 |
+
"""
|
56 |
+
Expects batched images with shape BxCxHxW and float format. This
|
57 |
+
transformation may not exactly match apply_image. apply_image is
|
58 |
+
the transformation expected by the model.
|
59 |
+
"""
|
60 |
+
# Expects an image in BCHW format. May not exactly match apply_image.
|
61 |
+
target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
|
62 |
+
return F.interpolate(
|
63 |
+
image, target_size, mode="bilinear", align_corners=False, antialias=True
|
64 |
+
)
|
65 |
+
|
66 |
+
def apply_coords_torch(
|
67 |
+
self, coords: torch.Tensor, original_size: Tuple[int, ...]
|
68 |
+
) -> torch.Tensor:
|
69 |
+
"""
|
70 |
+
Expects a torch tensor with length 2 in the last dimension. Requires the
|
71 |
+
original image size in (H, W) format.
|
72 |
+
"""
|
73 |
+
old_h, old_w = original_size
|
74 |
+
new_h, new_w = self.get_preprocess_shape(
|
75 |
+
original_size[0], original_size[1], self.target_length
|
76 |
+
)
|
77 |
+
coords = deepcopy(coords).to(torch.float)
|
78 |
+
coords[..., 0] = coords[..., 0] * (new_w / old_w)
|
79 |
+
coords[..., 1] = coords[..., 1] * (new_h / old_h)
|
80 |
+
return coords
|
81 |
+
|
82 |
+
def apply_boxes_torch(
|
83 |
+
self, boxes: torch.Tensor, original_size: Tuple[int, ...]
|
84 |
+
) -> torch.Tensor:
|
85 |
+
"""
|
86 |
+
Expects a torch tensor with shape Bx4. Requires the original image
|
87 |
+
size in (H, W) format.
|
88 |
+
"""
|
89 |
+
boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
|
90 |
+
return boxes.reshape(-1, 4)
|
91 |
+
|
92 |
+
def apply_mask(self, image: np.ndarray) -> np.ndarray:
|
93 |
+
"""
|
94 |
+
Expects a numpy array with shape HxWxC in uint8 format.
|
95 |
+
"""
|
96 |
+
target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
|
97 |
+
return np.array(resize(to_pil_image(image), target_size, interpolation=InterpolationMode.NEAREST))
|
98 |
+
|
99 |
+
@staticmethod
|
100 |
+
def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
|
101 |
+
"""
|
102 |
+
Compute the output size given input size and target long side length.
|
103 |
+
"""
|
104 |
+
scale = long_side_length * 1.0 / max(oldh, oldw)
|
105 |
+
newh, neww = oldh * scale, oldw * scale
|
106 |
+
neww = int(neww + 0.5)
|
107 |
+
newh = int(newh + 0.5)
|
108 |
+
return (newh, neww)
|
109 |
+
|
110 |
+
|
111 |
+
def remove_prefix(text, prefix):
|
112 |
+
if text.startswith(prefix):
|
113 |
+
return text[len(prefix) :]
|
114 |
+
return text
|
115 |
+
|
116 |
+
class AverageMeter(object):
|
117 |
+
"""Computes and stores the average and current value"""
|
118 |
+
|
119 |
+
def __init__(self, is_ddp):
|
120 |
+
self.is_ddp = is_ddp
|
121 |
+
self.reset()
|
122 |
+
|
123 |
+
def reset(self):
|
124 |
+
self.val = 0.0
|
125 |
+
self.avg = 0.0
|
126 |
+
self.sum = 0.0
|
127 |
+
self.count = 0.0
|
128 |
+
|
129 |
+
def update(self, val, n=1):
|
130 |
+
self.val = val
|
131 |
+
self.sum += val * n
|
132 |
+
self.count += n
|
133 |
+
self.avg = self.sum / (self.count + 1e-5)
|
134 |
+
|
135 |
+
def synch(self, device):
|
136 |
+
if self.is_ddp is False:
|
137 |
+
return
|
138 |
+
|
139 |
+
_sum = torch.tensor(self.sum).to(device)
|
140 |
+
_count = torch.tensor(self.count).to(device)
|
141 |
+
|
142 |
+
torch.distributed.reduce(_sum, dst=0)
|
143 |
+
torch.distributed.reduce(_count, dst=0)
|
144 |
+
|
145 |
+
if torch.distributed.get_rank() == 0:
|
146 |
+
self.sum = _sum.item()
|
147 |
+
self.count = _count.item()
|
148 |
+
self.avg = self.sum / (self.count + 1e-5)
|