YulianSa committed on
Commit 756ba34 · 1 Parent(s): bdf35dc
Files changed (1):
  1. infer_api.py +217 -221
infer_api.py CHANGED
@@ -107,7 +107,7 @@ for file in all_files:
hf_hub_download(repo_id, file, local_dir="./ckpt")

@spaces.GPU
- def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
@@ -174,7 +174,7 @@ def process_image(image, totensor, width, height):
def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
- set_seed(seed)
generator = torch.Generator(device=device).manual_seed(seed)

totensor = transforms.ToTensor()
@@ -372,10 +372,10 @@ class InferAPI:
return infer_multiview_gen(img, seed, num_levels)

def genStage3(self, img):
- return self.slrm_infer.gen(img)

def genStage4(self, meshes, imgs):
- return self.refine_infer.refine(meshes, imgs)


############## Refine ##############
@@ -400,6 +400,7 @@ def srgb_to_linear(c_srgb):
return c_linear.clip(0, 1.)


def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
# convert from pytorch3d meshes to trimesh mesh
vertices = meshes.verts_packed().cpu().float().numpy()
@@ -515,245 +516,240 @@ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None,
return distract_mask, distract_bbox, random_sampled_points, final_mask


- class InferRefineAPI:
- @spaces.GPU
- def __init__(self, config):
- self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
- self.generator = SamAutomaticMaskGenerator(
- model=self.sam,
- points_per_side=64,
- pred_iou_thresh=0.80,
- stability_score_thresh=0.92,
- crop_n_layers=1,
- crop_n_points_downscale_factor=2,
- min_mask_region_area=100,
- )
- self.outside_ratio = 0.20
-
- @spaces.GPU
- def refine(self, meshes, imgs):
- fixed_v, fixed_f, fixed_t = None, None, None
- flow_vert, flow_vector = None, None
- last_colors, last_normals = None, None
- last_front_color, last_front_normal = None, None
- distract_mask = None
-
- mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
- mv = mv[[4, 3, 2, 0, 6, 5]]
- renderer = NormalsRenderer(mv,proj,(1024,1024))
-
- results = []
-
- for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
- mesh = trimesh.load(meshes[name_idx])
- new_mesh = mesh.split(only_watertight=False)
- new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
- mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
- mesh_v, mesh_f = mesh.vertices, mesh.faces
-
- if last_colors is None:
- images = renderer.render(
- torch.tensor(mesh_v, device='cuda').float(),
- torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
- torch.tensor(mesh_f, device='cuda'),
- )
- mask = (images[..., 3] < 0.9).cpu().numpy()
-
- colors, normals = [], []
- for i in range(6):
- color = np.array(imgs[level]['images'][i])
- normal = np.array(imgs[level]['normals'][i])
-
- if last_colors is not None:
- offset = calc_horizontal_offset(np.array(last_colors[i]), color)
- # print('offset', i, offset)
- else:
- offset = calc_horizontal_offset2(mask[i], color)
- # print('init offset', i, offset)

- if offset != 0:
- color = np.roll(color, offset, axis=1)
- normal = np.roll(normal, offset, axis=1)

- color = Image.fromarray(color)
- normal = Image.fromarray(normal)
- colors.append(color)
- normals.append(normal)

- if last_front_color is not None and level == 0:
- original_mask, distract_bbox, _, distract_mask = get_distract_mask(self.generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=self.outside_ratio)
- else:
- distract_mask = None
- distract_bbox = None

- last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
- last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0

- if last_colors is None:
- from copy import deepcopy
- last_colors, last_normals = deepcopy(colors), deepcopy(normals)

- # my mesh flow weight by nearest vertexs
- if fixed_v is not None and fixed_f is not None and level == 1:
- t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)

fixed_v_cpu = fixed_v.cpu().numpy()
kdtree_anchor = KDTree(fixed_v_cpu)
- kdtree_mesh_v = KDTree(mesh_v)
- _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
- _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
idx_anchor = idx_anchor.squeeze()
- neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
- # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
- neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
neighbor_dists[neighbor_dists > 0.06] = 114514.
neighbor_weights = torch.exp(-neighbor_dists * 1.)
neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
anchors = fixed_v[idx_anchor] # V, 3
anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
- dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
- mesh_v += weighted_vec_anchor.cpu().numpy()
-
- t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
-
- mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
- mesh_f = torch.tensor(mesh_f, device='cuda')
-
- new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
-
- # my mesh flow weight by nearest vertexs
- try:
- if fixed_v is not None and fixed_f is not None and level != 0:
- new_mesh_v = new_mesh.verts_packed().cpu().numpy()
-
- fixed_v_cpu = fixed_v.cpu().numpy()
- kdtree_anchor = KDTree(fixed_v_cpu)
- kdtree_mesh_v = KDTree(new_mesh_v)
- _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
- _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
- idx_anchor = idx_anchor.squeeze()
- neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
- # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
- neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
- neighbor_dists[neighbor_dists > 0.06] = 114514.
- neighbor_weights = torch.exp(-neighbor_dists * 1.)
- neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
- anchors = fixed_v[idx_anchor] # V, 3
- anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
- dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
- vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
- vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
- weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
- new_mesh_v += weighted_vec_anchor.cpu().numpy()
-
- # replace new_mesh verts with new_mesh_v
- new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
-
- except Exception as e:
- pass
-
- notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
-
- if fixed_v is None:
- fixed_v, fixed_f = simp_v, simp_f
- complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
- else:
- fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
- fixed_v = torch.cat([fixed_v, simp_v], dim=0)
-
- complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
- complete_v = torch.cat([complete_v, notsimp_v], dim=0)
- complete_t = torch.cat([complete_t, notsimp_t], dim=0)
-
- if level == 2:
- new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
-
- save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
- results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
-
- # save whole mesh
- save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
- results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
-
- return results
-
-
- class InferSlrmAPI:
- @spaces.GPU
- def __init__(self, config):
- self.config_path = config['config_path']
- self.config = OmegaConf.load(self.config_path)
- self.config_name = os.path.basename(self.config_path).replace('.yaml', '')
- self.model_config = self.config.model_config
- self.infer_config = self.config.infer_config
- self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
- self.model = instantiate_from_config(self.model_config)
- state_dict = torch.load(self.infer_config.model_path, map_location='cpu')
- self.model.load_state_dict(state_dict, strict=False)
- self.model = self.model.to(self.device)
- self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho)
- self.model = self.model.eval()
-
- @spaces.GPU
- def gen(self, imgs):
- imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
- imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
- imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
- mesh_glb_fpaths = self.make3d(imgs)
- return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
-
- @spaces.GPU
- def make3d(self, images):
- input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
-
- images = images.unsqueeze(0).to(device)
- images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
-
- mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
- print(mesh_fpath)
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
- mesh_dirname = os.path.dirname(mesh_fpath)
-
- with torch.no_grad():
- # get triplane
- planes = self.model.forward_planes(images, input_cameras.float())
-
- # get mesh
- mesh_glb_fpaths = []
- for j in range(4):
- mesh_glb_fpath = self.make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
- mesh_glb_fpaths.append(mesh_glb_fpath)
-
- return mesh_glb_fpaths
-
- @spaces.GPU
- def make_mesh(self, mesh_fpath, planes, level=None):
- mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
- mesh_dirname = os.path.dirname(mesh_fpath)
- mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
-
- with torch.no_grad():
- # get mesh
- mesh_out = self.model.extract_mesh(
- planes,
- use_texture_map=False,
- levels=torch.tensor([level]).to(device),
- **self.infer_config,
- )

- vertices, faces, vertex_colors = mesh_out
- vertices = vertices[:, [1, 2, 0]]

- if level == 2:
- # fill all vertex_colors with 127
- vertex_colors = np.ones_like(vertex_colors) * 127

- save_obj(vertices, faces, vertex_colors, mesh_fpath)

- return mesh_fpath


parser = argparse.ArgumentParser()
 
hf_hub_download(repo_id, file, local_dir="./ckpt")

@spaces.GPU
+ def set_seed22(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
 
def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
+ set_seed2(seed)
generator = torch.Generator(device=device).manual_seed(seed)

totensor = transforms.ToTensor()
 
return infer_multiview_gen(img, seed, num_levels)

def genStage3(self, img):
+ return infer_slrm_gen(img)

def genStage4(self, meshes, imgs):
+ return infer_refine(meshes, imgs)


############## Refine ##############
 
return c_linear.clip(0, 1.)


+ @spaces.GPU
def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
# convert from pytorch3d meshes to trimesh mesh
vertices = meshes.verts_packed().cpu().float().numpy()
 
return distract_mask, distract_bbox, random_sampled_points, final_mask


+ infer_refine_sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
+ infer_refine_generator = SamAutomaticMaskGenerator(
+ model=infer_refine_sam,
+ points_per_side=64,
+ pred_iou_thresh=0.80,
+ stability_score_thresh=0.92,
+ crop_n_layers=1,
+ crop_n_points_downscale_factor=2,
+ min_mask_region_area=100,
+ )
+ infer_refine_outside_ratio = 0.20

+ @spaces.GPU
+ def infer_refine(meshes, imgs):
+ fixed_v, fixed_f, fixed_t = None, None, None
+ flow_vert, flow_vector = None, None
+ last_colors, last_normals = None, None
+ last_front_color, last_front_normal = None, None
+ distract_mask = None

+ mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
+ mv = mv[[4, 3, 2, 0, 6, 5]]
+ renderer = NormalsRenderer(mv,proj,(1024,1024))

+ results = []

+ for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
+ mesh = trimesh.load(meshes[name_idx])
+ new_mesh = mesh.split(only_watertight=False)
+ new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
+ mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
+ mesh_v, mesh_f = mesh.vertices, mesh.faces
+
+ if last_colors is None:
+ images = renderer.render(
+ torch.tensor(mesh_v, device='cuda').float(),
+ torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
+ torch.tensor(mesh_f, device='cuda'),
+ )
+ mask = (images[..., 3] < 0.9).cpu().numpy()

+ colors, normals = [], []
+ for i in range(6):
+ color = np.array(imgs[level]['images'][i])
+ normal = np.array(imgs[level]['normals'][i])

+ if last_colors is not None:
+ offset = calc_horizontal_offset(np.array(last_colors[i]), color)
+ # print('offset', i, offset)
+ else:
+ offset = calc_horizontal_offset2(mask[i], color)
+ # print('init offset', i, offset)
+
+ if offset != 0:
+ color = np.roll(color, offset, axis=1)
+ normal = np.roll(normal, offset, axis=1)
+
+ color = Image.fromarray(color)
+ normal = Image.fromarray(normal)
+ colors.append(color)
+ normals.append(normal)
+
+ if last_front_color is not None and level == 0:
+ original_mask, distract_bbox, _, distract_mask = get_distract_mask(infer_refine_generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=infer_refine_outside_ratio)
+ else:
+ distract_mask = None
+ distract_bbox = None
+
+ last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
+ last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
+
+ if last_colors is None:
+ from copy import deepcopy
+ last_colors, last_normals = deepcopy(colors), deepcopy(normals)
+
+ # my mesh flow weight by nearest vertexs
+ if fixed_v is not None and fixed_f is not None and level == 1:
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
+
+ fixed_v_cpu = fixed_v.cpu().numpy()
+ kdtree_anchor = KDTree(fixed_v_cpu)
+ kdtree_mesh_v = KDTree(mesh_v)
+ _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
+ _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
+ idx_anchor = idx_anchor.squeeze()
+ neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
+ # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
+ neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
+ neighbor_dists[neighbor_dists > 0.06] = 114514.
+ neighbor_weights = torch.exp(-neighbor_dists * 1.)
+ neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
+ anchors = fixed_v[idx_anchor] # V, 3
+ anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
+ dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
+ vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
+ vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
+ weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
+ mesh_v += weighted_vec_anchor.cpu().numpy()
+
+ t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
+
+ mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
+ mesh_f = torch.tensor(mesh_f, device='cuda')
+
+ new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
+
+ # my mesh flow weight by nearest vertexs
+ try:
+ if fixed_v is not None and fixed_f is not None and level != 0:
+ new_mesh_v = new_mesh.verts_packed().cpu().numpy()

fixed_v_cpu = fixed_v.cpu().numpy()
kdtree_anchor = KDTree(fixed_v_cpu)
+ kdtree_mesh_v = KDTree(new_mesh_v)
+ _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
+ _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
idx_anchor = idx_anchor.squeeze()
+ neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
+ # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
+ neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
neighbor_dists[neighbor_dists > 0.06] = 114514.
neighbor_weights = torch.exp(-neighbor_dists * 1.)
neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
anchors = fixed_v[idx_anchor] # V, 3
anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
+ dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
+ new_mesh_v += weighted_vec_anchor.cpu().numpy()
+
+ # replace new_mesh verts with new_mesh_v
+ new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)

+ except Exception as e:
+ pass

+ notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
+
+ if fixed_v is None:
+ fixed_v, fixed_f = simp_v, simp_f
+ complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
+ else:
+ fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
+ fixed_v = torch.cat([fixed_v, simp_v], dim=0)

+ complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
+ complete_v = torch.cat([complete_v, notsimp_v], dim=0)
+ complete_t = torch.cat([complete_t, notsimp_t], dim=0)
+
+ if level == 2:
+ new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
+
+ save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
+ results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
+
+ # save whole mesh
+ save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
+ results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
+
+ return results
+
+ config_slrm = {
+ 'config_path': './configs/mesh-slrm-infer.yaml'
+ }
+ infer_slrm_config_path = config_slrm['config_path']
+ infer_slrm_config = OmegaConf.load(infer_slrm_config_path)
+ infer_slrm_config_name = os.path.basename(infer_slrm_config_path).replace('.yaml', '')
+ infer_slrm_model_config = infer_slrm_config.model_config
+ infer_slrm_infer_config = infer_slrm_config.infer_config
+ infer_slrm_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ infer_slrm_model = instantiate_from_config(infer_slrm_model_config)
+ state_dict = torch.load(infer_slrm_infer_config.model_path, map_location='cpu')
+ infer_slrm_model.load_state_dict(state_dict, strict=False)
+ infer_slrm_model = infer_slrm_model.to(infer_slrm_device)
+ infer_slrm_model.init_flexicubes_geometry(infer_slrm_device, fovy=30.0, is_ortho=infer_slrm_model.is_ortho)
+ infer_slrm_model = infer_slrm_model.eval()
+
+ @spaces.GPU
+ def infer_slrm_gen(imgs):
+ imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
+ imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
+ imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
+ mesh_glb_fpaths = infer_slrm_make3d(imgs)
+ return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
+
+ @spaces.GPU
+ def infer_slrm_make3d(images):
+ input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
+
+ images = images.unsqueeze(0).to(device)
+ images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+ mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
+ print(mesh_fpath)
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+ mesh_dirname = os.path.dirname(mesh_fpath)
+
+ with torch.no_grad():
+ # get triplane
+ planes = infer_slrm_model.forward_planes(images, input_cameras.float())
+
+ # get mesh
+ mesh_glb_fpaths = []
+ for j in range(4):
+ mesh_glb_fpath = infer_slrm_make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
+ mesh_glb_fpaths.append(mesh_glb_fpath)
+
+ return mesh_glb_fpaths
+
+ @spaces.GPU
+ def infer_slrm_make_mesh(mesh_fpath, planes, level=None):
+ mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+ mesh_dirname = os.path.dirname(mesh_fpath)
+
+ with torch.no_grad():
+ # get mesh
+ mesh_out = infer_slrm_model.extract_mesh(
+ planes,
+ use_texture_map=False,
+ levels=torch.tensor([level]).to(device),
+ **infer_slrm_infer_config,
+ )
+
+ vertices, faces, vertex_colors = mesh_out
+ vertices = vertices[:, [1, 2, 0]]
+
+ if level == 2:
+ # fill all vertex_colors with 127
+ vertex_colors = np.ones_like(vertex_colors) * 127
+
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)

+ return mesh_fpath


parser = argparse.ArgumentParser()
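For context, this commit replaces the `InferRefineAPI` and `InferSlrmAPI` classes with module-level state and functions (`infer_refine`, `infer_slrm_gen`), which `InferAPI.genStage3` / `genStage4` now call directly; the SAM and SLRM checkpoints are therefore loaded when the module is imported rather than when the classes are constructed. Below is a minimal usage sketch of the resulting entry points; it is not part of the commit, and the view-image paths and input layout are assumptions inferred from how the functions read their arguments.

```python
# Hypothetical driver for the refactored module-level API in infer_api.py (sketch only).
# Assumes six multi-view renders exist on disk and that ./ckpt and the SLRM config are set up.
from infer_api import infer_slrm_gen, infer_refine

# infer_slrm_gen indexes img[0] and loads it with cv2.imread, so each view is
# passed as a 1-tuple containing an image path (assumed layout and filenames).
views = [(f"./outputs/view_{i}.png",) for i in range(6)]
mesh_paths = infer_slrm_gen(views)  # reordered .obj paths, one per extraction level

# infer_refine reads imgs[level]['images'][i] and imgs[level]['normals'][i]
# (PIL-compatible arrays) for the three levels it refines, so the stage-2
# output would be passed through unchanged:
# refined_paths = infer_refine(mesh_paths, imgs)
print(mesh_paths)
```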