hujiecpp committed
Commit ba148f1 · 1 parent: cb96c94

init project

Files changed (1)
  1. app.py +360 -370
app.py CHANGED
@@ -1,9 +1,3 @@
1
- # Copyright (C) 2024-present Naver Corporation. All rights reserved.
2
- # Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
3
- #
4
- # --------------------------------------------------------
5
- # gradio demo
6
- # --------------------------------------------------------
7
  import os
8
  import sys
9
  sys.path.append(os.path.abspath('./modules'))
@@ -37,23 +31,22 @@ from modules.mobilesamv2.utils.transforms import ResizeLongestSide
37
  # from modules.pe3r.models import Models
38
  import torchvision.transforms as tvf
39
 
 
40
 
41
-
42
- sys.path.append(os.path.abspath('./modules/ultralytics'))
43
-
44
- from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
45
  # from modules.mast3r.model import AsymmetricMASt3R
46
 
47
  # from modules.sam2.build_sam import build_sam2_video_predictor
48
- from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
49
- from modules.mobilesamv2 import sam_model_registry
50
 
51
- from sam2.sam2_video_predictor import SAM2VideoPredictor
52
  from modules.mast3r.model import AsymmetricMASt3R
53
 
54
 
55
  silent = False
56
- device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu' # #
 
57
  # pe3r = Models('cpu') # 'cpu' #
58
  # print(device)
59
 
@@ -124,369 +117,369 @@ def get_3D_model_from_scene(outdir, scene, min_conf_thr=3, as_pointcloud=False,
124
  return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
125
  transparent_cams=transparent_cams, cam_size=cam_size)
126
 
127
- def mask_nms(masks, threshold=0.8):
128
- keep = []
129
- mask_num = len(masks)
130
- suppressed = np.zeros((mask_num), dtype=np.int64)
131
- for i in range(mask_num):
132
- if suppressed[i] == 1:
133
- continue
134
- keep.append(i)
135
- for j in range(i + 1, mask_num):
136
- if suppressed[j] == 1:
137
- continue
138
- intersection = (masks[i] & masks[j]).sum()
139
- if min(intersection / masks[i].sum(), intersection / masks[j].sum()) > threshold:
140
- suppressed[j] = 1
141
- return keep
142
-
143
- def filter(masks, keep):
144
- ret = []
145
- for i, m in enumerate(masks):
146
- if i in keep: ret.append(m)
147
- return ret
148
-
149
- def mask_to_box(mask):
150
- if mask.sum() == 0:
151
- return np.array([0, 0, 0, 0])
152
 
153
- # Get the rows and columns where the mask is 1
154
- rows = np.any(mask, axis=1)
155
- cols = np.any(mask, axis=0)
156
 
157
- # Get top, bottom, left, right edges
158
- top = np.argmax(rows)
159
- bottom = len(rows) - 1 - np.argmax(np.flip(rows))
160
- left = np.argmax(cols)
161
- right = len(cols) - 1 - np.argmax(np.flip(cols))
162
 
163
- return np.array([left, top, right, bottom])
164
-
165
- def box_xyxy_to_xywh(box_xyxy):
166
- box_xywh = deepcopy(box_xyxy)
167
- box_xywh[2] = box_xywh[2] - box_xywh[0]
168
- box_xywh[3] = box_xywh[3] - box_xywh[1]
169
- return box_xywh
170
-
171
- def get_seg_img(mask, box, image):
172
- image = image.copy()
173
- x, y, w, h = box
174
- # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
175
- box_area = w * h
176
- mask_area = mask.sum()
177
- if 1 - (mask_area / box_area) < 0.2:
178
- image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
179
- else:
180
- random_values = np.random.randint(0, 255, size=image.shape, dtype=np.uint8)
181
- image[mask == 0] = random_values[mask == 0]
182
- seg_img = image[y:y+h, x:x+w, ...]
183
- return seg_img
184
-
185
- def pad_img(img):
186
- h, w, _ = img.shape
187
- l = max(w,h)
188
- pad = np.zeros((l,l,3), dtype=np.uint8) #
189
- if h > w:
190
- pad[:,(h-w)//2:(h-w)//2 + w, :] = img
191
- else:
192
- pad[(w-h)//2:(w-h)//2 + h, :, :] = img
193
- return pad
194
-
195
- def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
196
- assert len(args) > 0 and all(
197
- len(a) == len(args[0]) for a in args
198
- ), "Batched iteration must have inputs of all the same size."
199
- n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
200
- for b in range(n_batches):
201
- yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
202
-
203
- def slerp(u1, u2, t):
204
- """
205
- Perform spherical linear interpolation (Slerp) between two unit vectors.
206
 
207
- Args:
208
- - u1 (torch.Tensor): First unit vector, shape (1024,)
209
- - u2 (torch.Tensor): Second unit vector, shape (1024,)
210
- - t (float): Interpolation parameter
211
 
212
- Returns:
213
- - torch.Tensor: Interpolated vector, shape (1024,)
214
- """
215
- # Compute the dot product
216
- dot_product = torch.sum(u1 * u2)
217
 
218
- # Ensure the dot product is within the valid range [-1, 1]
219
- dot_product = torch.clamp(dot_product, -1.0, 1.0)
220
 
221
- # Compute the angle between the vectors
222
- theta = torch.acos(dot_product)
223
 
224
- # Compute the coefficients for the interpolation
225
- sin_theta = torch.sin(theta)
226
- if sin_theta == 0:
227
- # Vectors are parallel, return a linear interpolation
228
- return u1 + t * (u2 - u1)
229
 
230
- s1 = torch.sin((1 - t) * theta) / sin_theta
231
- s2 = torch.sin(t * theta) / sin_theta
232
 
233
- # Perform the interpolation
234
- return s1 * u1 + s2 * u2
235
 
236
- def slerp_multiple(vectors, t_values):
237
- """
238
- Perform spherical linear interpolation (Slerp) for multiple vectors.
239
 
240
- Args:
241
- - vectors (torch.Tensor): Tensor of vectors, shape (n, 1024)
242
- - a_values (torch.Tensor): Tensor of values corresponding to each vector, shape (n,)
243
 
244
- Returns:
245
- - torch.Tensor: Interpolated vector, shape (1024,)
246
- """
247
- n = vectors.shape[0]
248
 
249
- # Initialize the interpolated vector with the first vector
250
- interpolated_vector = vectors[0]
251
 
252
- # Perform Slerp iteratively
253
- for i in range(1, n):
254
- # Perform Slerp between the current interpolated vector and the next vector
255
- t = t_values[i] / (t_values[i] + t_values[i-1])
256
- interpolated_vector = slerp(interpolated_vector, vectors[i], t)
257
 
258
- return interpolated_vector
259
 
260
- @torch.no_grad
261
- def get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_image, yolov8_image, original_size, input_size, transform):
262
 
263
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
264
 
265
 
266
- sam_mask=[]
267
- img_area = original_size[0] * original_size[1]
268
 
269
- obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
270
- input_boxes1 = obj_results[0].boxes.xyxy
271
- input_boxes1 = input_boxes1.cpu().numpy()
272
- input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
273
- input_boxes = torch.from_numpy(input_boxes1).to(device)
274
 
275
- # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
276
- # input_boxes2 = obj_results[0].boxes.xyxy
277
- # input_boxes2 = input_boxes2.cpu().numpy()
278
- # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
279
- # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
280
-
281
- # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
282
-
283
- input_image = mobilesamv2.preprocess(sam1_image)
284
- image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
285
-
286
- image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
287
- prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
288
- prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
289
- for (boxes,) in batch_iterator(320, input_boxes):
290
- with torch.no_grad():
291
- image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
292
- prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
293
- sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
294
- points=None,
295
- boxes=boxes,
296
- masks=None,)
297
- low_res_masks, _ = mobilesamv2.mask_decoder(
298
- image_embeddings=image_embedding,
299
- image_pe=prompt_embedding,
300
- sparse_prompt_embeddings=sparse_embeddings,
301
- dense_prompt_embeddings=dense_embeddings,
302
- multimask_output=False,
303
- simple_type=True,
304
- )
305
- low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
306
- sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
307
- for mask in sam_mask_pre:
308
- if mask.sum() / img_area > 0.002:
309
- sam_mask.append(mask.squeeze(1))
310
- sam_mask=torch.cat(sam_mask)
311
- sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
312
- keep = mask_nms(sorted_sam_mask)
313
- ret_mask = filter(sorted_sam_mask, keep)
314
-
315
- return ret_mask
316
-
317
- @torch.no_grad
318
- def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
319
-
320
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
321
-
322
- cog_seg_maps = []
323
- rev_cog_seg_maps = []
324
- inference_state = sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
325
- mask_num = 0
326
 
327
- sam1_images = images.sam1_images
328
- sam1_images_size = images.sam1_images_size
329
- np_images = images.np_images
330
- np_images_size = images.np_images_size
331
-
332
- sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
333
- for mask in sam1_masks:
334
- _, _, _ = sam2.add_new_mask(
335
- inference_state=inference_state,
336
- frame_idx=0,
337
- obj_id=mask_num,
338
- mask=mask,
339
- )
340
- mask_num += 1
341
-
342
- video_segments = {} # video_segments contains the per-frame segmentation results
343
- for out_frame_idx, out_obj_ids, out_mask_logits in sam2.propagate_in_video(inference_state):
344
- sam2_masks = (out_mask_logits > 0.0).squeeze(1)
345
-
346
- video_segments[out_frame_idx] = {
347
- out_obj_id: sam2_masks[i].cpu().numpy()
348
- for i, out_obj_id in enumerate(out_obj_ids)
349
- }
350
-
351
- if out_frame_idx == 0:
352
- continue
353
-
354
- sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
355
-
356
- for sam1_mask in sam1_masks:
357
- flg = 1
358
- for sam2_mask in sam2_masks:
359
- # print(sam1_mask.shape, sam2_mask.shape)
360
- area1 = sam1_mask.sum()
361
- area2 = sam2_mask.sum()
362
- intersection = (sam1_mask & sam2_mask).sum()
363
- if min(intersection / area1, intersection / area2) > 0.25:
364
- flg = 0
365
- break
366
- if flg:
367
- video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
368
- mask_num += 1
369
-
370
- multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
371
- multi_view_clip_feats_map = {}
372
- multi_view_clip_area_map = {}
373
- for now_frame in range(0, len(video_segments), 1):
374
- image = np_images[now_frame]
375
-
376
- seg_img_list = []
377
- out_obj_id_list = []
378
- out_obj_mask_list = []
379
- out_obj_area_list = []
380
- # NOTE: background: -1
381
- rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
382
- sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
383
- for out_obj_id, mask in sorted_dict_items:
384
- if mask.sum() == 0:
385
- continue
386
- rev_seg_map[mask] = out_obj_id
387
- rev_cog_seg_maps.append(rev_seg_map)
388
 
389
- seg_map = -np.ones(image.shape[:2], dtype=np.int64)
390
- sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
391
- for out_obj_id, mask in sorted_dict_items:
392
- if mask.sum() == 0:
393
- continue
394
- box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))
 
395
 
396
- if box[2] == 0 and box[3] == 0:
397
- continue
398
- # print(box)
399
- seg_img = get_seg_img(mask, box, image)
400
- pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
401
- seg_img_list.append(pad_seg_img)
402
- seg_map[mask] = out_obj_id
403
- out_obj_id_list.append(out_obj_id)
404
- out_obj_area_list.append(np.count_nonzero(mask))
405
- out_obj_mask_list.append(mask)
406
-
407
- if len(seg_img_list) == 0:
408
- cog_seg_maps.append(seg_map)
409
- continue
410
-
411
- seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
412
- seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0
413
 
414
- inputs = siglip_processor(images=seg_imgs, return_tensors="pt")
415
- inputs = {key: value.to(device) for key, value in inputs.items()}
416
 
417
- image_features = siglip.get_image_features(**inputs)
418
- image_features = image_features / image_features.norm(dim=-1, keepdim=True)
419
- image_features = image_features.detach().cpu()
420
-
421
- for i in range(len(out_obj_mask_list)):
422
- for j in range(i + 1, len(out_obj_mask_list)):
423
- mask1 = out_obj_mask_list[i]
424
- mask2 = out_obj_mask_list[j]
425
- intersection = np.logical_and(mask1, mask2).sum()
426
- area1 = out_obj_area_list[i]
427
- area2 = out_obj_area_list[j]
428
- if min(intersection / area1, intersection / area2) > 0.025:
429
- conf1 = area1 / (area1 + area2)
430
- # conf2 = area2 / (area1 + area2)
431
- image_features[j] = slerp(image_features[j], image_features[i], conf1)
432
-
433
- for i, clip_feat in enumerate(image_features):
434
- id = out_obj_id_list[i]
435
- if id in multi_view_clip_feats_map.keys():
436
- multi_view_clip_feats_map[id].append(clip_feat)
437
- multi_view_clip_area_map[id].append(out_obj_area_list[i])
438
- else:
439
- multi_view_clip_feats_map[id] = [clip_feat]
440
- multi_view_clip_area_map[id] = [out_obj_area_list[i]]
441
-
442
- cog_seg_maps.append(seg_map)
443
- del image_features
444
 
445
- for i in range(mask_num):
446
- if i in multi_view_clip_feats_map.keys():
447
- clip_feats = multi_view_clip_feats_map[i]
448
- mask_area = multi_view_clip_area_map[i]
449
- multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
450
- else:
451
- multi_view_clip_feats[i] = torch.zeros((1024))
452
- multi_view_clip_feats[mask_num] = torch.zeros((1024))
453
 
454
- return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
455
 
456
 
457
  @spaces.GPU(duration=60)
458
- def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
459
- as_pointcloud, mask_sky, clean_depth, transparent_cams, cam_size,
460
- scenegraph_type, winsize, refid):
461
  """
462
  from a list of images, run dust3r inference, global aligner.
463
  then run get_3D_model_from_scene
464
  """
465
 
466
- # device = 'cuda' if torch.cuda.is_available() else 'cpu'
467
 
468
  MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
469
  mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
470
 
471
- sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
472
 
473
- siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
474
- siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
475
 
476
- SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
477
- mobilesamv2 = sam_model_registry['sam_vit_h'](None)
478
- sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
479
- image_encoder = sam1.vision_encoder
480
 
481
- prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
482
- mobilesamv2.prompt_encoder = prompt_encoder
483
- mobilesamv2.mask_decoder = mask_decoder
484
- mobilesamv2.image_encoder=image_encoder
485
- mobilesamv2.to(device=device)
486
- mobilesamv2.eval()
487
 
488
- YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
489
- yolov8 = ObjectAwareModel(YOLO8_CKP)
490
 
491
  if len(filelist) < 2:
492
  raise gradio.Error("Please input at least 2 images.")
@@ -494,16 +487,16 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
494
  images = Images(filelist=filelist, device=device)
495
 
496
  # try:
497
- cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2)
498
- imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
499
  # except Exception as e:
500
- # rev_cog_seg_maps = []
501
- # for tmp_img in images.np_images:
502
- # rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
503
- # rev_cog_seg_maps.append(rev_seg_map)
504
- # cog_seg_maps = rev_cog_seg_maps
505
- # cog_feats = torch.zeros((1, 1024))
506
- # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
507
 
508
  if len(imgs) == 1:
509
  imgs = [imgs[0], copy.deepcopy(imgs[0])]
@@ -546,7 +539,6 @@ def get_reconstructed_scene(outdir, filelist, schedule, niter, min_conf_thr,
546
 
547
  scene.to('cpu')
548
  torch.cuda.empty_cache()
549
-
550
  return scene, outfile
551
 
552
  # @spaces.GPU(duration=60)
@@ -581,37 +573,37 @@ with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
581
  gradio.HTML('<h2 style="text-align: center;">PE3R Demo</h2>')
582
  with gradio.Column():
583
  inputfiles = gradio.File(file_count="multiple")
584
- with gradio.Row():
585
- schedule = gradio.Dropdown(["linear", "cosine"],
586
- value='linear', label="schedule", info="For global alignment!",
587
- visible=False)
588
- niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
589
- label="num_iterations", info="For global alignment!",
590
- visible=False)
591
- scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
592
- ("swin: sliding window", "swin"),
593
- ("oneref: match one image with all", "oneref")],
594
- value='complete', label="Scenegraph",
595
- info="Define how to make pairs",
596
- interactive=True,
597
- visible=False)
598
- winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
599
- minimum=1, maximum=1, step=1, visible=False)
600
- refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
601
 
602
  run_btn = gradio.Button("Reconstruct")
603
 
604
- with gradio.Row():
605
  # adjust the confidence threshold
606
- min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
607
  # adjust the camera size in the output pointcloud
608
- cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001, visible=False)
609
- with gradio.Row():
610
- as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud", visible=False)
611
  # two post process implemented
612
- mask_sky = gradio.Checkbox(value=False, label="Mask sky", visible=False)
613
- clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
614
- transparent_cams = gradio.Checkbox(value=True, label="Transparent cameras", visible=False)
615
 
616
  with gradio.Row():
617
  text_input = gradio.Textbox(label="Query Text")
@@ -623,9 +615,7 @@ with tempfile.TemporaryDirectory(suffix='pe3r_gradio_demo') as tmpdirname:
623
  # events
624
 
625
  run_btn.click(fn=recon_fun,
626
- inputs=[inputfiles, schedule, niter, min_conf_thr, as_pointcloud,
627
- mask_sky, clean_depth, transparent_cams, cam_size,
628
- scenegraph_type, winsize, refid],
629
  outputs=[scene, outmodel]) # , outgallery
630
 
631
  # find_btn.click(fn=get_3D_object_from_scene_fun,
 
1
  import os
2
  import sys
3
  sys.path.append(os.path.abspath('./modules'))
 
31
  # from modules.pe3r.models import Models
32
  import torchvision.transforms as tvf
33
 
34
+ # sys.path.append(os.path.abspath('./modules/ultralytics'))
35
 
36
+ # from transformers import AutoTokenizer, AutoModel, AutoProcessor, SamModel
 
 
 
37
  # from modules.mast3r.model import AsymmetricMASt3R
38
 
39
  # from modules.sam2.build_sam import build_sam2_video_predictor
40
+ # from modules.mobilesamv2.promt_mobilesamv2 import ObjectAwareModel
41
+ # from modules.mobilesamv2 import sam_model_registry
42
 
43
+ # from sam2.sam2_video_predictor import SAM2VideoPredictor
44
  from modules.mast3r.model import AsymmetricMASt3R
45
 
46
 
47
  silent = False
48
+
49
+ # device = 'cpu' #'cuda' if torch.cuda.is_available() else 'cpu' # #
50
  # pe3r = Models('cpu') # 'cpu' #
51
  # print(device)
52
 
 
117
  return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
118
  transparent_cams=transparent_cams, cam_size=cam_size)
119
 
120
+ # def mask_nms(masks, threshold=0.8):
121
+ # keep = []
122
+ # mask_num = len(masks)
123
+ # suppressed = np.zeros((mask_num), dtype=np.int64)
124
+ # for i in range(mask_num):
125
+ # if suppressed[i] == 1:
126
+ # continue
127
+ # keep.append(i)
128
+ # for j in range(i + 1, mask_num):
129
+ # if suppressed[j] == 1:
130
+ # continue
131
+ # intersection = (masks[i] & masks[j]).sum()
132
+ # if min(intersection / masks[i].sum(), intersection / masks[j].sum()) > threshold:
133
+ # suppressed[j] = 1
134
+ # return keep
135
+
136
+ # def filter(masks, keep):
137
+ # ret = []
138
+ # for i, m in enumerate(masks):
139
+ # if i in keep: ret.append(m)
140
+ # return ret
141
+
142
+ # def mask_to_box(mask):
143
+ # if mask.sum() == 0:
144
+ # return np.array([0, 0, 0, 0])
145
 
146
+ # # Get the rows and columns where the mask is 1
147
+ # rows = np.any(mask, axis=1)
148
+ # cols = np.any(mask, axis=0)
149
 
150
+ # # Get top, bottom, left, right edges
151
+ # top = np.argmax(rows)
152
+ # bottom = len(rows) - 1 - np.argmax(np.flip(rows))
153
+ # left = np.argmax(cols)
154
+ # right = len(cols) - 1 - np.argmax(np.flip(cols))
155
 
156
+ # return np.array([left, top, right, bottom])
157
+
158
+ # def box_xyxy_to_xywh(box_xyxy):
159
+ # box_xywh = deepcopy(box_xyxy)
160
+ # box_xywh[2] = box_xywh[2] - box_xywh[0]
161
+ # box_xywh[3] = box_xywh[3] - box_xywh[1]
162
+ # return box_xywh
163
+
164
+ # def get_seg_img(mask, box, image):
165
+ # image = image.copy()
166
+ # x, y, w, h = box
167
+ # # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
168
+ # box_area = w * h
169
+ # mask_area = mask.sum()
170
+ # if 1 - (mask_area / box_area) < 0.2:
171
+ # image[mask == 0] = np.array([0, 0, 0], dtype=np.uint8)
172
+ # else:
173
+ # random_values = np.random.randint(0, 255, size=image.shape, dtype=np.uint8)
174
+ # image[mask == 0] = random_values[mask == 0]
175
+ # seg_img = image[y:y+h, x:x+w, ...]
176
+ # return seg_img
177
+
178
+ # def pad_img(img):
179
+ # h, w, _ = img.shape
180
+ # l = max(w,h)
181
+ # pad = np.zeros((l,l,3), dtype=np.uint8) #
182
+ # if h > w:
183
+ # pad[:,(h-w)//2:(h-w)//2 + w, :] = img
184
+ # else:
185
+ # pad[(w-h)//2:(w-h)//2 + h, :, :] = img
186
+ # return pad
187
+
188
+ # def batch_iterator(batch_size: int, *args) -> Generator[List[Any], None, None]:
189
+ # assert len(args) > 0 and all(
190
+ # len(a) == len(args[0]) for a in args
191
+ # ), "Batched iteration must have inputs of all the same size."
192
+ # n_batches = len(args[0]) // batch_size + int(len(args[0]) % batch_size != 0)
193
+ # for b in range(n_batches):
194
+ # yield [arg[b * batch_size : (b + 1) * batch_size] for arg in args]
195
+
196
+ # def slerp(u1, u2, t):
197
+ # """
198
+ # Perform spherical linear interpolation (Slerp) between two unit vectors.
199
 
200
+ # Args:
201
+ # - u1 (torch.Tensor): First unit vector, shape (1024,)
202
+ # - u2 (torch.Tensor): Second unit vector, shape (1024,)
203
+ # - t (float): Interpolation parameter
204
 
205
+ # Returns:
206
+ # - torch.Tensor: Interpolated vector, shape (1024,)
207
+ # """
208
+ # # Compute the dot product
209
+ # dot_product = torch.sum(u1 * u2)
210
 
211
+ # # Ensure the dot product is within the valid range [-1, 1]
212
+ # dot_product = torch.clamp(dot_product, -1.0, 1.0)
213
 
214
+ # # Compute the angle between the vectors
215
+ # theta = torch.acos(dot_product)
216
 
217
+ # # Compute the coefficients for the interpolation
218
+ # sin_theta = torch.sin(theta)
219
+ # if sin_theta == 0:
220
+ # # Vectors are parallel, return a linear interpolation
221
+ # return u1 + t * (u2 - u1)
222
 
223
+ # s1 = torch.sin((1 - t) * theta) / sin_theta
224
+ # s2 = torch.sin(t * theta) / sin_theta
225
 
226
+ # # Perform the interpolation
227
+ # return s1 * u1 + s2 * u2
228
 
229
+ # def slerp_multiple(vectors, t_values):
230
+ # """
231
+ # Perform spherical linear interpolation (Slerp) for multiple vectors.
232
 
233
+ # Args:
234
+ # - vectors (torch.Tensor): Tensor of vectors, shape (n, 1024)
235
+ # - a_values (torch.Tensor): Tensor of values corresponding to each vector, shape (n,)
236
 
237
+ # Returns:
238
+ # - torch.Tensor: Interpolated vector, shape (1024,)
239
+ # """
240
+ # n = vectors.shape[0]
241
 
242
+ # # Initialize the interpolated vector with the first vector
243
+ # interpolated_vector = vectors[0]
244
 
245
+ # # Perform Slerp iteratively
246
+ # for i in range(1, n):
247
+ # # Perform Slerp between the current interpolated vector and the next vector
248
+ # t = t_values[i] / (t_values[i] + t_values[i-1])
249
+ # interpolated_vector = slerp(interpolated_vector, vectors[i], t)
250
 
251
+ # return interpolated_vector
252
 
253
+ # @torch.no_grad
254
+ # def get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_image, yolov8_image, original_size, input_size, transform):
255
 
256
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
257
 
258
 
259
+ # sam_mask=[]
260
+ # img_area = original_size[0] * original_size[1]
261
 
262
+ # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=1024,conf=0.25,iou=0.95,verbose=False)
263
+ # input_boxes1 = obj_results[0].boxes.xyxy
264
+ # input_boxes1 = input_boxes1.cpu().numpy()
265
+ # input_boxes1 = transform.apply_boxes(input_boxes1, original_size)
266
+ # input_boxes = torch.from_numpy(input_boxes1).to(device)
267
 
268
+ # # obj_results = yolov8(yolov8_image,device=device,retina_masks=False,imgsz=512,conf=0.25,iou=0.9,verbose=False)
269
+ # # input_boxes2 = obj_results[0].boxes.xyxy
270
+ # # input_boxes2 = input_boxes2.cpu().numpy()
271
+ # # input_boxes2 = transform.apply_boxes(input_boxes2, original_size)
272
+ # # input_boxes2 = torch.from_numpy(input_boxes2).to(device)
273
+
274
+ # # input_boxes = torch.cat((input_boxes1, input_boxes2), dim=0)
275
+
276
+ # input_image = mobilesamv2.preprocess(sam1_image)
277
+ # image_embedding = mobilesamv2.image_encoder(input_image)['last_hidden_state']
278
+
279
+ # image_embedding=torch.repeat_interleave(image_embedding, 320, dim=0)
280
+ # prompt_embedding=mobilesamv2.prompt_encoder.get_dense_pe()
281
+ # prompt_embedding=torch.repeat_interleave(prompt_embedding, 320, dim=0)
282
+ # for (boxes,) in batch_iterator(320, input_boxes):
283
+ # with torch.no_grad():
284
+ # image_embedding=image_embedding[0:boxes.shape[0],:,:,:]
285
+ # prompt_embedding=prompt_embedding[0:boxes.shape[0],:,:,:]
286
+ # sparse_embeddings, dense_embeddings = mobilesamv2.prompt_encoder(
287
+ # points=None,
288
+ # boxes=boxes,
289
+ # masks=None,)
290
+ # low_res_masks, _ = mobilesamv2.mask_decoder(
291
+ # image_embeddings=image_embedding,
292
+ # image_pe=prompt_embedding,
293
+ # sparse_prompt_embeddings=sparse_embeddings,
294
+ # dense_prompt_embeddings=dense_embeddings,
295
+ # multimask_output=False,
296
+ # simple_type=True,
297
+ # )
298
+ # low_res_masks=mobilesamv2.postprocess_masks(low_res_masks, input_size, original_size)
299
+ # sam_mask_pre = (low_res_masks > mobilesamv2.mask_threshold)
300
+ # for mask in sam_mask_pre:
301
+ # if mask.sum() / img_area > 0.002:
302
+ # sam_mask.append(mask.squeeze(1))
303
+ # sam_mask=torch.cat(sam_mask)
304
+ # sorted_sam_mask = sorted(sam_mask, key=(lambda x: x.sum()), reverse=True)
305
+ # keep = mask_nms(sorted_sam_mask)
306
+ # ret_mask = filter(sorted_sam_mask, keep)
307
+
308
+ # return ret_mask
309
+
310
+ # @torch.no_grad
311
+ # def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
 
 
 
 
 
 
 
312
 
313
+ # device = 'cuda' if torch.cuda.is_available() else 'cpu'
314
+
315
+ # cog_seg_maps = []
316
+ # rev_cog_seg_maps = []
317
+ # inference_state = sam2.init_state(images=images.sam2_images, video_height=images.sam2_video_size[0], video_width=images.sam2_video_size[1])
318
+ # mask_num = 0
 
 
319
 
320
+ # sam1_images = images.sam1_images
321
+ # sam1_images_size = images.sam1_images_size
322
+ # np_images = images.np_images
323
+ # np_images_size = images.np_images_size
324
+
325
+ # sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[0], np_images[0], np_images_size[0], sam1_images_size[0], images.sam1_transform)
326
+ # for mask in sam1_masks:
327
+ # _, _, _ = sam2.add_new_mask(
328
+ # inference_state=inference_state,
329
+ # frame_idx=0,
330
+ # obj_id=mask_num,
331
+ # mask=mask,
332
+ # )
333
+ # mask_num += 1
334
+
335
+ # video_segments = {} # video_segments contains the per-frame segmentation results
336
+ # for out_frame_idx, out_obj_ids, out_mask_logits in sam2.propagate_in_video(inference_state):
337
+ # sam2_masks = (out_mask_logits > 0.0).squeeze(1)
338
+
339
+ # video_segments[out_frame_idx] = {
340
+ # out_obj_id: sam2_masks[i].cpu().numpy()
341
+ # for i, out_obj_id in enumerate(out_obj_ids)
342
+ # }
343
+
344
+ # if out_frame_idx == 0:
345
+ # continue
346
+
347
+ # sam1_masks = get_mask_from_img_sam1(yolov8, mobilesamv2, sam1_images[out_frame_idx], np_images[out_frame_idx], np_images_size[out_frame_idx], sam1_images_size[out_frame_idx], images.sam1_transform)
348
+
349
+ # for sam1_mask in sam1_masks:
350
+ # flg = 1
351
+ # for sam2_mask in sam2_masks:
352
+ # # print(sam1_mask.shape, sam2_mask.shape)
353
+ # area1 = sam1_mask.sum()
354
+ # area2 = sam2_mask.sum()
355
+ # intersection = (sam1_mask & sam2_mask).sum()
356
+ # if min(intersection / area1, intersection / area2) > 0.25:
357
+ # flg = 0
358
+ # break
359
+ # if flg:
360
+ # video_segments[out_frame_idx][mask_num] = sam1_mask.cpu().numpy()
361
+ # mask_num += 1
362
+
363
+ # multi_view_clip_feats = torch.zeros((mask_num+1, 1024))
364
+ # multi_view_clip_feats_map = {}
365
+ # multi_view_clip_area_map = {}
366
+ # for now_frame in range(0, len(video_segments), 1):
367
+ # image = np_images[now_frame]
368
+
369
+ # seg_img_list = []
370
+ # out_obj_id_list = []
371
+ # out_obj_mask_list = []
372
+ # out_obj_area_list = []
373
+ # # NOTE: background: -1
374
+ # rev_seg_map = -np.ones(image.shape[:2], dtype=np.int64)
375
+ # sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=False)
376
+ # for out_obj_id, mask in sorted_dict_items:
377
+ # if mask.sum() == 0:
378
+ # continue
379
+ # rev_seg_map[mask] = out_obj_id
380
+ # rev_cog_seg_maps.append(rev_seg_map)
381
+
382
+ # seg_map = -np.ones(image.shape[:2], dtype=np.int64)
383
+ # sorted_dict_items = sorted(video_segments[now_frame].items(), key=lambda x: np.count_nonzero(x[1]), reverse=True)
384
+ # for out_obj_id, mask in sorted_dict_items:
385
+ # if mask.sum() == 0:
386
+ # continue
387
+ # box = np.int32(box_xyxy_to_xywh(mask_to_box(mask)))
388
 
389
+ # if box[2] == 0 and box[3] == 0:
390
+ # continue
391
+ # # print(box)
392
+ # seg_img = get_seg_img(mask, box, image)
393
+ # pad_seg_img = cv2.resize(pad_img(seg_img), (256,256))
394
+ # seg_img_list.append(pad_seg_img)
395
+ # seg_map[mask] = out_obj_id
396
+ # out_obj_id_list.append(out_obj_id)
397
+ # out_obj_area_list.append(np.count_nonzero(mask))
398
+ # out_obj_mask_list.append(mask)
399
+
400
+ # if len(seg_img_list) == 0:
401
+ # cog_seg_maps.append(seg_map)
402
+ # continue
403
+
404
+ # seg_imgs = np.stack(seg_img_list, axis=0) # b,H,W,3
405
+ # seg_imgs = torch.from_numpy(seg_imgs).permute(0,3,1,2) # / 255.0
406
 
407
+ # inputs = siglip_processor(images=seg_imgs, return_tensors="pt")
408
+ # inputs = {key: value.to(device) for key, value in inputs.items()}
409
 
410
+ # image_features = siglip.get_image_features(**inputs)
411
+ # image_features = image_features / image_features.norm(dim=-1, keepdim=True)
412
+ # image_features = image_features.detach().cpu()
413
+
414
+ # for i in range(len(out_obj_mask_list)):
415
+ # for j in range(i + 1, len(out_obj_mask_list)):
416
+ # mask1 = out_obj_mask_list[i]
417
+ # mask2 = out_obj_mask_list[j]
418
+ # intersection = np.logical_and(mask1, mask2).sum()
419
+ # area1 = out_obj_area_list[i]
420
+ # area2 = out_obj_area_list[j]
421
+ # if min(intersection / area1, intersection / area2) > 0.025:
422
+ # conf1 = area1 / (area1 + area2)
423
+ # # conf2 = area2 / (area1 + area2)
424
+ # image_features[j] = slerp(image_features[j], image_features[i], conf1)
425
+
426
+ # for i, clip_feat in enumerate(image_features):
427
+ # id = out_obj_id_list[i]
428
+ # if id in multi_view_clip_feats_map.keys():
429
+ # multi_view_clip_feats_map[id].append(clip_feat)
430
+ # multi_view_clip_area_map[id].append(out_obj_area_list[i])
431
+ # else:
432
+ # multi_view_clip_feats_map[id] = [clip_feat]
433
+ # multi_view_clip_area_map[id] = [out_obj_area_list[i]]
434
+
435
+ # cog_seg_maps.append(seg_map)
436
+ # del image_features
437
 
438
+ # for i in range(mask_num):
439
+ # if i in multi_view_clip_feats_map.keys():
440
+ # clip_feats = multi_view_clip_feats_map[i]
441
+ # mask_area = multi_view_clip_area_map[i]
442
+ # multi_view_clip_feats[i] = slerp_multiple(torch.stack(clip_feats), np.stack(mask_area))
443
+ # else:
444
+ # multi_view_clip_feats[i] = torch.zeros((1024))
445
+ # multi_view_clip_feats[mask_num] = torch.zeros((1024))
446
 
447
+ # return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
448
 
449
 
450
  @spaces.GPU(duration=60)
451
+ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_conf_thr=3.0,
452
+ as_pointcloud=True, mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05,
453
+ scenegraph_type='complete', winsize=1, refid=0):
454
  """
455
  from a list of images, run dust3r inference, global aligner.
456
  then run get_3D_model_from_scene
457
  """
458
 
459
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
460
 
461
  MAST3R_CKP = 'naver/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric'
462
  mast3r = AsymmetricMASt3R.from_pretrained(MAST3R_CKP).to(device)
463
 
464
+ # sam2 = SAM2VideoPredictor.from_pretrained('facebook/sam2.1-hiera-large', device=device)
465
 
466
+ # siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
467
+ # siglip_processor = AutoProcessor.from_pretrained("google/siglip-large-patch16-256")
468
 
469
+ # SAM1_DECODER_CKP = './checkpoints/Prompt_guided_Mask_Decoder.pt'
470
+ # mobilesamv2 = sam_model_registry['sam_vit_h'](None)
471
+ # sam1 = SamModel.from_pretrained('facebook/sam-vit-huge')
472
+ # image_encoder = sam1.vision_encoder
473
 
474
+ # prompt_encoder, mask_decoder = sam_model_registry['prompt_guided_decoder'](SAM1_DECODER_CKP)
475
+ # mobilesamv2.prompt_encoder = prompt_encoder
476
+ # mobilesamv2.mask_decoder = mask_decoder
477
+ # mobilesamv2.image_encoder=image_encoder
478
+ # mobilesamv2.to(device=device)
479
+ # mobilesamv2.eval()
480
 
481
+ # YOLO8_CKP='./checkpoints/ObjectAwareModel.pt'
482
+ # yolov8 = ObjectAwareModel(YOLO8_CKP)
483
 
484
  if len(filelist) < 2:
485
  raise gradio.Error("Please input at least 2 images.")
 
487
  images = Images(filelist=filelist, device=device)
488
 
489
  # try:
490
+ # cog_seg_maps, rev_cog_seg_maps, cog_feats = get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2)
491
+ # imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
492
  # except Exception as e:
493
+ rev_cog_seg_maps = []
494
+ for tmp_img in images.np_images:
495
+ rev_seg_map = -np.ones(tmp_img.shape[:2], dtype=np.int64)
496
+ rev_cog_seg_maps.append(rev_seg_map)
497
+ cog_seg_maps = rev_cog_seg_maps
498
+ cog_feats = torch.zeros((1, 1024))
499
+ imgs = load_images(images, rev_cog_seg_maps, size=512, verbose=not silent)
500
 
501
  if len(imgs) == 1:
502
  imgs = [imgs[0], copy.deepcopy(imgs[0])]
 
539
 
540
  scene.to('cpu')
541
  torch.cuda.empty_cache()
 
542
  return scene, outfile
543
 
544
  # @spaces.GPU(duration=60)
 
573
  gradio.HTML('<h2 style="text-align: center;">PE3R Demo</h2>')
574
  with gradio.Column():
575
  inputfiles = gradio.File(file_count="multiple")
576
+ # with gradio.Row():
577
+ # schedule = gradio.Dropdown(["linear", "cosine"],
578
+ # value='linear', label="schedule", info="For global alignment!",
579
+ # visible=False)
580
+ # niter = gradio.Number(value=300, precision=0, minimum=0, maximum=5000,
581
+ # label="num_iterations", info="For global alignment!",
582
+ # visible=False)
583
+ # scenegraph_type = gradio.Dropdown([("complete: all possible image pairs", "complete"),
584
+ # ("swin: sliding window", "swin"),
585
+ # ("oneref: match one image with all", "oneref")],
586
+ # value='complete', label="Scenegraph",
587
+ # info="Define how to make pairs",
588
+ # interactive=True,
589
+ # visible=False)
590
+ # winsize = gradio.Slider(label="Scene Graph: Window Size", value=1,
591
+ # minimum=1, maximum=1, step=1, visible=False)
592
+ # refid = gradio.Slider(label="Scene Graph: Id", value=0, minimum=0, maximum=0, step=1, visible=False)
593
 
594
  run_btn = gradio.Button("Reconstruct")
595
 
596
+ # with gradio.Row():
597
  # adjust the confidence threshold
598
+ # min_conf_thr = gradio.Slider(label="min_conf_thr", value=3.0, minimum=1.0, maximum=20, step=0.1, visible=False)
599
  # adjust the camera size in the output pointcloud
600
+ # cam_size = gradio.Slider(label="cam_size", value=0.05, minimum=0.001, maximum=0.1, step=0.001, visible=False)
601
+ # with gradio.Row():
602
+ # as_pointcloud = gradio.Checkbox(value=True, label="As pointcloud", visible=False)
603
  # two post process implemented
604
+ # mask_sky = gradio.Checkbox(value=False, label="Mask sky", visible=False)
605
+ # clean_depth = gradio.Checkbox(value=True, label="Clean-up depthmaps", visible=False)
606
+ # transparent_cams = gradio.Checkbox(value=True, label="Transparent cameras", visible=False)
607
 
608
  with gradio.Row():
609
  text_input = gradio.Textbox(label="Query Text")
 
615
  # events
616
 
617
  run_btn.click(fn=recon_fun,
618
+ inputs=[inputfiles],
 
 
619
  outputs=[scene, outmodel]) # , outgallery
620
 
621
  # find_btn.click(fn=get_3D_object_from_scene_fun,
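
Note on the added lines: the segmentation stack (SAM-2, SigLIP, MobileSAMv2/YOLOv8) is commented out, the values of the now-hidden widgets become keyword defaults on get_reconstructed_scene (schedule='linear', niter=300, min_conf_thr=3.0, as_pointcloud=True, cam_size=0.05, scenegraph_type='complete', ...), and run_btn.click passes only inputs=[inputfiles]. Instead of get_cog_feats, the new code builds neutral per-frame segmentation maps and a single zero feature vector. A minimal, self-contained sketch of that fallback follows; neutral_cog_inputs is a hypothetical helper name, and np_images stands in for images.np_images from the demo's Images class:

import numpy as np
import torch

def neutral_cog_inputs(np_images):
    # One segmentation map per frame, filled with -1 (the diff's "background" value),
    # plus a single zero feature vector in place of the per-object SigLIP features.
    rev_cog_seg_maps = [-np.ones(img.shape[:2], dtype=np.int64) for img in np_images]
    cog_seg_maps = rev_cog_seg_maps
    cog_feats = torch.zeros((1, 1024))
    return cog_seg_maps, rev_cog_seg_maps, cog_feats

# Example with two dummy RGB frames:
frames = [np.zeros((480, 640, 3), dtype=np.uint8) for _ in range(2)]
seg_maps, rev_seg_maps, feats = neutral_cog_inputs(frames)
print(len(seg_maps), seg_maps[0].shape, tuple(feats.shape))  # 2 (480, 640) (1, 1024)

In effect the demo now reconstructs with MASt3R only, without loading any of the segmentation checkpoints.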