YulianSa committed on
Commit
216a665
·
1 Parent(s): f1e6905
Files changed (4) hide show
  1. app.py +14 -14
  2. infer_api.py +68 -73
  3. refine/mesh_refine.py +168 -14
  4. slrm/models/lrm_mesh.py +2 -2
app.py CHANGED
@@ -10,20 +10,20 @@ import os
10
  import shlex
11
  import subprocess
12
 
13
- os.makedirs("./ckpt", exist_ok=True)
14
- # download ViT-H SAM model into ./ckpt
15
- subprocess.call(["wget", "-q", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "-O", "./ckpt/sam_vit_h_4b8939.pth"])
16
-
17
- subprocess.run(
18
- shlex.split(
19
- "pip install pip==24.0"
20
- )
21
- )
22
- subprocess.run(
23
- shlex.split(
24
- "pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
25
- )
26
- )
27
 
28
  from infer_api import InferAPI
29
 
 
10
  import shlex
11
  import subprocess
12
 
13
+ # os.makedirs("./ckpt", exist_ok=True)
14
+ # # download ViT-H SAM model into ./ckpt
15
+ # subprocess.call(["wget", "-q", "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth", "-O", "./ckpt/sam_vit_h_4b8939.pth"])
16
+
17
+ # subprocess.run(
18
+ # shlex.split(
19
+ # "pip install pip==24.0"
20
+ # )
21
+ # )
22
+ # subprocess.run(
23
+ # shlex.split(
24
+ # "pip install package/nvdiffrast-0.3.1.torch-cp310-cp310-linux_x86_64.whl --force-reinstall --no-deps"
25
+ # )
26
+ # )
27
 
28
  from infer_api import InferAPI
29
 
infer_api.py CHANGED
@@ -12,6 +12,7 @@ from omegaconf import OmegaConf
12
  import numpy as np
13
 
14
  import torch
 
15
 
16
  from diffusers import AutoencoderKL, DDIMScheduler
17
  from diffusers.utils import check_min_version
@@ -72,7 +73,7 @@ from slrm.utils.camera_util import (
72
  FOV_to_intrinsics,
73
  get_circular_camera_poses,
74
  )
75
- from slrm.utils.mesh_util import save_obj, save_glb
76
  from slrm.utils.infer_util import images_to_video
77
 
78
  import cv2
@@ -477,7 +478,7 @@ def calc_horizontal_offset2(target_mask, source_img):
477
 
478
 
479
  @spaces.GPU
480
- def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
481
  distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
482
  if normal_0 is not None and normal_1 is not None:
483
  distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres
@@ -503,43 +504,7 @@ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None,
503
  max_x, max_y = bbox.max(axis=0)
504
  distract_bbox[min_x:max_x, min_y:max_y] = 1
505
 
506
- points = np.array(random_sampled_points)[:, ::-1]
507
- labels = np.ones(len(points), dtype=np.int32)
508
-
509
- masks = generator.generate((color_1 * 255).astype(np.uint8))
510
-
511
- outside_area = np.abs(color_0 - color_1).sum(axis=-1) < outside_thres
512
-
513
- final_mask = np.zeros_like(distract_mask)
514
- for iii, mask in enumerate(masks):
515
- mask['segmentation'] = cv2.resize(mask['segmentation'].astype(np.float32), (1024, 1024)) > 0.5
516
- intersection = np.logical_and(mask['segmentation'], distract_mask).sum()
517
- total = mask['segmentation'].sum()
518
- iou = intersection / total
519
- outside_intersection = np.logical_and(mask['segmentation'], outside_area).sum()
520
- outside_total = mask['segmentation'].sum()
521
- outside_iou = outside_intersection / outside_total
522
- if iou > ratio and outside_iou < outside_ratio:
523
- final_mask |= mask['segmentation']
524
-
525
- # calculate coverage
526
- intersection = np.logical_and(final_mask, distract_mask).sum()
527
- total = distract_mask.sum()
528
- coverage = intersection / total
529
-
530
- if coverage < 0.8:
531
- # use original distract mask
532
- final_mask = (distract_mask.copy() * 255).astype(np.uint8)
533
- final_mask = cv2.dilate(final_mask, np.ones((3, 3), np.uint8), iterations=3)
534
- labeled_array_dilate, num_features_dilate = scipy.ndimage.label(final_mask)
535
- for i in range(num_features_dilate + 1):
536
- if np.sum(labeled_array_dilate == i) < 200:
537
- final_mask[labeled_array_dilate == i] = 255
538
-
539
- final_mask = cv2.erode(final_mask, np.ones((3, 3), np.uint8), iterations=3)
540
- final_mask = final_mask > 127
541
-
542
- return distract_mask, distract_bbox, random_sampled_points, final_mask
543
 
544
 
545
  # infer_refine_sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
@@ -563,6 +528,7 @@ def infer_refine(meshes, imgs):
563
  distract_mask = None
564
 
565
  results = []
 
566
 
567
  for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
568
  mesh = trimesh.load(meshes[name_idx])
@@ -607,11 +573,11 @@ def infer_refine(meshes, imgs):
607
  colors.append(color)
608
  normals.append(normal)
609
 
610
- # if last_front_color is not None and level == 0:
611
- # original_mask, distract_bbox, _, distract_mask = get_distract_mask(infer_refine_generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=infer_refine_outside_ratio)
612
- # else:
613
- distract_mask = None
614
- distract_bbox = None
615
 
616
  if last_colors is None:
617
  from copy import deepcopy
@@ -625,15 +591,15 @@ def infer_refine(meshes, imgs):
625
  _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
626
  _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
627
  idx_anchor = idx_anchor.squeeze()
628
- neighbors = torch.tensor(mesh_v)[idx_mesh_v] # V, 25, 3
629
  # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
630
- neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v)[:, None], dim=-1)
631
  neighbor_dists[neighbor_dists > 0.06] = 114514.
632
  neighbor_weights = torch.exp(-neighbor_dists * 1.)
633
  neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
634
  anchors = fixed_v[idx_anchor] # V, 3
635
  anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
636
- dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v)) * anchor_normals).sum(-1), min=0) + 0.01
637
  vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
638
  vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
639
  weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
@@ -647,7 +613,7 @@ def infer_refine(meshes, imgs):
647
  # my mesh flow weight by nearest vertexs
648
  try:
649
  if fixed_v is not None and fixed_f is not None and level != 0:
650
- new_mesh_v = new_mesh.verts_packed().cpu().numpy()
651
 
652
  fixed_v_cpu = fixed_v.cpu().numpy()
653
  kdtree_anchor = KDTree(fixed_v_cpu)
@@ -655,48 +621,60 @@ def infer_refine(meshes, imgs):
655
  _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
656
  _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
657
  idx_anchor = idx_anchor.squeeze()
658
- neighbors = torch.tensor(new_mesh_v)[idx_mesh_v] # V, 25, 3
659
  # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
660
- neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v)[:, None], dim=-1)
661
  neighbor_dists[neighbor_dists > 0.06] = 114514.
662
  neighbor_weights = torch.exp(-neighbor_dists * 1.)
663
  neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
664
  anchors = fixed_v[idx_anchor] # V, 3
665
  anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
666
- dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v)) * anchor_normals).sum(-1), min=0) + 0.01
667
  vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
668
  vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
669
  weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
670
  new_mesh_v += weighted_vec_anchor.cpu().numpy()
671
 
672
  # replace new_mesh verts with new_mesh_v
673
- new_mesh = Meshes(verts=[torch.tensor(new_mesh_v)], faces=new_mesh.faces_list(), textures=new_mesh.textures)
674
 
675
  except Exception as e:
676
  pass
677
 
678
- notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
679
-
680
  if fixed_v is None:
681
  fixed_v, fixed_f = simp_v, simp_f
682
- complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
683
  else:
684
  fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
685
  fixed_v = torch.cat([fixed_v, simp_v], dim=0)
686
-
687
- complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
688
- complete_v = torch.cat([complete_v, notsimp_v], dim=0)
689
- complete_t = torch.cat([complete_t, notsimp_t], dim=0)
690
 
691
  if level == 2:
692
- new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
693
 
694
- save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
695
- results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
 
 
 
 
 
 
 
 
696
 
697
  # save whole mesh
698
- save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
699
- results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
 
 
 
 
 
 
 
 
 
700
 
701
  return results
702
 
@@ -749,7 +727,7 @@ def infer_slrm_make3d(images):
749
  return mesh_glb_fpaths
750
 
751
  @spaces.GPU
752
- def infer_slrm_make_mesh(mesh_fpath, planes, level=None):
753
  mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
754
  mesh_dirname = os.path.dirname(mesh_fpath)
755
 
@@ -757,19 +735,36 @@ def infer_slrm_make_mesh(mesh_fpath, planes, level=None):
757
  # get mesh
758
  mesh_out = infer_slrm_model.extract_mesh(
759
  planes,
760
- use_texture_map=False,
761
  levels=torch.tensor([level]).to(device),
762
  **infer_slrm_infer_config,
763
  )
764
 
765
- vertices, faces, vertex_colors = mesh_out
766
- vertices = vertices[:, [1, 2, 0]]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
767
 
768
- if level == 2:
769
- # fill all vertex_colors with 127
770
- vertex_colors = np.ones_like(vertex_colors) * 127
771
-
772
- save_obj(vertices, faces, vertex_colors, mesh_fpath)
773
 
774
  return mesh_fpath
775
 
 
12
  import numpy as np
13
 
14
  import torch
15
+ from pygltflib import GLTF2, Material, PbrMetallicRoughness
16
 
17
  from diffusers import AutoencoderKL, DDIMScheduler
18
  from diffusers.utils import check_min_version
 
73
  FOV_to_intrinsics,
74
  get_circular_camera_poses,
75
  )
76
+ from slrm.utils.mesh_util import save_obj, save_glb, save_obj_with_mtl
77
  from slrm.utils.infer_util import images_to_video
78
 
79
  import cv2
 
478
 
479
 
480
  @spaces.GPU
481
+ def get_distract_mask(color_0, color_1, normal_0=None, normal_1=None, thres=0.25, ratio=0.50, outside_thres=0.10, outside_ratio=0.20):
482
  distract_area = np.abs(color_0 - color_1).sum(axis=-1) > thres
483
  if normal_0 is not None and normal_1 is not None:
484
  distract_area |= np.abs(normal_0 - normal_1).sum(axis=-1) > thres
 
504
  max_x, max_y = bbox.max(axis=0)
505
  distract_bbox[min_x:max_x, min_y:max_y] = 1
506
 
507
+ return distract_mask, distract_bbox
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
508
 
509
 
510
  # infer_refine_sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
 
528
  distract_mask = None
529
 
530
  results = []
531
+ mesh_list = []
532
 
533
  for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
534
  mesh = trimesh.load(meshes[name_idx])
 
573
  colors.append(color)
574
  normals.append(normal)
575
 
576
+ if last_front_color is not None and level == 0:
577
+ distract_mask, distract_bbox = get_distract_mask(last_front_color, np.array(colors[0]).astype(np.float32) / 255.0)
578
+ else:
579
+ distract_mask = None
580
+ distract_bbox = None
581
 
582
  if last_colors is None:
583
  from copy import deepcopy
 
591
  _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
592
  _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
593
  idx_anchor = idx_anchor.squeeze()
594
+ neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
595
  # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
596
+ neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
597
  neighbor_dists[neighbor_dists > 0.06] = 114514.
598
  neighbor_weights = torch.exp(-neighbor_dists * 1.)
599
  neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
600
  anchors = fixed_v[idx_anchor] # V, 3
601
  anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
602
+ dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
603
  vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
604
  vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
605
  weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
 
613
  # my mesh flow weight by nearest vertexs
614
  try:
615
  if fixed_v is not None and fixed_f is not None and level != 0:
616
+ new_mesh_v = new_mesh.vertices.copy()
617
 
618
  fixed_v_cpu = fixed_v.cpu().numpy()
619
  kdtree_anchor = KDTree(fixed_v_cpu)
 
621
  _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
622
  _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
623
  idx_anchor = idx_anchor.squeeze()
624
+ neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
625
  # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
626
+ neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
627
  neighbor_dists[neighbor_dists > 0.06] = 114514.
628
  neighbor_weights = torch.exp(-neighbor_dists * 1.)
629
  neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
630
  anchors = fixed_v[idx_anchor] # V, 3
631
  anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
632
+ dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
633
  vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
634
  vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
635
  weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
636
  new_mesh_v += weighted_vec_anchor.cpu().numpy()
637
 
638
  # replace new_mesh verts with new_mesh_v
639
+ new_mesh.vertices = new_mesh_v
640
 
641
  except Exception as e:
642
  pass
643
 
 
 
644
  if fixed_v is None:
645
  fixed_v, fixed_f = simp_v, simp_f
 
646
  else:
647
  fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
648
  fixed_v = torch.cat([fixed_v, simp_v], dim=0)
649
+
650
+ mesh_list.append(new_mesh)
 
 
651
 
652
  if level == 2:
653
+ new_mesh = trimesh.Trimesh(simp_v.cpu().numpy(), simp_f.cpu().numpy(), process=False)
654
 
655
+ new_mesh.export(meshes[name_idx].replace('.obj', '_refined.glb'))
656
+ results.append(meshes[name_idx].replace('.obj', '_refined.glb'))
657
+
658
+ gltf = GLTF2().load(meshes[name_idx].replace('.obj', '_refined.glb'))
659
+ for material in gltf.materials:
660
+ if material.pbrMetallicRoughness:
661
+ material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0]
662
+ material.pbrMetallicRoughness.metallicFactor = 0.0
663
+ material.pbrMetallicRoughness.roughnessFactor = 1.0
664
+ gltf.save(meshes[name_idx].replace('.obj', '_refined.glb'))
665
 
666
  # save whole mesh
667
+ scene = trimesh.Scene(mesh_list)
668
+ scene.export(meshes[name_idx].replace('.obj', '_refined_whole.glb'))
669
+ results.append(meshes[name_idx].replace('.obj', '_refined_whole.glb'))
670
+
671
+ gltf = GLTF2().load(meshes[name_idx].replace('.obj', '_refined_whole.glb'))
672
+ for material in gltf.materials:
673
+ if material.pbrMetallicRoughness:
674
+ material.pbrMetallicRoughness.baseColorFactor = [1.0, 1.0, 1.0, 100.0]
675
+ material.pbrMetallicRoughness.metallicFactor = 0.0
676
+ material.pbrMetallicRoughness.roughnessFactor = 1.0
677
+ gltf.save(meshes[name_idx].replace('.obj', '_refined_whole.glb'))
678
 
679
  return results
680
 
 
727
  return mesh_glb_fpaths
728
 
729
  @spaces.GPU
730
+ def infer_slrm_make_mesh(mesh_fpath, planes, level=None, use_texture_map=False):
731
  mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
732
  mesh_dirname = os.path.dirname(mesh_fpath)
733
 
 
735
  # get mesh
736
  mesh_out = infer_slrm_model.extract_mesh(
737
  planes,
738
+ use_texture_map=use_texture_map,
739
  levels=torch.tensor([level]).to(device),
740
  **infer_slrm_infer_config,
741
  )
742
 
743
+ if use_texture_map:
744
+ vertices, faces, uvs, mesh_tex_idx, tex_map = mesh_out
745
+ vertices = vertices[:, [1, 2, 0]]
746
+ tex_map = tex_map.permute(1, 2, 0).data.cpu().numpy()
747
+
748
+ if level == 2:
749
+ # fill all vertex_colors with 127
750
+ tex_map = np.ones_like(tex_map) * 127
751
+ save_obj_with_mtl(
752
+ vertices.data.cpu().numpy(),
753
+ uvs.data.cpu().numpy(),
754
+ faces.data.cpu().numpy(),
755
+ mesh_tex_idx.data.cpu().numpy(),
756
+ tex_map,
757
+ mesh_fpath
758
+ )
759
+ else:
760
+ vertices, faces, vertex_colors = mesh_out
761
+ vertices = vertices[:, [1, 2, 0]]
762
 
763
+ if level == 2:
764
+ # fill all vertex_colors with 127
765
+ vertex_colors = np.ones_like(vertex_colors) * 127
766
+
767
+ save_obj(vertices, faces, vertex_colors, mesh_fpath)
768
 
769
  return mesh_fpath
770
 
refine/mesh_refine.py CHANGED
@@ -13,6 +13,104 @@ from refine.render import NormalsRenderer, calc_vertex_normals
13
 
14
  import pytorch3d
15
  from pytorch3d.structures import Meshes
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  def remove_color(arr):
18
  if arr.shape[-1] == 4:
@@ -301,11 +399,11 @@ def geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_
301
  return mesh_v, mesh_f
302
 
303
  vertices, faces = reconstruct_stage1(rm_normals, steps=200, vertices=mesh_v, faces=mesh_f, fixed_v=fixed_v, fixed_f=fixed_f,
304
- lr=stage1_lr, remesh_interval=stage1_remesh_interval, start_edge_len=0.02,
305
- end_edge_len=0.005, gain=0.05, loss_expansion_weight=expansion_weight,
306
  distract_mask=distract_mask, distract_bbox=distract_bbox)
307
 
308
- vertices, faces = run_mesh_refine(vertices, faces, rm_normals, fixed_v=fixed_v, fixed_f=fixed_f, steps=100, start_edge_len=0.005, end_edge_len=0.0002,
309
  decay=0.99, update_normal_interval=20, update_warmup=5, process_inputs=False, process_outputs=False, remesh_interval=1)
310
  return vertices, faces
311
 
@@ -314,21 +412,77 @@ def geo_refine_2(vertices, faces, fixed_v=None):
314
  simp_vertices, simp_faces = meshes.verts_packed(), meshes.faces_packed()
315
  vertices, faces = simp_vertices.detach().cpu().numpy(), simp_faces.detach().cpu().numpy()
316
  # vertices, faces = trimesh.remesh.subdivide(vertices, faces)
317
- if fixed_v is not None:
318
- vertices, faces = trimesh.remesh.subdivide(vertices, faces)
319
  return vertices, faces
320
 
321
- def geo_refine_3(vertices, faces, rgb_ls, fixed_v=None, fixed_f=None, distract_mask=None):
322
- origin_len_v, origin_len_f = len(vertices), len(faces)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
323
  # concatenate fixed_v and fixed_f
324
  if fixed_v is not None and fixed_f is not None:
325
- vertices, faces = np.concatenate([vertices, fixed_v.detach().cpu().numpy()], axis=0), np.concatenate([faces, fixed_f.detach().cpu().numpy() + len(vertices)], axis=0)
326
- vertices, faces = torch.tensor(vertices, device='cuda'), torch.tensor(faces, device='cuda')
327
  # reconstruct meshes
328
- meshes = Meshes(verts=[vertices], faces=[faces], textures=pytorch3d.renderer.mesh.textures.TexturesVertex([torch.zeros_like(vertices).float()]))
329
  new_meshes = multiview_color_projection(meshes, rgb_ls, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([180, 225, 270, 0, 90, 135], "cuda", focal=1/1.2), weights=[2.0, 0.5, 0.0, 1.0, 0.0, 0.5] if distract_mask is None else [2.0, 0.0, 0.5, 1.0, 0.5, 0.0], distract_mask=distract_mask)
330
- # exclude fixed_v and fixed_f
331
  if fixed_v is not None and fixed_f is not None:
332
- new_meshes = Meshes(verts=[new_meshes.verts_packed()[:origin_len_v]], faces=[new_meshes.faces_packed()[:origin_len_f]],
333
- textures=pytorch3d.renderer.mesh.textures.TexturesVertex([new_meshes.textures.verts_features_packed()[:origin_len_v]]))
334
- return new_meshes.to("cpu"), vertices.cpu(), faces.cpu()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  import pytorch3d
15
  from pytorch3d.structures import Meshes
16
+ import xatlas
17
+ import cv2
18
+
19
+
20
def mesh_uv_wrap(vertices, faces, max_faces=50000):
    """Unwrap a triangle mesh into a UV atlas using xatlas.

    Args:
        vertices: Vertex array handed to ``xatlas.parametrize``. It is also
            indexed with the returned vertex mapping, so it must support
            integer-array indexing (presumably a NumPy ``(V, 3)`` array --
            confirm against callers).
        faces: Face index array handed to ``xatlas.parametrize``.
        max_faces: Largest face count that is accepted. Defaults to 50000,
            the limit that was previously hard-coded.

    Returns:
        A tuple ``(vertices[vmapping], indices, uvs)`` where ``vmapping``,
        ``indices`` and ``uvs`` are the three values returned by
        ``xatlas.parametrize``.

    Raises:
        ValueError: If ``faces`` holds more than ``max_faces`` entries.
    """
    if len(faces) > max_faces:
        raise ValueError(
            f"The mesh has more than {max_faces:,} faces, which is not supported."
        )

    vmapping, indices, uvs = xatlas.parametrize(vertices, faces)
    # Unwrapping may split vertices along seams; re-index the input vertices
    # so they line up with the new face indices and UVs.
    return vertices[vmapping], indices, uvs
26
+
27
+
28
def stride_from_shape(shape):
    """Return the row-major (C-order) element strides for ``shape``.

    The last dimension gets stride 1 and every earlier dimension gets the
    product of all later dimension sizes, e.g. ``(2, 3, 4) -> [12, 4, 1]``.

    Args:
        shape: A sliceable sequence of dimension sizes (tuple, list,
            ``torch.Size``, ...).

    Returns:
        A list of strides with one entry per dimension. An empty ``shape``
        yields ``[1]``, because the list is seeded with the innermost stride.
    """
    stride = [1]
    # Walk from the innermost dimension outwards, accumulating the product.
    for dim in reversed(shape[1:]):
        stride.append(stride[-1] * dim)
    stride.reverse()
    return stride
33
+
34
def scatter_add_nd_with_count(input, count, indices, values, weights=None):
    """Scatter-add ``values`` into an N-D grid and accumulate their weights.

    Both ``input`` and ``count`` are flattened with ``view`` and updated with
    in-place ``scatter_add_``, so the tensors passed in are modified.

    Args:
        input: Accumulator of shape ``[..., C]`` (``D`` grid dimensions plus
            a channel dimension).
        count: Weight accumulator of shape ``[..., 1]`` over the same grid.
        indices: Long tensor of shape ``[N, D]`` with the grid cell of each
            value.
        values: Tensor of shape ``[N, C]`` to add into ``input``.
        weights: Optional ``[N, 1]`` tensor added into ``count``; defaults to
            ones.

    Returns:
        ``(input, count)`` viewed back to ``[..., C]`` and ``[..., 1]``.

    Raises:
        ValueError: If the number of grid dimensions of ``input`` does not
            equal ``indices.shape[-1]``.
    """
    D = indices.shape[-1]
    C = input.shape[-1]
    size = input.shape[:-1]
    stride = stride_from_shape(size)

    # Explicit check instead of ``assert`` so it still runs under ``python -O``.
    if len(size) != D:
        raise ValueError(
            f"indices address {D} dimension(s) but input has {len(size)} grid dimension(s)"
        )

    flat_input = input.view(-1, C)  # [prod(size), C]
    flat_count = count.view(-1, 1)  # [prod(size), 1]

    # Collapse each D-dimensional index into an offset into the flat grid.
    stride_t = torch.tensor(stride, dtype=torch.long, device=indices.device)
    flat_idx = (indices * stride_t).sum(-1).unsqueeze(1)  # [N, 1]

    if weights is None:
        weights = torch.ones_like(values[..., :1])

    flat_input.scatter_add_(0, flat_idx.expand(-1, C), values)
    flat_count.scatter_add_(0, flat_idx, weights)

    return flat_input.view(*size, C), flat_count.view(*size, 1)
60
+
61
+
62
def linear_grid_put_2d(H, W, coords, values, return_count=False):
    """Splat point samples onto an ``H x W`` grid with bilinear weights.

    Each sample is distributed over the four grid cells surrounding its
    position, weighted by the bilinear coefficients, via
    ``scatter_add_nd_with_count``.

    Args:
        H: Grid size along the first coordinate.
        W: Grid size along the second coordinate.
        coords: ``[N, 2]`` float tensor in ``[0, 1]``; column 0 is scaled by
            ``H - 1`` and column 1 by ``W - 1``.
        values: ``[N, C]`` tensor of per-sample values.
        return_count: If true, return the raw accumulators instead of the
            normalised grid.

    Returns:
        If ``return_count`` is true: ``(result, count)``, the un-normalised
        ``[H, W, C]`` sum and the ``[H, W, 1]`` accumulated weights.
        Otherwise: ``(result, empty)`` where ``result`` is divided by the
        accumulated weight wherever that weight is positive, and ``empty`` is
        an ``[H, W]`` boolean tensor that is true for cells whose accumulated
        weight is exactly zero.
    """
    C = values.shape[-1]

    indices = coords * torch.tensor(
        [H - 1, W - 1], dtype=torch.float32, device=coords.device
    )
    # Top-left corner of each sample's cell, clamped so that the +1
    # neighbours stay inside the grid.
    indices_00 = indices.floor().long()  # [N, 2]
    indices_00[:, 0].clamp_(0, H - 2)
    indices_00[:, 1].clamp_(0, W - 2)

    # Fractional position inside the cell along each axis.
    h = indices[..., 0] - indices_00[..., 0].float()
    w = indices[..., 1] - indices_00[..., 1].float()

    # (index offset, bilinear weight) for the four surrounding cells.
    corners = (
        ((0, 0), (1 - h) * (1 - w)),
        ((0, 1), (1 - h) * w),
        ((1, 0), h * (1 - w)),
        ((1, 1), h * w),
    )

    result = torch.zeros(H, W, C, device=values.device, dtype=values.dtype)  # [H, W, C]
    count = torch.zeros(H, W, 1, device=values.device, dtype=values.dtype)  # [H, W, 1]
    weights = torch.ones_like(values[..., :1])  # [N, 1]

    for offset, corner_weight in corners:
        corner_idx = indices_00 + torch.tensor(
            offset, dtype=torch.long, device=indices.device
        )
        corner_weight = corner_weight.unsqueeze(1)  # [N, 1]
        result, count = scatter_add_nd_with_count(
            result, count, corner_idx, values * corner_weight, weights * corner_weight
        )

    if return_count:
        return result, count

    mask = count.squeeze(-1) > 0
    result[mask] = result[mask] / count[mask].repeat(1, C)

    return result, count.squeeze(-1) == 0
113
+
114
 
115
  def remove_color(arr):
116
  if arr.shape[-1] == 4:
 
399
  return mesh_v, mesh_f
400
 
401
  vertices, faces = reconstruct_stage1(rm_normals, steps=200, vertices=mesh_v, faces=mesh_f, fixed_v=fixed_v, fixed_f=fixed_f,
402
+ lr=stage1_lr, remesh_interval=stage1_remesh_interval, start_edge_len=0.04,
403
+ end_edge_len=0.02, gain=0.05, loss_expansion_weight=expansion_weight,
404
  distract_mask=distract_mask, distract_bbox=distract_bbox)
405
 
406
+ vertices, faces = run_mesh_refine(vertices, faces, rm_normals, fixed_v=fixed_v, fixed_f=fixed_f, steps=100, start_edge_len=0.02, end_edge_len=0.001,
407
  decay=0.99, update_normal_interval=20, update_warmup=5, process_inputs=False, process_outputs=False, remesh_interval=1)
408
  return vertices, faces
409
 
 
412
  simp_vertices, simp_faces = meshes.verts_packed(), meshes.faces_packed()
413
  vertices, faces = simp_vertices.detach().cpu().numpy(), simp_faces.detach().cpu().numpy()
414
  # vertices, faces = trimesh.remesh.subdivide(vertices, faces)
 
 
415
  return vertices, faces
416
 
417
def geo_refine_3(vertices_, faces_, rgb_ls, fixed_v=None, fixed_f=None, distract_mask=None):
    """Project multi-view colours onto a mesh and bake them into a UV texture.

    Steps, as implemented below:
      1. UV-unwrap the input mesh with ``xatlas.parametrize``.
      2. Subdivide both the unwrapped mesh (carrying its UVs) and the input
         mesh twice.
      3. Optionally append the fixed geometry, colour the dense mesh
         per-vertex with ``multiview_color_projection``, then discard the
         fixed part again.
      4. Give every dense atlas vertex the colour of its nearest dense mesh
         vertex and splat those colours into a 1024x1024 texture, inpainting
         texels that received no sample.

    Args:
        vertices_: Mesh vertices; indexed with the xatlas vertex mapping and
            passed to ``trimesh.remesh.subdivide`` (presumably a NumPy
            ``(V, 3)`` array -- confirm against the caller).
        faces_: Mesh faces (presumably a NumPy ``(F, 3)`` integer array --
            confirm against the caller).
        rgb_ls: Views forwarded unchanged to ``multiview_color_projection``.
        fixed_v: Optional tensor of already-finalised vertices; only used
            when ``fixed_f`` is given as well.
        fixed_f: Optional tensor of faces indexing into ``fixed_v``.
        distract_mask: Optional mask forwarded to
            ``multiview_color_projection``; it also switches the per-view
            weights.

    Returns:
        ``(mesh, vertices, faces)``: a ``trimesh.Trimesh`` of the unwrapped
        (non-subdivided) mesh with a PBR texture attached, plus its vertices
        and faces (``int64``) as CUDA tensors.
    """
    vmapping, indices, uvs = xatlas.parametrize(vertices_, faces_)
    vertices, faces = vertices_[vmapping], indices

    def subdivide(vertices, faces, uvs):
        # Stack the UVs next to the positions so one subdivision call
        # produces both for the new vertices (relies on trimesh handling
        # the extra vertex columns -- see trimesh.remesh.subdivide).
        vertices, faces = trimesh.remesh.subdivide(
            vertices=np.hstack((vertices, uvs.copy())),
            faces=faces,
        )
        return vertices[:, :3], faces, vertices[:, 3:]

    # Densify both meshes with two rounds of subdivision so the per-vertex
    # colours sample the texture finely enough.
    dense_atlas_vertices, dense_atlas_faces, dense_atlas_uvs = vertices, faces, uvs
    dense_vertices, dense_faces = vertices_, faces_
    for _ in range(2):
        dense_atlas_vertices, dense_atlas_faces, dense_atlas_uvs = subdivide(
            dense_atlas_vertices, dense_atlas_faces, dense_atlas_uvs
        )
        dense_vertices, dense_faces = trimesh.remesh.subdivide(dense_vertices, dense_faces)

    origin_len_v = len(dense_vertices)

    # Append the fixed geometry so it takes part in the colour projection.
    if fixed_v is not None and fixed_f is not None:
        # fixed_f indexes into fixed_v, which is placed right after the dense
        # vertices, so the offset must be the dense vertex count (it used to
        # be the un-subdivided atlas vertex count, which pointed the fixed
        # faces at the wrong vertices).
        shifted_fixed_f = fixed_f.detach().cpu().numpy() + origin_len_v
        dense_vertices = np.concatenate([dense_vertices, fixed_v.detach().cpu().numpy()], axis=0)
        dense_faces = np.concatenate([dense_faces, shifted_fixed_f], axis=0)
    dense_vertices = torch.from_numpy(dense_vertices).cuda()
    dense_faces = torch.from_numpy(dense_faces.astype('int32')).cuda()

    # reconstruct meshes
    meshes = Meshes(verts=[dense_vertices], faces=[dense_faces], textures=pytorch3d.renderer.mesh.textures.TexturesVertex([torch.zeros_like(dense_vertices).float()]))
    new_meshes = multiview_color_projection(meshes, rgb_ls, resolution=1024, device="cuda", complete_unseen=True, confidence_threshold=0.2, cameras_list = get_cameras_list([180, 225, 270, 0, 90, 135], "cuda", focal=1/1.2), weights=[2.0, 0.5, 0.0, 1.0, 0.0, 0.5] if distract_mask is None else [2.0, 0.0, 0.5, 1.0, 0.5, 0.0], distract_mask=distract_mask)

    # Keep only the part that belongs to this mesh. When nothing was
    # appended the slice covers the whole tensor, so it is a no-op.
    dense_vertices = dense_vertices[:origin_len_v]
    textures = new_meshes.textures.verts_features_packed()[:origin_len_v]

    # Nearest-neighbour colour transfer from the dense mesh to the dense
    # atlas mesh, chunked so the pairwise distance matrix stays small.
    chunk_size = 500
    atlas_points = torch.tensor(dense_atlas_vertices).cuda()
    atlas_textures_chunks = []
    for chunk in atlas_points.split(chunk_size):
        distances = torch.cdist(chunk, dense_vertices)
        nearest_indices = torch.argmin(distances, dim=1)
        atlas_textures_chunks.append(textures[nearest_indices])
    atlas_textures = torch.cat(atlas_textures_chunks, dim=0)

    # Bake the per-vertex colours into the texture image and fill texels
    # that received no sample.
    dense_atlas_uvs = torch.tensor(dense_atlas_uvs, dtype=torch.float32).cuda()
    tex_img, mask = linear_grid_put_2d(1024, 1024, dense_atlas_uvs, atlas_textures)
    tex_img, mask = tex_img.cpu().numpy(), mask.cpu().numpy()
    tex_img = cv2.inpaint((tex_img * 255).astype(np.uint8), (mask*255).astype('uint8'), 3, cv2.INPAINT_NS)
    # Swap the two grid axes and flip vertically before wrapping as an image.
    tex_img = Image.fromarray(np.transpose(tex_img,(1,0,2))[::-1])

    mesh = trimesh.Trimesh(vertices, faces, process=False)
    material = trimesh.visual.material.PBRMaterial(
        roughnessFactor=1.0,
        baseColorTexture=tex_img,
        baseColorFactor=np.array([255, 255, 255, 255], dtype=np.uint8)
    )
    mesh.visual = trimesh.visual.TextureVisuals(uv=uvs, image=tex_img, material=material)

    return mesh, torch.tensor(vertices).cuda(), torch.tensor(faces.astype('int64')).cuda()
slrm/models/lrm_mesh.py CHANGED
@@ -116,13 +116,13 @@ class MeshSLRM(nn.Module):
116
  camera = OrthogonalCamera(device=device)
117
 
118
  with torch.cuda.amp.autocast(enabled=False):
119
- # renderer = NeuralRender(device, camera_model=camera)
120
  self.geometry = FlexiCubesGeometry(
121
  grid_res_xy=self.grid_res_xy,
122
  grid_res_z=self.grid_res_z,
123
  scale_xy=self.grid_scale_xy,
124
  scale_z=self.grid_scale_z,
125
- renderer=None,
126
  render_type='neural_render',
127
  device=device,
128
  )
 
116
  camera = OrthogonalCamera(device=device)
117
 
118
  with torch.cuda.amp.autocast(enabled=False):
119
+ renderer = NeuralRender(device, camera_model=camera)
120
  self.geometry = FlexiCubesGeometry(
121
  grid_res_xy=self.grid_res_xy,
122
  grid_res_z=self.grid_res_z,
123
  scale_xy=self.grid_scale_xy,
124
  scale_z=self.grid_scale_z,
125
+ renderer=renderer,
126
  render_type='neural_render',
127
  device=device,
128
  )