update

infer_api.py CHANGED (+217 -221)
@@ -107,7 +107,7 @@ for file in all_files:
     hf_hub_download(repo_id, file, local_dir="./ckpt")
 
 @spaces.GPU
-def set_seed2(seed):
+def set_seed22(seed):
     random.seed(seed)
     np.random.seed(seed)
     torch.manual_seed(seed)
@@ -174,7 +174,7 @@ def process_image(image, totensor, width, height):
 def inference(validation_pipeline, input_image, vae, feature_extractor, image_encoder, unet, ref_unet, tokenizer,
               text_encoder, pretrained_model_path, validation, val_width, val_height, unet_condition_type,
               use_noise=True, noise_d=256, crop=False, seed=100, timestep=20):
-
+    set_seed2(seed)
     generator = torch.Generator(device=device).manual_seed(seed)
 
     totensor = transforms.ToTensor()
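
The hunk above wires an explicit seeding call into `inference`, resetting every RNG the pipeline uses before sampling. A minimal self-contained sketch of such a helper (the body of the renamed function is only partially visible in the first hunk; the CUDA seeding line is an assumption, not read from this diff):

import random
import numpy as np
import torch

def set_seed_sketch(seed: int) -> None:
    # Seed Python, NumPy and PyTorch so a given seed argument
    # reproduces the same generation run.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # assumption: GPU kernels should be reproducible as well
        torch.cuda.manual_seed_all(seed)
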
@@ -372,10 +372,10 @@ class InferAPI:
         return infer_multiview_gen(img, seed, num_levels)
 
     def genStage3(self, img):
-        return
+        return infer_slrm_gen(img)
 
     def genStage4(self, meshes, imgs):
-        return
+        return infer_refine(meshes, imgs)
 
 
 ############## Refine ##############
@@ -400,6 +400,7 @@ def srgb_to_linear(c_srgb):
     return c_linear.clip(0, 1.)
 
 
+@spaces.GPU
 def save_py3dmesh_with_trimesh_fast(meshes: Meshes, save_glb_path, apply_sRGB_to_LinearRGB=True):
     # convert from pytorch3d meshes to trimesh mesh
     vertices = meshes.verts_packed().cpu().float().numpy()
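
For context, only the final line of `srgb_to_linear` is visible above. A self-contained sketch of the standard sRGB-to-linear conversion it likely performs (the piecewise constants are the usual IEC 61966-2-1 values, not read from this diff):

import numpy as np

def srgb_to_linear_sketch(c_srgb: np.ndarray) -> np.ndarray:
    # Linear segment below 0.04045, gamma-2.4 segment above,
    # clipped to [0, 1] like the visible return statement.
    c_linear = np.where(
        c_srgb <= 0.04045,
        c_srgb / 12.92,
        ((c_srgb + 0.055) / 1.055) ** 2.4,
    )
    return c_linear.clip(0, 1.)
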
@@ -515,245 +516,240 @@ def get_distract_mask(generator, color_0, color_1, normal_0=None, normal_1=None,
     return distract_mask, distract_bbox, random_sampled_points, final_mask
 
 
-class InferRefineAPI:
-    @spaces.GPU
-    def __init__(self, config):
-        self.sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
-        self.generator = SamAutomaticMaskGenerator(
-            model=self.sam,
-            points_per_side=64,
-            pred_iou_thresh=0.80,
-            stability_score_thresh=0.92,
-            crop_n_layers=1,
-            crop_n_points_downscale_factor=2,
-            min_mask_region_area=100,
-        )
-        self.outside_ratio = 0.20
-
-    @spaces.GPU
-    def refine(self, meshes, imgs):
-        fixed_v, fixed_f, fixed_t = None, None, None
-        flow_vert, flow_vector = None, None
-        last_colors, last_normals = None, None
-        last_front_color, last_front_normal = None, None
-        distract_mask = None
-
-        mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
-        mv = mv[[4, 3, 2, 0, 6, 5]]
-        renderer = NormalsRenderer(mv,proj,(1024,1024))
-
-        results = []
-
-        for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
-            mesh = trimesh.load(meshes[name_idx])
-            new_mesh = mesh.split(only_watertight=False)
-            new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
-            mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
-            mesh_v, mesh_f = mesh.vertices, mesh.faces
-
-            if last_colors is None:
-                images = renderer.render(
-                    torch.tensor(mesh_v, device='cuda').float(),
-                    torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
-                    torch.tensor(mesh_f, device='cuda'),
-                )
-                mask = (images[..., 3] < 0.9).cpu().numpy()
-
-            colors, normals = [], []
-            for i in range(6):
-                color = np.array(imgs[level]['images'][i])
-                normal = np.array(imgs[level]['normals'][i])
-
-                if last_colors is not None:
-                    offset = calc_horizontal_offset(np.array(last_colors[i]), color)
-                    # print('offset', i, offset)
-                else:
-                    offset = calc_horizontal_offset2(mask[i], color)
-                    # print('init offset', i, offset)
-
-                if offset != 0:
-                    color = np.roll(color, offset, axis=1)
-                    normal = np.roll(normal, offset, axis=1)
-
-                color = Image.fromarray(color)
-                normal = Image.fromarray(normal)
-                colors.append(color)
-                normals.append(normal)
-
-            if last_front_color is not None and level == 0:
-                original_mask, distract_bbox, _, distract_mask = get_distract_mask(self.generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=self.outside_ratio)
-            else:
-                distract_mask = None
-                distract_bbox = None
-
-            last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
-            last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
-
-            if last_colors is None:
-                from copy import deepcopy
-                last_colors, last_normals = deepcopy(colors), deepcopy(normals)
-
-            # my mesh flow weight by nearest vertexs
-            if fixed_v is not None and fixed_f is not None and level == 1:
-                t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
-
-                fixed_v_cpu = fixed_v.cpu().numpy()
-                kdtree_anchor = KDTree(fixed_v_cpu)
-                kdtree_mesh_v = KDTree(mesh_v)
-                _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
-                _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
-                idx_anchor = idx_anchor.squeeze()
-                neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
-                # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
-                neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
-                neighbor_dists[neighbor_dists > 0.06] = 114514.
-                neighbor_weights = torch.exp(-neighbor_dists * 1.)
-                neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
-                anchors = fixed_v[idx_anchor] # V, 3
-                anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
-                dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
-                vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
-                vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
-                weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
-                mesh_v += weighted_vec_anchor.cpu().numpy()
-
-                t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
-
-            mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
-            mesh_f = torch.tensor(mesh_f, device='cuda')
-
-            new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
-
-            # my mesh flow weight by nearest vertexs
-            try:
-                if fixed_v is not None and fixed_f is not None and level != 0:
-                    new_mesh_v = new_mesh.verts_packed().cpu().numpy()
-
-                    fixed_v_cpu = fixed_v.cpu().numpy()
-                    kdtree_anchor = KDTree(fixed_v_cpu)
-                    kdtree_mesh_v = KDTree(new_mesh_v)
-                    _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
-                    _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
-                    idx_anchor = idx_anchor.squeeze()
-                    neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
-                    # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
-                    neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
-                    neighbor_dists[neighbor_dists > 0.06] = 114514.
-                    neighbor_weights = torch.exp(-neighbor_dists * 1.)
-                    neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
-                    anchors = fixed_v[idx_anchor] # V, 3
-                    anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
-                    dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
-                    vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
-                    vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
-                    weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
-                    new_mesh_v += weighted_vec_anchor.cpu().numpy()
-
-                    # replace new_mesh verts with new_mesh_v
-                    new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
-
-            except Exception as e:
-                pass
-
-            notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
-
-            if fixed_v is None:
-                fixed_v, fixed_f = simp_v, simp_f
-                complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
-            else:
-                fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
-                fixed_v = torch.cat([fixed_v, simp_v], dim=0)
-
-                complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
-                complete_v = torch.cat([complete_v, notsimp_v], dim=0)
-                complete_t = torch.cat([complete_t, notsimp_t], dim=0)
-
-            if level == 2:
-                new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
-
-            save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
-            results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
-
-            # save whole mesh
-            save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
-            results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
-
-        return results
-
-
-class InferSlrmAPI:
-    @spaces.GPU
-    def __init__(self, config):
-        self.config_path = config['config_path']
-        self.config = OmegaConf.load(self.config_path)
-        self.config_name = os.path.basename(self.config_path).replace('.yaml', '')
-        self.model_config = self.config.model_config
-        self.infer_config = self.config.infer_config
-        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-        self.model = instantiate_from_config(self.model_config)
-        state_dict = torch.load(self.infer_config.model_path, map_location='cpu')
-        self.model.load_state_dict(state_dict, strict=False)
-        self.model = self.model.to(self.device)
-        self.model.init_flexicubes_geometry(self.device, fovy=30.0, is_ortho=self.model.is_ortho)
-        self.model = self.model.eval()
-
-    @spaces.GPU
-    def gen(self, imgs):
-        imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
-        imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
-        imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
-        mesh_glb_fpaths = self.make3d(imgs)
-        return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
-
-    @spaces.GPU
-    def make3d(self, images):
-        input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
-
-        images = images.unsqueeze(0).to(device)
-        images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
-
-        mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
-        print(mesh_fpath)
-        mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
-        mesh_dirname = os.path.dirname(mesh_fpath)
-
-        with torch.no_grad():
-            # get triplane
-            planes = self.model.forward_planes(images, input_cameras.float())
-
-            # get mesh
-            mesh_glb_fpaths = []
-            for j in range(4):
-                mesh_glb_fpath = self.make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
-                mesh_glb_fpaths.append(mesh_glb_fpath)
-
-        return mesh_glb_fpaths
-
-    @spaces.GPU
-    def make_mesh(self, mesh_fpath, planes, level=None):
-        mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
-        mesh_dirname = os.path.dirname(mesh_fpath)
-        mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb")
-
-        with torch.no_grad():
-            # get mesh
-            mesh_out = self.model.extract_mesh(
-                planes,
-                use_texture_map=False,
-                levels=torch.tensor([level]).to(device),
-                **self.infer_config,
-            )
-
-            vertices, faces, vertex_colors = mesh_out
-            vertices = vertices[:, [1, 2, 0]]
-
-            if level == 2:
-                # fill all vertex_colors with 127
-                vertex_colors = np.ones_like(vertex_colors) * 127
-
-            save_obj(vertices, faces, vertex_colors, mesh_fpath)
-
-            return mesh_fpath
+infer_refine_sam = sam_model_registry["vit_h"](checkpoint="./ckpt/sam_vit_h_4b8939.pth").cuda()
+infer_refine_generator = SamAutomaticMaskGenerator(
+    model=infer_refine_sam,
+    points_per_side=64,
+    pred_iou_thresh=0.80,
+    stability_score_thresh=0.92,
+    crop_n_layers=1,
+    crop_n_points_downscale_factor=2,
+    min_mask_region_area=100,
+)
+infer_refine_outside_ratio = 0.20
+
+@spaces.GPU
+def infer_refine(meshes, imgs):
+    fixed_v, fixed_f, fixed_t = None, None, None
+    flow_vert, flow_vector = None, None
+    last_colors, last_normals = None, None
+    last_front_color, last_front_normal = None, None
+    distract_mask = None
+
+    mv, proj = make_star_cameras_orthographic(8, 1, r=1.2)
+    mv = mv[[4, 3, 2, 0, 6, 5]]
+    renderer = NormalsRenderer(mv,proj,(1024,1024))
+
+    results = []
+
+    for name_idx, level in zip([2, 0, 1], [2, 1, 0]):
+        mesh = trimesh.load(meshes[name_idx])
+        new_mesh = mesh.split(only_watertight=False)
+        new_mesh = [ j for j in new_mesh if len(j.vertices) >= 300 ]
+        mesh = trimesh.Scene(new_mesh).dump(concatenate=True)
+        mesh_v, mesh_f = mesh.vertices, mesh.faces
+
+        if last_colors is None:
+            images = renderer.render(
+                torch.tensor(mesh_v, device='cuda').float(),
+                torch.ones_like(torch.from_numpy(mesh_v), device='cuda').float(),
+                torch.tensor(mesh_f, device='cuda'),
+            )
+            mask = (images[..., 3] < 0.9).cpu().numpy()
+
+        colors, normals = [], []
+        for i in range(6):
+            color = np.array(imgs[level]['images'][i])
+            normal = np.array(imgs[level]['normals'][i])
+
+            if last_colors is not None:
+                offset = calc_horizontal_offset(np.array(last_colors[i]), color)
+                # print('offset', i, offset)
+            else:
+                offset = calc_horizontal_offset2(mask[i], color)
+                # print('init offset', i, offset)
+
+            if offset != 0:
+                color = np.roll(color, offset, axis=1)
+                normal = np.roll(normal, offset, axis=1)
+
+            color = Image.fromarray(color)
+            normal = Image.fromarray(normal)
+            colors.append(color)
+            normals.append(normal)
+
+        if last_front_color is not None and level == 0:
+            original_mask, distract_bbox, _, distract_mask = get_distract_mask(infer_refine_generator, last_front_color, np.array(colors[0]).astype(np.float32) / 255.0, outside_ratio=infer_refine_outside_ratio)
+        else:
+            distract_mask = None
+            distract_bbox = None
+
+        last_front_color = np.array(colors[0]).astype(np.float32) / 255.0
+        last_front_normal = np.array(normals[0]).astype(np.float32) / 255.0
+
+        if last_colors is None:
+            from copy import deepcopy
+            last_colors, last_normals = deepcopy(colors), deepcopy(normals)
+
+        # my mesh flow weight by nearest vertexs
+        if fixed_v is not None and fixed_f is not None and level == 1:
+            t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
+
+            fixed_v_cpu = fixed_v.cpu().numpy()
+            kdtree_anchor = KDTree(fixed_v_cpu)
+            kdtree_mesh_v = KDTree(mesh_v)
+            _, idx_anchor = kdtree_anchor.query(mesh_v, k=1)
+            _, idx_mesh_v = kdtree_mesh_v.query(mesh_v, k=25)
+            idx_anchor = idx_anchor.squeeze()
+            neighbors = torch.tensor(mesh_v).cuda()[idx_mesh_v] # V, 25, 3
+            # calculate the distances neighbors [V, 25, 3]; mesh_v [V, 3] -> [V, 25]
+            neighbor_dists = torch.norm(neighbors - torch.tensor(mesh_v).cuda()[:, None], dim=-1)
+            neighbor_dists[neighbor_dists > 0.06] = 114514.
+            neighbor_weights = torch.exp(-neighbor_dists * 1.)
+            neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
+            anchors = fixed_v[idx_anchor] # V, 3
+            anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
+            dis_anchor = torch.clamp(((anchors - torch.tensor(mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
+            vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
+            vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
+            weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
+            mesh_v += weighted_vec_anchor.cpu().numpy()
+
+            t = trimesh.Trimesh(vertices=mesh_v, faces=mesh_f)
+
+        mesh_v = torch.tensor(mesh_v, device='cuda', dtype=torch.float32)
+        mesh_f = torch.tensor(mesh_f, device='cuda')
+
+        new_mesh, simp_v, simp_f = geo_refine(mesh_v, mesh_f, colors, normals, fixed_v=fixed_v, fixed_f=fixed_f, distract_mask=distract_mask, distract_bbox=distract_bbox)
+
+        # my mesh flow weight by nearest vertexs
+        try:
+            if fixed_v is not None and fixed_f is not None and level != 0:
+                new_mesh_v = new_mesh.verts_packed().cpu().numpy()
+
+                fixed_v_cpu = fixed_v.cpu().numpy()
+                kdtree_anchor = KDTree(fixed_v_cpu)
+                kdtree_mesh_v = KDTree(new_mesh_v)
+                _, idx_anchor = kdtree_anchor.query(new_mesh_v, k=1)
+                _, idx_mesh_v = kdtree_mesh_v.query(new_mesh_v, k=25)
+                idx_anchor = idx_anchor.squeeze()
+                neighbors = torch.tensor(new_mesh_v).cuda()[idx_mesh_v] # V, 25, 3
+                # calculate the distances neighbors [V, 25, 3]; new_mesh_v [V, 3] -> [V, 25]
+                neighbor_dists = torch.norm(neighbors - torch.tensor(new_mesh_v).cuda()[:, None], dim=-1)
+                neighbor_dists[neighbor_dists > 0.06] = 114514.
+                neighbor_weights = torch.exp(-neighbor_dists * 1.)
+                neighbor_weights = neighbor_weights / neighbor_weights.sum(dim=1, keepdim=True)
+                anchors = fixed_v[idx_anchor] # V, 3
+                anchor_normals = calc_vertex_normals(fixed_v, fixed_f)[idx_anchor] # V, 3
+                dis_anchor = torch.clamp(((anchors - torch.tensor(new_mesh_v).cuda()) * anchor_normals).sum(-1), min=0) + 0.01
+                vec_anchor = dis_anchor[:, None] * anchor_normals # V, 3
+                vec_anchor = vec_anchor[idx_mesh_v] # V, 25, 3
+                weighted_vec_anchor = (vec_anchor * neighbor_weights[:, :, None]).sum(1) # V, 3
+                new_mesh_v += weighted_vec_anchor.cpu().numpy()
+
+                # replace new_mesh verts with new_mesh_v
+                new_mesh = Meshes(verts=[torch.tensor(new_mesh_v, device='cuda')], faces=new_mesh.faces_list(), textures=new_mesh.textures)
+
+        except Exception as e:
+            pass
+
+        notsimp_v, notsimp_f, notsimp_t = new_mesh.verts_packed(), new_mesh.faces_packed(), new_mesh.textures.verts_features_packed()
+
+        if fixed_v is None:
+            fixed_v, fixed_f = simp_v, simp_f
+            complete_v, complete_f, complete_t = notsimp_v, notsimp_f, notsimp_t
+        else:
+            fixed_f = torch.cat([fixed_f, simp_f + fixed_v.shape[0]], dim=0)
+            fixed_v = torch.cat([fixed_v, simp_v], dim=0)
+
+            complete_f = torch.cat([complete_f, notsimp_f + complete_v.shape[0]], dim=0)
+            complete_v = torch.cat([complete_v, notsimp_v], dim=0)
+            complete_t = torch.cat([complete_t, notsimp_t], dim=0)
+
+        if level == 2:
+            new_mesh = Meshes(verts=[new_mesh.verts_packed()], faces=[new_mesh.faces_packed()], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[torch.ones_like(new_mesh.textures.verts_features_packed(), device=new_mesh.verts_packed().device)*0.5]))
+
+        save_py3dmesh_with_trimesh_fast(new_mesh, meshes[name_idx].replace('.obj', '_refined.obj'), apply_sRGB_to_LinearRGB=False)
+        results.append(meshes[name_idx].replace('.obj', '_refined.obj'))
+
+        # save whole mesh
+        save_py3dmesh_with_trimesh_fast(Meshes(verts=[complete_v], faces=[complete_f], textures=pytorch3d.renderer.mesh.textures.TexturesVertex(verts_features=[complete_t])), meshes[name_idx].replace('.obj', '_refined_whole.obj'), apply_sRGB_to_LinearRGB=False)
+        results.append(meshes[name_idx].replace('.obj', '_refined_whole.obj'))
+
+    return results
+
+config_slrm = {
+    'config_path': './configs/mesh-slrm-infer.yaml'
+}
+infer_slrm_config_path = config_slrm['config_path']
+infer_slrm_config = OmegaConf.load(infer_slrm_config_path)
+infer_slrm_config_name = os.path.basename(infer_slrm_config_path).replace('.yaml', '')
+infer_slrm_model_config = infer_slrm_config.model_config
+infer_slrm_infer_config = infer_slrm_config.infer_config
+infer_slrm_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+infer_slrm_model = instantiate_from_config(infer_slrm_model_config)
+state_dict = torch.load(infer_slrm_infer_config.model_path, map_location='cpu')
+infer_slrm_model.load_state_dict(state_dict, strict=False)
+infer_slrm_model = infer_slrm_model.to(infer_slrm_device)
+infer_slrm_model.init_flexicubes_geometry(infer_slrm_device, fovy=30.0, is_ortho=infer_slrm_model.is_ortho)
+infer_slrm_model = infer_slrm_model.eval()
+
+@spaces.GPU
+def infer_slrm_gen(imgs):
+    imgs = [ cv2.imread(img[0])[:, :, ::-1] for img in imgs ]
+    imgs = np.stack(imgs, axis=0).astype(np.float32) / 255.0
+    imgs = torch.from_numpy(np.array(imgs)).permute(0, 3, 1, 2).contiguous().float() # (6, 3, 1024, 1024)
+    mesh_glb_fpaths = infer_slrm_make3d(imgs)
+    return mesh_glb_fpaths[1:4] + mesh_glb_fpaths[0:1]
+
+@spaces.GPU
+def infer_slrm_make3d(images):
+    input_cameras = torch.tensor(np.load('slrm/cameras.npy')).to(device)
+
+    images = images.unsqueeze(0).to(device)
+    images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1)
+
+    mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name
+    print(mesh_fpath)
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+
+    with torch.no_grad():
+        # get triplane
+        planes = infer_slrm_model.forward_planes(images, input_cameras.float())
+
+        # get mesh
+        mesh_glb_fpaths = []
+        for j in range(4):
+            mesh_glb_fpath = infer_slrm_make_mesh(mesh_fpath.replace(mesh_fpath[-4:], f'_{j}{mesh_fpath[-4:]}'), planes, level=[0, 3, 4, 2][j])
+            mesh_glb_fpaths.append(mesh_glb_fpath)
+
+    return mesh_glb_fpaths
+
+@spaces.GPU
+def infer_slrm_make_mesh(mesh_fpath, planes, level=None):
+    mesh_basename = os.path.basename(mesh_fpath).split('.')[0]
+    mesh_dirname = os.path.dirname(mesh_fpath)
+
+    with torch.no_grad():
+        # get mesh
+        mesh_out = infer_slrm_model.extract_mesh(
+            planes,
+            use_texture_map=False,
+            levels=torch.tensor([level]).to(device),
+            **infer_slrm_infer_config,
+        )
+
+        vertices, faces, vertex_colors = mesh_out
+        vertices = vertices[:, [1, 2, 0]]
+
+        if level == 2:
+            # fill all vertex_colors with 127
+            vertex_colors = np.ones_like(vertex_colors) * 127
+
+        save_obj(vertices, faces, vertex_colors, mesh_fpath)
+
+        return mesh_fpath
 
 
 parser = argparse.ArgumentParser()
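
The KDTree block that appears twice in this hunk implements a soft snap of a moving mesh toward an already-fixed anchor surface: each vertex finds its single nearest anchor and its 25 nearest neighbors in the moving mesh, distances beyond 0.06 are overwritten with a huge sentinel (114514) so their exp(-distance) weight vanishes, and every vertex is pulled along the anchor normal by the proximity-weighted average of its neighborhood's clamped offsets. A minimal self-contained sketch of that weighting scheme (the function name is hypothetical, scipy's KDTree is an assumption, and this CPU version only illustrates the math the CUDA code above performs in place):

import numpy as np
import torch
from scipy.spatial import KDTree  # assumption: the file's KDTree is scipy's

def flow_toward_anchors(verts: np.ndarray, anchor_v: np.ndarray,
                        anchor_n: np.ndarray, k: int = 25,
                        max_dist: float = 0.06) -> np.ndarray:
    # Nearest fixed anchor for every vertex, plus each vertex's
    # k nearest neighbors within the moving mesh itself.
    _, idx_anchor = KDTree(anchor_v).query(verts, k=1)
    _, idx_nbr = KDTree(verts).query(verts, k=k)
    idx_nbr = torch.from_numpy(idx_nbr)

    v = torch.tensor(verts, dtype=torch.float32)              # V, 3
    nbr_dists = torch.norm(v[idx_nbr] - v[:, None], dim=-1)   # V, k
    nbr_dists[nbr_dists > max_dist] = 1e6  # sentinel: exp(-1e6) underflows to 0
    w = torch.exp(-nbr_dists)
    w = w / w.sum(dim=1, keepdim=True)  # safe: each vertex is its own neighbor

    anchors = torch.tensor(anchor_v[idx_anchor], dtype=torch.float32)  # V, 3
    normals = torch.tensor(anchor_n[idx_anchor], dtype=torch.float32)  # V, 3
    # Clamped offset of each vertex from its anchor plane, along the normal.
    dis = torch.clamp(((anchors - v) * normals).sum(-1), min=0) + 0.01
    vec = dis[:, None] * normals                               # V, 3
    # Pull each vertex by the proximity-weighted average of its neighbors' pulls.
    flow = (vec[idx_nbr] * w[:, :, None]).sum(1)               # V, 3
    return verts + flow.numpy()
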
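
Throughout the new code, `@spaces.GPU` decorates plain module-level functions, while model and SAM setup runs at import time on the CPU. On ZeroGPU Spaces the decorator attaches a GPU only for the duration of the decorated call. A minimal sketch of that pattern (the `embed` function and its model are hypothetical, assuming the Hugging Face `spaces` package):

import spaces  # Hugging Face ZeroGPU helper
import torch

model = torch.nn.Linear(16, 4)  # built on CPU at import time, like the globals above

@spaces.GPU  # a GPU is attached only while this call is running
def embed(x: torch.Tensor) -> torch.Tensor:
    # Move model and input to the freshly attached GPU, compute, return on CPU.
    return model.cuda()(x.cuda()).cpu()
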