YulianSa committed
Commit 13973ba · 1 Parent(s): cfae272
Files changed (3)
  1. app.py +2 -2
  2. infer_api.py +34 -33
  3. refine/mesh_refine.py +23 -12
app.py CHANGED
@@ -47,8 +47,8 @@ This is official demo for our CVPR 2025 paper <a href="">StdGEN: Semantic-Decomp
 
 Code: <a href='https://github.com/hyz317/StdGEN' target='_blank'>GitHub</a>. Paper: <a href='https://arxiv.org/abs/2411.05738' target='_blank'>ArXiv</a>.
 
-❗️❗️❗️**Important Notes:**
-1. Refinement stage takes about ~3.5min, and the mesh result may possibly delayed due to the server load, please wait patiently.
+❗️❗️❗️**Important Notes:** This is only a **PREVIEW** version with lower quality. We only perform color back-projection to clothes and hair. Please refer to the GitHub repo for the complete version.
+1. Refinement stage takes about ~2.5min, and the mesh result may possibly be delayed due to the server load; please wait patiently.
 
 2. You can upload any reference image (with or without background), A-pose images are also supported (white bkg required). If the image has an alpha channel (transparency), background segmentation will be automatically performed. Alternatively, you can pre-segment the background using other tools and upload the result directly.
 
infer_api.py CHANGED
@@ -542,19 +542,19 @@ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None,
     return distract_mask, distract_bbox, random_sampled_points, final_mask
 
 
-infer_refine_sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
-infer_refine_generator = SamAutomaticMaskGenerator(
-    model=infer_refine_sam,
-    points_per_side=64,
-    pred_iou_thresh=0.80,
-    stability_score_thresh=0.92,
-    crop_n_layers=1,
-    crop_n_points_downscale_factor=2,
-    min_mask_region_area=100,
-)
+# infer_refine_sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
+# infer_refine_generator = SamAutomaticMaskGenerator(
+#     model=infer_refine_sam,
+#     points_per_side=64,
+#     pred_iou_thresh=0.80,
+#     stability_score_thresh=0.92,
+#     crop_n_layers=1,
+#     crop_n_points_downscale_factor=2,
+#     min_mask_region_area=100,
+# )
 infer_refine_outside_ratio = 0.20
 
-@spaces.GPU(duration=150)
+
 def infer_refine(meshes, imgs):
     fixed_v, fixed_f, fixed_t = None, None, None
     flow_vert, flow_vector = None, None
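
Note: the module-level SAM initialization is commented out rather than deleted, and the blanket `@spaces.GPU(duration=150)` decorator comes off `infer_refine`, so importing the module no longer loads a ViT-H checkpoint or reserves a GPU. If the SAM-based distraction mask were re-enabled, a lazy-initialization variant along these lines could defer the load cost to first use (a sketch using the `segment_anything` API; `_get_generator` is a hypothetical helper, not part of this commit):

```python
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

_generator = None  # populated on first use instead of at import time

def _get_generator():
    global _generator
    if _generator is None:
        sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
        _generator = SamAutomaticMaskGenerator(
            model=sam,
            points_per_side=64,
            pred_iou_thresh=0.80,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100,
        )
    return _generator
```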
@@ -564,7 +564,6 @@ def infer_refine(meshes, imgs):
 
     mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
     mv = mv[[4, 3, 2, 0, 6, 5]]
-    renderer = NormalsRenderer(mv,proj,(1024,1024))
 
     results = []
 
@@ -576,12 +575,17 @@ def infer_refine(meshes, imgs):
         mesh_v, mesh_f = mesh.vertices, mesh.faces
 
         if last_colors is None:
-            images = renderer.render(
-                torch.tensor(mesh_v, device='cuda').float(),
-                torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
-                torch.tensor(mesh_f, device='cuda'),
-            )
-            mask = (images[..., 3] < 0.9).cpu().numpy()
+            @spaces.GPU()
+            def get_mask():
+                renderer = NormalsRenderer(mv,proj,(1024,1024))
+                images = renderer.render(
+                    torch.tensor(mesh_v, device='cuda').float(),
+                    torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
+                    torch.tensor(mesh_f, device='cuda'),
+                )
+                mask = (images[..., 3] < 0.9).cpu().numpy()
+                return mask
+            mask = get_mask()
 
         colors, normals = [], []
         for i in range(6):
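
The renderer construction and alpha-mask rendering now live inside a `@spaces.GPU()`-decorated closure, so a ZeroGPU device is attached only for the duration of `get_mask()` rather than for the whole call. A minimal, self-contained sketch of this pattern (assuming the `spaces` package available on Hugging Face ZeroGPU Spaces; `heavy_step` is a stand-in, not code from this repo):

```python
import spaces
import torch

@spaces.GPU()  # a GPU is allocated only while this call runs
def heavy_step(x_cpu: torch.Tensor) -> torch.Tensor:
    x = x_cpu.cuda()   # move inputs onto the just-allocated device
    y = x @ x.T        # stand-in for the rendering / refinement work
    return y.cpu()     # return CPU tensors; the GPU is released afterwards

result = heavy_step(torch.randn(4, 4))  # caller never touches CUDA directly
```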
@@ -604,18 +608,15 @@ def infer_refine(meshes, imgs):
             colors.append(color)
             normals.append(normal)
 
-        if last_front_color is not None and level == 0:
-            original_mask, distract_bbox, _, distract_mask = get_distract_mask(infer_refine_generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=infer_refine_outside_ratio)
-        else:
-            distract_mask = None
-            distract_bbox = None
-
-        last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
-        last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
+        # if last_front_color is not None and level == 0:
+        #     original_mask, distract_bbox, _, distract_mask = get_distract_mask(infer_refine_generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=infer_refine_outside_ratio)
+        # else:
+        distract_mask = None
+        distract_bbox = None
 
         if last_colors is None:
             from copy import deepcopy
-            last_colors, last_normals = deepcopy(colors), deepcopy(normals)
+            last_colors = deepcopy(colors)
 
         # my mesh flow weight by nearest vertexs
         if fixed_v is not None and fixed_f is not None and level == 1:
@@ -643,8 +644,8 @@ def infer_refine(meshes, imgs):
 
         t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
 
-        mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
-        mesh_f = torch.tensor(mesh_f, device='cuda')
+        mesh_v = torch.tensor(mesh_v, dtype=torch.float32)
+        mesh_f = torch.tensor(mesh_f)
 
         new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
 
@@ -659,22 +660,22 @@ def infer_refine(meshes, imgs):
 
             _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
             _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
             idx_anchor = idx_anchor.squeeze()
-            neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
+            neighbors = torch.tensor(new_mesh_v)[idx_mesh_v] # V, 25, 3
             # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
-            neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
+            neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v)[:, None], dim=-1)
             neighbor_dists[neighbor_dists > 0.06] = 114514.
             neighbor_weights = torch.exp(-neighbor_dists * 1.)
             neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
             anchors = fixed_v[idx_anchor] # V, 3
             anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
-            dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
+            dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v)) * anchor_normals).sum(-1), min=0) + 0.01
             vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
             vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
             weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
             new_mesh_v += weighted_vec_anchor.cpu().numpy()
 
             # replace new_mesh verts with new_mesh_v
-            new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
+            new_mesh = Meshes(verts=[torch.tensor(new_mesh_v)], faces=new_mesh.faces_list(), textures=new_mesh.textures)
 
         except Exception as e:
             pass
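
With the `.cuda()` calls removed, the anchor-projection math above runs entirely on CPU tensors, and `geo_refine` now receives CPU inputs as well. The core displacement computation, written device-agnostically (a sketch; `project_to_anchors` is a hypothetical helper, not a function in this repo):

```python
import torch

def project_to_anchors(new_mesh_v, anchors, anchor_normals, device="cpu"):
    # anchors / anchor_normals: (V, 3) tensors already on `device`
    v = torch.as_tensor(new_mesh_v, dtype=torch.float32, device=device)
    # distance to the anchor plane along its normal, clamped so vertices
    # are only pushed outward, plus a small constant offset
    dis = torch.clamp(((anchors - v) * anchor_normals).sum(-1), min=0) + 0.01
    return dis[:, None] * anchor_normals  # per-vertex displacement, (V, 3)
```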
 
refine/mesh_refine.py CHANGED
@@ -1,4 +1,5 @@
 import torch
+import spaces
 import numpy as np
 import trimesh
 from PIL import Image
@@ -267,6 +268,15 @@ def run_mesh_refine(vertices, faces, pils: List[Image.Image], fixed_v=None, fixe
 
 def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
                distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
+    vertices, faces = geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=expansion_weight, fixed_v=fixed_v, fixed_f=fixed_f,
+                                   distract_mask=distract_mask, distract_bbox=distract_bbox, thres=thres, no_decompose=no_decompose)
+    vertices, faces = geo_refine_2(vertices, faces, fixed_v=fixed_v)
+    return geo_refine_3(vertices, faces, rgb_ls, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask)
+
+@spaces.GPU()
+def geo_refine_1(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=None, fixed_f=None,
+                 distract_mask=None, distract_bbox=None, thres=3e-6, no_decompose=False):
+
     rm_normals = simple_remove(normal_ls)
 
     # transfer the alpha channel of rm_normals to img_list
@@ -282,6 +292,9 @@ def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=
     if no_decompose:
         stage1_lr = 0.03
         stage1_remesh_interval = 30
+
+    if fixed_v is not None:
+        return mesh_v, mesh_f
 
     vertices, faces = reconstruct_stage1(rm_normals, steps=200, vertices=mesh_v, faces=mesh_f, fixed_v=fixed_v, fixed_f=fixed_f,
                                          lr=stage1_lr, remesh_interval=stage1_remesh_interval, start_edge_len=0.02,
@@ -290,20 +303,18 @@ def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=
 
     vertices, faces = run_mesh_refine(vertices, faces, rm_normals, fixed_v=fixed_v, fixed_f=fixed_f, steps=100, start_edge_len=0.005, end_edge_len=0.0002,
                                       decay=0.99, update_normal_interval=20, update_warmup=5, process_inputs=False, process_outputs=False, remesh_interval=1)
-    meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=2, apply_sub_divide=False, sub_divide_threshold=0.25).to("cuda")
-    # subdivide meshes
+    return vertices, faces
+
+def geo_refine_2(vertices, faces, fixed_v=None):
+    meshes = simple_clean_mesh(to_pyml_mesh(vertices, faces), apply_smooth=True, stepsmoothnum=2, apply_sub_divide=False, sub_divide_threshold=0.25)
     simp_vertices, simp_faces = meshes.verts_packed(), meshes.faces_packed()
     vertices, faces = simp_vertices.detach().cpu().numpy(), simp_faces.detach().cpu().numpy()
+    if fixed_v is not None:
+        vertices, faces = trimesh.remesh.subdivide(vertices, faces)
+    return vertices, faces
 
-    mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False)
-    mesh = merge_small_faces(mesh, thres=thres)
-    new_mesh = mesh.split(only_watertight=False)
-
-    new_mesh = [ j for j in new_mesh if len(j.vertices) >= 200 ]
-    mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
-    vertices, faces = mesh.vertices.astype('float32'), mesh.faces
-
-    vertices, faces = trimesh.remesh.subdivide(vertices, faces)
+@spaces.GPU()
+def geo_refine_3(vertices, faces, rgb_ls, fixed_v=None, fixed_f=None, distract_mask=None):
     origin_len_v, origin_len_f = len(vertices), len(faces)
     # concatenate fixed_v and fixed_f
     if fixed_v is not None and fixed_f is not None:
@@ -316,4 +327,4 @@ def geo_refine(mesh_v, mesh_f, rgb_ls, normal_ls, expansion_weight=0.1, fixed_v=
     if fixed_v is not None and fixed_f is not None:
         new_meshes = Meshes(verts=[new_meshes.verts_packed()[:origin_len_v]], faces=[new_meshes.faces_packed()[:origin_len_f]],
                             textures=pytorch3d.renderer.mesh.textures.TexturesVertex([new_meshes.textures.verts_features_packed()[:origin_len_v]]))
-    return new_meshes, simp_vertices, simp_faces
+    return new_meshes.to("cpu"), vertices, faces
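
Net effect of the mesh_refine.py changes: `geo_refine` becomes a thin wrapper over two GPU-decorated stages (`geo_refine_1` for normal-guided reconstruction, `geo_refine_3` for color back-projection) with a CPU-only clean-up stage (`geo_refine_2`) in between, and the final mesh is returned on CPU. A skeleton of that staging idea (hypothetical stage names; bodies elided with `...`):

```python
import spaces

@spaces.GPU()
def stage_gpu_reconstruct(mesh, normals):
    # CUDA-heavy geometry refinement, analogous to geo_refine_1
    ...

def stage_cpu_cleanup(mesh):
    # smoothing / simplification / subdivision on CPU, analogous to geo_refine_2
    ...

@spaces.GPU()
def stage_gpu_backproject(mesh, rgb_images):
    # CUDA color back-projection, analogous to geo_refine_3
    ...

def pipeline(mesh, normals, rgb_images):
    # GPU quota is consumed only inside the two decorated calls
    mesh = stage_gpu_reconstruct(mesh, normals)
    mesh = stage_cpu_cleanup(mesh)
    return stage_gpu_backproject(mesh, rgb_images)
```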