hujiecpp committed
Commit 5505892 · 1 Parent(s): 2768473

init project
app.py CHANGED
@@ -43,6 +43,7 @@ from modules.mobilesamv2 import sam_model_registry
 from sam2.sam2_video_predictor import SAM2VideoPredictor
 from modules.mast3r.model import AsymmetricMASt3R
 
+from torch.nn.functional import cosine_similarity
 
 silent = False
 
@@ -448,6 +449,44 @@ def get_cog_feats(images, sam2, siglip, siglip_processor, yolov8, mobilesamv2):
     return cog_seg_maps, rev_cog_seg_maps, multi_view_clip_feats
 
 
+class Scene_cpu:
+    def __init__(self, fix_imgs, cogs, focals, cams2world, pts3d, min_conf_thr, msk):
+        self.fix_imgs = fix_imgs
+        self.cogs = cogs
+        self.focals = focals
+        self.cams2world = cams2world
+        self.pts3d = pts3d
+        self.min_conf_thr = min_conf_thr
+        self.msk = msk
+
+    def render_image(self, text_feats, threshold=0.85):
+        self.rendered_imgs = []
+        # Collect all cosine similarities to compute min-max normalization
+        all_similarities = []
+        for each_cog in self.cogs:
+            similarity_map = cosine_similarity(each_cog, text_feats.unsqueeze(1), dim=-1)
+            all_similarities.append(similarity_map.squeeze().numpy())
+        # Flatten and normalize all similarities
+        total_similarities = np.concatenate(all_similarities)
+        min_sim, max_sim = total_similarities.min(), total_similarities.max()
+        normalized_similarities = [(sim - min_sim) / (max_sim - min_sim) for sim in all_similarities]
+        # Process each image with normalized similarities
+        for i, (each_cog, heatmap) in enumerate(zip(self.cogs, normalized_similarities)):
+            mask = heatmap > threshold
+            # Scale heatmap for visualization
+            heatmap = np.uint8(255 * heatmap)
+            heatmap_color = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)
+            # Prepare image
+            image = self.fix_imgs[i]
+            image = image * 255.0
+            image = np.clip(image, 0, 255).astype(np.uint8)
+            # Apply mask and overlay heatmap with red RGB for masked areas
+            mask_indices = np.where(mask)  # Get indices where mask is True
+            heatmap_color[mask_indices[0], mask_indices[1]] = [0, 0, 255]  # Red color for masked regions
+            superimposed_img = np.where(np.expand_dims(mask, axis=-1), heatmap_color, image) / 255.0
+            self.rendered_imgs.append(superimposed_img)
+
+
 @spaces.GPU(duration=120)
 def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_conf_thr=3.0,
                             as_pointcloud=True, mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05,
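For reference, a minimal smoke test of the new Scene_cpu.render_image (not part of the commit; the view count, the 4x4x8 feature shapes, and the random inputs are assumptions chosen to satisfy the broadcasting in cosine_similarity, and Scene_cpu plus its module-level np / cv2 imports are assumed to be in scope as in app.py):

import numpy as np
import torch

H, W, D = 4, 4, 8                                                   # hypothetical sizes
fix_imgs = [np.random.rand(H, W, 3) for _ in range(2)]              # two RGB views in [0, 1]
cogs = [torch.randn(H, W, D) for _ in range(2)]                     # assumed per-pixel feature maps
text_feats = torch.randn(1, D)
text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)     # same normalization as the app

scene = Scene_cpu(fix_imgs, cogs, focals=None, cams2world=None,
                  pts3d=None, min_conf_thr=3.0, msk=None)
scene.render_image(text_feats, threshold=0.85)
print(len(scene.rendered_imgs), scene.rendered_imgs[0].shape)       # 2 (4, 4, 3)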
@@ -540,37 +579,46 @@ def get_reconstructed_scene(outdir, filelist, schedule='linear', niter=300, min_
 
     torch.cuda.empty_cache()
 
-    return outfile
-
-# def get_3D_object_from_scene(outdir, text, threshold, scene, min_conf_thr=3.0, as_pointcloud=True,
-#                              mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05):
-
-#     device = 'cpu'
-#     siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
-#     siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
-
-#     texts = [text]
-#     inputs = siglip_tokenizer(text=texts, padding="max_length", return_tensors="pt")
-#     inputs = {key: value.to(device) for key, value in inputs.items()}
-#     with torch.no_grad():
-#         text_feats =siglip.get_text_features(**inputs)
-#     text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
-#     scene.render_image(text_feats, threshold)
-#     scene.ori_imgs = scene.rendered_imgs
-
+    fix_imgs = []
+    for img in scene.fix_imgs:
+        fix_imgs.append(img)
+    cogs = []
+    for cog in scene.cogs:
+        cog_cpu = cog.cpu()
+        cogs.append(cog_cpu)
+    focals = scene.get_focals().cpu()
+    cams2world = scene.get_im_poses().cpu()
+    pts3d = to_numpy(scene.get_pts3d())
+    min_conf_thr = float(scene.conf_trf(torch.tensor(3.0)))
+    msk = to_numpy(scene.get_masks())
+    scene_cpu = Scene_cpu(fix_imgs, cogs, focals, cams2world, pts3d, min_conf_thr, msk)
 
-#     rgbimg = scene.ori_imgs
-#     focals = scene.get_focals().cpu()
-#     cams2world = scene.get_im_poses().cpu()
-#     # 3D pointcloud from depthmap, poses and intrinsics
-#     pts3d = to_numpy(scene.get_pts3d())
-#     scene.min_conf_thr = float(scene.conf_trf(torch.tensor(min_conf_thr)))
-#     msk = to_numpy(scene.get_masks())
-#     return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
-#                                         transparent_cams=transparent_cams, cam_size=cam_size)
+    return scene_cpu, outfile
 
+
+def get_3D_object_from_scene(outdir, text, threshold, scene, min_conf_thr=3.0, as_pointcloud=True,
+                             mask_sky=False, clean_depth=True, transparent_cams=True, cam_size=0.05):
+
+    device = 'cpu'
+    siglip_tokenizer = AutoTokenizer.from_pretrained("google/siglip-large-patch16-256")
+    siglip = AutoModel.from_pretrained("google/siglip-large-patch16-256", device_map=device)
+
+    texts = [text]
+    inputs = siglip_tokenizer(text=texts, padding="max_length", return_tensors="pt")
+    inputs = {key: value.to(device) for key, value in inputs.items()}
+    with torch.no_grad():
+        text_feats =siglip.get_text_features(**inputs)
+    text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
+    scene.render_image(text_feats, threshold)
+    scene.ori_imgs = scene.rendered_imgs
+    rgbimg = scene.ori_imgs
+    focals = scene.focals
+    cams2world = scene.cams2world
+    # 3D pointcloud from depthmap, poses and intrinsics
+    pts3d = scene.pts3d
+    msk = scene.msk
+    return _convert_scene_output_to_glb(outdir, rgbimg, pts3d, msk, focals, cams2world, as_pointcloud=as_pointcloud,
+                                        transparent_cams=transparent_cams, cam_size=cam_size)
 
 
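With this split, the @spaces.GPU-decorated reconstruction runs once and the text query is served from the cached CPU copy. A hypothetical call sequence, where the image paths and the query string are placeholders rather than anything from the commit:

import tempfile

outdir = tempfile.mkdtemp(suffix='pe3r_gradio_demo')
filelist = ['view1.png', 'view2.png']                  # placeholder image paths

# Step 1: GPU reconstruction; now returns the CPU-side scene copy plus the GLB path.
scene_cpu, glb_path = get_reconstructed_scene(outdir, filelist)

# Step 2: open-vocabulary query on CPU; can be repeated without re-running step 1.
obj_glb = get_3D_object_from_scene(outdir, text='chair', threshold=0.85, scene=scene_cpu)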
 
@@ -579,11 +627,11 @@ tmpdirname = tempfile.mkdtemp(suffix='pe3r_gradio_demo')
 
 recon_fun = functools.partial(get_reconstructed_scene, tmpdirname)
 # model_from_scene_fun = functools.partial(get_3D_model_from_scene, tmpdirname)
-# get_3D_object_from_scene_fun = functools.partial(get_3D_object_from_scene, tmpdirname)
+get_3D_object_from_scene_fun = functools.partial(get_3D_object_from_scene, tmpdirname)
 
 with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 100%};""", title="PE3R Demo") as demo:
     # scene state is save so that you can change conf_thr, cam_size... without rerunning the inference
-    # scene = gradio.State(None)
+    scene = gradio.State(None)
 
     gradio.HTML('<h2 style="text-align: center;">PE3R Demo</h2>')
     with gradio.Column():
@@ -602,9 +650,9 @@ with gradio.Blocks(css=""".gradio-container {margin: 0 !important; min-width: 10
 
         run_btn.click(fn=recon_fun,
                       inputs=[inputfiles],
-                      outputs=[outmodel]) # , outgallery, scene,
+                      outputs=[scene, outmodel]) # , outgallery, ,
 
-        # find_btn.click(fn=get_3D_object_from_scene_fun,
-        #               inputs=[text_input, threshold, scene],
-        #               outputs=outmodel)
+        find_btn.click(fn=get_3D_object_from_scene_fun,
+                      inputs=[text_input, threshold, scene],
+                      outputs=outmodel)
 demo.launch(show_error=True, share=None, server_name=None, server_port=None)
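The wiring follows the usual gradio.State pattern: the run callback writes the scene into per-session state and the find callback reads it back. A stripped-down, self-contained sketch of that pattern; the callback names and widgets below are placeholders, not the app's real components:

import gradio

def build(files):
    scene = {"n_views": len(files or [])}               # stand-in for the Scene_cpu object
    return scene, f"reconstructed {scene['n_views']} views"

def query(text, scene):
    return f"looking for '{text}' in a scene with {scene['n_views']} views"

with gradio.Blocks() as demo:
    scene = gradio.State(None)                          # kept per session between clicks
    files = gradio.File(file_count="multiple")
    text = gradio.Textbox(label="query")
    out = gradio.Textbox()
    run = gradio.Button("Reconstruct")
    find = gradio.Button("Find")
    run.click(build, inputs=[files], outputs=[scene, out])
    find.click(query, inputs=[text, scene], outputs=[out])

# demo.launch()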
 
modules/pe3r/__pycache__/models.cpython-312.pyc CHANGED
Binary files a/modules/pe3r/__pycache__/models.cpython-312.pyc and b/modules/pe3r/__pycache__/models.cpython-312.pyc differ