scyonggg committed on
Commit 9860a06
1 Parent(s): 4ef8570

Initial commit

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ ckpts/ filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,419 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+ import os, sys
10
+ sys.path.append(os.getcwd())
11
+
12
+ # Gradio demo comparing SAM and ZIM
13
+ import os
14
+ import torch
15
+ import gradio as gr
16
+ from gradio_image_prompter import ImagePrompter
17
+ import numpy as np
18
+ import cv2
19
+ from zim import zim_model_registry, ZimPredictor, ZimAutomaticMaskGenerator
20
+ from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
21
+ from zim.utils import show_mat_anns
22
+
23
+ def get_shortest_axis(image):
24
+ h, w, _ = image.shape
25
+ return h if h < w else w
26
+
27
+ def reset_image(image, prompts):
28
+ if image is None:
29
+ image = np.zeros((1024, 1024, 3), dtype=np.uint8)
30
+ else:
31
+ image = image['image']
32
+ zim_predictor.set_image(image)
33
+ sam_predictor.set_image(image)
34
+ prompts = dict()
35
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
36
+
37
+ return (image, image, image, image, black, black, black, black, prompts)
38
+
39
+ def reset_example_image(image, prompts):
40
+ if image is None:
41
+ image = np.zeros((1024, 1024, 3), dtype=np.uint8)
42
+
43
+ zim_predictor.set_image(image)
44
+ sam_predictor.set_image(image)
45
+ prompts = dict()
46
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
47
+
48
+ image_dict = {}
49
+ image_dict['image'] = image
50
+ image_dict['prompts'] = prompts
51
+
52
+ return (image, image_dict, image, image, image, black, black, black, black, prompts)
53
+
54
+ def run_amg(image):
55
+ gr.Info('Check out the ZIM Auto Mask tab.', duration=3)
56
+ zim_masks = zim_mask_generator.generate(image)
57
+ zim_masks_vis = show_mat_anns(image, zim_masks)
58
+
59
+ sam_masks = sam_mask_generator.generate(image)
60
+ sam_masks_vis = show_mat_anns(image, sam_masks)
61
+
62
+ return zim_masks_vis, sam_masks_vis
63
+
64
+
65
+ def run_model(image, prompts):
66
+ if not prompts:
67
+ raise gr.Error('Please input at least one point or box.')
68
+ gr.Info('Check out the ZIM Mask tab.', duration=3)
69
+ point_coords = None
70
+ point_labels = None
71
+ boxes = None
72
+
73
+ if "point" in prompts:
74
+ point_coords, point_labels = [], []
75
+
76
+ for type, pts in prompts["point"]:
77
+ point_coords.append(pts)
78
+ point_labels.append(type)
79
+ point_coords = np.array(point_coords)
80
+ point_labels = np.array(point_labels)
81
+
82
+ if "bbox" in prompts:
83
+ boxes = prompts['bbox']
84
+ boxes = np.array(boxes)
85
+
86
+ if "scribble" in prompts:
87
+ point_coords, point_labels = [], []
88
+
89
+ for pts in prompts["scribble"]:
90
+ point_coords.append(np.flip(pts))
91
+ point_labels.append(1)
92
+ if len(point_coords) == 0:
93
+ raise gr.Error("Please input any scribbles.")
94
+ point_coords = np.array(point_coords)
95
+ point_labels = np.array(point_labels)
96
+
97
+ # run ZIM
98
+ zim_mask, _, _ = zim_predictor.predict(
99
+ point_coords=point_coords,
100
+ point_labels=point_labels,
101
+ box=boxes,
102
+ multimask_output=False,
103
+ )
104
+ zim_mask = np.squeeze(zim_mask, axis=0)
105
+ zim_mask = np.uint8(zim_mask * 255)
106
+
107
+ # run SAM
108
+ sam_mask, _, _ = sam_predictor.predict(
109
+ point_coords=point_coords,
110
+ point_labels=point_labels,
111
+ box=boxes,
112
+ multimask_output=False,
113
+ )
114
+ sam_mask = np.squeeze(sam_mask, axis=0)
115
+ sam_mask = np.uint8(sam_mask * 255)
116
+
117
+ return zim_mask, sam_mask
118
+
119
+ def reset_scribble(image, scribble, prompts):
120
+ # scribble = dict()
121
+ for k in prompts.keys():
122
+ prompts[k] = []
123
+
124
+ for k, v in scribble.items():
125
+ scribble[k] = None
126
+
127
+ black = np.zeros(image.shape[:2], dtype=np.uint8)
128
+
129
+ return scribble, black, black
130
+
131
+ def update_scribble(image, scribble, prompts):
132
+ if "point" in prompts:
133
+ del prompts["point"]
134
+
135
+ if "bbox" in prompts:
136
+ del prompts["bbox"]
137
+
138
+ prompts = dict() # reset prompt
139
+ scribble_mask = scribble["layers"][0][..., -1] > 0
140
+
141
+ scribble_coords = np.argwhere(scribble_mask)
142
+ n_points = min(len(scribble_coords), 24)
143
+ indices = np.linspace(0, len(scribble_coords)-1, n_points, dtype=int)
144
+ scribble_sampled = scribble_coords[indices]
145
+
146
+ prompts["scribble"] = scribble_sampled
147
+
148
+ zim_mask, sam_mask = run_model(image, prompts)
149
+
150
+ return zim_mask, sam_mask, prompts
151
+
152
+
153
+ def draw_point(img, pt, size, color):
154
+ # draw circle with white boundary region
155
+ cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 1.3), (255, 255, 255), -1)
156
+ cv2.circle(img, (int(pt[0]), int(pt[1])), int(size * 0.9), color, -1)
157
+
158
+
159
+ def draw_images(image, mask, prompts):
160
+ if len(prompts) == 0 or mask.shape[1] == 1:
161
+ return image, image
162
+
163
+ minor = get_shortest_axis(image)
164
+ size = int(minor / 80)
165
+
166
+ image = np.float32(image)
167
+
168
+ def blending(image, mask):
169
+ mask = np.float32(mask) / 255
170
+ blended_image = np.zeros_like(image, dtype=np.float32)
171
+ blended_image[:, :, :] = [108, 0, 192]
172
+ blended_image = (image * 0.5) + (blended_image * 0.5)
173
+
174
+ img_with_mask = mask[:, :, None] * blended_image + (1 - mask[:, :, None]) * image
175
+ img_with_mask = np.uint8(img_with_mask)
176
+
177
+ return img_with_mask
178
+
179
+ img_with_mask = blending(image, mask)
180
+ img_with_point = img_with_mask.copy()
181
+
182
+ if "point" in prompts:
183
+ for type, pts in prompts["point"]:
184
+ if type == 1: # positive point label
185
+ color = (0, 0, 255)
186
+ draw_point(img_with_point, pts, size, color)
187
+ elif type == 0: # negative point label
188
+ color = (255, 0, 0)
189
+ draw_point(img_with_point, pts, size, color)
190
+
191
+ size = int(minor / 200)
192
+
193
+ return (
194
+ np.uint8(image),
195
+ img_with_point,
196
+ )
197
+
198
+ def get_point_or_box_prompts(img, prompts):
199
+ image, img_prompts = img['image'], img['points']
200
+ point_prompts = []
201
+ box_prompts = []
202
+ for prompt in img_prompts:
203
+ for p in range(len(prompt)):
204
+ prompt[p] = int(prompt[p])
205
+ if prompt[2] == 2 and prompt[5] == 3: # box prompt
206
+ if len(box_prompts) != 0:
207
+ raise gr.Error("Please input only one BBox.", duration=3)
208
+ box_prompts.append([prompt[0], prompt[1], prompt[3], prompt[4]])
209
+ elif prompt[2] == 1 and prompt[5] == 4: # Positive point prompt
210
+ point_prompts.append((1, (prompt[0], prompt[1])))
211
+ elif prompt[2] == 0 and prompt[5] == 4: # Negative point prompt
212
+ point_prompts.append((0, (prompt[0], prompt[1])))
213
+
214
+ if "scribble" in prompts:
215
+ del prompts["scribble"]
216
+
217
+ if len(point_prompts) > 0:
218
+ prompts['point'] = point_prompts
219
+ elif 'point' in prompts:
220
+ del prompts['point']
221
+
222
+ if len(box_prompts) > 0:
223
+ prompts['bbox'] = box_prompts
224
+ elif 'bbox' in prompts:
225
+ del prompts['bbox']
226
+
227
+ zim_mask, sam_mask = run_model(image, prompts)
228
+
229
+ return image, zim_mask, sam_mask, prompts
230
+
231
+ def get_examples():
232
+ assets_dir = os.path.join(os.path.dirname(__file__), 'examples')
233
+ images = os.listdir(assets_dir)
234
+ return [os.path.join(assets_dir, img) for img in images]
235
+
236
+ if __name__ == "__main__":
237
+ backbone = "vit_b"
238
+
239
+ # load ZIM
240
+ ckpt_mat = "ckpts/zim_vit_b_2043"
241
+ zim = zim_model_registry[backbone](checkpoint=ckpt_mat)
242
+ if torch.cuda.is_available():
243
+ zim.cuda()
244
+ zim_predictor = ZimPredictor(zim)
245
+ zim_mask_generator = ZimAutomaticMaskGenerator(
246
+ zim,
247
+ pred_iou_thresh=0.7,
248
+ points_per_batch=8,
249
+ stability_score_thresh=0.9,
250
+ )
251
+
252
+ # load SAM
253
+ ckpt_sam = "ckpts/sam_vit_b_01ec64.pth"
254
+ sam = sam_model_registry[backbone](checkpoint=ckpt_sam)
255
+ if torch.cuda.is_available():
256
+ sam.cuda()
257
+ sam_predictor = SamPredictor(sam)
258
+ sam_mask_generator = SamAutomaticMaskGenerator(
259
+ sam,
260
+ points_per_batch=8,
261
+ )
262
+
263
+ with gr.Blocks() as demo:
264
+ gr.Markdown("# <center> [Demo] ZIM: Zero-Shot Image Matting for Anything")
265
+
266
+ prompts = gr.State(dict())
267
+ img = gr.Image(visible=False)
268
+ example_image = gr.Image(visible=False)
269
+
270
+ with gr.Row():
271
+ with gr.Column():
272
+ # Point and Bbox prompt
273
+ with gr.Tab(label="Point or Box"):
274
+ img_with_point_or_box = ImagePrompter(
275
+ label="query image",
276
+ sources="upload"
277
+ )
278
+ interactions = "Left Click (Pos) | Middle/Right Click (Neg) | Press Move (Box)"
279
+ gr.Markdown("<h3 style='text-align: center'> {} </h3>".format(interactions))
280
+ run_bttn = gr.Button("Run")
281
+ amg_bttn = gr.Button("Automatic Mask Generation")
282
+
283
+ # Scribble prompt
284
+ with gr.Tab(label="Scribble"):
285
+ img_with_scribble = gr.ImageEditor(
286
+ label="Scribble",
287
+ brush=gr.Brush(colors=["#00FF00"], default_size=15),
288
+ sources="upload",
289
+ transforms=None,
290
+ layers=False
291
+ )
292
+ interactions = "Press Move (Scribble)"
293
+ gr.Markdown("<h3 style='text-align: center'> Step 1. Select Draw button </h3>")
294
+ gr.Markdown("<h3 style='text-align: center'> Step 2. {} </h3>".format(interactions))
295
+ scribble_bttn = gr.Button("Run")
296
+ scribble_reset_bttn = gr.Button("Reset Scribbles")
297
+ amg_scribble_bttn = gr.Button("Automatic Mask Generation")
298
+
299
+ # Example image
300
+ gr.Examples(get_examples(), inputs=[example_image])
301
+
302
+ # with gr.Row():
303
+ with gr.Column():
304
+ with gr.Tab(label="ZIM Image"):
305
+ img_with_zim_mask = gr.Image(
306
+ label="ZIM Image",
307
+ interactive=False
308
+ )
309
+
310
+ with gr.Tab(label="ZIM Mask"):
311
+ zim_mask = gr.Image(
312
+ label="ZIM Mask",
313
+ image_mode="L",
314
+ interactive=False
315
+ )
316
+ with gr.Tab(label="ZIM Auto Mask"):
317
+ zim_amg = gr.Image(
318
+ label="ZIM Auto Mask",
319
+ interactive=False
320
+ )
321
+
322
+ with gr.Column():
323
+ with gr.Tab(label="SAM Image"):
324
+ img_with_sam_mask = gr.Image(
325
+ label="SAM image",
326
+ interactive=False
327
+ )
328
+
329
+ with gr.Tab(label="SAM Mask"):
330
+ sam_mask = gr.Image(
331
+ label="SAM Mask",
332
+ image_mode="L",
333
+ interactive=False
334
+ )
335
+
336
+ with gr.Tab(label="SAM Auto Mask"):
337
+ sam_amg = gr.Image(
338
+ label="SAM Auto Mask",
339
+ interactive=False
340
+ )
341
+
342
+ example_image.change(
343
+ reset_example_image,
344
+ [example_image, prompts],
345
+ [
346
+ img,
347
+ img_with_point_or_box,
348
+ img_with_scribble,
349
+ img_with_zim_mask,
350
+ img_with_sam_mask,
351
+ zim_amg,
352
+ sam_amg,
353
+ zim_mask,
354
+ sam_mask,
355
+ prompts,
356
+ ]
357
+ )
358
+
359
+ img_with_point_or_box.upload(
360
+ reset_image,
361
+ [img_with_point_or_box, prompts],
362
+ [
363
+ img,
364
+ img_with_scribble,
365
+ img_with_zim_mask,
366
+ img_with_sam_mask,
367
+ zim_amg,
368
+ sam_amg,
369
+ zim_mask,
370
+ sam_mask,
371
+ prompts,
372
+ ],
373
+ )
374
+
375
+ amg_bttn.click(
376
+ run_amg,
377
+ [img],
378
+ [zim_amg, sam_amg]
379
+ )
380
+ amg_scribble_bttn.click(
381
+ run_amg,
382
+ [img],
383
+ [zim_amg, sam_amg]
384
+ )
385
+
386
+ run_bttn.click(
387
+ get_point_or_box_prompts,
388
+ [img_with_point_or_box, prompts],
389
+ [img, zim_mask, sam_mask, prompts]
390
+ )
391
+
392
+ zim_mask.change(
393
+ draw_images,
394
+ [img, zim_mask, prompts],
395
+ [
396
+ img, img_with_zim_mask,
397
+ ],
398
+ )
399
+ sam_mask.change(
400
+ draw_images,
401
+ [img, sam_mask, prompts],
402
+ [
403
+ img, img_with_sam_mask,
404
+ ],
405
+ )
406
+
407
+ scribble_reset_bttn.click(
408
+ reset_scribble,
409
+ [img, img_with_scribble, prompts],
410
+ [img_with_scribble, zim_mask, sam_mask],
411
+ )
412
+ scribble_bttn.click(
413
+ update_scribble,
414
+ [img, img_with_scribble, prompts],
415
+ [zim_mask, sam_mask, prompts],
416
+ )
417
+
418
+ demo.queue()
419
+ demo.launch()
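
The demo above wires many callbacks together, but the underlying flow is just the predictor API from the zim package. A minimal sketch of that flow, assuming the bundled checkpoint and example image; the click coordinates are illustrative:

import cv2
import numpy as np
from zim import zim_model_registry, ZimPredictor

# Load the ONNX-backed ZIM model (a directory containing encoder.onnx / decoder.onnx).
model = zim_model_registry["vit_b"](checkpoint="ckpts/zim_vit_b_2043")
predictor = ZimPredictor(model)

# Embed the image once, then prompt it repeatedly.
image = cv2.cvtColor(cv2.imread("examples/example1.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)
mask, iou, low_res = predictor.predict(
    point_coords=np.array([[512, 512]]),  # assumed example click, (x, y) in pixels
    point_labels=np.array([1]),           # 1 = foreground point
    multimask_output=False,
)
matte = np.uint8(mask[0] * 255)           # ZIM returns soft (sigmoid) mattes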
ckpts/sam_vit_b_01ec64.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ec2df62732614e57411cdcf32a23ffdf28910380d03139ee0f4fcbe91eb8c912
3
+ size 375042383
ckpts/zim_vit_b_2043/decoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a99d14bac3dc28b17809020b1711b1991bd5b7eb18f678c3624ce508748cfa11
3
+ size 19330176
ckpts/zim_vit_b_2043/encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6123ea10722dd43a0f8bbb96cd6a0532d8a65e2b1e8339f0cd6593d494807809
3
+ size 360680416
ckpts/zim_vit_l_2092/decoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c24a42d79053d20594760ad9afdef9cca985f29493f3e1e0fe3462ca291b57f7
3
+ size 19330176
ckpts/zim_vit_l_2092/encoder.onnx ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6b0360b3c32dfa11555fb1de27cd3533cec1e932204da0dec2b3772045dc7db3
3
+ size 1235692110
config/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from config.config import generate_config
config/config.py ADDED
@@ -0,0 +1,66 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+
4
+ This source code is licensed under the license found in the
5
+ LICENSE file in the root directory of this source tree.
6
+ """
7
+
8
+ from easydict import EasyDict as edict
9
+
10
+ config_ = edict()
11
+
12
+ """
13
+ Common configs
14
+ """
15
+ config_.data_root = "/mnt/tmp"
16
+ config_.use_ddp = True
17
+ config_.use_amp = False
18
+ config_.local_rank = 0
19
+ config_.world_size = 1
20
+ config_.random_seed = 3407
21
+ """
22
+ Network configs
23
+ """
24
+ config_.network = edict()
25
+ config_.network.encoder = "vit_b"
26
+ config_.network.decoder = "zim"
27
+ config_.network.encode_kernel = 21
28
+ """
29
+ Evaluation configs
30
+ """
31
+ config_.eval = edict()
32
+ config_.eval.workers = 4
33
+ config_.eval.image_size = 1024
34
+ config_.eval.prompt_type = "point,bbox"
35
+ config_.eval.model_list = "zim,sam"
36
+ config_.eval.zim_weights = ""
37
+ config_.eval.sam_weights = ""
38
+ """
39
+ Dataset configs
40
+ """
41
+ config_.dataset = edict()
42
+ config_.dataset.valset = "MicroMat3K"
43
+ config_.dataset.data_type = "fine,coarse"
44
+ config_.dataset.data_list_txt = "data_list.txt"
45
+
46
+
47
+ def remove_prefix(text, prefix):
48
+ if text.startswith(prefix):
49
+ return text[len(prefix) :]
50
+ return text
51
+
52
+
53
+ def generate_config(args):
54
+ # merge args & config
55
+ for k, v in args.items():
56
+ if k.startswith("network_"):
57
+ config_["network"][remove_prefix(k, "network_")] = v
58
+ elif k.startswith("eval_"):
59
+ config_["eval"][remove_prefix(k, "eval_")] = v
60
+ elif k.startswith("dataset_"):
61
+ config_["dataset"][remove_prefix(k, "dataset_")] = v
62
+ elif k == "amp":
63
+ config_["use_amp"] = v
64
+ else:
65
+ config_[k] = v
66
+ return config_
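
For reference, a minimal sketch of how generate_config consumes a flat argparse-style dict; the specific keys below are illustrative assumptions, not part of this file:

from config import generate_config

args = {
    "data_root": "/data",        # top-level key, stored as-is
    "amp": True,                 # special-cased into use_amp
    "network_encoder": "vit_l",  # "network_" prefix routed into config.network
    "eval_workers": 8,           # "eval_" prefix routed into config.eval
    "dataset_valset": "MicroMat3K",
}
cfg = generate_config(args)
assert cfg.network.encoder == "vit_l" and cfg.use_amp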
examples/example1.jpg ADDED
examples/example2.jpg ADDED
examples/example3.jpg ADDED
examples/example4.jpg ADDED
examples/example5.jpg ADDED
examples/example6.jpg ADDED
examples/example7.jpg ADDED
examples/example8.jpg ADDED
packages.txt ADDED
@@ -0,0 +1 @@
1
+ wget
pre-requirements.txt ADDED
@@ -0,0 +1,3 @@
1
+ torch
2
+ torchvision
3
+ numpy==1.24.4
requirements.txt ADDED
@@ -0,0 +1,7 @@
1
+ easydict
2
+ opencv-python
3
+ gradio==4.38.1
4
+ gradio-image-prompter
5
+ fastapi==0.112.2
6
+ git+https://github.com/facebookresearch/segment-anything.git
7
+ onnxruntime-gpu==1.17.0
zim/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from .build_model import build_zim_model, zim_model_registry
8
+ from .predictor import ZimPredictor
9
+ from .automatic_mask_generator import ZimAutomaticMaskGenerator
zim/automatic_mask_generator.py ADDED
@@ -0,0 +1,378 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torchvision.ops.boxes import batched_nms, box_area
13
+
14
+ from typing import Any, Dict, List, Optional, Tuple
15
+
16
+ from .modeling.zim import Zim
17
+ from .predictor import ZimPredictor
18
+ from .utils.amg import (
19
+ MaskData,
20
+ area_from_rle,
21
+ batch_iterator,
22
+ batched_mask_to_box,
23
+ box_xyxy_to_xywh,
24
+ build_all_layer_point_grids,
25
+ calculate_stability_score,
26
+ coco_encode_rle,
27
+ generate_crop_boxes,
28
+ is_box_near_crop_edge,
29
+ mask_to_rle_pytorch,
30
+ remove_small_regions,
31
+ rle_to_mask,
32
+ uncrop_boxes_xyxy,
33
+ uncrop_masks,
34
+ uncrop_points,
35
+ )
36
+
37
+
38
+ class ZimAutomaticMaskGenerator:
39
+ def __init__(
40
+ self,
41
+ model: Zim,
42
+ points_per_side: Optional[int] = 32,
43
+ points_per_batch: int = 64,
44
+ pred_iou_thresh: float = 0.88,
45
+ stability_score_thresh: float = 0.9,
46
+ stability_score_offset: float = 0.1,
47
+ box_nms_thresh: float = 0.7,
48
+ crop_n_layers: int = 0,
49
+ crop_nms_thresh: float = 0.7,
50
+ crop_overlap_ratio: float = 512 / 1500,
51
+ crop_n_points_downscale_factor: int = 1,
52
+ point_grids: Optional[List[np.ndarray]] = None,
53
+ min_mask_region_area: int = 0,
54
+ output_mode: str = "binary_mask",
55
+ ) -> None:
56
+ """
57
+ Using a ZIM model, generates masks for the entire image.
58
+ Generates a grid of point prompts over the image, then filters
59
+ low quality and duplicate masks. The default settings are chosen
60
+ for SAM with a ViT-H backbone.
61
+
62
+ Arguments:
63
+ model (Zim): The ZIM model to use for mask prediction.
64
+ points_per_side (int or None): The number of points to be sampled
65
+ along one side of the image. The total number of points is
66
+ points_per_side**2. If None, 'point_grids' must provide explicit
67
+ point sampling.
68
+ points_per_batch (int): Sets the number of points run simultaneously
69
+ by the model. Higher numbers may be faster but use more GPU memory.
70
+ pred_iou_thresh (float): A filtering threshold in [0,1], using the
71
+ model's predicted mask quality.
72
+ stability_score_thresh (float): A filtering threshold in [0,1], using
73
+ the stability of the mask under changes to the cutoff used to binarize
74
+ the model's mask predictions.
75
+ stability_score_offset (float): The amount to shift the cutoff when
76
+ calculating the stability score.
77
+ box_nms_thresh (float): The box IoU cutoff used by non-maximal
78
+ suppression to filter duplicate masks.
79
+ crop_n_layers (int): If >0, mask prediction will be run again on
80
+ crops of the image. Sets the number of layers to run, where each
81
+ layer has 2**i_layer number of image crops.
82
+ crop_nms_thresh (float): The box IoU cutoff used by non-maximal
83
+ suppression to filter duplicate masks between different crops.
84
+ crop_overlap_ratio (float): Sets the degree to which crops overlap.
85
+ In the first crop layer, crops will overlap by this fraction of
86
+ the image length. Later layers with more crops scale down this overlap.
87
+ crop_n_points_downscale_factor (int): The number of points-per-side
88
+ sampled in layer n is scaled down by crop_n_points_downscale_factor**n.
89
+ point_grids (list(np.ndarray) or None): A list over explicit grids
90
+ of points used for sampling, normalized to [0,1]. The nth grid in the
91
+ list is used in the nth crop layer. Exclusive with points_per_side.
92
+ min_mask_region_area (int): If >0, postprocessing will be applied
93
+ to remove disconnected regions and holes in masks with area smaller
94
+ than min_mask_region_area. Requires opencv.
95
+ output_mode (str): The form masks are returned in. Can be 'binary_mask',
96
+ 'uncompressed_rle', or 'coco_rle'. 'coco_rle' requires pycocotools.
97
+ For large resolutions, 'binary_mask' may consume large amounts of
98
+ memory.
99
+ """
100
+
101
+ assert (points_per_side is None) != (
102
+ point_grids is None
103
+ ), "Exactly one of points_per_side or point_grid must be provided."
104
+ if points_per_side is not None:
105
+ self.point_grids = build_all_layer_point_grids(
106
+ points_per_side,
107
+ crop_n_layers,
108
+ crop_n_points_downscale_factor,
109
+ )
110
+ elif point_grids is not None:
111
+ self.point_grids = point_grids
112
+ else:
113
+ raise ValueError("Can't have both points_per_side and point_grid be None.")
114
+
115
+ assert output_mode in [
116
+ "binary_mask",
117
+ "uncompressed_rle",
118
+ "coco_rle",
119
+ ], f"Unknown output_mode {output_mode}."
120
+ if output_mode == "coco_rle":
121
+ from pycocotools import mask as mask_utils # type: ignore # noqa: F401
122
+
123
+ if min_mask_region_area > 0:
124
+ import cv2 # type: ignore # noqa: F401
125
+
126
+ self.predictor = ZimPredictor(model)
127
+ self.points_per_batch = points_per_batch
128
+ self.pred_iou_thresh = pred_iou_thresh
129
+ self.stability_score_thresh = stability_score_thresh
130
+ self.stability_score_offset = stability_score_offset
131
+ self.box_nms_thresh = box_nms_thresh
132
+ self.crop_n_layers = crop_n_layers
133
+ self.crop_nms_thresh = crop_nms_thresh
134
+ self.crop_overlap_ratio = crop_overlap_ratio
135
+ self.crop_n_points_downscale_factor = crop_n_points_downscale_factor
136
+ self.min_mask_region_area = min_mask_region_area
137
+ self.output_mode = output_mode
138
+
139
+ @torch.no_grad()
140
+ def generate(self, image: np.ndarray) -> List[Dict[str, Any]]:
141
+ """
142
+ Generates masks for the given image.
143
+
144
+ Arguments:
145
+ image (np.ndarray): The image to generate masks for, in HWC uint8 format.
146
+
147
+ Returns:
148
+ list(dict(str, any)): A list over records for masks. Each record is
149
+ a dict containing the following keys:
150
+ segmentation (dict(str, any) or np.ndarray): The mask. If
151
+ output_mode='binary_mask', is an array of shape HW. Otherwise,
152
+ is a dictionary containing the RLE.
153
+ bbox (list(float)): The box around the mask, in XYWH format.
154
+ area (int): The area in pixels of the mask.
155
+ predicted_iou (float): The model's own prediction of the mask's
156
+ quality. This is filtered by the pred_iou_thresh parameter.
157
+ point_coords (list(list(float))): The point coordinates input
158
+ to the model to generate this mask.
159
+ stability_score (float): A measure of the mask's quality. This
160
+ is filtered on using the stability_score_thresh parameter.
161
+ crop_box (list(float)): The crop of the image used to generate
162
+ the mask, given in XYWH format.
163
+ """
164
+
165
+ # Generate masks
166
+ mask_data = self._generate_masks(image)
167
+
168
+ # Filter small disconnected regions and holes in masks
169
+ if self.min_mask_region_area > 0:
170
+ mask_data = self.postprocess_small_regions(
171
+ mask_data,
172
+ self.min_mask_region_area,
173
+ max(self.box_nms_thresh, self.crop_nms_thresh),
174
+ )
175
+
176
+ # Encode masks
177
+ if self.output_mode == "coco_rle":
178
+ mask_data["segmentations"] = [coco_encode_rle(rle) for rle in mask_data["rles"]]
179
+ elif self.output_mode == "binary_mask":
180
+ mask_data["segmentations"] = [rle_to_mask(rle) for rle in mask_data["rles"]]
181
+ else:
182
+ mask_data["segmentations"] = mask_data["rles"]
183
+
184
+ # Write mask records
185
+ curr_anns = []
186
+ for idx in range(len(mask_data["segmentations"])):
187
+ ann = {
188
+ "segmentation": mask_data["segmentations"][idx],
189
+ "logit": mask_data["logits"][idx],
190
+ "area": area_from_rle(mask_data["rles"][idx]),
191
+ "bbox": box_xyxy_to_xywh(mask_data["boxes"][idx]).tolist(),
192
+ "predicted_iou": mask_data["iou_preds"][idx].item(),
193
+ "point_coords": [mask_data["points"][idx].tolist()],
194
+ "stability_score": mask_data["stability_score"][idx].item(),
195
+ "crop_box": box_xyxy_to_xywh(mask_data["crop_boxes"][idx]).tolist(),
196
+ }
197
+ curr_anns.append(ann)
198
+
199
+ return curr_anns
200
+
201
+ def _generate_masks(self, image: np.ndarray) -> MaskData:
202
+ orig_size = image.shape[:2]
203
+ crop_boxes, layer_idxs = generate_crop_boxes(
204
+ orig_size, self.crop_n_layers, self.crop_overlap_ratio
205
+ )
206
+
207
+ # Iterate over image crops
208
+ data = MaskData()
209
+ for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
210
+ crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
211
+ data.cat(crop_data)
212
+
213
+ # Remove duplicate masks between crops
214
+ if len(crop_boxes) > 1:
215
+ # Prefer masks from smaller crops
216
+ scores = 1 / box_area(data["crop_boxes"])
217
+ scores = scores.to(data["boxes"].device)
218
+ keep_by_nms = batched_nms(
219
+ data["boxes"].float(),
220
+ scores,
221
+ torch.zeros_like(data["boxes"][:, 0]), # categories
222
+ iou_threshold=self.crop_nms_thresh,
223
+ )
224
+ data.filter(keep_by_nms)
225
+
226
+ data.to_numpy()
227
+ return data
228
+
229
+ def _process_crop(
230
+ self,
231
+ image: np.ndarray,
232
+ crop_box: List[int],
233
+ crop_layer_idx: int,
234
+ orig_size: Tuple[int, ...],
235
+ ) -> MaskData:
236
+ # Crop the image and calculate embeddings
237
+ x0, y0, x1, y1 = crop_box
238
+ cropped_im = image[y0:y1, x0:x1, :]
239
+ cropped_im_size = cropped_im.shape[:2]
240
+ self.predictor.set_image(cropped_im)
241
+
242
+ # Get points for this crop
243
+ points_scale = np.array(cropped_im_size)[None, ::-1]
244
+ points_for_image = self.point_grids[crop_layer_idx] * points_scale
245
+
246
+ # Generate masks for this crop in batches
247
+ data = MaskData()
248
+ for (points,) in batch_iterator(self.points_per_batch, points_for_image):
249
+ batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
250
+ data.cat(batch_data)
251
+ del batch_data
252
+ self.predictor.reset_image()
253
+
254
+ # Remove duplicates within this crop.
255
+ keep_by_nms = batched_nms(
256
+ data["boxes"].float(),
257
+ data["iou_preds"],
258
+ torch.zeros_like(data["boxes"][:, 0]), # categories
259
+ iou_threshold=self.box_nms_thresh,
260
+ )
261
+ data.filter(keep_by_nms)
262
+
263
+ # Return to the original image frame
264
+ data["boxes"] = uncrop_boxes_xyxy(data["boxes"], crop_box)
265
+ data["points"] = uncrop_points(data["points"], crop_box)
266
+ data["crop_boxes"] = torch.tensor([crop_box for _ in range(len(data["rles"]))])
267
+
268
+ return data
269
+
270
+ def _process_batch(
271
+ self,
272
+ points: np.ndarray,
273
+ im_size: Tuple[int, ...],
274
+ crop_box: List[int],
275
+ orig_size: Tuple[int, ...],
276
+ ) -> MaskData:
277
+ orig_h, orig_w = orig_size
278
+
279
+ # Run model on this batch
280
+ transformed_points = self.predictor.transform.apply_coords(points, im_size)
281
+ in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
282
+ in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
283
+ masks, iou_preds, _ = self.predictor.predict_torch(
284
+ in_points[:, None, :],
285
+ in_labels[:, None],
286
+ multimask_output=True,
287
+ return_logits=True,
288
+ )
289
+
290
+ # Serialize predictions and store in MaskData
291
+ data = MaskData(
292
+ masks=masks.flatten(0, 1),
293
+ logits=(masks.flatten(0, 1) * 255).byte(),
294
+ iou_preds=iou_preds.flatten(0, 1),
295
+ points=torch.as_tensor(points.repeat(masks.shape[1], axis=0)),
296
+ )
297
+ del masks
298
+
299
+ # Filter by predicted IoU
300
+ if self.pred_iou_thresh > 0.0:
301
+ keep_mask = data["iou_preds"] > self.pred_iou_thresh
302
+ data.filter(keep_mask)
303
+
304
+ # Calculate stability score
305
+ data["stability_score"] = calculate_stability_score(
306
+ data["masks"], self.predictor.model.mask_threshold, self.stability_score_offset
307
+ )
308
+ if self.stability_score_thresh > 0.0:
309
+ keep_mask = data["stability_score"] >= self.stability_score_thresh
310
+ data.filter(keep_mask)
311
+
312
+ # Threshold masks and calculate boxes
313
+ data["masks"] = data["masks"] > self.predictor.model.mask_threshold
314
+ data["boxes"] = batched_mask_to_box(data["masks"])
315
+
316
+ # Filter boxes that touch crop boundaries
317
+ keep_mask = ~is_box_near_crop_edge(data["boxes"], crop_box, [0, 0, orig_w, orig_h])
318
+ if not torch.all(keep_mask):
319
+ data.filter(keep_mask)
320
+
321
+ # Compress to RLE
322
+ data["masks"] = uncrop_masks(data["masks"], crop_box, orig_h, orig_w)
323
+ data["logits"] = uncrop_masks(data["logits"], crop_box, orig_h, orig_w)
324
+ data["rles"] = mask_to_rle_pytorch(data["masks"])
325
+ del data["masks"]
326
+
327
+ return data
328
+
329
+ @staticmethod
330
+ def postprocess_small_regions(
331
+ mask_data: MaskData, min_area: int, nms_thresh: float
332
+ ) -> MaskData:
333
+ """
334
+ Removes small disconnected regions and holes in masks, then reruns
335
+ box NMS to remove any new duplicates.
336
+
337
+ Edits mask_data in place.
338
+
339
+ Requires open-cv as a dependency.
340
+ """
341
+ if len(mask_data["rles"]) == 0:
342
+ return mask_data
343
+
344
+ # Filter small disconnected regions and holes
345
+ new_masks = []
346
+ scores = []
347
+ for rle in mask_data["rles"]:
348
+ mask = rle_to_mask(rle)
349
+
350
+ mask, changed = remove_small_regions(mask, min_area, mode="holes")
351
+ unchanged = not changed
352
+ mask, changed = remove_small_regions(mask, min_area, mode="islands")
353
+ unchanged = unchanged and not changed
354
+
355
+ new_masks.append(torch.as_tensor(mask).unsqueeze(0))
356
+ # Give score=0 to changed masks and score=1 to unchanged masks
357
+ # so NMS will prefer ones that didn't need postprocessing
358
+ scores.append(float(unchanged))
359
+
360
+ # Recalculate boxes and remove any new duplicates
361
+ masks = torch.cat(new_masks, dim=0)
362
+ boxes = batched_mask_to_box(masks)
363
+ keep_by_nms = batched_nms(
364
+ boxes.float(),
365
+ torch.as_tensor(scores),
366
+ torch.zeros_like(boxes[:, 0]), # categories
367
+ iou_threshold=nms_thresh,
368
+ )
369
+
370
+ # Only recalculate RLEs for masks that have changed
371
+ for i_mask in keep_by_nms:
372
+ if scores[i_mask] == 0.0:
373
+ mask_torch = masks[i_mask].unsqueeze(0)
374
+ mask_data["rles"][i_mask] = mask_to_rle_pytorch(mask_torch)[0]
375
+ mask_data["boxes"][i_mask] = boxes[i_mask] # update res directly
376
+ mask_data.filter(keep_by_nms)
377
+
378
+ return mask_data
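
As documented in generate()'s docstring, each returned record carries the matte plus its metadata. A minimal sketch of consuming the generator; the thresholds mirror those used in app.py and the image path is illustrative:

import cv2
from zim import zim_model_registry, ZimAutomaticMaskGenerator

model = zim_model_registry["vit_b"](checkpoint="ckpts/zim_vit_b_2043")
generator = ZimAutomaticMaskGenerator(
    model,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.9,
    points_per_batch=8,
)

image = cv2.cvtColor(cv2.imread("examples/example1.jpg"), cv2.COLOR_BGR2RGB)
for record in generator.generate(image):
    seg = record["segmentation"]     # HxW binary mask (default output_mode)
    x, y, w, h = record["bbox"]      # XYWH box around the mask
    score = record["predicted_iou"]  # model's own quality estimate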
zim/build_model.py ADDED
@@ -0,0 +1,29 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+ import os
10
+ import torch
11
+
12
+ from .modeling.zim import Zim
13
+ from .modeling.encoder import ZIM_Encoder
14
+ from .modeling.decoder import ZIM_Decoder
15
+
16
+ def build_zim_model(checkpoint):
17
+
18
+ encoder = ZIM_Encoder(os.path.join(checkpoint, "encoder.onnx"))
19
+ decoder = ZIM_Decoder(os.path.join(checkpoint, "decoder.onnx"))
20
+ net = Zim(encoder, decoder)
21
+
22
+ return net
23
+
24
+ zim_model_registry = {
25
+ "default": build_zim_model,
26
+ "vit_l": build_zim_model,
27
+ "vit_b": build_zim_model,
28
+ }
29
+
zim/modeling/decoder.py ADDED
@@ -0,0 +1,90 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+ import torch
10
+ from typing import Any, Callable
11
+ import onnxruntime
12
+ import numpy as np
13
+
14
+ def np2tensor(np_array, device):
15
+ return torch.from_numpy(np_array).to(device)
16
+
17
+ def tensor2np(torch_tensor):
18
+ if torch_tensor is None:
19
+ return None
20
+
21
+ return torch_tensor.detach().cpu().numpy()
22
+
23
+ class ZIM_Decoder():
24
+ def __init__(self, onnx_path, num_threads=16):
25
+ self.onnx_path = onnx_path
26
+
27
+ sessionOptions = onnxruntime.SessionOptions()
28
+ sessionOptions.intra_op_num_threads = num_threads
29
+ sessionOptions.inter_op_num_threads = num_threads
30
+ providers = ["CPUExecutionProvider"]
31
+
32
+ self.ort_session = onnxruntime.InferenceSession(
33
+ onnx_path, sess_options=sessionOptions, providers=providers
34
+ )
35
+ self.num_mask_tokens = 4
36
+
37
+ def cuda(self, device_id=0):
38
+ providers = [
39
+ (
40
+ "CUDAExecutionProvider",
41
+ {
42
+ "device_id": device_id,
43
+ },
44
+ ),
45
+ ]
46
+
47
+ self.ort_session.set_providers(providers)
48
+
49
+ def forward(
50
+ self,
51
+ interm_feats,
52
+ image_embeddings,
53
+ points,
54
+ boxes,
55
+ attn_mask,
56
+ ):
57
+ device = image_embeddings.device
58
+
59
+ ort_inputs = {
60
+ "feat_D0": tensor2np(interm_feats[0]),
61
+ "feat_D1": tensor2np(interm_feats[1]),
62
+ "feat_D2": tensor2np(interm_feats[2]),
63
+ "image_embeddings": tensor2np(image_embeddings),
64
+ "attn_mask": tensor2np(attn_mask),
65
+ }
66
+
67
+ if points is not None:
68
+ point_coords, point_labels = points
69
+ ort_inputs["point_coords"] = tensor2np(point_coords.float())
70
+ ort_inputs["point_labels"] = tensor2np(point_labels.float())
71
+
72
+ # add paddings as done in SAM
73
+ padding_point = np.zeros((ort_inputs["point_coords"].shape[0], 1, 2), dtype=np.float32) - 0.5
74
+ padding_label = -np.ones((ort_inputs["point_labels"].shape[0], 1), dtype=np.float32)
75
+ ort_inputs["point_coords"] = np.concatenate([ort_inputs["point_coords"], padding_point], axis=1)
76
+ ort_inputs["point_labels"] = np.concatenate([ort_inputs["point_labels"], padding_label], axis=1)
77
+
78
+ if boxes is not None:
79
+ ort_inputs["point_coords"] = tensor2np(boxes.reshape(-1, 2, 2))
80
+ ort_inputs["point_labels"] = np.array([[2, 3]], dtype=np.float32).repeat(boxes.shape[0], 0)
81
+
82
+ masks, iou_predictions = self.ort_session.run(None, ort_inputs)
83
+
84
+ masks = np2tensor(masks, device)
85
+ iou_predictions = np2tensor(iou_predictions, device)
86
+
87
+ return masks, iou_predictions
88
+
89
+ __call__: Callable[..., Any] = forward
90
+
zim/modeling/encoder.py ADDED
@@ -0,0 +1,62 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+ import torch
10
+ from typing import Any, Callable
11
+ import onnxruntime
12
+
13
+ def np2tensor(np_array, device):
14
+ return torch.from_numpy(np_array).to(device)
15
+
16
+ def tensor2np(torch_tensor):
17
+ return torch_tensor.detach().cpu().numpy()
18
+
19
+ class ZIM_Encoder():
20
+ def __init__(self, onnx_path, num_threads=16):
21
+ self.onnx_path = onnx_path
22
+
23
+ sessionOptions = onnxruntime.SessionOptions()
24
+ sessionOptions.intra_op_num_threads = num_threads
25
+ sessionOptions.inter_op_num_threads = num_threads
26
+ providers = ["CPUExecutionProvider"]
27
+
28
+ self.ort_session = onnxruntime.InferenceSession(
29
+ onnx_path, sess_options=sessionOptions, providers=providers
30
+ )
31
+
32
+ def cuda(self, device_id=0):
33
+ providers = [
34
+ (
35
+ "CUDAExecutionProvider",
36
+ {
37
+ "device_id": device_id,
38
+ },
39
+ ),
40
+ ]
41
+
42
+ self.ort_session.set_providers(providers)
43
+
44
+ def forward(
45
+ self,
46
+ image,
47
+ ):
48
+ device = image.device
49
+
50
+ ort_inputs = {
51
+ "image": tensor2np(image),
52
+ }
53
+ image_embeddings, feat_D0, feat_D1, feat_D2 = self.ort_session.run(None, ort_inputs)
54
+
55
+ image_embeddings = np2tensor(image_embeddings, device)
56
+ feat_D0 = np2tensor(feat_D0, device)
57
+ feat_D1 = np2tensor(feat_D1, device)
58
+ feat_D2 = np2tensor(feat_D2, device)
59
+
60
+ return image_embeddings, (feat_D0, feat_D1, feat_D2)
61
+
62
+ __call__: Callable[..., Any] = forward
zim/modeling/zim.py ADDED
@@ -0,0 +1,190 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+
10
+ import torch
11
+ from torch import nn
12
+ from torch.nn import functional as F
13
+ from typing import Any, Dict, List
14
+
15
+ def gaussian(sigma=6):
16
+ """
17
+ 2D Gaussian Kernel Generation.
18
+ """
19
+ size = 6 * sigma + 3
20
+ x = torch.arange(0, size, 1)
21
+ y = x[:, None]
22
+ x0, y0 = 3 * sigma + 1, 3 * sigma + 1
23
+ g = torch.exp(- ((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma ** 2))
24
+ return g
25
+
26
+ class Zim(nn.Module):
27
+ def __init__(
28
+ self,
29
+ encoder,
30
+ decoder,
31
+ *,
32
+ image_size: int = 1024,
33
+ pixel_mean: List[float] = [123.675, 116.28, 103.53],
34
+ pixel_std: List[float] = [58.395, 57.12, 57.375],
35
+ ) -> None:
36
+ """
37
+ SAM predicts object masks from an image and input prompts.
38
+
39
+ Arguments:
40
+ encoder : The backbone used to encode the
41
+ image into image embeddings that allow for efficient mask prediction.
42
+ decoder : Predicts masks from the image embeddings and given prompts.
43
+ pixel_mean (list(float)): Mean values for normalizing pixels in the input image.
44
+ pixel_std (list(float)): Std values for normalizing pixels in the input image.
45
+ """
46
+ super().__init__()
47
+ self.encoder = encoder
48
+ self.decoder = decoder
49
+ self.output_activation = nn.Sigmoid()
50
+
51
+ self.image_size = image_size
52
+ self.register_buffer(
53
+ "pixel_mean", torch.Tensor(pixel_mean).view(-1, 1, 1), False
54
+ )
55
+ self.register_buffer("pixel_std", torch.Tensor(pixel_std).view(-1, 1, 1), False)
56
+
57
+ self.mask_threshold: float = 0.5
58
+ self.image_format: str = "RGB"
59
+ self.num_mask_tokens = decoder.num_mask_tokens
60
+
61
+ self.encode_stride = 16
62
+ self.encode_kernel = 21
63
+ self.attn_mask_size = 64
64
+ self.g = gaussian(self.encode_kernel)
65
+
66
+ self.output_conv = nn.Conv2d(
67
+ self.num_mask_tokens,
68
+ self.num_mask_tokens,
69
+ kernel_size=1, stride=1, padding=0,
70
+ )
71
+
72
+ @property
73
+ def device(self) -> Any:
74
+ return self.pixel_mean.device
75
+
76
+ def cuda(self, device_id=None):
77
+ if type(device_id) == torch.device:
78
+ device_id = device_id.index
79
+
80
+ if device_id is None:
81
+ device_id = 0
82
+
83
+ device = torch.device(f"cuda:{device_id}")
84
+ super(Zim, self).cuda(device)
85
+
86
+ self.encoder.cuda(device_id)
87
+ self.decoder.cuda(device_id)
88
+
89
+ return self
90
+
91
+ def postprocess_masks(
92
+ self, masks: torch.Tensor, input_size: List[int], original_size: torch.Tensor
93
+ ) -> torch.Tensor:
94
+ """
95
+ Remove padding and upscale masks to the original image size.
96
+
97
+ Arguments:
98
+ masks (torch.Tensor): Batched masks from the decoder,
99
+ in BxCxHxW format.
100
+ input_size (tuple(int, int)): The size of the image input to the
101
+ model, in (H, W) format. Used to remove padding.
102
+ original_size (tuple(int, int)): The original size of the image
103
+ before resizing for input to the model, in (H, W) format.
104
+
105
+ Returns:
106
+ (torch.Tensor): Batched masks in BxCxHxW format, where (H, W)
107
+ is given by original_size.
108
+ """
109
+ masks = F.interpolate(
110
+ masks,
111
+ (self.image_size, self.image_size),
112
+ mode="bilinear",
113
+ align_corners=False,
114
+ )
115
+ masks = masks[..., : input_size[0], : input_size[1]]
116
+ masks = F.interpolate(
117
+ masks, original_size, mode="bilinear", align_corners=False
118
+ )
119
+ return masks
120
+
121
+ def preprocess(self, x: torch.Tensor) -> torch.Tensor:
122
+ """Normalize pixel values and pad to a square input."""
123
+ # Normalize colors
124
+ x = (x - self.pixel_mean) / self.pixel_std
125
+
126
+ # Pad
127
+ h, w = x.shape[-2:]
128
+ padh = self.image_size - h
129
+ padw = self.image_size - w
130
+ x = F.pad(x, (0, padw, 0, padh))
131
+ return x
132
+
133
+ def bbox_attn_mask(self, boxes):
134
+ """Prompt-aware Masked Attention: box prompt (binary attn mask) """
135
+ bs = boxes.shape[0]
136
+ attn_mask = torch.zeros((bs, self.attn_mask_size, self.attn_mask_size), device=boxes.device)
137
+
138
+ # attn_weight = attn_weight.masked_fill(m.logical_not(), -1e4)
139
+
140
+ for n in range(bs):
141
+ xmin, ymin, xmax, ymax = boxes[n]
142
+
143
+ xmin, xmax = min(xmin, xmax), max(xmin, xmax)
144
+ ymin, ymax = min(ymin, ymax), max(ymin, ymax)
145
+
146
+ xmin, xmax = int(xmin / self.encode_stride), int(xmax / self.encode_stride)
147
+ ymin, ymax = int(ymin / self.encode_stride), int(ymax / self.encode_stride)
148
+
149
+ xmin, ymin = max(0, xmin), max(0, ymin)
150
+ xmax = min(self.attn_mask_size, xmax+1)
151
+ ymax = min(self.attn_mask_size, ymax+1)
152
+
153
+ attn_mask[n, ymin:ymax, xmin:xmax] = 1
154
+
155
+ return attn_mask
156
+
157
+ def point_attn_mask(self, point_coords):
158
+ """Prompt-aware Masked Attention: point prompt (soft attn mask) """
159
+ bs = point_coords.shape[0]
160
+ attn_mask = torch.zeros((bs, self.attn_mask_size, self.attn_mask_size), device=point_coords.device)
161
+
162
+ if self.g.device != point_coords.device:
163
+ self.g = self.g.to(point_coords.device)
164
+
165
+ for n in range(bs):
166
+ for point in point_coords[n]:
167
+ x, y = int(point[0] / self.encode_stride), int(point[1].item() / self.encode_stride)
168
+
169
+ # outside image boundary
170
+ if x < 0 or y < 0 or x >= self.attn_mask_size or y >= self.attn_mask_size:
171
+ continue
172
+
173
+ # upper left
174
+ ul = int(round(x - 3 * self.encode_kernel - 1)), int(round(y - 3 * self.encode_kernel - 1))
175
+ # bottom right
176
+ br = int(round(x + 3 * self.encode_kernel + 2)), int(round(y + 3 * self.encode_kernel + 2))
177
+
178
+ c, d = int(max(0, -ul[0])), int(min(br[0], self.attn_mask_size) - ul[0])
179
+ a, b = int(max(0, -ul[1])), int(min(br[1], self.attn_mask_size) - ul[1])
180
+
181
+ cc, dd = int(max(0, ul[0])), int(min(br[0], self.attn_mask_size))
182
+ aa, bb = int(max(0, ul[1])), int(min(br[1], self.attn_mask_size))
183
+
184
+ attn_mask[n, aa:bb, cc:dd] = torch.maximum(
185
+ attn_mask[n, aa:bb, cc:dd], self.g[a:b, c:d]
186
+ )
187
+
188
+ return attn_mask
189
+
190
+
zim/predictor.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ Copyright (c) 2024-present Naver Cloud Corp.
3
+ This source code is based on code from the Segment Anything Model (SAM)
4
+ (https://github.com/facebookresearch/segment-anything).
5
+
6
+ This source code is licensed under the license found in the
7
+ LICENSE file in the root directory of this source tree.
8
+ """
9
+
10
+ import numpy as np
11
+ import torch
12
+ from torch.nn import functional as F
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ from typing import Optional, Tuple, List
15
+
16
+ from .utils import ResizeLongestSide
17
+
18
+ class ZimPredictor:
19
+ def __init__(
20
+ self,
21
+ model,
22
+ ) -> None:
23
+ """
24
+ Uses SAM to calculate the image embedding for an image, and then
25
+ allow repeated, efficient mask prediction given prompts.
26
+
27
+ Arguments:
28
+ sam_model (Sam): The model to use for mask prediction.
29
+ """
30
+ super().__init__()
31
+ self.model = model.module if isinstance(model, DDP) else model
32
+ self.transform = ResizeLongestSide(self.model.image_size)
33
+ self.reset_image()
34
+
35
+ def set_image(
36
+ self,
37
+ image: np.ndarray,
38
+ image_format: str = "RGB",
39
+ ) -> None:
40
+ """
41
+ Calculates the image embeddings for the provided image, allowing
42
+ masks to be predicted with the 'predict' method.
43
+
44
+ Arguments:
45
+ image (np.ndarray): The image for calculating masks. Expects an
46
+ image in HWC uint8 format, with pixel values in [0, 255].
47
+ image_format (str): The color format of the image, in ['RGB', 'BGR'].
48
+ """
49
+ assert image_format in [
50
+ "RGB",
51
+ "BGR",
52
+ ], f"image_format must be in ['RGB', 'BGR'], is {image_format}."
53
+ if image_format != self.model.image_format:
54
+ image = image[..., ::-1]
55
+
56
+ # Transform the image to the form expected by the model
57
+ input_image = self.transform.apply_image(image)
58
+ input_image_torch = torch.as_tensor(input_image, device=self.device)
59
+ input_image_torch = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
60
+
61
+ self.set_torch_image(input_image_torch, image.shape[:2])
62
+
63
+ @torch.no_grad()
64
+ def set_torch_image(
65
+ self,
66
+ transformed_image: torch.Tensor,
67
+ original_image_size: Tuple[int, ...],
68
+ ) -> None:
69
+ """
70
+ Calculates the image embeddings for the provided image, allowing
71
+ masks to be predicted with the 'predict' method. Expects the input
72
+ image to be already transformed to the format expected by the model.
73
+
74
+ Arguments:
75
+ transformed_image (torch.Tensor): The input image, with shape
76
+ 1x3xHxW, which has been transformed with ResizeLongestSide.
77
+ original_image_size (tuple(int, int)): The size of the image
78
+ before transformation, in (H, W) format.
79
+ """
80
+ assert (
81
+ len(transformed_image.shape) == 4
82
+ and transformed_image.shape[1] == 3
83
+ and max(*transformed_image.shape[2:]) == self.model.image_size
84
+ ), f"set_torch_image input must be BCHW with long side {self.model.image_size}."
85
+ self.reset_image()
86
+
87
+ self.original_size = original_image_size
88
+ self.input_size = tuple(transformed_image.shape[-2:])
89
+ input_image = self.model.preprocess(transformed_image)
90
+ self.features, self.interm_feats = self.model.encoder(input_image)
91
+ self.is_image_set = True
92
+
93
+ def predict(
94
+ self,
95
+ point_coords: Optional[np.ndarray] = None,
96
+ point_labels: Optional[np.ndarray] = None,
97
+ box: Optional[np.ndarray] = None,
98
+ multimask_output: bool = True,
99
+ return_logits: bool = False,
100
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
101
+ """
102
+ Predict masks for the given input prompts, using the currently set image.
103
+
104
+ Arguments:
105
+ point_coords (np.ndarray or None): A Nx2 array of point prompts to the
106
+ model. Each point is in (X,Y) in pixels.
107
+ point_labels (np.ndarray or None): A length N array of labels for the
108
+ point prompts. 1 indicates a foreground point and 0 indicates a
109
+ background point.
110
+ box (np.ndarray or None): A length 4 array given a box prompt to the
111
+ model, in XYXY format.
112
+ mask_input (np.ndarray): A low resolution mask input to the model, typically
113
+ coming from a previous prediction iteration. Has form 1xHxW, where
114
+ for SAM, H=W=256.
115
+ multimask_output (bool): If true, the model will return three masks.
116
+ For ambiguous input prompts (such as a single click), this will often
117
+ produce better masks than a single prediction. If only a single
118
+ mask is needed, the model's predicted quality score can be used
119
+ to select the best mask. For non-ambiguous prompts, such as multiple
120
+ input prompts, multimask_output=False can give better results.
121
+ return_logits (bool): If true, returns un-thresholded masks logits
122
+ instead of a binary mask.
123
+
124
+ Returns:
125
+ (np.ndarray): The output masks in CxHxW format, where C is the
126
+ number of masks, and (H, W) is the original image size.
127
+ (np.ndarray): An array of length C containing the model's
128
+ predictions for the quality of each mask.
129
+ (np.ndarray): An array of shape CxHxW, where C is the number
130
+ of masks and H=W=256. These low resolution logits can be passed to
131
+ a subsequent iteration as mask input.
132
+ """
133
+ if not self.is_image_set:
134
+ raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
135
+
136
+ # Transform input prompts
137
+ coords_torch = None
138
+ labels_torch = None
139
+ box_torch = None
140
+
141
+ if point_coords is not None:
142
+ assert (
143
+ point_labels is not None
144
+ ), "point_labels must be supplied if point_coords is supplied."
145
+ point_coords = self.transform.apply_coords(point_coords, self.original_size)
146
+ coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=self.device)
147
+ labels_torch = torch.as_tensor(point_labels, dtype=torch.float, device=self.device)
148
+ coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :]
149
+ if box is not None:
150
+ box = self.transform.apply_boxes(box, self.original_size)
151
+ box_torch = torch.as_tensor(box, dtype=torch.float, device=self.device)
152
+
153
+ masks, iou_predictions, low_res_masks = self.predict_torch(
154
+ coords_torch,
155
+ labels_torch,
156
+ box_torch,
157
+ multimask_output,
158
+ return_logits=return_logits,
159
+ )
160
+
161
+ masks_np = masks[0].detach().cpu().numpy()
162
+ iou_predictions_np = iou_predictions[0].detach().cpu().numpy()
163
+ low_res_masks_np = low_res_masks[0].detach().cpu().numpy()
164
+
165
+ return masks_np, iou_predictions_np, low_res_masks_np
166
+
167
+ @torch.no_grad()
168
+ def predict_torch(
169
+ self,
170
+ point_coords: Optional[torch.Tensor],
171
+ point_labels: Optional[torch.Tensor],
172
+ boxes: Optional[torch.Tensor] = None,
173
+ multimask_output: bool = True,
174
+ return_logits: bool = False,
175
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
176
+ """
177
+         Predict masks for the given input prompts, using the currently set image.
+         Input prompts are batched torch tensors and are expected to already be
+         transformed to the input frame using ResizeLongestSide.
+
+         Arguments:
+           point_coords (torch.Tensor or None): A BxNx2 array of point prompts to the
+             model. Each point is in (X,Y) in pixels.
+           point_labels (torch.Tensor or None): A BxN array of labels for the
+             point prompts. 1 indicates a foreground point and 0 indicates a
+             background point.
+           boxes (torch.Tensor or None): A Bx4 array of box prompts to the
+             model, in XYXY format.
+           mask_input (torch.Tensor): A low resolution mask input to the model, typically
+             coming from a previous prediction iteration. Has form Bx1xHxW, where
+             for SAM, H=W=256. Masks returned by a previous iteration of the
+             predict method do not need further transformation.
+           multimask_output (bool): If true, the model will return three masks.
+             For ambiguous input prompts (such as a single click), this will often
+             produce better masks than a single prediction. If only a single
+             mask is needed, the model's predicted quality score can be used
+             to select the best mask. For non-ambiguous prompts, such as multiple
+             input prompts, multimask_output=False can give better results.
+           return_logits (bool): If true, returns un-thresholded mask logits
+             instead of a binary mask.
+
+         Returns:
+           (torch.Tensor): The output masks in BxCxHxW format, where C is the
+             number of masks, and (H, W) is the original image size.
+           (torch.Tensor): An array of shape BxC containing the model's
+             predictions for the quality of each mask.
+           (torch.Tensor): An array of shape BxCxHxW, where C is the number
+             of masks and H=W=256. These low res logits can be passed to
+             a subsequent iteration as mask input.
+         """
+         if not self.is_image_set:
+             raise RuntimeError("An image must be set with .set_image(...) before mask prediction.")
+
+         if point_coords is not None:
+             points = (point_coords, point_labels)
+             attn_mask = self.model.point_attn_mask(point_coords)
+         else:
+             points = None
+             attn_mask = self.model.bbox_attn_mask(boxes)
+
+         # Embed prompts
+         masks, iou_predictions = self.model.decoder(
+             interm_feats=self.interm_feats,
+             image_embeddings=self.features,
+             points=points,
+             boxes=boxes,
+             attn_mask=attn_mask,
+         )
+
+         # Select the correct mask or masks for output
+         if multimask_output:
+             mask_slice = slice(0, None)
+         else:
+             mask_slice = slice(0, 1)
+
+         masks = masks[:, mask_slice, :, :]
+         iou_predictions = iou_predictions[:, mask_slice]
+
+         low_res_masks = F.interpolate(masks, scale_factor=2, mode='bilinear', align_corners=False)
+
+         masks = self.model.postprocess_masks(
+             masks,
+             input_size=self.input_size,
+             original_size=self.original_size,
+         )
+
+         return masks.sigmoid(), iou_predictions, low_res_masks.sigmoid()
+
+     def get_image_embedding(self) -> torch.Tensor:
+         """
+         Returns the image embeddings for the currently set image, with
+         shape 1xCxHxW, where C is the embedding dimension and (H,W) are
+         the embedding spatial dimension of SAM (typically C=256, H=W=64).
+         """
+         if not self.is_image_set:
+             raise RuntimeError(
+                 "An image must be set with .set_image(...) to generate an embedding."
+             )
+         assert self.features is not None, "Features must exist if an image has been set."
+         return self.features
+
+     @property
+     def device(self) -> torch.device:
+         return self.model.device
+
+     def reset_image(self) -> None:
+         """Resets the currently set image."""
+         self.is_image_set = False
+         self.features = None
+         self.interm_feats = None
+         self.orig_h = None
+         self.orig_w = None
+         self.input_h = None
+         self.input_w = None
+
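A minimal usage sketch of the torch-prompt path above. The names are assumptions: `predictor` stands for an already-constructed ZIM predictor whose image was set with `.set_image(...)`, `predict_torch` is the SAM-style name for the method documented above, and the 1024 input resolution and 768x1024 image size are only illustrative.

import torch
from zim.utils import ResizeLongestSide

transform = ResizeLongestSide(1024)              # assumed model input resolution
point_coords = torch.tensor([[[250.0, 375.0]]])  # BxNx2 click, in original-image pixels
point_labels = torch.tensor([[1]])               # 1 = foreground point

# Prompts must be mapped into the resized input frame before calling the model.
point_coords = transform.apply_coords_torch(point_coords, original_size=(768, 1024))

masks, quality, low_res = predictor.predict_torch(   # method name assumed from the SAM-style API
    point_coords=point_coords.to(predictor.device),
    point_labels=point_labels.to(predictor.device),
    boxes=None,
    multimask_output=True,
)
# masks: BxCxHxW mattes in [0, 1] at the original image size; quality: BxC scores;
# low_res: lower-resolution mattes that can feed a subsequent prediction iteration.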
zim/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ from .argparser import get_parser
+ from .print import print_once, pretty
+ from .utils import AverageMeter, ResizeLongestSide
+ from .amg import show_mat_anns
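These re-exports form the public surface of `zim.utils`; a typical downstream import (a sketch, nothing more) looks like:

from zim.utils import get_parser, print_once, pretty, AverageMeter, ResizeLongestSide, show_mat_anns

print_once("zim.utils loaded")  # printed only on rank 0 when torch.distributed is initialized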
zim/utils/amg.py ADDED
@@ -0,0 +1,373 @@
+ """
+ Copyright (c) 2024-present Naver Cloud Corp.
+ This source code is based on code from the Segment Anything Model (SAM)
+ (https://github.com/facebookresearch/segment-anything).
+
+ This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ import numpy as np
+ import torch
+ import cv2
+
+ import math
+ from copy import deepcopy
+ from itertools import product
+ from typing import Any, Dict, Generator, ItemsView, List, Tuple
+
+
+ class MaskData:
+     """
+     A structure for storing masks and their related data in batched format.
+     Implements basic filtering and concatenation.
+     """
+
+     def __init__(self, **kwargs) -> None:
+         for v in kwargs.values():
+             assert isinstance(
+                 v, (list, np.ndarray, torch.Tensor)
+             ), "MaskData only supports list, numpy arrays, and torch tensors."
+         self._stats = dict(**kwargs)
+
+     def __setitem__(self, key: str, item: Any) -> None:
+         assert isinstance(
+             item, (list, np.ndarray, torch.Tensor)
+         ), "MaskData only supports list, numpy arrays, and torch tensors."
+         self._stats[key] = item
+
+     def __delitem__(self, key: str) -> None:
+         del self._stats[key]
+
+     def __getitem__(self, key: str) -> Any:
+         return self._stats[key]
+
+     def items(self) -> ItemsView[str, Any]:
+         return self._stats.items()
+
+     def filter(self, keep: torch.Tensor) -> None:
+         for k, v in self._stats.items():
+             if v is None:
+                 self._stats[k] = None
+             elif isinstance(v, torch.Tensor):
+                 self._stats[k] = v[torch.as_tensor(keep, device=v.device)]
+             elif isinstance(v, np.ndarray):
+                 self._stats[k] = v[keep.detach().cpu().numpy()]
+             elif isinstance(v, list) and keep.dtype == torch.bool:
+                 self._stats[k] = [a for i, a in enumerate(v) if keep[i]]
+             elif isinstance(v, list):
+                 self._stats[k] = [v[i] for i in keep]
+             else:
+                 raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+     def cat(self, new_stats: "MaskData") -> None:
+         for k, v in new_stats.items():
+             if k not in self._stats or self._stats[k] is None:
+                 self._stats[k] = deepcopy(v)
+             elif isinstance(v, torch.Tensor):
+                 self._stats[k] = torch.cat([self._stats[k], v], dim=0)
+             elif isinstance(v, np.ndarray):
+                 self._stats[k] = np.concatenate([self._stats[k], v], axis=0)
+             elif isinstance(v, list):
+                 self._stats[k] = self._stats[k] + deepcopy(v)
+             else:
+                 raise TypeError(f"MaskData key {k} has an unsupported type {type(v)}.")
+
+     def to_numpy(self) -> None:
+         for k, v in self._stats.items():
+             if isinstance(v, torch.Tensor):
+                 self._stats[k] = v.detach().cpu().numpy()
+
+
+ def is_box_near_crop_edge(
+     boxes: torch.Tensor, crop_box: List[int], orig_box: List[int], atol: float = 20.0
+ ) -> torch.Tensor:
+     """Filter masks at the edge of a crop, but not at the edge of the original image."""
+     crop_box_torch = torch.as_tensor(crop_box, dtype=torch.float, device=boxes.device)
+     orig_box_torch = torch.as_tensor(orig_box, dtype=torch.float, device=boxes.device)
+     boxes = uncrop_boxes_xyxy(boxes, crop_box).float()
+     near_crop_edge = torch.isclose(boxes, crop_box_torch[None, :], atol=atol, rtol=0)
+     near_image_edge = torch.isclose(boxes, orig_box_torch[None, :], atol=atol, rtol=0)
+     near_crop_edge = torch.logical_and(near_crop_edge, ~near_image_edge)
+     return torch.any(near_crop_edge, dim=1)
+
+
+ def box_xyxy_to_xywh(box_xyxy: torch.Tensor) -> torch.Tensor:
+     box_xywh = deepcopy(box_xyxy)
+     box_xywh[2] = box_xywh[2] - box_xywh[0]
+     box_xywh[3] = box_xywh[3] - box_xywh[1]
+     return box_xywh
+
+
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
+     assert len(args) > 0 and all(
+         len(a) == len(args[0]) for a in args
+     ), "Batched iteration must have inputs of all the same size."
+     n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
+     for b in range(n_batches):
+         yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
+
+
+ def mask_to_rle_pytorch(tensor: torch.Tensor) -> List[Dict[str, Any]]:
+     """
+     Encodes masks to an uncompressed RLE, in the format expected by
+     pycoco tools.
+     """
+     # Put in fortran order and flatten h,w
+     b, h, w = tensor.shape
+     tensor = tensor.permute(0, 2, 1).flatten(1)
+
+     # Compute change indices
+     diff = tensor[:, 1:] ^ tensor[:, :-1]
+     change_indices = diff.nonzero()
+
+     # Encode run length
+     out = []
+     for i in range(b):
+         cur_idxs = change_indices[change_indices[:, 0] == i, 1]
+         cur_idxs = torch.cat(
+             [
+                 torch.tensor([0], dtype=cur_idxs.dtype, device=cur_idxs.device),
+                 cur_idxs + 1,
+                 torch.tensor([h * w], dtype=cur_idxs.dtype, device=cur_idxs.device),
+             ]
+         )
+         btw_idxs = cur_idxs[1:] - cur_idxs[:-1]
+         counts = [] if tensor[i, 0] == 0 else [0]
+         counts.extend(btw_idxs.detach().cpu().tolist())
+         out.append({"size": [h, w], "counts": counts})
+     return out
+
+
+ def rle_to_mask(rle: Dict[str, Any]) -> np.ndarray:
+     """Compute a binary mask from an uncompressed RLE."""
+     h, w = rle["size"]
+     mask = np.empty(h * w, dtype=bool)
+     idx = 0
+     parity = False
+     for count in rle["counts"]:
+         mask[idx : idx + count] = parity
+         idx += count
+         parity ^= True
+     mask = mask.reshape(w, h)
+     return mask.transpose()  # Put in C order
+
+
+ def area_from_rle(rle: Dict[str, Any]) -> int:
+     return sum(rle["counts"][1::2])
+
+
+ def calculate_stability_score(
+     masks: torch.Tensor, mask_threshold: float, threshold_offset: float
+ ) -> torch.Tensor:
+     """
+     Computes the stability score for a batch of masks. The stability
+     score is the IoU between the binary masks obtained by thresholding
+     the predicted mask logits at high and low values.
+     """
+     # One mask is always contained inside the other.
+     # Save memory by preventing unnecessary cast to torch.int64
+     intersections = (
+         (masks > (mask_threshold + threshold_offset))
+         .sum(-1, dtype=torch.int16)
+         .sum(-1, dtype=torch.int32)
+     )
+     unions = (
+         (masks > (mask_threshold - threshold_offset))
+         .sum(-1, dtype=torch.int16)
+         .sum(-1, dtype=torch.int32)
+     )
+     return intersections / unions
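To make the stability score concrete, here is a tiny worked sketch; the threshold and offset values are arbitrary illustrations, not the generator's defaults.

import torch
from zim.utils.amg import calculate_stability_score

# One "mask" of four logit values around a threshold of 1.0.
masks = torch.tensor([[[2.0, 1.2, 0.8, -1.0]]])
score = calculate_stability_score(masks, mask_threshold=1.0, threshold_offset=0.5)
# > 1.5 keeps 1 pixel (intersection); > 0.5 keeps 3 pixels (union) -> IoU = 1/3
print(score)  # tensor([0.3333]): a low score, i.e. the mask flips a lot near the threshold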
+
+
+ def build_point_grid(n_per_side: int) -> np.ndarray:
+     """Generates a 2D grid of points evenly spaced in [0,1]x[0,1]."""
+     offset = 1 / (2 * n_per_side)
+     points_one_side = np.linspace(offset, 1 - offset, n_per_side)
+     points_x = np.tile(points_one_side[None, :], (n_per_side, 1))
+     points_y = np.tile(points_one_side[:, None], (1, n_per_side))
+     points = np.stack([points_x, points_y], axis=-1).reshape(-1, 2)
+     return points
+
+
+ def build_all_layer_point_grids(
+     n_per_side: int, n_layers: int, scale_per_layer: int
+ ) -> List[np.ndarray]:
+     """Generates point grids for all crop layers."""
+     points_by_layer = []
+     for i in range(n_layers + 1):
+         n_points = int(n_per_side / (scale_per_layer**i))
+         points_by_layer.append(build_point_grid(n_points))
+     return points_by_layer
+
+
+ def generate_crop_boxes(
+     im_size: Tuple[int, ...], n_layers: int, overlap_ratio: float
+ ) -> Tuple[List[List[int]], List[int]]:
+     """
+     Generates a list of crop boxes of different sizes. Each layer
+     has (2**i)**2 boxes for the ith layer.
+     """
+     crop_boxes, layer_idxs = [], []
+     im_h, im_w = im_size
+     short_side = min(im_h, im_w)
+
+     # Original image
+     crop_boxes.append([0, 0, im_w, im_h])
+     layer_idxs.append(0)
+
+     def crop_len(orig_len, n_crops, overlap):
+         return int(math.ceil((overlap * (n_crops - 1) + orig_len) / n_crops))
+
+     for i_layer in range(n_layers):
+         n_crops_per_side = 2 ** (i_layer + 1)
+         overlap = int(overlap_ratio * short_side * (2 / n_crops_per_side))
+
+         crop_w = crop_len(im_w, n_crops_per_side, overlap)
+         crop_h = crop_len(im_h, n_crops_per_side, overlap)
+
+         crop_box_x0 = [int((crop_w - overlap) * i) for i in range(n_crops_per_side)]
+         crop_box_y0 = [int((crop_h - overlap) * i) for i in range(n_crops_per_side)]
+
+         # Crops in XYXY format
+         for x0, y0 in product(crop_box_x0, crop_box_y0):
+             box = [x0, y0, min(x0 + crop_w, im_w), min(y0 + crop_h, im_h)]
+             crop_boxes.append(box)
+             layer_idxs.append(i_layer + 1)
+
+     return crop_boxes, layer_idxs
+
+
+ def uncrop_boxes_xyxy(boxes: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+     x0, y0, _, _ = crop_box
+     offset = torch.tensor([[x0, y0, x0, y0]], device=boxes.device)
+     # Check if boxes has a channel dimension
+     if len(boxes.shape) == 3:
+         offset = offset.unsqueeze(1)
+     return boxes + offset
+
+
+ def uncrop_points(points: torch.Tensor, crop_box: List[int]) -> torch.Tensor:
+     x0, y0, _, _ = crop_box
+     offset = torch.tensor([[x0, y0]], device=points.device)
+     # Check if points has a channel dimension
+     if len(points.shape) == 3:
+         offset = offset.unsqueeze(1)
+     return points + offset
+
+
+ def uncrop_masks(
+     masks: torch.Tensor, crop_box: List[int], orig_h: int, orig_w: int
+ ) -> torch.Tensor:
+     x0, y0, x1, y1 = crop_box
+     if x0 == 0 and y0 == 0 and x1 == orig_w and y1 == orig_h:
+         return masks
+     # Coordinate transform masks
+     pad_x, pad_y = orig_w - (x1 - x0), orig_h - (y1 - y0)
+     pad = (x0, pad_x - x0, y0, pad_y - y0)
+     return torch.nn.functional.pad(masks, pad, value=0)
+
+
+ def remove_small_regions(
+     mask: np.ndarray, area_thresh: float, mode: str
+ ) -> Tuple[np.ndarray, bool]:
+     """
+     Removes small disconnected regions and holes in a mask. Returns the
+     mask and an indicator of if the mask has been modified.
+     """
+     import cv2  # type: ignore
+
+     assert mode in ["holes", "islands"]
+     correct_holes = mode == "holes"
+     working_mask = (correct_holes ^ mask).astype(np.uint8)
+     n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
+     sizes = stats[:, -1][1:]  # Row 0 is background label
+     small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
+     if len(small_regions) == 0:
+         return mask, False
+     fill_labels = [0] + small_regions
+     if not correct_holes:
+         fill_labels = [i for i in range(n_labels) if i not in fill_labels]
+         # If every region is below threshold, keep largest
+         if len(fill_labels) == 0:
+             fill_labels = [int(np.argmax(sizes)) + 1]
+     mask = np.isin(regions, fill_labels)
+     return mask, True
+
+
+ def coco_encode_rle(uncompressed_rle: Dict[str, Any]) -> Dict[str, Any]:
+     from pycocotools import mask as mask_utils  # type: ignore
+
+     h, w = uncompressed_rle["size"]
+     rle = mask_utils.frPyObjects(uncompressed_rle, h, w)
+     rle["counts"] = rle["counts"].decode("utf-8")  # Necessary to serialize with json
+     return rle
+
+
+ def batched_mask_to_box(masks: torch.Tensor) -> torch.Tensor:
+     """
+     Calculates boxes in XYXY format around masks. Return [0,0,0,0] for
+     an empty mask. For input shape C1xC2x...xHxW, the output shape is C1xC2x...x4.
+     """
+     # torch.max below raises an error on empty inputs, just skip in this case
+     if torch.numel(masks) == 0:
+         return torch.zeros(*masks.shape[:-2], 4, device=masks.device)
+
+     # Normalize shape to CxHxW
+     shape = masks.shape
+     h, w = shape[-2:]
+     if len(shape) > 2:
+         masks = masks.flatten(0, -3)
+     else:
+         masks = masks.unsqueeze(0)
+
+     # Get top and bottom edges
+     in_height, _ = torch.max(masks, dim=-1)
+     in_height_coords = in_height * torch.arange(h, device=in_height.device)[None, :]
+     bottom_edges, _ = torch.max(in_height_coords, dim=-1)
+     in_height_coords = in_height_coords + h * (~in_height)
+     top_edges, _ = torch.min(in_height_coords, dim=-1)
+
+     # Get left and right edges
+     in_width, _ = torch.max(masks, dim=-2)
+     in_width_coords = in_width * torch.arange(w, device=in_width.device)[None, :]
+     right_edges, _ = torch.max(in_width_coords, dim=-1)
+     in_width_coords = in_width_coords + w * (~in_width)
+     left_edges, _ = torch.min(in_width_coords, dim=-1)
+
+     # If the mask is empty the right edge will be to the left of the left edge.
+     # Replace these boxes with [0, 0, 0, 0]
+     empty_filter = (right_edges < left_edges) | (bottom_edges < top_edges)
+     out = torch.stack([left_edges, top_edges, right_edges, bottom_edges], dim=-1)
+     out = out * (~empty_filter).unsqueeze(-1)
+
+     # Return to original shape
+     if len(shape) > 2:
+         out = out.reshape(*shape[:-2], 4)
+     else:
+         out = out[0]
+
+     return out
+
+
+ def show_mat_anns(image, anns):
+     if len(anns) == 0:
+         return np.zeros_like(image) + 128
+
+     sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
+
+     image = image.astype(np.float32)
+     colorized_mat = np.zeros_like(image)
+
+     for ann in sorted_anns:
+         color = (np.random.random(3) * 255).astype(np.float32)
+         if 'logit' in ann:
+             mat = ann['logit'].astype(np.float32) / 255.
+         else:
+             mat = ann['segmentation'].astype(np.float32)
+
+         color_mat = np.zeros_like(image) + color[None, None]
+         colorized_mat = color_mat * mat[:, :, None] + colorized_mat * (1. - mat[:, :, None])
+
+     colorized_mat = np.uint8(colorized_mat)
+
+     return colorized_mat
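A small round-trip sketch of the RLE helpers defined in this file; the 4x4 toy mask is illustrative only.

import torch
from zim.utils.amg import mask_to_rle_pytorch, rle_to_mask, area_from_rle

mask = torch.zeros((1, 4, 4), dtype=torch.bool)
mask[0, 1:3, 1:3] = True                          # a 2x2 square of foreground

rle = mask_to_rle_pytorch(mask)[0]                # {"size": [4, 4], "counts": [...]}, uncompressed RLE
print(area_from_rle(rle))                         # 4 foreground pixels
restored = rle_to_mask(rle)                       # back to a 4x4 numpy bool array
print(bool((restored == mask[0].numpy()).all()))  # True: the encoding round-trips losslessly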
zim/utils/argparser.py ADDED
@@ -0,0 +1,96 @@
+ """
+ Copyright (c) 2024-present Naver Cloud Corp.
+
+ This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ import os
+ import argparse
+ from config.config import config_
+
+ def str2bool(v):
+     if isinstance(v, bool):
+         return v
+     if v.lower() in ("yes", "true", "t", "y", "1"):
+         return True
+     elif v.lower() in ("no", "false", "f", "n", "0"):
+         return False
+     else:
+         raise argparse.ArgumentTypeError("Boolean value expected.")
+
+
+ def get_parser(verbose=False):
+     p = argparse.ArgumentParser("argparser", add_help=False)
+
+     p.add_argument(
+         "--data-root", type=str, default=config_.data_root, help="data root directory"
+     )
+     p.add_argument(
+         "--local_rank", type=int, default=int(os.getenv("LOCAL_RANK", "0"))
+     )
+     p.add_argument(
+         "--amp", type=str2bool, default=True
+     )
+     p.add_argument(
+         "--ddp", action="store_true"
+     )
+     p.add_argument(
+         "--random-seed", type=int, default=config_.random_seed
+     )
+
+     # network config
+     p.add_argument(
+         "--network-encoder",
+         type=str,
+         default=config_.network.encoder,
+         choices=["vit_b", "vit_l"],
+     )
+     p.add_argument(
+         "--network-decoder",
+         type=str,
+         default=config_.network.decoder,
+         choices=["zim", "sam"],
+     )
+     p.add_argument(
+         "--network-encode-kernel",
+         type=int,
+         default=config_.network.encode_kernel,
+     )
+
+     # evaluation config
+     p.add_argument(
+         "--eval-workers", type=int, default=config_.eval.workers,
+     )
+     p.add_argument(
+         "--eval-image-size", type=int, default=config_.eval.image_size,
+     )
+     p.add_argument(
+         "--eval-prompt-type", type=str, default=config_.eval.prompt_type,
+     )
+     p.add_argument(
+         "--eval-model-list", type=str, default=config_.eval.model_list,
+     )
+     p.add_argument(
+         "--eval-zim-weights",
+         type=str,
+         default=config_.eval.zim_weights,
+     )
+     p.add_argument(
+         "--eval-sam-weights",
+         type=str,
+         default=config_.eval.sam_weights,
+     )
+
+     # dataset config
+     p.add_argument(
+         "--dataset-valset", type=str, default=config_.dataset.valset,
+     )
+     p.add_argument(
+         "--dataset-data-type", type=str, default=config_.dataset.data_type,
+     )
+     p.add_argument(
+         "--dataset-data-list-txt", type=str, default=config_.dataset.data_list_txt,
+     )
+
+     return p
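A usage sketch for the parser; the flag values shown are only illustrative, real defaults come from `config_`, and since the parser is built with `add_help=False` it is typically composed with other parsers.

from zim.utils import get_parser

args = get_parser().parse_args(["--eval-prompt-type", "bbox", "--amp", "false"])
print(args.eval_prompt_type)  # "bbox"
print(args.amp)               # False, parsed through str2bool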
zim/utils/print.py ADDED
@@ -0,0 +1,20 @@
+ """
+ Copyright (c) 2024-present Naver Cloud Corp.
+
+ This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ import torch
+
+ def print_once(message):
+     if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0:
+         print(message)
+
+ def pretty(d, indent=0):
+     for key, value in d.items():
+         print_once("\t" * indent + str(key))
+         if isinstance(value, dict):
+             pretty(value, indent + 1)
+         else:
+             print_once("\t" * (indent + 1) + str(value))
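A sketch of how these helpers render a nested config dict (the dict values are made up; the real indentation is tab characters):

from zim.utils import pretty

cfg = {"network": {"encoder": "vit_b", "decoder": "zim"}}
pretty(cfg)
# network
#     encoder
#         vit_b
#     decoder
#         zim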
zim/utils/utils.py ADDED
@@ -0,0 +1,148 @@
+ """
+ Copyright (c) 2024-present Naver Cloud Corp.
+
+ This source code is licensed under the license found in the
+ LICENSE file in the root directory of this source tree.
+ """
+
+ import numpy as np
+ import torch
+ from torch.nn import functional as F
+ from torchvision.transforms.functional import resize, to_pil_image, InterpolationMode
+ from copy import deepcopy
+ from typing import Optional, Tuple, List
+
+ class ResizeLongestSide:
+     """
+     Resizes images to the longest side 'target_length', as well as provides
+     methods for resizing coordinates and boxes. Provides methods for
+     transforming both numpy array and batched torch tensors.
+     """
+
+     def __init__(self, target_length: int) -> None:
+         self.target_length = target_length
+
+     def apply_image(self, image: np.ndarray) -> np.ndarray:
+         """
+         Expects a numpy array with shape HxWxC in uint8 format.
+         """
+         target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+         return np.array(resize(to_pil_image(image), target_size))
+
+     def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+         """
+         Expects a numpy array of length 2 in the final dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).astype(float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
+         """
+         Expects a numpy array shape Bx4. Requires the original image size
+         in (H, W) format.
+         """
+         boxes = self.apply_coords(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+
+     def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
+         """
+         Expects batched images with shape BxCxHxW and float format. This
+         transformation may not exactly match apply_image. apply_image is
+         the transformation expected by the model.
+         """
+         # Expects an image in BCHW format. May not exactly match apply_image.
+         target_size = self.get_preprocess_shape(image.shape[2], image.shape[3], self.target_length)
+         return F.interpolate(
+             image, target_size, mode="bilinear", align_corners=False, antialias=True
+         )
+
+     def apply_coords_torch(
+         self, coords: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with length 2 in the last dimension. Requires the
+         original image size in (H, W) format.
+         """
+         old_h, old_w = original_size
+         new_h, new_w = self.get_preprocess_shape(
+             original_size[0], original_size[1], self.target_length
+         )
+         coords = deepcopy(coords).to(torch.float)
+         coords[..., 0] = coords[..., 0] * (new_w / old_w)
+         coords[..., 1] = coords[..., 1] * (new_h / old_h)
+         return coords
+
+     def apply_boxes_torch(
+         self, boxes: torch.Tensor, original_size: Tuple[int, ...]
+     ) -> torch.Tensor:
+         """
+         Expects a torch tensor with shape Bx4. Requires the original image
+         size in (H, W) format.
+         """
+         boxes = self.apply_coords_torch(boxes.reshape(-1, 2, 2), original_size)
+         return boxes.reshape(-1, 4)
+
+     def apply_mask(self, image: np.ndarray) -> np.ndarray:
+         """
+         Expects a numpy array with shape HxWxC in uint8 format.
+         """
+         target_size = self.get_preprocess_shape(image.shape[0], image.shape[1], self.target_length)
+         return np.array(resize(to_pil_image(image), target_size, interpolation=InterpolationMode.NEAREST))
+
+     @staticmethod
+     def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
+         """
+         Compute the output size given input size and target long side length.
+         """
+         scale = long_side_length * 1.0 / max(oldh, oldw)
+         newh, neww = oldh * scale, oldw * scale
+         neww = int(neww + 0.5)
+         newh = int(newh + 0.5)
+         return (newh, neww)
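Two worked examples of the long-side resize arithmetic above (the image sizes are arbitrary):

from zim.utils import ResizeLongestSide

print(ResizeLongestSide.get_preprocess_shape(768, 1024, 1024))  # (768, 1024): long side already at target
print(ResizeLongestSide.get_preprocess_shape(600, 400, 1024))   # (1024, 683): scaled by 1024/600, rounded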
+
+
+ def remove_prefix(text, prefix):
+     if text.startswith(prefix):
+         return text[len(prefix) :]
+     return text
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value"""
+
+     def __init__(self, is_ddp):
+         self.is_ddp = is_ddp
+         self.reset()
+
+     def reset(self):
+         self.val = 0.0
+         self.avg = 0.0
+         self.sum = 0.0
+         self.count = 0.0
+
+     def update(self, val, n=1):
+         self.val = val
+         self.sum += val * n
+         self.count += n
+         self.avg = self.sum / (self.count + 1e-5)
+
+     def synch(self, device):
+         if self.is_ddp is False:
+             return
+
+         _sum = torch.tensor(self.sum).to(device)
+         _count = torch.tensor(self.count).to(device)
+
+         torch.distributed.reduce(_sum, dst=0)
+         torch.distributed.reduce(_count, dst=0)
+
+         if torch.distributed.get_rank() == 0:
+             self.sum = _sum.item()
+             self.count = _count.item()
+             self.avg = self.sum / (self.count + 1e-5)
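Finally, a sketch of the meter in a non-distributed run (the loss values are made up); with `is_ddp=True` inside an initialized torch.distributed group, `synch(device)` would additionally reduce the sums onto rank 0.

from zim.utils import AverageMeter

meter = AverageMeter(is_ddp=False)
for loss in [0.9, 0.7, 0.5]:
    meter.update(loss)
print(round(meter.avg, 3))  # 0.7, the running mean of the three updates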