hujiecpp committed
Commit 3d8bebe · 1 Parent(s): 8bcfdbb

init project

Files changed (1)
  1. app.py +345 -347
app.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import sys
3
  sys.path.append(os.path.abspath('./modules'))
4
 
5
- # import math
6
  import tempfile
7
  import gradio
8
  import torch
@@ -11,7 +11,7 @@ import numpy as np
11
  import functools
12
  import trimesh
13
  import copy
14
- # from PIL import Image
15
  from scipy.spatial.transform import Rotation
16
 
17
  from modules.pe3r.images import Images
@@ -22,25 +22,25 @@ from modules.dust3r.utils.image import load_images #, rgb
22
  from modules.dust3r.utils.device import to_numpy
23
  from modules.dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
24
  from modules.dust3r.cloud_opt import global_aligner, GlobalAlignerMode
25
- # from copy import deepcopy
26
- # import cv2
27
- # from typing import Any, Dict, Generator,List
28
- # import matplotlib.pyplot as pl
29
 
30
- # from modules.mobilesamv2.utils.transforms import ResizeLongestSide
31
- # from modules.pe3r.models import Models
32
  import torchvision.transforms as tvf
33
 
34
- # sys.path.append(os.path.abspath('./modules/ultralytics'))
35
 
36
- # from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
37
- # from modules.mast3r.model import AsymmetricMASt3R
38
 
39
- # from modules.sam2.build_sam import build_sam2_video_predictor
40
- # from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
41
- # from modules.mobilesamv2 import sam_model_registry
42
 
43
- # from sam2.sam2_video_predictor import SAM2VideoPredictor
44
  from modules.mast3r.model import AsymmetricMASt3R
45
 
46
 
@@ -117,337 +117,337 @@ def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False,
117
  return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
118
  transparent_cams=transparent_cams, cam_size=cam_size)
119
 
120
- # def mask_nms(masks, threshold=0.8):
121
- # keep = []
122
- # mask_num = len(masks)
123
- # suppressed = np.zeros((mask_num), dtype=np.int64)
124
- # for i in range(mask_num):
125
- # if suppressed[i] == 1:
126
- # continue
127
- # keep.append(i)
128
- # for j in range(i + 1, mask_num):
129
- # if suppressed[j] == 1:
130
- # continue
131
- # intersection = (masks[i] & masks[j]).sum()
132
- # if min(intersection / masks[i].sum(), intersection / masks[j].sum()) > threshold:
133
- # suppressed[j] = 1
134
- # return keep
135
-
136
- # def filter(masks, keep):
137
- # ret = []
138
- # for i, m in enumerate(masks):
139
- # if i in keep: ret.append(m)
140
- # return ret
141
-
142
- # def mask_to_box(mask):
143
- # if mask.sum() == 0:
144
- # return np.array([0, 0, 0, 0])
145
 
146
- # # Get the rows and columns where the mask is 1
147
- # rows = np.any(mask, axis=1)
148
- # cols = np.any(mask, axis=0)
149
 
150
- # # Get top, bottom, left, right edges
151
- # top = np.argmax(rows)
152
- # bottom = len(rows) - 1 - np.argmax(np.flip(rows))
153
- # left = np.argmax(cols)
154
- # right = len(cols) - 1 - np.argmax(np.flip(cols))
155
 
156
- # return np.array([left, top, right, bottom])
157
-
158
- # def box_xyxy_to_xywh(box_xyxy):
159
- # box_xywh = deepcopy(box_xyxy)
160
- # box_xywh[2] = box_xywh[2] - box_xywh[0]
161
- # box_xywh[3] = box_xywh[3] - box_xywh[1]
162
- # return box_xywh
163
-
164
- # def get_seg_img(mask, box, image):
165
- # image = image.copy()
166
- # x, y, w, h = box
167
- # # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
168
- # box_area = w * h
169
- # mask_area = mask.sum()
170
- # if 1 - (mask_area / box_area) < 0.2:
171
- # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
172
- # else:
173
- # random_values = np.random.randint(0, 255, size=image.shape, dtype=np.uint8)
174
- # image[mask == 0] = random_values[mask == 0]
175
- # seg_img = image[y:y+h, x:x+w, ...]
176
- # return seg_img
177
-
178
- # def pad_img(img):
179
- # h, w, _ = img.shape
180
- # l = max(w,h)
181
- # pad = np.zeros((l,l,3), dtype=np.uint8) #
182
- # if h > w:
183
- # pad[:,(h-w)//2:(h-w)//2 + w, :] = img
184
- # else:
185
- # pad[(w-h)//2:(w-h)//2 + h, :, :] = img
186
- # return pad
187
-
188
- # def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
189
- # assert len(args) > 0 and all(
190
- # len(a) == len(args[0]) for a in args
191
- # ), "Batched iteration must have inputs of all the same size."
192
- # n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
193
- # for b in range(n_batches):
194
- # yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
195
-
196
- # def slerp(u1, u2, t):
197
- # """
198
- # Perform spherical linear interpolation (Slerp) between two unit vectors.
199
 
200
- # Args:
201
- # - u1 (torch.Tensor): First unit vector, shape (1024,)
202
- # - u2 (torch.Tensor): Second unit vector, shape (1024,)
203
- # - t (float): Interpolation parameter
204
 
205
- # Returns:
206
- # - torch.Tensor: Interpolated vector, shape (1024,)
207
- # """
208
- # # Compute the dot product
209
- # dot_product = torch.sum(u1 * u2)
210
 
211
- # # Ensure the dot product is within the valid range [-1, 1]
212
- # dot_product = torch.clamp(dot_product, -1.0, 1.0)
213
 
214
- # # Compute the angle between the vectors
215
- # theta = torch.acos(dot_product)
216
 
217
- # # Compute the coefficients for the interpolation
218
- # sin_theta = torch.sin(theta)
219
- # if sin_theta == 0:
220
- # # Vectors are parallel, return a linear interpolation
221
- # return u1 + t * (u2 - u1)
222
 
223
- # s1 = torch.sin((1 - t) * theta) / sin_theta
224
- # s2 = torch.sin(t * theta) / sin_theta
225
 
226
- # # Perform the interpolation
227
- # return s1 * u1 + s2 * u2
228
 
229
- # def slerp_multiple(vectors, t_values):
230
- # """
231
- # Perform spherical linear interpolation (Slerp) for multiple vectors.
232
 
233
- # Args:
234
- # - vectors (torch.Tensor): Tensor of vectors, shape (n, 1024)
235
- # - a_values (torch.Tensor): Tensor of values corresponding to each vector, shape (n,)
236
 
237
- # Returns:
238
- # - torch.Tensor: Interpolated vector, shape (1024,)
239
- # """
240
- # n = vectors.shape[0]
241
 
242
- # # Initialize the interpolated vector with the first vector
243
- # interpolated_vector = vectors[0]
244
 
245
- # # Perform Slerp iteratively
246
- # for i in range(1, n):
247
- # # Perform Slerp between the current interpolated vector and the next vector
248
- # t = t_values[i] / (t_values[i] + t_values[i-1])
249
- # interpolated_vector = slerp(interpolated_vector, vectors[i], t)
250
 
251
- # return interpolated_vector
252
 
253
- # @torch.no_grad
254
- # def get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_image, yolov8_image, original_size, input_size, transform):
255
 
256
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
257
 
258
 
259
- # sam_mask=[]
260
- # img_area = original_size[0] * original_size[1]
261
 
262
- # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
263
- # input_boxes1 = obj_results[0].boxes.xyxy
264
- # input_boxes1 = input_boxes1.cpu().numpy()
265
- # input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
266
- # input_boxes = torch.from_numpy(input_boxes1).to(device)
267
 
268
- # # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
269
- # # input_boxes2 = obj_results[0].boxes.xyxy
270
- # # input_boxes2 = input_boxes2.cpu().numpy()
271
- # # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
272
- # # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
273
-
274
- # # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
275
-
276
- # input_image = mobilesamv2.preprocess(sam1_image)
277
- # image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
278
-
279
- # image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
280
- # prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
281
- # prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
282
- # for (boxes,) in batch_iterator(320, input_boxes):
283
- # with torch.no_grad():
284
- # image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
285
- # prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
286
- # sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
287
- # points=None,
288
- # boxes=boxes,
289
- # masks=None,)
290
- # low_res_masks, _ = mobilesamv2.mask_decoder(
291
- # image_embeddings=image_embedding,
292
- # image_pe=prompt_embedding,
293
- # sparse_prompt_embeddings=sparse_embeddings,
294
- # dense_prompt_embeddings=dense_embeddings,
295
- # multimask_output=False,
296
- # simple_type=True,
297
- # )
298
- # low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
299
- # sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
300
- # for mask in sam_mask_pre:
301
- # if mask.sum() / img_area > 0.002:
302
- # sam_mask.append(mask.squeeze(1))
303
- # sam_mask=torch.cat(sam_mask)
304
- # sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
305
- # keep = mask_nms(sorted_sam_mask)
306
- # ret_mask = filter(sorted_sam_mask, keep)
307
-
308
- # return ret_mask
309
-
310
- # @torch.no_grad
311
- # def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
312
-
313
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
314
-
315
- # cog_seg_maps = []
316
- # rev_cog_seg_maps = []
317
- # inference_state = sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
318
- # mask_num = 0
319
-
320
- # sam1_images = images.sam1_images
321
- # sam1_images_size = images.sam1_images_size
322
- # np_images = images.np_images
323
- # np_images_size = images.np_images_size
324
 
325
- # sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
326
- # for mask in sam1_masks:
327
- # _, _, _ = sam2.add_new_mask(
328
- # inference_state=inference_state,
329
- # frame_idx=0,
330
- # obj_id=mask_num,
331
- # mask=mask,
332
- # )
333
- # mask_num += 1
334
-
335
- # video_segments = {} # video_segments contains the per-frame segmentation results
336
- # for out_frame_idx, out_obj_ids, out_mask_logits in sam2.propagate_in_video(inference_state):
337
- # sam2_masks = (out_mask_logits > 0.0).squeeze(1)
338
-
339
- # video_segments[out_frame_idx] = {
340
- # out_obj_id: sam2_masks[i].cpu().numpy()
341
- # for i, out_obj_id in enumerate(out_obj_ids)
342
- # }
343
-
344
- # if out_frame_idx == 0:
345
- # continue
346
-
347
- # sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
348
-
349
- # for sam1_mask in sam1_masks:
350
- # flg = 1
351
- # for sam2_mask in sam2_masks:
352
- # # print(sam1_mask.shape, sam2_mask.shape)
353
- # area1 = sam1_mask.sum()
354
- # area2 = sam2_mask.sum()
355
- # intersection = (sam1_mask & sam2_mask).sum()
356
- # if min(intersection / area1, intersection / area2) > 0.25:
357
- # flg = 0
358
- # break
359
- # if flg:
360
- # video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
361
- # mask_num += 1
362
-
363
- # multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
364
- # multi_view_clip_feats_map = {}
365
- # multi_view_clip_area_map = {}
366
- # for now_frame in range(0, len(video_segments), 1):
367
- # image = np_images[now_frame]
368
-
369
- # seg_img_list = []
370
- # out_obj_id_list = []
371
- # out_obj_mask_list = []
372
- # out_obj_area_list = []
373
- # # NOTE: background: -1
374
- # rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
375
- # sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
376
- # for out_obj_id, mask in sorted_dict_items:
377
- # if mask.sum() == 0:
378
- # continue
379
- # rev_seg_map[mask] = out_obj_id
380
- # rev_cog_seg_maps.append(rev_seg_map)
381
-
382
- # seg_map = -np.ones(image.shape[:2], dtype=np.int64)
383
- # sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
384
- # for out_obj_id, mask in sorted_dict_items:
385
- # if mask.sum() == 0:
386
- # continue
387
- # box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))
388
 
389
- # if box[2] == 0 and box[3] == 0:
390
- # continue
391
- # # print(box)
392
- # seg_img = get_seg_img(mask, box, image)
393
- # pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
394
- # seg_img_list.append(pad_seg_img)
395
- # seg_map[mask] = out_obj_id
396
- # out_obj_id_list.append(out_obj_id)
397
- # out_obj_area_list.append(np.count_nonzero(mask))
398
- # out_obj_mask_list.append(mask)
399
-
400
- # if len(seg_img_list) == 0:
401
- # cog_seg_maps.append(seg_map)
402
- # continue
403
-
404
- # seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
405
- # seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0
406
 
407
- # inputs = siglip_processor(images=seg_imgs, return_tensors="pt")
408
- # inputs = {key: value.to(device) for key, value in inputs.items()}
409
 
410
- # image_features = siglip.get_image_features(**inputs)
411
- # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
412
- # image_features = image_features.detach().cpu()
413
-
414
- # for i in range(len(out_obj_mask_list)):
415
- # for j in range(i + 1, len(out_obj_mask_list)):
416
- # mask1 = out_obj_mask_list[i]
417
- # mask2 = out_obj_mask_list[j]
418
- # intersection = np.logical_and(mask1, mask2).sum()
419
- # area1 = out_obj_area_list[i]
420
- # area2 = out_obj_area_list[j]
421
- # if min(intersection / area1, intersection / area2) > 0.025:
422
- # conf1 = area1 / (area1 + area2)
423
- # # conf2 = area2 / (area1 + area2)
424
- # image_features[j] = slerp(image_features[j], image_features[i], conf1)
425
-
426
- # for i, clip_feat in enumerate(image_features):
427
- # id = out_obj_id_list[i]
428
- # if id in multi_view_clip_feats_map.keys():
429
- # multi_view_clip_feats_map[id].append(clip_feat)
430
- # multi_view_clip_area_map[id].append(out_obj_area_list[i])
431
- # else:
432
- # multi_view_clip_feats_map[id] = [clip_feat]
433
- # multi_view_clip_area_map[id] = [out_obj_area_list[i]]
434
-
435
- # cog_seg_maps.append(seg_map)
436
- # del image_features
437
 
438
- # for i in range(mask_num):
439
- # if i in multi_view_clip_feats_map.keys():
440
- # clip_feats = multi_view_clip_feats_map[i]
441
- # mask_area = multi_view_clip_area_map[i]
442
- # multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
443
- # else:
444
- # multi_view_clip_feats[i] = torch.zeros((1024))
445
- # multi_view_clip_feats[mask_num] = torch.zeros((1024))
446
 
447
- # return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
448
 
449
 
450
- @spaces.GPU#(duration=30)
451
  def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_conf_thr=3.0,
452
  as_pointcloud=True, mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05,
453
  scenegraph_type='complete', winsize=1, refid=0):
@@ -461,25 +461,25 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
461
  MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
462
  mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
463
 
464
- # sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
465
 
466
- # siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
467
- # siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
468
 
469
- # SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
470
- # mobilesamv2 = sam_model_registry['sam_vit_h'](None)
471
- # sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
472
- # image_encoder = sam1.vision_encoder
473
 
474
- # prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
475
- # mobilesamv2.prompt_encoder = prompt_encoder
476
- # mobilesamv2.mask_decoder = mask_decoder
477
- # mobilesamv2.image_encoder=image_encoder
478
- # mobilesamv2.to(device=device)
479
- # mobilesamv2.eval()
480
 
481
- # YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
482
- # yolov8 = ObjectAwareModel(YOLO8_CKP)
483
 
484
  if len(filelist) < 2:
485
  raise gradio.Error("Please input at least 2 images.")
@@ -487,16 +487,16 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
487
  images = Images(filelist=filelist, device=device)
488
 
489
  # try:
490
- # cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2)
491
- # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
492
- # except Exception as e:
493
- rev_cog_seg_maps = []
494
- for tmp_img in images.np_images:
495
- rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
496
- rev_cog_seg_maps.append(rev_seg_map)
497
- cog_seg_maps = rev_cog_seg_maps
498
- cog_feats = torch.zeros((1, 1024))
499
  imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
500
 
501
  if len(imgs) == 1:
502
  imgs = [imgs[0], copy.deepcopy(imgs[0])]
@@ -537,19 +537,17 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
537
  outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
538
  clean_depth, transparent_cams, cam_size)
539
 
540
- # scene.to('cpu')
541
- # print(scene)
542
- # print(scene.imgs)
543
- # print(scene.cogs) scene,
544
-
545
  torch.cuda.empty_cache()
546
  return outfile
547
 
548
- # @spaces.GPU #(duration=30)
549
  # def get_3D_object_from_scene(outdir, text, threshold, scene, min_conf_thr, as_pointcloud,
550
  # mask_sky, clean_depth, transparent_cams, cam_size):
551
 
552
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
553
  # siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
554
  # siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
555
 
 
@@ -2,7 +2,7 @@ import os
2
  import sys
3
  sys.path.append(os.path.abspath('./modules'))
4
 
5
+ import math
6
  import tempfile
7
  import gradio
8
  import torch
 
@@ -11,7 +11,7 @@ import numpy as np
11
  import functools
12
  import trimesh
13
  import copy
14
+ from PIL import Image
15
  from scipy.spatial.transform import Rotation
16
 
17
  from modules.pe3r.images import Images
 
@@ -22,25 +22,25 @@ from modules.dust3r.utils.image import load_images #, rgb
22
  from modules.dust3r.utils.device import to_numpy
23
  from modules.dust3r.viz import add_scene_cam, CAM_COLORS, OPENGL, pts3d_to_trimesh, cat_meshes
24
  from modules.dust3r.cloud_opt import global_aligner, GlobalAlignerMode
25
+ from copy import deepcopy
26
+ import cv2
27
+ from typing import Any, Dict, Generator,List
28
+ import matplotlib.pyplot as pl
29
 
30
+ from modules.mobilesamv2.utils.transforms import ResizeLongestSide
31
+ from modules.pe3r.models import Models
32
  import torchvision.transforms as tvf
33
 
34
+ sys.path.append(os.path.abspath('./modules/ultralytics'))
35
 
36
+ from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
37
+ from modules.mast3r.model import AsymmetricMASt3R
38
 
39
+ from modules.sam2.build_sam import build_sam2_video_predictor
40
+ from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
41
+ from modules.mobilesamv2 import sam_model_registry
42
 
43
+ from sam2.sam2_video_predictor import SAM2VideoPredictor
44
  from modules.mast3r.model import AsymmetricMASt3R
45
 
46
 
 
@@ -117,337 +117,337 @@ def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False,
117
  return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
118
  transparent_cams=transparent_cams, cam_size=cam_size)
119
 
120
+ def mask_nms(masks, threshold=0.8):
121
+ keep = []
122
+ mask_num = len(masks)
123
+ suppressed = np.zeros((mask_num), dtype=np.int64)
124
+ for i in range(mask_num):
125
+ if suppressed[i] == 1:
126
+ continue
127
+ keep.append(i)
128
+ for j in range(i + 1, mask_num):
129
+ if suppressed[j] == 1:
130
+ continue
131
+ intersection = (masks[i] & masks[j]).sum()
132
+ if min(intersection / masks[i].sum(), intersection / masks[j].sum()) > threshold:
133
+ suppressed[j] = 1
134
+ return keep
135
+
136
+ def filter(masks, keep):
137
+ ret = []
138
+ for i, m in enumerate(masks):
139
+ if i in keep: ret.append(m)
140
+ return ret
141
+
142
+ def mask_to_box(mask):
143
+ if mask.sum() == 0:
144
+ return np.array([0, 0, 0, 0])
145
 
146
+ # Get the rows and columns where the mask is 1
147
+ rows = np.any(mask, axis=1)
148
+ cols = np.any(mask, axis=0)
149
 
150
+ # Get top, bottom, left, right edges
151
+ top = np.argmax(rows)
152
+ bottom = len(rows) - 1 - np.argmax(np.flip(rows))
153
+ left = np.argmax(cols)
154
+ right = len(cols) - 1 - np.argmax(np.flip(cols))
155
 
156
+ return np.array([left, top, right, bottom])
157
+
158
+ def box_xyxy_to_xywh(box_xyxy):
159
+ box_xywh = deepcopy(box_xyxy)
160
+ box_xywh[2] = box_xywh[2] - box_xywh[0]
161
+ box_xywh[3] = box_xywh[3] - box_xywh[1]
162
+ return box_xywh
163
+
164
+ def get_seg_img(mask, box, image):
165
+ image = image.copy()
166
+ x, y, w, h = box
167
+ # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
168
+ box_area = w * h
169
+ mask_area = mask.sum()
170
+ if 1 - (mask_area / box_area) < 0.2:
171
+ image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
172
+ else:
173
+ random_values = np.random.randint(0, 255, size=image.shape, dtype=np.uint8)
174
+ image[mask == 0] = random_values[mask == 0]
175
+ seg_img = image[y:y+h, x:x+w, ...]
176
+ return seg_img
177
+
178
+ def pad_img(img):
179
+ h, w, _ = img.shape
180
+ l = max(w,h)
181
+ pad = np.zeros((l,l,3), dtype=np.uint8) #
182
+ if h > w:
183
+ pad[:,(h-w)//2:(h-w)//2 + w, :] = img
184
+ else:
185
+ pad[(w-h)//2:(w-h)//2 + h, :, :] = img
186
+ return pad
187
+
188
+ def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
189
+ assert len(args) > 0 and all(
190
+ len(a) == len(args[0]) for a in args
191
+ ), "Batched iteration must have inputs of all the same size."
192
+ n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
193
+ for b in range(n_batches):
194
+ yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
195
+
196
+ def slerp(u1, u2, t):
197
+ """
198
+ Perform spherical linear interpolation (Slerp) between two unit vectors.
199
 
200
+ Args:
201
+ - u1 (torch.Tensor): First unit vector, shape (1024,)
202
+ - u2 (torch.Tensor): Second unit vector, shape (1024,)
203
+ - t (float): Interpolation parameter
204
 
205
+ Returns:
206
+ - torch.Tensor: Interpolated vector, shape (1024,)
207
+ """
208
+ # Compute the dot product
209
+ dot_product = torch.sum(u1 * u2)
210
 
211
+ # Ensure the dot product is within the valid range [-1, 1]
212
+ dot_product = torch.clamp(dot_product, -1.0, 1.0)
213
 
214
+ # Compute the angle between the vectors
215
+ theta = torch.acos(dot_product)
216
 
217
+ # Compute the coefficients for the interpolation
218
+ sin_theta = torch.sin(theta)
219
+ if sin_theta == 0:
220
+ # Vectors are parallel, return a linear interpolation
221
+ return u1 + t * (u2 - u1)
222
 
223
+ s1 = torch.sin((1 - t) * theta) / sin_theta
224
+ s2 = torch.sin(t * theta) / sin_theta
225
 
226
+ # Perform the interpolation
227
+ return s1 * u1 + s2 * u2
228
 
229
+ def slerp_multiple(vectors, t_values):
230
+ """
231
+ Perform spherical linear interpolation (Slerp) for multiple vectors.
232
 
233
+ Args:
234
+ - vectors (torch.Tensor): Tensor of vectors, shape (n, 1024)
235
+ - a_values (torch.Tensor): Tensor of values corresponding to each vector, shape (n,)
236
 
237
+ Returns:
238
+ - torch.Tensor: Interpolated vector, shape (1024,)
239
+ """
240
+ n = vectors.shape[0]
241
 
242
+ # Initialize the interpolated vector with the first vector
243
+ interpolated_vector = vectors[0]
244
 
245
+ # Perform Slerp iteratively
246
+ for i in range(1, n):
247
+ # Perform Slerp between the current interpolated vector and the next vector
248
+ t = t_values[i] / (t_values[i] + t_values[i-1])
249
+ interpolated_vector = slerp(interpolated_vector, vectors[i], t)
250
 
251
+ return interpolated_vector
252
 
253
+ @torch.no_grad
254
+ def get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_image, yolov8_image, original_size, input_size, transform):
255
 
256
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
257
 
258
 
259
+ sam_mask=[]
260
+ img_area = original_size[0] * original_size[1]
261
 
262
+ obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
263
+ input_boxes1 = obj_results[0].boxes.xyxy
264
+ input_boxes1 = input_boxes1.cpu().numpy()
265
+ input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
266
+ input_boxes = torch.from_numpy(input_boxes1).to(device)
267
 
268
+ # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
269
+ # input_boxes2 = obj_results[0].boxes.xyxy
270
+ # input_boxes2 = input_boxes2.cpu().numpy()
271
+ # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
272
+ # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
273
+
274
+ # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
275
+
276
+ input_image = mobilesamv2.preprocess(sam1_image)
277
+ image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
278
+
279
+ image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
280
+ prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
281
+ prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
282
+ for (boxes,) in batch_iterator(320, input_boxes):
283
+ with torch.no_grad():
284
+ image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
285
+ prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
286
+ sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
287
+ points=None,
288
+ boxes=boxes,
289
+ masks=None,)
290
+ low_res_masks, _ = mobilesamv2.mask_decoder(
291
+ image_embeddings=image_embedding,
292
+ image_pe=prompt_embedding,
293
+ sparse_prompt_embeddings=sparse_embeddings,
294
+ dense_prompt_embeddings=dense_embeddings,
295
+ multimask_output=False,
296
+ simple_type=True,
297
+ )
298
+ low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
299
+ sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
300
+ for mask in sam_mask_pre:
301
+ if mask.sum() / img_area > 0.002:
302
+ sam_mask.append(mask.squeeze(1))
303
+ sam_mask=torch.cat(sam_mask)
304
+ sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
305
+ keep = mask_nms(sorted_sam_mask)
306
+ ret_mask = filter(sorted_sam_mask, keep)
307
+
308
+ return ret_mask
309
+
310
+ @torch.no_grad
311
+ def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
312
+
313
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
314
+
315
+ cog_seg_maps = []
316
+ rev_cog_seg_maps = []
317
+ inference_state = sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
318
+ mask_num = 0
319
+
320
+ sam1_images = images.sam1_images
321
+ sam1_images_size = images.sam1_images_size
322
+ np_images = images.np_images
323
+ np_images_size = images.np_images_size
324
 
325
+ sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
326
+ for mask in sam1_masks:
327
+ _, _, _ = sam2.add_new_mask(
328
+ inference_state=inference_state,
329
+ frame_idx=0,
330
+ obj_id=mask_num,
331
+ mask=mask,
332
+ )
333
+ mask_num += 1
334
+
335
+ video_segments = {} # video_segments contains the per-frame segmentation results
336
+ for out_frame_idx, out_obj_ids, out_mask_logits in sam2.propagate_in_video(inference_state):
337
+ sam2_masks = (out_mask_logits > 0.0).squeeze(1)
338
+
339
+ video_segments[out_frame_idx] = {
340
+ out_obj_id: sam2_masks[i].cpu().numpy()
341
+ for i, out_obj_id in enumerate(out_obj_ids)
342
+ }
343
+
344
+ if out_frame_idx == 0:
345
+ continue
346
+
347
+ sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
348
+
349
+ for sam1_mask in sam1_masks:
350
+ flg = 1
351
+ for sam2_mask in sam2_masks:
352
+ # print(sam1_mask.shape, sam2_mask.shape)
353
+ area1 = sam1_mask.sum()
354
+ area2 = sam2_mask.sum()
355
+ intersection = (sam1_mask & sam2_mask).sum()
356
+ if min(intersection / area1, intersection / area2) > 0.25:
357
+ flg = 0
358
+ break
359
+ if flg:
360
+ video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
361
+ mask_num += 1
362
+
363
+ multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
364
+ multi_view_clip_feats_map = {}
365
+ multi_view_clip_area_map = {}
366
+ for now_frame in range(0, len(video_segments), 1):
367
+ image = np_images[now_frame]
368
+
369
+ seg_img_list = []
370
+ out_obj_id_list = []
371
+ out_obj_mask_list = []
372
+ out_obj_area_list = []
373
+ # NOTE: background: -1
374
+ rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
375
+ sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
376
+ for out_obj_id, mask in sorted_dict_items:
377
+ if mask.sum() == 0:
378
+ continue
379
+ rev_seg_map[mask] = out_obj_id
380
+ rev_cog_seg_maps.append(rev_seg_map)
381
+
382
+ seg_map = -np.ones(image.shape[:2], dtype=np.int64)
383
+ sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
384
+ for out_obj_id, mask in sorted_dict_items:
385
+ if mask.sum() == 0:
386
+ continue
387
+ box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))
388
 
389
+ if box[2] == 0 and box[3] == 0:
390
+ continue
391
+ # print(box)
392
+ seg_img = get_seg_img(mask, box, image)
393
+ pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
394
+ seg_img_list.append(pad_seg_img)
395
+ seg_map[mask] = out_obj_id
396
+ out_obj_id_list.append(out_obj_id)
397
+ out_obj_area_list.append(np.count_nonzero(mask))
398
+ out_obj_mask_list.append(mask)
399
+
400
+ if len(seg_img_list) == 0:
401
+ cog_seg_maps.append(seg_map)
402
+ continue
403
+
404
+ seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
405
+ seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0
406
 
407
+ inputs = siglip_processor(images=seg_imgs, return_tensors="pt")
408
+ inputs = {key: value.to(device) for key, value in inputs.items()}
409
 
410
+ image_features = siglip.get_image_features(**inputs)
411
+ image_features = image_features / image_features.norm(dim=-1, keepdim=True)
412
+ image_features = image_features.detach().cpu()
413
+
414
+ for i in range(len(out_obj_mask_list)):
415
+ for j in range(i + 1, len(out_obj_mask_list)):
416
+ mask1 = out_obj_mask_list[i]
417
+ mask2 = out_obj_mask_list[j]
418
+ intersection = np.logical_and(mask1, mask2).sum()
419
+ area1 = out_obj_area_list[i]
420
+ area2 = out_obj_area_list[j]
421
+ if min(intersection / area1, intersection / area2) > 0.025:
422
+ conf1 = area1 / (area1 + area2)
423
+ # conf2 = area2 / (area1 + area2)
424
+ image_features[j] = slerp(image_features[j], image_features[i], conf1)
425
+
426
+ for i, clip_feat in enumerate(image_features):
427
+ id = out_obj_id_list[i]
428
+ if id in multi_view_clip_feats_map.keys():
429
+ multi_view_clip_feats_map[id].append(clip_feat)
430
+ multi_view_clip_area_map[id].append(out_obj_area_list[i])
431
+ else:
432
+ multi_view_clip_feats_map[id] = [clip_feat]
433
+ multi_view_clip_area_map[id] = [out_obj_area_list[i]]
434
+
435
+ cog_seg_maps.append(seg_map)
436
+ del image_features
437
 
438
+ for i in range(mask_num):
439
+ if i in multi_view_clip_feats_map.keys():
440
+ clip_feats = multi_view_clip_feats_map[i]
441
+ mask_area = multi_view_clip_area_map[i]
442
+ multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
443
+ else:
444
+ multi_view_clip_feats[i] = torch.zeros((1024))
445
+ multi_view_clip_feats[mask_num] = torch.zeros((1024))
446
 
447
+ return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
448
 
449
 
450
+ @spaces.GPU(duration=30)
451
  def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_conf_thr=3.0,
452
  as_pointcloud=True, mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05,
453
  scenegraph_type='complete', winsize=1, refid=0):
 
@@ -461,25 +461,25 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
461
  MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
462
  mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
463
 
464
+ sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
465
 
466
+ siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
467
+ siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
468
 
469
+ SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
470
+ mobilesamv2 = sam_model_registry['sam_vit_h'](None)
471
+ sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
472
+ image_encoder = sam1.vision_encoder
473
 
474
+ prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
475
+ mobilesamv2.prompt_encoder = prompt_encoder
476
+ mobilesamv2.mask_decoder = mask_decoder
477
+ mobilesamv2.image_encoder=image_encoder
478
+ mobilesamv2.to(device=device)
479
+ mobilesamv2.eval()
480
 
481
+ YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
482
+ yolov8 = ObjectAwareModel(YOLO8_CKP)
483
 
484
  if len(filelist) < 2:
485
  raise gradio.Error("Please input at least 2 images.")
 
@@ -487,16 +487,16 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
487
  images = Images(filelist=filelist, device=device)
488
 
489
  # try:
490
+ cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2)
491
  imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
492
+ # except Exception as e:
493
+ # rev_cog_seg_maps = []
494
+ # for tmp_img in images.np_images:
495
+ # rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
496
+ # rev_cog_seg_maps.append(rev_seg_map)
497
+ # cog_seg_maps = rev_cog_seg_maps
498
+ # cog_feats = torch.zeros((1, 1024))
499
+ # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
500
 
501
  if len(imgs) == 1:
502
  imgs = [imgs[0], copy.deepcopy(imgs[0])]
 
@@ -537,19 +537,17 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
537
  outfile = get_3D_model_from_scene(outdir, scene, min_conf_thr, as_pointcloud, mask_sky,
538
  clean_depth, transparent_cams, cam_size)
539

540
  torch.cuda.empty_cache()
541
+
542
+
543
+
544
  return outfile
545
 
546
+
547
  # def get_3D_object_from_scene(outdir, text, threshold, scene, min_conf_thr, as_pointcloud,
548
  # mask_sky, clean_depth, transparent_cams, cam_size):
549
 
550
+ # device = 'cpu'
551
  # siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
552
  # siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
553