diff --git a/instant-mesh/configs/instant-mesh-base.yaml b/instant-mesh/configs/instant-mesh-base.yaml deleted file mode 100644 index ad4f4c0cd0d3c6f4d3038b657a41dab82c048dd1..0000000000000000000000000000000000000000 --- a/instant-mesh/configs/instant-mesh-base.yaml +++ /dev/null @@ -1,22 +0,0 @@ -model_config: - target: src.models.lrm_mesh.InstantMesh - params: - encoder_feat_dim: 768 - encoder_freeze: false - encoder_model_name: facebook/dino-vitb16 - transformer_dim: 1024 - transformer_layers: 12 - transformer_heads: 16 - triplane_low_res: 32 - triplane_high_res: 64 - triplane_dim: 40 - rendering_samples_per_ray: 96 - grid_res: 128 - grid_scale: 2.1 - - -infer_config: - unet_path: ckpts/diffusion_pytorch_model.bin - model_path: ckpts/instant_mesh_base.ckpt - texture_resolution: 1024 - render_resolution: 512 \ No newline at end of file diff --git a/instant-mesh/configs/instant-mesh-large.yaml b/instant-mesh/configs/instant-mesh-large.yaml deleted file mode 100644 index e296bc89f6d0d0649136ba2ce0e34490f76a5e41..0000000000000000000000000000000000000000 --- a/instant-mesh/configs/instant-mesh-large.yaml +++ /dev/null @@ -1,22 +0,0 @@ -model_config: - target: src.models.lrm_mesh.InstantMesh - params: - encoder_feat_dim: 768 - encoder_freeze: false - encoder_model_name: facebook/dino-vitb16 - transformer_dim: 1024 - transformer_layers: 16 - transformer_heads: 16 - triplane_low_res: 32 - triplane_high_res: 64 - triplane_dim: 80 - rendering_samples_per_ray: 128 - grid_res: 128 - grid_scale: 2.1 - - -infer_config: - unet_path: ckpts/diffusion_pytorch_model.bin - model_path: ckpts/instant_mesh_large.ckpt - texture_resolution: 1024 - render_resolution: 512 \ No newline at end of file diff --git a/instant-mesh/configs/instant-nerf-base.yaml b/instant-mesh/configs/instant-nerf-base.yaml deleted file mode 100644 index ded3d484751127d430891fc28eb2de664aecd5e1..0000000000000000000000000000000000000000 --- a/instant-mesh/configs/instant-nerf-base.yaml +++ /dev/null @@ -1,21 +0,0 @@ -model_config: - target: src.models.lrm.InstantNeRF - params: - encoder_feat_dim: 768 - encoder_freeze: false - encoder_model_name: facebook/dino-vitb16 - transformer_dim: 1024 - transformer_layers: 12 - transformer_heads: 16 - triplane_low_res: 32 - triplane_high_res: 64 - triplane_dim: 40 - rendering_samples_per_ray: 96 - - -infer_config: - unet_path: ckpts/diffusion_pytorch_model.bin - model_path: ckpts/instant_nerf_base.ckpt - mesh_threshold: 10.0 - mesh_resolution: 256 - render_resolution: 384 \ No newline at end of file diff --git a/instant-mesh/configs/instant-nerf-large.yaml b/instant-mesh/configs/instant-nerf-large.yaml deleted file mode 100644 index 57494b69d74ee78dca2e2cead2ef68ddfd0fd531..0000000000000000000000000000000000000000 --- a/instant-mesh/configs/instant-nerf-large.yaml +++ /dev/null @@ -1,21 +0,0 @@ -model_config: - target: src.models.lrm.InstantNeRF - params: - encoder_feat_dim: 768 - encoder_freeze: false - encoder_model_name: facebook/dino-vitb16 - transformer_dim: 1024 - transformer_layers: 16 - transformer_heads: 16 - triplane_low_res: 32 - triplane_high_res: 64 - triplane_dim: 80 - rendering_samples_per_ray: 128 - - -infer_config: - unet_path: ckpts/diffusion_pytorch_model.bin - model_path: ckpts/instant_nerf_large.ckpt - mesh_threshold: 10.0 - mesh_resolution: 256 - render_resolution: 384 \ No newline at end of file diff --git a/instant-mesh/examples/bird.jpg b/instant-mesh/examples/bird.jpg deleted file mode 100644 index ac70a36ebefb87fb283f3bb95d07fe71700702a3..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/bird.jpg and /dev/null differ diff --git a/instant-mesh/examples/bubble_mart_blue.png b/instant-mesh/examples/bubble_mart_blue.png deleted file mode 100644 index af870322d4a8a2f237546fbea9560bb8e5f50364..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/bubble_mart_blue.png and /dev/null differ diff --git a/instant-mesh/examples/cake.jpg b/instant-mesh/examples/cake.jpg deleted file mode 100644 index 8dbebb6901e1230405be3451c0165e80458d5542..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/cake.jpg and /dev/null differ diff --git a/instant-mesh/examples/cartoon_dinosaur.png b/instant-mesh/examples/cartoon_dinosaur.png deleted file mode 100644 index 598964626b767eb6470a28a68537c091fc5de2f8..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/cartoon_dinosaur.png and /dev/null differ diff --git a/instant-mesh/examples/cartoon_panda.png b/instant-mesh/examples/cartoon_panda.png deleted file mode 100644 index f283753d2a17fa46ac2b8dd76afe998284e1ba03..0000000000000000000000000000000000000000 --- a/instant-mesh/examples/cartoon_panda.png +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c82fea6ac66b782b2aa1c6bd133447b5f54f688c7eb44998c4b00f190d47b2b7 -size 1517334 diff --git a/instant-mesh/examples/chair_armed.png b/instant-mesh/examples/chair_armed.png deleted file mode 100644 index 2ab67e95ed57fbc5ebcd7d934827fd7fb03ab3ff..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/chair_armed.png and /dev/null differ diff --git a/instant-mesh/examples/chair_comfort.jpg b/instant-mesh/examples/chair_comfort.jpg deleted file mode 100644 index 918347fe51773d7ecaa7fb929274db8d7d5d3e19..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/chair_comfort.jpg and /dev/null differ diff --git a/instant-mesh/examples/chair_wood.jpg b/instant-mesh/examples/chair_wood.jpg deleted file mode 100644 index bc60569896fb02a46185aabb85086890f0f400d7..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/chair_wood.jpg and /dev/null differ diff --git a/instant-mesh/examples/chest.jpg b/instant-mesh/examples/chest.jpg deleted file mode 100644 index 26ae0b145887e43b850d298b94fe54828e909492..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/chest.jpg and /dev/null differ diff --git a/instant-mesh/examples/cute_horse.jpg b/instant-mesh/examples/cute_horse.jpg deleted file mode 100644 index ec8807d313b983e3cc34ee89bbf3f312d6ce66eb..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/cute_horse.jpg and /dev/null differ diff --git a/instant-mesh/examples/cute_tiger.jpg b/instant-mesh/examples/cute_tiger.jpg deleted file mode 100644 index 82e873258d9f3fd6d569205ab75deb8a26918356..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/cute_tiger.jpg and /dev/null differ diff --git a/instant-mesh/examples/earphone.jpg b/instant-mesh/examples/earphone.jpg deleted file mode 100644 index 498e4196b0d68f8809d049e7178b80592a31a0a2..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/earphone.jpg and /dev/null differ diff --git a/instant-mesh/examples/fox.jpg b/instant-mesh/examples/fox.jpg deleted file mode 100644 index 1f2efc1c3a9c4ad8f36ad93082c124c91a6e9ef7..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/fox.jpg and /dev/null differ diff --git a/instant-mesh/examples/fruit.jpg b/instant-mesh/examples/fruit.jpg deleted file mode 100644 index 07034ad3721de0e09c7509b22a7d3bc9679304d0..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/fruit.jpg and /dev/null differ diff --git a/instant-mesh/examples/fruit_elephant.jpg b/instant-mesh/examples/fruit_elephant.jpg deleted file mode 100644 index ef8eaf3b88ae0a38272b34802fe40032055afa58..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/fruit_elephant.jpg and /dev/null differ diff --git a/instant-mesh/examples/genshin_building.png b/instant-mesh/examples/genshin_building.png deleted file mode 100644 index 00b6a949d01283e1ae30fac4bd6040e13f18a055..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/genshin_building.png and /dev/null differ diff --git a/instant-mesh/examples/genshin_teapot.png b/instant-mesh/examples/genshin_teapot.png deleted file mode 100644 index 1f13a6edfe67ced810b4513117279067f0360fae..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/genshin_teapot.png and /dev/null differ diff --git a/instant-mesh/examples/hatsune_miku.png b/instant-mesh/examples/hatsune_miku.png deleted file mode 100644 index 2fecf005fdd56a396c4894256fbb98fcc1c4dd8f..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/hatsune_miku.png and /dev/null differ diff --git a/instant-mesh/examples/house2.jpg b/instant-mesh/examples/house2.jpg deleted file mode 100644 index 2eb8d63a6b91d5b16e729710c8b703aa5c11f9e5..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/house2.jpg and /dev/null differ diff --git a/instant-mesh/examples/mushroom_teapot.jpg b/instant-mesh/examples/mushroom_teapot.jpg deleted file mode 100644 index a6c767354305f5467a4c0d5f199eee2a120f4501..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/mushroom_teapot.jpg and /dev/null differ diff --git a/instant-mesh/examples/pikachu.png b/instant-mesh/examples/pikachu.png deleted file mode 100644 index e7579c16957a3e13b80d53cf0a41ddfdfd47b92d..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/pikachu.png and /dev/null differ diff --git a/instant-mesh/examples/plant.jpg b/instant-mesh/examples/plant.jpg deleted file mode 100644 index 3519c1639c3f837d9f1147cba1172e6aaab25a23..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/plant.jpg and /dev/null differ diff --git a/instant-mesh/examples/robot.jpg b/instant-mesh/examples/robot.jpg deleted file mode 100644 index 929450fba69a20389f39d46cb51d27facc1bba6d..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/robot.jpg and /dev/null differ diff --git a/instant-mesh/examples/sea_turtle.png b/instant-mesh/examples/sea_turtle.png deleted file mode 100644 index 27c3e2a9c7d44cb33914422b410ef41cf6591433..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/sea_turtle.png and /dev/null differ diff --git a/instant-mesh/examples/skating_shoe.jpg b/instant-mesh/examples/skating_shoe.jpg deleted file mode 100644 index 5f21cb1d43e9d42d2836118963fc1d2874523748..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/skating_shoe.jpg and /dev/null differ diff --git a/instant-mesh/examples/sorting_board.png b/instant-mesh/examples/sorting_board.png deleted file mode 100644 index a40fb8362afce0e323dd4517bba784cc652f5f6c..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/sorting_board.png and /dev/null differ diff --git a/instant-mesh/examples/sword.png b/instant-mesh/examples/sword.png deleted file mode 100644 index 3068cb9bdbbd9ed3c0a143fd5c741abbc58508e3..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/sword.png and /dev/null differ diff --git a/instant-mesh/examples/toy_car.jpg b/instant-mesh/examples/toy_car.jpg deleted file mode 100644 index ffa72aa6c1510e200e5d640461b779d2e7bf4997..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/toy_car.jpg and /dev/null differ diff --git a/instant-mesh/examples/watermelon.png b/instant-mesh/examples/watermelon.png deleted file mode 100644 index 52b39917abcbd2f1eef9b7c8cf9aa602bddde1bf..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/watermelon.png and /dev/null differ diff --git a/instant-mesh/examples/whitedog.png b/instant-mesh/examples/whitedog.png deleted file mode 100644 index 16c598a8133643898408ea806b69d5b18c53be7d..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/whitedog.png and /dev/null differ diff --git a/instant-mesh/examples/x_teapot.jpg b/instant-mesh/examples/x_teapot.jpg deleted file mode 100644 index 4e1cb46c5541dcc4ea544864e2eeebd42dfcb18a..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/x_teapot.jpg and /dev/null differ diff --git a/instant-mesh/examples/x_toyduck.jpg b/instant-mesh/examples/x_toyduck.jpg deleted file mode 100644 index 5e60d43bd76d7511e44568c4f9bba2a11a1a4f04..0000000000000000000000000000000000000000 Binary files a/instant-mesh/examples/x_toyduck.jpg and /dev/null differ diff --git a/instant-mesh/src/__init__.py b/instant-mesh/src/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/instant-mesh/src/data/__init__.py b/instant-mesh/src/data/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/instant-mesh/src/data/objaverse.py b/instant-mesh/src/data/objaverse.py deleted file mode 100644 index dd27f86c2469e74da28e27929929d84cd1718965..0000000000000000000000000000000000000000 --- a/instant-mesh/src/data/objaverse.py +++ /dev/null @@ -1,329 +0,0 @@ -import os, sys -import math -import json -import importlib -from pathlib import Path - -import cv2 -import random -import numpy as np -from PIL import Image -import webdataset as wds -import pytorch_lightning as pl - -import torch -import torch.nn.functional as F -from torch.utils.data import Dataset -from torch.utils.data import DataLoader -from torch.utils.data.distributed import DistributedSampler -from torchvision import transforms - -from src.utils.train_util import instantiate_from_config -from src.utils.camera_util import ( - FOV_to_intrinsics, - center_looking_at_camera_pose, - get_surrounding_views, -) - - -class DataModuleFromConfig(pl.LightningDataModule): - def __init__( - self, - batch_size=8, - num_workers=4, - train=None, - validation=None, - test=None, - **kwargs, - ): - super().__init__() - - self.batch_size = batch_size - self.num_workers = num_workers - - self.dataset_configs = dict() - if train is not None: - self.dataset_configs['train'] = train - if validation is not None: - self.dataset_configs['validation'] = validation - if test is not None: - self.dataset_configs['test'] = test - - def setup(self, stage): - - if stage in ['fit']: - self.datasets = dict((k, instantiate_from_config(self.dataset_configs[k])) for k in self.dataset_configs) - else: - raise NotImplementedError - - def train_dataloader(self): - - sampler = DistributedSampler(self.datasets['train']) - return wds.WebLoader(self.datasets['train'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False, sampler=sampler) - - def val_dataloader(self): - - sampler = DistributedSampler(self.datasets['validation']) - return wds.WebLoader(self.datasets['validation'], batch_size=1, num_workers=self.num_workers, shuffle=False, sampler=sampler) - - def test_dataloader(self): - - return wds.WebLoader(self.datasets['test'], batch_size=self.batch_size, num_workers=self.num_workers, shuffle=False) - - -class ObjaverseData(Dataset): - def __init__(self, - root_dir='objaverse/', - meta_fname='valid_paths.json', - input_image_dir='rendering_random_32views', - target_image_dir='rendering_random_32views', - input_view_num=6, - target_view_num=2, - total_view_n=32, - fov=50, - camera_rotation=True, - validation=False, - ): - self.root_dir = Path(root_dir) - self.input_image_dir = input_image_dir - self.target_image_dir = target_image_dir - - self.input_view_num = input_view_num - self.target_view_num = target_view_num - self.total_view_n = total_view_n - self.fov = fov - self.camera_rotation = camera_rotation - - with open(os.path.join(root_dir, meta_fname)) as f: - filtered_dict = json.load(f) - paths = filtered_dict['good_objs'] - self.paths = paths - - self.depth_scale = 4.0 - - total_objects = len(self.paths) - print('============= length of dataset %d =============' % len(self.paths)) - - def __len__(self): - return len(self.paths) - - def load_im(self, path, color): - ''' - replace background pixel with random color in rendering - ''' - pil_img = Image.open(path) - - image = np.asarray(pil_img, dtype=np.float32) / 255. - alpha = image[:, :, 3:] - image = image[:, :, :3] * alpha + color * (1 - alpha) - - image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() - alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() - return image, alpha - - def __getitem__(self, index): - # load data - while True: - input_image_path = os.path.join(self.root_dir, self.input_image_dir, self.paths[index]) - target_image_path = os.path.join(self.root_dir, self.target_image_dir, self.paths[index]) - - indices = np.random.choice(range(self.total_view_n), self.input_view_num + self.target_view_num, replace=False) - input_indices = indices[:self.input_view_num] - target_indices = indices[self.input_view_num:] - - '''background color, default: white''' - bg_white = [1., 1., 1.] - bg_black = [0., 0., 0.] - - image_list = [] - alpha_list = [] - depth_list = [] - normal_list = [] - pose_list = [] - - try: - input_cameras = np.load(os.path.join(input_image_path, 'cameras.npz'))['cam_poses'] - for idx in input_indices: - image, alpha = self.load_im(os.path.join(input_image_path, '%03d.png' % idx), bg_white) - normal, _ = self.load_im(os.path.join(input_image_path, '%03d_normal.png' % idx), bg_black) - depth = cv2.imread(os.path.join(input_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale - depth = torch.from_numpy(depth).unsqueeze(0) - pose = input_cameras[idx] - pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0) - - image_list.append(image) - alpha_list.append(alpha) - depth_list.append(depth) - normal_list.append(normal) - pose_list.append(pose) - - target_cameras = np.load(os.path.join(target_image_path, 'cameras.npz'))['cam_poses'] - for idx in target_indices: - image, alpha = self.load_im(os.path.join(target_image_path, '%03d.png' % idx), bg_white) - normal, _ = self.load_im(os.path.join(target_image_path, '%03d_normal.png' % idx), bg_black) - depth = cv2.imread(os.path.join(target_image_path, '%03d_depth.png' % idx), cv2.IMREAD_UNCHANGED) / 255.0 * self.depth_scale - depth = torch.from_numpy(depth).unsqueeze(0) - pose = target_cameras[idx] - pose = np.concatenate([pose, np.array([[0, 0, 0, 1]])], axis=0) - - image_list.append(image) - alpha_list.append(alpha) - depth_list.append(depth) - normal_list.append(normal) - pose_list.append(pose) - - except Exception as e: - print(e) - index = np.random.randint(0, len(self.paths)) - continue - - break - - images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W) - alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W) - depths = torch.stack(depth_list, dim=0).float() # (6+V, 1, H, W) - normals = torch.stack(normal_list, dim=0).float() # (6+V, 3, H, W) - w2cs = torch.from_numpy(np.stack(pose_list, axis=0)).float() # (6+V, 4, 4) - c2ws = torch.linalg.inv(w2cs).float() - - normals = normals * 2.0 - 1.0 - normals = F.normalize(normals, dim=1) - normals = (normals + 1.0) / 2.0 - normals = torch.lerp(torch.zeros_like(normals), normals, alphas) - - # random rotation along z axis - if self.camera_rotation: - degree = np.random.uniform(0, math.pi * 2) - rot = torch.tensor([ - [np.cos(degree), -np.sin(degree), 0, 0], - [np.sin(degree), np.cos(degree), 0, 0], - [0, 0, 1, 0], - [0, 0, 0, 1], - ]).unsqueeze(0).float() - c2ws = torch.matmul(rot, c2ws) - - # rotate normals - N, _, H, W = normals.shape - normals = normals * 2.0 - 1.0 - normals = torch.matmul(rot[:, :3, :3], normals.view(N, 3, -1)).view(N, 3, H, W) - normals = F.normalize(normals, dim=1) - normals = (normals + 1.0) / 2.0 - normals = torch.lerp(torch.zeros_like(normals), normals, alphas) - - # random scaling - if np.random.rand() < 0.5: - scale = np.random.uniform(0.8, 1.0) - c2ws[:, :3, 3] *= scale - depths *= scale - - # instrinsics of perspective cameras - K = FOV_to_intrinsics(self.fov) - Ks = K.unsqueeze(0).repeat(self.input_view_num + self.target_view_num, 1, 1).float() - - data = { - 'input_images': images[:self.input_view_num], # (6, 3, H, W) - 'input_alphas': alphas[:self.input_view_num], # (6, 1, H, W) - 'input_depths': depths[:self.input_view_num], # (6, 1, H, W) - 'input_normals': normals[:self.input_view_num], # (6, 3, H, W) - 'input_c2ws': c2ws_input[:self.input_view_num], # (6, 4, 4) - 'input_Ks': Ks[:self.input_view_num], # (6, 3, 3) - - # lrm generator input and supervision - 'target_images': images[self.input_view_num:], # (V, 3, H, W) - 'target_alphas': alphas[self.input_view_num:], # (V, 1, H, W) - 'target_depths': depths[self.input_view_num:], # (V, 1, H, W) - 'target_normals': normals[self.input_view_num:], # (V, 3, H, W) - 'target_c2ws': c2ws[self.input_view_num:], # (V, 4, 4) - 'target_Ks': Ks[self.input_view_num:], # (V, 3, 3) - - 'depth_available': 1, - } - return data - - -class ValidationData(Dataset): - def __init__(self, - root_dir='objaverse/', - input_view_num=6, - input_image_size=256, - fov=50, - ): - self.root_dir = Path(root_dir) - self.input_view_num = input_view_num - self.input_image_size = input_image_size - self.fov = fov - - self.paths = sorted(os.listdir(self.root_dir)) - print('============= length of dataset %d =============' % len(self.paths)) - - cam_distance = 2.5 - azimuths = np.array([30, 90, 150, 210, 270, 330]) - elevations = np.array([30, -20, 30, -20, 30, -20]) - azimuths = np.deg2rad(azimuths) - elevations = np.deg2rad(elevations) - - x = cam_distance * np.cos(elevations) * np.cos(azimuths) - y = cam_distance * np.cos(elevations) * np.sin(azimuths) - z = cam_distance * np.sin(elevations) - - cam_locations = np.stack([x, y, z], axis=-1) - cam_locations = torch.from_numpy(cam_locations).float() - c2ws = center_looking_at_camera_pose(cam_locations) - self.c2ws = c2ws.float() - self.Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(6, 1, 1).float() - - render_c2ws = get_surrounding_views(M=8, radius=cam_distance) - render_Ks = FOV_to_intrinsics(self.fov).unsqueeze(0).repeat(render_c2ws.shape[0], 1, 1) - self.render_c2ws = render_c2ws.float() - self.render_Ks = render_Ks.float() - - def __len__(self): - return len(self.paths) - - def load_im(self, path, color): - ''' - replace background pixel with random color in rendering - ''' - pil_img = Image.open(path) - pil_img = pil_img.resize((self.input_image_size, self.input_image_size), resample=Image.BICUBIC) - - image = np.asarray(pil_img, dtype=np.float32) / 255. - if image.shape[-1] == 4: - alpha = image[:, :, 3:] - image = image[:, :, :3] * alpha + color * (1 - alpha) - else: - alpha = np.ones_like(image[:, :, :1]) - - image = torch.from_numpy(image).permute(2, 0, 1).contiguous().float() - alpha = torch.from_numpy(alpha).permute(2, 0, 1).contiguous().float() - return image, alpha - - def __getitem__(self, index): - # load data - input_image_path = os.path.join(self.root_dir, self.paths[index]) - - '''background color, default: white''' - # color = np.random.uniform(0.48, 0.52) - bkg_color = [1.0, 1.0, 1.0] - - image_list = [] - alpha_list = [] - - for idx in range(self.input_view_num): - image, alpha = self.load_im(os.path.join(input_image_path, f'{idx:03d}.png'), bkg_color) - image_list.append(image) - alpha_list.append(alpha) - - images = torch.stack(image_list, dim=0).float() # (6+V, 3, H, W) - alphas = torch.stack(alpha_list, dim=0).float() # (6+V, 1, H, W) - - data = { - 'input_images': images, # (6, 3, H, W) - 'input_alphas': alphas, # (6, 1, H, W) - 'input_c2ws': self.c2ws, # (6, 4, 4) - 'input_Ks': self.Ks, # (6, 3, 3) - - 'render_c2ws': self.render_c2ws, - 'render_Ks': self.render_Ks, - } - return data diff --git a/instant-mesh/src/model.py b/instant-mesh/src/model.py deleted file mode 100644 index 584a6dcc59a641104f8942e7f4b4fc225e551f6a..0000000000000000000000000000000000000000 --- a/instant-mesh/src/model.py +++ /dev/null @@ -1,310 +0,0 @@ -import os -import numpy as np -import torch -import torch.nn.functional as F -from torchvision.transforms import v2 -from torchvision.utils import make_grid, save_image -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -import pytorch_lightning as pl -from einops import rearrange, repeat - -from src.utils.train_util import instantiate_from_config - - -class MVRecon(pl.LightningModule): - def __init__( - self, - lrm_generator_config, - lrm_path=None, - input_size=256, - render_size=192, - ): - super(MVRecon, self).__init__() - - self.input_size = input_size - self.render_size = render_size - - # init modules - self.lrm_generator = instantiate_from_config(lrm_generator_config) - if lrm_path is not None: - lrm_ckpt = torch.load(lrm_path) - self.lrm_generator.load_state_dict(lrm_ckpt['weights'], strict=False) - - self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') - - self.validation_step_outputs = [] - - def on_fit_start(self): - if self.global_rank == 0: - os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) - os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) - - def prepare_batch_data(self, batch): - lrm_generator_input = {} - render_gt = {} # for supervision - - # input images - images = batch['input_images'] - images = v2.functional.resize( - images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) - - lrm_generator_input['images'] = images.to(self.device) - - # input cameras and render cameras - input_c2ws = batch['input_c2ws'].flatten(-2) - input_Ks = batch['input_Ks'].flatten(-2) - target_c2ws = batch['target_c2ws'].flatten(-2) - target_Ks = batch['target_Ks'].flatten(-2) - render_cameras_input = torch.cat([input_c2ws, input_Ks], dim=-1) - render_cameras_target = torch.cat([target_c2ws, target_Ks], dim=-1) - render_cameras = torch.cat([render_cameras_input, render_cameras_target], dim=1) - - input_extrinsics = input_c2ws[:, :, :12] - input_intrinsics = torch.stack([ - input_Ks[:, :, 0], input_Ks[:, :, 4], - input_Ks[:, :, 2], input_Ks[:, :, 5], - ], dim=-1) - cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) - - # add noise to input cameras - cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02 - - lrm_generator_input['cameras'] = cameras.to(self.device) - lrm_generator_input['render_cameras'] = render_cameras.to(self.device) - - # target images - target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1) - target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1) - target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1) - - # random crop - render_size = np.random.randint(self.render_size, 513) - target_images = v2.functional.resize( - target_images, render_size, interpolation=3, antialias=True).clamp(0, 1) - target_depths = v2.functional.resize( - target_depths, render_size, interpolation=0, antialias=True) - target_alphas = v2.functional.resize( - target_alphas, render_size, interpolation=0, antialias=True) - - crop_params = v2.RandomCrop.get_params( - target_images, output_size=(self.render_size, self.render_size)) - target_images = v2.functional.crop(target_images, *crop_params) - target_depths = v2.functional.crop(target_depths, *crop_params)[:, :, 0:1] - target_alphas = v2.functional.crop(target_alphas, *crop_params)[:, :, 0:1] - - lrm_generator_input['render_size'] = render_size - lrm_generator_input['crop_params'] = crop_params - - render_gt['target_images'] = target_images.to(self.device) - render_gt['target_depths'] = target_depths.to(self.device) - render_gt['target_alphas'] = target_alphas.to(self.device) - - return lrm_generator_input, render_gt - - def prepare_validation_batch_data(self, batch): - lrm_generator_input = {} - - # input images - images = batch['input_images'] - images = v2.functional.resize( - images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) - - lrm_generator_input['images'] = images.to(self.device) - - input_c2ws = batch['input_c2ws'].flatten(-2) - input_Ks = batch['input_Ks'].flatten(-2) - - input_extrinsics = input_c2ws[:, :, :12] - input_intrinsics = torch.stack([ - input_Ks[:, :, 0], input_Ks[:, :, 4], - input_Ks[:, :, 2], input_Ks[:, :, 5], - ], dim=-1) - cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) - - lrm_generator_input['cameras'] = cameras.to(self.device) - - render_c2ws = batch['render_c2ws'].flatten(-2) - render_Ks = batch['render_Ks'].flatten(-2) - render_cameras = torch.cat([render_c2ws, render_Ks], dim=-1) - - lrm_generator_input['render_cameras'] = render_cameras.to(self.device) - lrm_generator_input['render_size'] = 384 - lrm_generator_input['crop_params'] = None - - return lrm_generator_input - - def forward_lrm_generator( - self, - images, - cameras, - render_cameras, - render_size=192, - crop_params=None, - chunk_size=1, - ): - planes = torch.utils.checkpoint.checkpoint( - self.lrm_generator.forward_planes, - images, - cameras, - use_reentrant=False, - ) - frames = [] - for i in range(0, render_cameras.shape[1], chunk_size): - frames.append( - torch.utils.checkpoint.checkpoint( - self.lrm_generator.synthesizer, - planes, - cameras=render_cameras[:, i:i+chunk_size], - render_size=render_size, - crop_params=crop_params, - use_reentrant=False - ) - ) - frames = { - k: torch.cat([r[k] for r in frames], dim=1) - for k in frames[0].keys() - } - return frames - - def forward(self, lrm_generator_input): - images = lrm_generator_input['images'] - cameras = lrm_generator_input['cameras'] - render_cameras = lrm_generator_input['render_cameras'] - render_size = lrm_generator_input['render_size'] - crop_params = lrm_generator_input['crop_params'] - - out = self.forward_lrm_generator( - images, - cameras, - render_cameras, - render_size=render_size, - crop_params=crop_params, - chunk_size=1, - ) - render_images = torch.clamp(out['images_rgb'], 0.0, 1.0) - render_depths = out['images_depth'] - render_alphas = torch.clamp(out['images_weight'], 0.0, 1.0) - - out = { - 'render_images': render_images, - 'render_depths': render_depths, - 'render_alphas': render_alphas, - } - return out - - def training_step(self, batch, batch_idx): - lrm_generator_input, render_gt = self.prepare_batch_data(batch) - - render_out = self.forward(lrm_generator_input) - - loss, loss_dict = self.compute_loss(render_out, render_gt) - - self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) - - if self.global_step % 1000 == 0 and self.global_rank == 0: - B, N, C, H, W = render_gt['target_images'].shape - N_in = lrm_generator_input['images'].shape[1] - - input_images = v2.functional.resize( - lrm_generator_input['images'], (H, W), interpolation=3, antialias=True).clamp(0, 1) - input_images = torch.cat( - [input_images, torch.ones(B, N-N_in, C, H, W).to(input_images)], dim=1) - - input_images = rearrange( - input_images, 'b n c h w -> b c h (n w)') - target_images = rearrange( - render_gt['target_images'], 'b n c h w -> b c h (n w)') - render_images = rearrange( - render_out['render_images'], 'b n c h w -> b c h (n w)') - target_alphas = rearrange( - repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - render_alphas = rearrange( - repeat(render_out['render_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - target_depths = rearrange( - repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - render_depths = rearrange( - repeat(render_out['render_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - MAX_DEPTH = torch.max(target_depths) - target_depths = target_depths / MAX_DEPTH * target_alphas - render_depths = render_depths / MAX_DEPTH - - grid = torch.cat([ - input_images, - target_images, render_images, - target_alphas, render_alphas, - target_depths, render_depths, - ], dim=-2) - grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) - - save_image(grid, os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png')) - - return loss - - def compute_loss(self, render_out, render_gt): - # NOTE: the rgb value range of OpenLRM is [0, 1] - render_images = render_out['render_images'] - target_images = render_gt['target_images'].to(render_images) - render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 - target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 - - loss_mse = F.mse_loss(render_images, target_images) - loss_lpips = 2.0 * self.lpips(render_images, target_images) - - render_alphas = render_out['render_alphas'] - target_alphas = render_gt['target_alphas'] - loss_mask = F.mse_loss(render_alphas, target_alphas) - - loss = loss_mse + loss_lpips + loss_mask - - prefix = 'train' - loss_dict = {} - loss_dict.update({f'{prefix}/loss_mse': loss_mse}) - loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) - loss_dict.update({f'{prefix}/loss_mask': loss_mask}) - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - lrm_generator_input = self.prepare_validation_batch_data(batch) - - render_out = self.forward(lrm_generator_input) - render_images = render_out['render_images'] - render_images = rearrange(render_images, 'b n c h w -> b c h (n w)') - - self.validation_step_outputs.append(render_images) - - def on_validation_epoch_end(self): - images = torch.cat(self.validation_step_outputs, dim=-1) - - all_images = self.all_gather(images) - all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') - - if self.global_rank == 0: - image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png') - - grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1)) - save_image(grid, image_path) - print(f"Saved image to {image_path}") - - self.validation_step_outputs.clear() - - def configure_optimizers(self): - lr = self.learning_rate - - params = [] - - lrm_params_fast, lrm_params_slow = [], [] - for n, p in self.lrm_generator.named_parameters(): - if 'adaLN_modulation' in n or 'camera_embedder' in n: - lrm_params_fast.append(p) - else: - lrm_params_slow.append(p) - params.append({"params": lrm_params_fast, "lr": lr, "weight_decay": 0.01 }) - params.append({"params": lrm_params_slow, "lr": lr / 10.0, "weight_decay": 0.01 }) - - optimizer = torch.optim.AdamW(params, lr=lr, betas=(0.90, 0.95)) - scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 3000, eta_min=lr/4) - - return {'optimizer': optimizer, 'lr_scheduler': scheduler} diff --git a/instant-mesh/src/model_mesh.py b/instant-mesh/src/model_mesh.py deleted file mode 100644 index 99945a0b410242a71678ad0034bf38315a34571b..0000000000000000000000000000000000000000 --- a/instant-mesh/src/model_mesh.py +++ /dev/null @@ -1,325 +0,0 @@ -import os -import numpy as np -import torch -import torch.nn.functional as F -from torchvision.transforms import v2 -from torchvision.utils import make_grid, save_image -from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity -import pytorch_lightning as pl -from einops import rearrange, repeat - -from src.utils.train_util import instantiate_from_config - - -# Regulrarization loss for FlexiCubes -def sdf_reg_loss_batch(sdf, all_edges): - sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) - mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) - sdf_f1x6x2 = sdf_f1x6x2[mask] - sdf_diff = F.binary_cross_entropy_with_logits( - sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ - F.binary_cross_entropy_with_logits( - sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) - return sdf_diff - - -class MVRecon(pl.LightningModule): - def __init__( - self, - lrm_generator_config, - input_size=256, - render_size=512, - init_ckpt=None, - ): - super(MVRecon, self).__init__() - - self.input_size = input_size - self.render_size = render_size - - # init modules - self.lrm_generator = instantiate_from_config(lrm_generator_config) - - self.lpips = LearnedPerceptualImagePatchSimilarity(net_type='vgg') - - # Load weights from pretrained MVRecon model, and use the mlp - # weights to initialize the weights of sdf and rgb mlps. - if init_ckpt is not None: - sd = torch.load(init_ckpt, map_location='cpu')['state_dict'] - sd = {k: v for k, v in sd.items() if k.startswith('lrm_generator')} - sd_fc = {} - for k, v in sd.items(): - if k.startswith('lrm_generator.synthesizer.decoder.net.'): - if k.startswith('lrm_generator.synthesizer.decoder.net.6.'): # last layer - # Here we assume the density filed's isosurface threshold is t, - # we reverse the sign of density filed to initialize SDF field. - # -(w*x + b - t) = (-w)*x + (t - b) - if 'weight' in k: - sd_fc[k.replace('net.', 'net_sdf.')] = -v[0:1] - else: - sd_fc[k.replace('net.', 'net_sdf.')] = 3.0 - v[0:1] - sd_fc[k.replace('net.', 'net_rgb.')] = v[1:4] - else: - sd_fc[k.replace('net.', 'net_sdf.')] = v - sd_fc[k.replace('net.', 'net_rgb.')] = v - else: - sd_fc[k] = v - sd_fc = {k.replace('lrm_generator.', ''): v for k, v in sd_fc.items()} - # missing `net_deformation` and `net_weight` parameters - self.lrm_generator.load_state_dict(sd_fc, strict=False) - print(f'Loaded weights from {init_ckpt}') - - self.validation_step_outputs = [] - - def on_fit_start(self): - device = torch.device(f'cuda:{self.global_rank}') - self.lrm_generator.init_flexicubes_geometry(device) - if self.global_rank == 0: - os.makedirs(os.path.join(self.logdir, 'images'), exist_ok=True) - os.makedirs(os.path.join(self.logdir, 'images_val'), exist_ok=True) - - def prepare_batch_data(self, batch): - lrm_generator_input = {} - render_gt = {} - - # input images - images = batch['input_images'] - images = v2.functional.resize( - images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) - - lrm_generator_input['images'] = images.to(self.device) - - # input cameras and render cameras - input_c2ws = batch['input_c2ws'] - input_Ks = batch['input_Ks'] - target_c2ws = batch['target_c2ws'] - - render_c2ws = torch.cat([input_c2ws, target_c2ws], dim=1) - render_w2cs = torch.linalg.inv(render_c2ws) - - input_extrinsics = input_c2ws.flatten(-2) - input_extrinsics = input_extrinsics[:, :, :12] - input_intrinsics = input_Ks.flatten(-2) - input_intrinsics = torch.stack([ - input_intrinsics[:, :, 0], input_intrinsics[:, :, 4], - input_intrinsics[:, :, 2], input_intrinsics[:, :, 5], - ], dim=-1) - cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) - - # add noise to input_cameras - cameras = cameras + torch.rand_like(cameras) * 0.04 - 0.02 - - lrm_generator_input['cameras'] = cameras.to(self.device) - lrm_generator_input['render_cameras'] = render_w2cs.to(self.device) - - # target images - target_images = torch.cat([batch['input_images'], batch['target_images']], dim=1) - target_depths = torch.cat([batch['input_depths'], batch['target_depths']], dim=1) - target_alphas = torch.cat([batch['input_alphas'], batch['target_alphas']], dim=1) - target_normals = torch.cat([batch['input_normals'], batch['target_normals']], dim=1) - - render_size = self.render_size - target_images = v2.functional.resize( - target_images, render_size, interpolation=3, antialias=True).clamp(0, 1) - target_depths = v2.functional.resize( - target_depths, render_size, interpolation=0, antialias=True) - target_alphas = v2.functional.resize( - target_alphas, render_size, interpolation=0, antialias=True) - target_normals = v2.functional.resize( - target_normals, render_size, interpolation=3, antialias=True) - - lrm_generator_input['render_size'] = render_size - - render_gt['target_images'] = target_images.to(self.device) - render_gt['target_depths'] = target_depths.to(self.device) - render_gt['target_alphas'] = target_alphas.to(self.device) - render_gt['target_normals'] = target_normals.to(self.device) - - return lrm_generator_input, render_gt - - def prepare_validation_batch_data(self, batch): - lrm_generator_input = {} - - # input images - images = batch['input_images'] - images = v2.functional.resize( - images, self.input_size, interpolation=3, antialias=True).clamp(0, 1) - - lrm_generator_input['images'] = images.to(self.device) - - # input cameras - input_c2ws = batch['input_c2ws'].flatten(-2) - input_Ks = batch['input_Ks'].flatten(-2) - - input_extrinsics = input_c2ws[:, :, :12] - input_intrinsics = torch.stack([ - input_Ks[:, :, 0], input_Ks[:, :, 4], - input_Ks[:, :, 2], input_Ks[:, :, 5], - ], dim=-1) - cameras = torch.cat([input_extrinsics, input_intrinsics], dim=-1) - - lrm_generator_input['cameras'] = cameras.to(self.device) - - # render cameras - render_c2ws = batch['render_c2ws'] - render_w2cs = torch.linalg.inv(render_c2ws) - - lrm_generator_input['render_cameras'] = render_w2cs.to(self.device) - lrm_generator_input['render_size'] = 384 - - return lrm_generator_input - - def forward_lrm_generator(self, images, cameras, render_cameras, render_size=512): - planes = torch.utils.checkpoint.checkpoint( - self.lrm_generator.forward_planes, - images, - cameras, - use_reentrant=False, - ) - out = self.lrm_generator.forward_geometry( - planes, - render_cameras, - render_size, - ) - return out - - def forward(self, lrm_generator_input): - images = lrm_generator_input['images'] - cameras = lrm_generator_input['cameras'] - render_cameras = lrm_generator_input['render_cameras'] - render_size = lrm_generator_input['render_size'] - - out = self.forward_lrm_generator( - images, cameras, render_cameras, render_size=render_size) - - return out - - def training_step(self, batch, batch_idx): - lrm_generator_input, render_gt = self.prepare_batch_data(batch) - - render_out = self.forward(lrm_generator_input) - - loss, loss_dict = self.compute_loss(render_out, render_gt) - - self.log_dict(loss_dict, prog_bar=True, logger=True, on_step=True, on_epoch=True) - - if self.global_step % 1000 == 0 and self.global_rank == 0: - B, N, C, H, W = render_gt['target_images'].shape - N_in = lrm_generator_input['images'].shape[1] - - target_images = rearrange( - render_gt['target_images'], 'b n c h w -> b c h (n w)') - render_images = rearrange( - render_out['img'], 'b n c h w -> b c h (n w)') - target_alphas = rearrange( - repeat(render_gt['target_alphas'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - render_alphas = rearrange( - repeat(render_out['mask'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - target_depths = rearrange( - repeat(render_gt['target_depths'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - render_depths = rearrange( - repeat(render_out['depth'], 'b n 1 h w -> b n 3 h w'), 'b n c h w -> b c h (n w)') - target_normals = rearrange( - render_gt['target_normals'], 'b n c h w -> b c h (n w)') - render_normals = rearrange( - render_out['normal'], 'b n c h w -> b c h (n w)') - MAX_DEPTH = torch.max(target_depths) - target_depths = target_depths / MAX_DEPTH * target_alphas - render_depths = render_depths / MAX_DEPTH - - grid = torch.cat([ - target_images, render_images, - target_alphas, render_alphas, - target_depths, render_depths, - target_normals, render_normals, - ], dim=-2) - grid = make_grid(grid, nrow=target_images.shape[0], normalize=True, value_range=(0, 1)) - - image_path = os.path.join(self.logdir, 'images', f'train_{self.global_step:07d}.png') - save_image(grid, image_path) - print(f"Saved image to {image_path}") - - return loss - - def compute_loss(self, render_out, render_gt): - # NOTE: the rgb value range of OpenLRM is [0, 1] - render_images = render_out['img'] - target_images = render_gt['target_images'].to(render_images) - render_images = rearrange(render_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 - target_images = rearrange(target_images, 'b n ... -> (b n) ...') * 2.0 - 1.0 - loss_mse = F.mse_loss(render_images, target_images) - loss_lpips = 2.0 * self.lpips(render_images, target_images) - - render_alphas = render_out['mask'] - target_alphas = render_gt['target_alphas'] - loss_mask = F.mse_loss(render_alphas, target_alphas) - - render_depths = render_out['depth'] - target_depths = render_gt['target_depths'] - loss_depth = 0.5 * F.l1_loss(render_depths[target_alphas>0], target_depths[target_alphas>0]) - - render_normals = render_out['normal'] * 2.0 - 1.0 - target_normals = render_gt['target_normals'] * 2.0 - 1.0 - similarity = (render_normals * target_normals).sum(dim=-3).abs() - normal_mask = target_alphas.squeeze(-3) - loss_normal = 1 - similarity[normal_mask>0].mean() - loss_normal = 0.2 * loss_normal - - # flexicubes regularization loss - sdf = render_out['sdf'] - sdf_reg_loss = render_out['sdf_reg_loss'] - sdf_reg_loss_entropy = sdf_reg_loss_batch(sdf, self.lrm_generator.geometry.all_edges).mean() * 0.01 - _, flexicubes_surface_reg, flexicubes_weights_reg = sdf_reg_loss - flexicubes_surface_reg = flexicubes_surface_reg.mean() * 0.5 - flexicubes_weights_reg = flexicubes_weights_reg.mean() * 0.1 - - loss_reg = sdf_reg_loss_entropy + flexicubes_surface_reg + flexicubes_weights_reg - - loss = loss_mse + loss_lpips + loss_mask + loss_normal + loss_reg - - prefix = 'train' - loss_dict = {} - loss_dict.update({f'{prefix}/loss_mse': loss_mse}) - loss_dict.update({f'{prefix}/loss_lpips': loss_lpips}) - loss_dict.update({f'{prefix}/loss_mask': loss_mask}) - loss_dict.update({f'{prefix}/loss_normal': loss_normal}) - loss_dict.update({f'{prefix}/loss_depth': loss_depth}) - loss_dict.update({f'{prefix}/loss_reg_sdf': sdf_reg_loss_entropy}) - loss_dict.update({f'{prefix}/loss_reg_surface': flexicubes_surface_reg}) - loss_dict.update({f'{prefix}/loss_reg_weights': flexicubes_weights_reg}) - loss_dict.update({f'{prefix}/loss': loss}) - - return loss, loss_dict - - @torch.no_grad() - def validation_step(self, batch, batch_idx): - lrm_generator_input = self.prepare_validation_batch_data(batch) - - render_out = self.forward(lrm_generator_input) - render_images = render_out['img'] - render_images = rearrange(render_images, 'b n c h w -> b c h (n w)') - - self.validation_step_outputs.append(render_images) - - def on_validation_epoch_end(self): - images = torch.cat(self.validation_step_outputs, dim=-1) - - all_images = self.all_gather(images) - all_images = rearrange(all_images, 'r b c h w -> (r b) c h w') - - if self.global_rank == 0: - image_path = os.path.join(self.logdir, 'images_val', f'val_{self.global_step:07d}.png') - - grid = make_grid(all_images, nrow=1, normalize=True, value_range=(0, 1)) - save_image(grid, image_path) - print(f"Saved image to {image_path}") - - self.validation_step_outputs.clear() - - def configure_optimizers(self): - lr = self.learning_rate - - optimizer = torch.optim.AdamW( - self.lrm_generator.parameters(), lr=lr, betas=(0.90, 0.95), weight_decay=0.01) - scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, 100000, eta_min=0) - - return {'optimizer': optimizer, 'lr_scheduler': scheduler} \ No newline at end of file diff --git a/instant-mesh/src/models/__init__.py b/instant-mesh/src/models/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/instant-mesh/src/models/decoder/__init__.py b/instant-mesh/src/models/decoder/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/instant-mesh/src/models/decoder/transformer.py b/instant-mesh/src/models/decoder/transformer.py deleted file mode 100644 index d8e628c0bf589ee827908c894b93cc107f1c58b9..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/decoder/transformer.py +++ /dev/null @@ -1,123 +0,0 @@ -# Copyright (c) 2023, Zexin He -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch -import torch.nn as nn - - -class BasicTransformerBlock(nn.Module): - """ - Transformer block that takes in a cross-attention condition and another modulation vector applied to sub-blocks. - """ - # use attention from torch.nn.MultiHeadAttention - # Block contains a cross-attention layer, a self-attention layer, and a MLP - def __init__( - self, - inner_dim: int, - cond_dim: int, - num_heads: int, - eps: float, - attn_drop: float = 0., - attn_bias: bool = False, - mlp_ratio: float = 4., - mlp_drop: float = 0., - ): - super().__init__() - - self.norm1 = nn.LayerNorm(inner_dim) - self.cross_attn = nn.MultiheadAttention( - embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim, - dropout=attn_drop, bias=attn_bias, batch_first=True) - self.norm2 = nn.LayerNorm(inner_dim) - self.self_attn = nn.MultiheadAttention( - embed_dim=inner_dim, num_heads=num_heads, - dropout=attn_drop, bias=attn_bias, batch_first=True) - self.norm3 = nn.LayerNorm(inner_dim) - self.mlp = nn.Sequential( - nn.Linear(inner_dim, int(inner_dim * mlp_ratio)), - nn.GELU(), - nn.Dropout(mlp_drop), - nn.Linear(int(inner_dim * mlp_ratio), inner_dim), - nn.Dropout(mlp_drop), - ) - - def forward(self, x, cond): - # x: [N, L, D] - # cond: [N, L_cond, D_cond] - x = x + self.cross_attn(self.norm1(x), cond, cond)[0] - before_sa = self.norm2(x) - x = x + self.self_attn(before_sa, before_sa, before_sa)[0] - x = x + self.mlp(self.norm3(x)) - return x - - -class TriplaneTransformer(nn.Module): - """ - Transformer with condition that generates a triplane representation. - - Reference: - Timm: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L486 - """ - def __init__( - self, - inner_dim: int, - image_feat_dim: int, - triplane_low_res: int, - triplane_high_res: int, - triplane_dim: int, - num_layers: int, - num_heads: int, - eps: float = 1e-6, - ): - super().__init__() - - # attributes - self.triplane_low_res = triplane_low_res - self.triplane_high_res = triplane_high_res - self.triplane_dim = triplane_dim - - # modules - # initialize pos_embed with 1/sqrt(dim) * N(0, 1) - self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, inner_dim) * (1. / inner_dim) ** 0.5) - self.layers = nn.ModuleList([ - BasicTransformerBlock( - inner_dim=inner_dim, cond_dim=image_feat_dim, num_heads=num_heads, eps=eps) - for _ in range(num_layers) - ]) - self.norm = nn.LayerNorm(inner_dim, eps=eps) - self.deconv = nn.ConvTranspose2d(inner_dim, triplane_dim, kernel_size=2, stride=2, padding=0) - - def forward(self, image_feats): - # image_feats: [N, L_cond, D_cond] - - N = image_feats.shape[0] - H = W = self.triplane_low_res - L = 3 * H * W - - x = self.pos_embed.repeat(N, 1, 1) # [N, L, D] - for layer in self.layers: - x = layer(x, image_feats) - x = self.norm(x) - - # separate each plane and apply deconv - x = x.view(N, 3, H, W, -1) - x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W] - x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W] - x = self.deconv(x) # [3*N, D', H', W'] - x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W'] - x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W'] - x = x.contiguous() - - return x diff --git a/instant-mesh/src/models/encoder/__init__.py b/instant-mesh/src/models/encoder/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/instant-mesh/src/models/encoder/dino.py b/instant-mesh/src/models/encoder/dino.py deleted file mode 100644 index 684444cab2a13979bcd5688069e9f7729d4ca784..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/encoder/dino.py +++ /dev/null @@ -1,550 +0,0 @@ -# coding=utf-8 -# Copyright 2021 Google AI, Ross Wightman, The HuggingFace Inc. team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" PyTorch ViT model.""" - - -import collections.abc -import math -from typing import Dict, List, Optional, Set, Tuple, Union - -import torch -from torch import nn - -from transformers.activations import ACT2FN -from transformers.modeling_outputs import ( - BaseModelOutput, - BaseModelOutputWithPooling, -) -from transformers import PreTrainedModel, ViTConfig -from transformers.pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer - - -class ViTEmbeddings(nn.Module): - """ - Construct the CLS token, position and patch embeddings. Optionally, also the mask token. - """ - - def __init__(self, config: ViTConfig, use_mask_token: bool = False) -> None: - super().__init__() - - self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size)) - self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size)) if use_mask_token else None - self.patch_embeddings = ViTPatchEmbeddings(config) - num_patches = self.patch_embeddings.num_patches - self.position_embeddings = nn.Parameter(torch.randn(1, num_patches + 1, config.hidden_size)) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - self.config = config - - def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor: - """ - This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher - resolution images. - - Source: - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174 - """ - - num_patches = embeddings.shape[1] - 1 - num_positions = self.position_embeddings.shape[1] - 1 - if num_patches == num_positions and height == width: - return self.position_embeddings - class_pos_embed = self.position_embeddings[:, 0] - patch_pos_embed = self.position_embeddings[:, 1:] - dim = embeddings.shape[-1] - h0 = height // self.config.patch_size - w0 = width // self.config.patch_size - # we add a small number to avoid floating point error in the interpolation - # see discussion at https://github.com/facebookresearch/dino/issues/8 - h0, w0 = h0 + 0.1, w0 + 0.1 - patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim) - patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2) - patch_pos_embed = nn.functional.interpolate( - patch_pos_embed, - scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)), - mode="bicubic", - align_corners=False, - ) - assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1] - patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) - return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1) - - def forward( - self, - pixel_values: torch.Tensor, - bool_masked_pos: Optional[torch.BoolTensor] = None, - interpolate_pos_encoding: bool = False, - ) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - embeddings = self.patch_embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding) - - if bool_masked_pos is not None: - seq_length = embeddings.shape[1] - mask_tokens = self.mask_token.expand(batch_size, seq_length, -1) - # replace the masked visual tokens by mask_tokens - mask = bool_masked_pos.unsqueeze(-1).type_as(mask_tokens) - embeddings = embeddings * (1.0 - mask) + mask_tokens * mask - - # add the [CLS] token to the embedded patch tokens - cls_tokens = self.cls_token.expand(batch_size, -1, -1) - embeddings = torch.cat((cls_tokens, embeddings), dim=1) - - # add positional encoding to each token - if interpolate_pos_encoding: - embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width) - else: - embeddings = embeddings + self.position_embeddings - - embeddings = self.dropout(embeddings) - - return embeddings - - -class ViTPatchEmbeddings(nn.Module): - """ - This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial - `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a - Transformer. - """ - - def __init__(self, config): - super().__init__() - image_size, patch_size = config.image_size, config.patch_size - num_channels, hidden_size = config.num_channels, config.hidden_size - - image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) - patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) - num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) - self.image_size = image_size - self.patch_size = patch_size - self.num_channels = num_channels - self.num_patches = num_patches - - self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) - - def forward(self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False) -> torch.Tensor: - batch_size, num_channels, height, width = pixel_values.shape - if num_channels != self.num_channels: - raise ValueError( - "Make sure that the channel dimension of the pixel values match with the one set in the configuration." - f" Expected {self.num_channels} but got {num_channels}." - ) - if not interpolate_pos_encoding: - if height != self.image_size[0] or width != self.image_size[1]: - raise ValueError( - f"Input image size ({height}*{width}) doesn't match model" - f" ({self.image_size[0]}*{self.image_size[1]})." - ) - embeddings = self.projection(pixel_values).flatten(2).transpose(1, 2) - return embeddings - - -class ViTSelfAttention(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"): - raise ValueError( - f"The hidden size {config.hidden_size,} is not a multiple of the number of attention " - f"heads {config.num_attention_heads}." - ) - - self.num_attention_heads = config.num_attention_heads - self.attention_head_size = int(config.hidden_size / config.num_attention_heads) - self.all_head_size = self.num_attention_heads * self.attention_head_size - - self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias) - - self.dropout = nn.Dropout(config.attention_probs_dropout_prob) - - def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor: - new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) - x = x.view(new_x_shape) - return x.permute(0, 2, 1, 3) - - def forward( - self, hidden_states, head_mask: Optional[torch.Tensor] = None, output_attentions: bool = False - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - mixed_query_layer = self.query(hidden_states) - - key_layer = self.transpose_for_scores(self.key(hidden_states)) - value_layer = self.transpose_for_scores(self.value(hidden_states)) - query_layer = self.transpose_for_scores(mixed_query_layer) - - # Take the dot product between "query" and "key" to get the raw attention scores. - attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) - - attention_scores = attention_scores / math.sqrt(self.attention_head_size) - - # Normalize the attention scores to probabilities. - attention_probs = nn.functional.softmax(attention_scores, dim=-1) - - # This is actually dropping out entire tokens to attend to, which might - # seem a bit unusual, but is taken from the original Transformer paper. - attention_probs = self.dropout(attention_probs) - - # Mask heads if we want to - if head_mask is not None: - attention_probs = attention_probs * head_mask - - context_layer = torch.matmul(attention_probs, value_layer) - - context_layer = context_layer.permute(0, 2, 1, 3).contiguous() - new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,) - context_layer = context_layer.view(new_context_layer_shape) - - outputs = (context_layer, attention_probs) if output_attentions else (context_layer,) - - return outputs - - -class ViTSelfOutput(nn.Module): - """ - The residual connection is defined in ViTLayer instead of here (as is the case with other models), due to the - layernorm applied before each block. - """ - - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - - return hidden_states - - -class ViTAttention(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.attention = ViTSelfAttention(config) - self.output = ViTSelfOutput(config) - self.pruned_heads = set() - - def prune_heads(self, heads: Set[int]) -> None: - if len(heads) == 0: - return - heads, index = find_pruneable_heads_and_indices( - heads, self.attention.num_attention_heads, self.attention.attention_head_size, self.pruned_heads - ) - - # Prune linear layers - self.attention.query = prune_linear_layer(self.attention.query, index) - self.attention.key = prune_linear_layer(self.attention.key, index) - self.attention.value = prune_linear_layer(self.attention.value, index) - self.output.dense = prune_linear_layer(self.output.dense, index, dim=1) - - # Update hyper params and store pruned heads - self.attention.num_attention_heads = self.attention.num_attention_heads - len(heads) - self.attention.all_head_size = self.attention.attention_head_size * self.attention.num_attention_heads - self.pruned_heads = self.pruned_heads.union(heads) - - def forward( - self, - hidden_states: torch.Tensor, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - self_outputs = self.attention(hidden_states, head_mask, output_attentions) - - attention_output = self.output(self_outputs[0], hidden_states) - - outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them - return outputs - - -class ViTIntermediate(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.intermediate_size) - if isinstance(config.hidden_act, str): - self.intermediate_act_fn = ACT2FN[config.hidden_act] - else: - self.intermediate_act_fn = config.hidden_act - - def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.intermediate_act_fn(hidden_states) - - return hidden_states - - -class ViTOutput(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.dense = nn.Linear(config.intermediate_size, config.hidden_size) - self.dropout = nn.Dropout(config.hidden_dropout_prob) - - def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor: - hidden_states = self.dense(hidden_states) - hidden_states = self.dropout(hidden_states) - - hidden_states = hidden_states + input_tensor - - return hidden_states - - -def modulate(x, shift, scale): - return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) - - -class ViTLayer(nn.Module): - """This corresponds to the Block class in the timm implementation.""" - - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.chunk_size_feed_forward = config.chunk_size_feed_forward - self.seq_len_dim = 1 - self.attention = ViTAttention(config) - self.intermediate = ViTIntermediate(config) - self.output = ViTOutput(config) - self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - - self.adaLN_modulation = nn.Sequential( - nn.SiLU(), - nn.Linear(config.hidden_size, 4 * config.hidden_size, bias=True) - ) - nn.init.constant_(self.adaLN_modulation[-1].weight, 0) - nn.init.constant_(self.adaLN_modulation[-1].bias, 0) - - def forward( - self, - hidden_states: torch.Tensor, - adaln_input: torch.Tensor = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - ) -> Union[Tuple[torch.Tensor, torch.Tensor], Tuple[torch.Tensor]]: - shift_msa, scale_msa, shift_mlp, scale_mlp = self.adaLN_modulation(adaln_input).chunk(4, dim=1) - - self_attention_outputs = self.attention( - modulate(self.layernorm_before(hidden_states), shift_msa, scale_msa), # in ViT, layernorm is applied before self-attention - head_mask, - output_attentions=output_attentions, - ) - attention_output = self_attention_outputs[0] - outputs = self_attention_outputs[1:] # add self attentions if we output attention weights - - # first residual connection - hidden_states = attention_output + hidden_states - - # in ViT, layernorm is also applied after self-attention - layer_output = modulate(self.layernorm_after(hidden_states), shift_mlp, scale_mlp) - layer_output = self.intermediate(layer_output) - - # second residual connection is done here - layer_output = self.output(layer_output, hidden_states) - - outputs = (layer_output,) + outputs - - return outputs - - -class ViTEncoder(nn.Module): - def __init__(self, config: ViTConfig) -> None: - super().__init__() - self.config = config - self.layer = nn.ModuleList([ViTLayer(config) for _ in range(config.num_hidden_layers)]) - self.gradient_checkpointing = False - - def forward( - self, - hidden_states: torch.Tensor, - adaln_input: torch.Tensor = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, - output_hidden_states: bool = False, - return_dict: bool = True, - ) -> Union[tuple, BaseModelOutput]: - all_hidden_states = () if output_hidden_states else None - all_self_attentions = () if output_attentions else None - - for i, layer_module in enumerate(self.layer): - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - layer_head_mask = head_mask[i] if head_mask is not None else None - - if self.gradient_checkpointing and self.training: - layer_outputs = self._gradient_checkpointing_func( - layer_module.__call__, - hidden_states, - adaln_input, - layer_head_mask, - output_attentions, - ) - else: - layer_outputs = layer_module(hidden_states, adaln_input, layer_head_mask, output_attentions) - - hidden_states = layer_outputs[0] - - if output_attentions: - all_self_attentions = all_self_attentions + (layer_outputs[1],) - - if output_hidden_states: - all_hidden_states = all_hidden_states + (hidden_states,) - - if not return_dict: - return tuple(v for v in [hidden_states, all_hidden_states, all_self_attentions] if v is not None) - return BaseModelOutput( - last_hidden_state=hidden_states, - hidden_states=all_hidden_states, - attentions=all_self_attentions, - ) - - -class ViTPreTrainedModel(PreTrainedModel): - """ - An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained - models. - """ - - config_class = ViTConfig - base_model_prefix = "vit" - main_input_name = "pixel_values" - supports_gradient_checkpointing = True - _no_split_modules = ["ViTEmbeddings", "ViTLayer"] - - def _init_weights(self, module: Union[nn.Linear, nn.Conv2d, nn.LayerNorm]) -> None: - """Initialize the weights""" - if isinstance(module, (nn.Linear, nn.Conv2d)): - # Upcast the input in `fp32` and cast it back to desired `dtype` to avoid - # `trunc_normal_cpu` not implemented in `half` issues - module.weight.data = nn.init.trunc_normal_( - module.weight.data.to(torch.float32), mean=0.0, std=self.config.initializer_range - ).to(module.weight.dtype) - if module.bias is not None: - module.bias.data.zero_() - elif isinstance(module, nn.LayerNorm): - module.bias.data.zero_() - module.weight.data.fill_(1.0) - elif isinstance(module, ViTEmbeddings): - module.position_embeddings.data = nn.init.trunc_normal_( - module.position_embeddings.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.position_embeddings.dtype) - - module.cls_token.data = nn.init.trunc_normal_( - module.cls_token.data.to(torch.float32), - mean=0.0, - std=self.config.initializer_range, - ).to(module.cls_token.dtype) - - -class ViTModel(ViTPreTrainedModel): - def __init__(self, config: ViTConfig, add_pooling_layer: bool = True, use_mask_token: bool = False): - super().__init__(config) - self.config = config - - self.embeddings = ViTEmbeddings(config, use_mask_token=use_mask_token) - self.encoder = ViTEncoder(config) - - self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.pooler = ViTPooler(config) if add_pooling_layer else None - - # Initialize weights and apply final processing - self.post_init() - - def get_input_embeddings(self) -> ViTPatchEmbeddings: - return self.embeddings.patch_embeddings - - def _prune_heads(self, heads_to_prune: Dict[int, List[int]]) -> None: - """ - Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base - class PreTrainedModel - """ - for layer, heads in heads_to_prune.items(): - self.encoder.layer[layer].attention.prune_heads(heads) - - def forward( - self, - pixel_values: Optional[torch.Tensor] = None, - adaln_input: Optional[torch.Tensor] = None, - bool_masked_pos: Optional[torch.BoolTensor] = None, - head_mask: Optional[torch.Tensor] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - interpolate_pos_encoding: Optional[bool] = None, - return_dict: Optional[bool] = None, - ) -> Union[Tuple, BaseModelOutputWithPooling]: - r""" - bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, num_patches)`, *optional*): - Boolean masked positions. Indicates which patches are masked (1) and which aren't (0). - """ - output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions - output_hidden_states = ( - output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states - ) - return_dict = return_dict if return_dict is not None else self.config.use_return_dict - - if pixel_values is None: - raise ValueError("You have to specify pixel_values") - - # Prepare head mask if needed - # 1.0 in head_mask indicate we keep the head - # attention_probs has shape bsz x n_heads x N x N - # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads] - # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length] - head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers) - - # TODO: maybe have a cleaner way to cast the input (from `ImageProcessor` side?) - expected_dtype = self.embeddings.patch_embeddings.projection.weight.dtype - if pixel_values.dtype != expected_dtype: - pixel_values = pixel_values.to(expected_dtype) - - embedding_output = self.embeddings( - pixel_values, bool_masked_pos=bool_masked_pos, interpolate_pos_encoding=interpolate_pos_encoding - ) - - encoder_outputs = self.encoder( - embedding_output, - adaln_input=adaln_input, - head_mask=head_mask, - output_attentions=output_attentions, - output_hidden_states=output_hidden_states, - return_dict=return_dict, - ) - sequence_output = encoder_outputs[0] - sequence_output = self.layernorm(sequence_output) - pooled_output = self.pooler(sequence_output) if self.pooler is not None else None - - if not return_dict: - head_outputs = (sequence_output, pooled_output) if pooled_output is not None else (sequence_output,) - return head_outputs + encoder_outputs[1:] - - return BaseModelOutputWithPooling( - last_hidden_state=sequence_output, - pooler_output=pooled_output, - hidden_states=encoder_outputs.hidden_states, - attentions=encoder_outputs.attentions, - ) - - -class ViTPooler(nn.Module): - def __init__(self, config: ViTConfig): - super().__init__() - self.dense = nn.Linear(config.hidden_size, config.hidden_size) - self.activation = nn.Tanh() - - def forward(self, hidden_states): - # We "pool" the model by simply taking the hidden state corresponding - # to the first token. - first_token_tensor = hidden_states[:, 0] - pooled_output = self.dense(first_token_tensor) - pooled_output = self.activation(pooled_output) - return pooled_output \ No newline at end of file diff --git a/instant-mesh/src/models/encoder/dino_wrapper.py b/instant-mesh/src/models/encoder/dino_wrapper.py deleted file mode 100644 index e84fd51e7dfcfd1a969b763f5a49aeb7f608e6f9..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/encoder/dino_wrapper.py +++ /dev/null @@ -1,80 +0,0 @@ -# Copyright (c) 2023, Zexin He -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -import torch.nn as nn -from transformers import ViTImageProcessor -from einops import rearrange, repeat -from .dino import ViTModel - - -class DinoWrapper(nn.Module): - """ - Dino v1 wrapper using huggingface transformer implementation. - """ - def __init__(self, model_name: str, freeze: bool = True): - super().__init__() - self.model, self.processor = self._build_dino(model_name) - self.camera_embedder = nn.Sequential( - nn.Linear(16, self.model.config.hidden_size, bias=True), - nn.SiLU(), - nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size, bias=True) - ) - if freeze: - self._freeze() - - def forward(self, image, camera): - # image: [B, N, C, H, W] - # camera: [B, N, D] - # RGB image with [0,1] scale and properly sized - if image.ndim == 5: - image = rearrange(image, 'b n c h w -> (b n) c h w') - dtype = image.dtype - inputs = self.processor( - images=image.float(), - return_tensors="pt", - do_rescale=False, - do_resize=False, - ).to(self.model.device).to(dtype) - # embed camera - N = camera.shape[1] - camera_embeddings = self.camera_embedder(camera) - camera_embeddings = rearrange(camera_embeddings, 'b n d -> (b n) d') - embeddings = camera_embeddings - # This resampling of positional embedding uses bicubic interpolation - outputs = self.model(**inputs, adaln_input=embeddings, interpolate_pos_encoding=True) - last_hidden_states = outputs.last_hidden_state - return last_hidden_states - - def _freeze(self): - print(f"======== Freezing DinoWrapper ========") - self.model.eval() - for name, param in self.model.named_parameters(): - param.requires_grad = False - - @staticmethod - def _build_dino(model_name: str, proxy_error_retries: int = 3, proxy_error_cooldown: int = 5): - import requests - try: - model = ViTModel.from_pretrained(model_name, add_pooling_layer=False) - processor = ViTImageProcessor.from_pretrained(model_name) - return model, processor - except requests.exceptions.ProxyError as err: - if proxy_error_retries > 0: - print(f"Huggingface ProxyError: Retrying in {proxy_error_cooldown} seconds...") - import time - time.sleep(proxy_error_cooldown) - return DinoWrapper._build_dino(model_name, proxy_error_retries - 1, proxy_error_cooldown) - else: - raise err diff --git a/instant-mesh/src/models/geometry/__init__.py b/instant-mesh/src/models/geometry/__init__.py deleted file mode 100644 index 89e9a6c2fffe82a55693885dae78c1a630924389..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. diff --git a/instant-mesh/src/models/geometry/camera/__init__.py b/instant-mesh/src/models/geometry/camera/__init__.py deleted file mode 100644 index c5c7082e47c65a08e25489b3c3fd010d07ad9758..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/camera/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -from torch import nn - - -class Camera(nn.Module): - def __init__(self): - super(Camera, self).__init__() - pass diff --git a/instant-mesh/src/models/geometry/camera/perspective_camera.py b/instant-mesh/src/models/geometry/camera/perspective_camera.py deleted file mode 100644 index 7dcab0d2a321a77a5d3c2d4c3f40ba2cc32f6dfa..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/camera/perspective_camera.py +++ /dev/null @@ -1,35 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -from . import Camera -import numpy as np - - -def projection(x=0.1, n=1.0, f=50.0, near_plane=None): - if near_plane is None: - near_plane = n - return np.array( - [[n / x, 0, 0, 0], - [0, n / -x, 0, 0], - [0, 0, -(f + near_plane) / (f - near_plane), -(2 * f * near_plane) / (f - near_plane)], - [0, 0, -1, 0]]).astype(np.float32) - - -class PerspectiveCamera(Camera): - def __init__(self, fovy=49.0, device='cuda'): - super(PerspectiveCamera, self).__init__() - self.device = device - focal = np.tan(fovy / 180.0 * np.pi * 0.5) - self.proj_mtx = torch.from_numpy(projection(x=focal, f=1000.0, n=1.0, near_plane=0.1)).to(self.device).unsqueeze(dim=0) - - def project(self, points_bxnx4): - out = torch.matmul( - points_bxnx4, - torch.transpose(self.proj_mtx, 1, 2)) - return out diff --git a/instant-mesh/src/models/geometry/render/__init__.py b/instant-mesh/src/models/geometry/render/__init__.py deleted file mode 100644 index 483cfabbf395853f1ca3e67b856d5f17b9889d1b..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/render/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -import torch - -class Renderer(): - def __init__(self): - pass - - def forward(self): - pass \ No newline at end of file diff --git a/instant-mesh/src/models/geometry/render/neural_render.py b/instant-mesh/src/models/geometry/render/neural_render.py deleted file mode 100644 index 473464480125c050ee6dba973450678a197145fb..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/render/neural_render.py +++ /dev/null @@ -1,121 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -import torch.nn.functional as F -import nvdiffrast.torch as dr -from . import Renderer - -_FG_LUT = None - - -def interpolate(attr, rast, attr_idx, rast_db=None): - return dr.interpolate( - attr.contiguous(), rast, attr_idx, rast_db=rast_db, - diff_attrs=None if rast_db is None else 'all') - - -def xfm_points(points, matrix, use_python=True): - '''Transform points. - Args: - points: Tensor containing 3D points with shape [minibatch_size, num_vertices, 3] or [1, num_vertices, 3] - matrix: A 4x4 transform matrix with shape [minibatch_size, 4, 4] - use_python: Use PyTorch's torch.matmul (for validation) - Returns: - Transformed points in homogeneous 4D with shape [minibatch_size, num_vertices, 4]. - ''' - out = torch.matmul(torch.nn.functional.pad(points, pad=(0, 1), mode='constant', value=1.0), torch.transpose(matrix, 1, 2)) - if torch.is_anomaly_enabled(): - assert torch.all(torch.isfinite(out)), "Output of xfm_points contains inf or NaN" - return out - - -def dot(x, y): - return torch.sum(x * y, -1, keepdim=True) - - -def compute_vertex_normal(v_pos, t_pos_idx): - i0 = t_pos_idx[:, 0] - i1 = t_pos_idx[:, 1] - i2 = t_pos_idx[:, 2] - - v0 = v_pos[i0, :] - v1 = v_pos[i1, :] - v2 = v_pos[i2, :] - - face_normals = torch.cross(v1 - v0, v2 - v0) - - # Splat face normals to vertices - v_nrm = torch.zeros_like(v_pos) - v_nrm.scatter_add_(0, i0[:, None].repeat(1, 3), face_normals) - v_nrm.scatter_add_(0, i1[:, None].repeat(1, 3), face_normals) - v_nrm.scatter_add_(0, i2[:, None].repeat(1, 3), face_normals) - - # Normalize, replace zero (degenerated) normals with some default value - v_nrm = torch.where( - dot(v_nrm, v_nrm) > 1e-20, v_nrm, torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm) - ) - v_nrm = F.normalize(v_nrm, dim=1) - assert torch.all(torch.isfinite(v_nrm)) - - return v_nrm - - -class NeuralRender(Renderer): - def __init__(self, device='cuda', camera_model=None): - super(NeuralRender, self).__init__() - self.device = device - self.ctx = dr.RasterizeCudaContext(device=device) - self.projection_mtx = None - self.camera = camera_model - - def render_mesh( - self, - mesh_v_pos_bxnx3, - mesh_t_pos_idx_fx3, - camera_mv_bx4x4, - mesh_v_feat_bxnxd, - resolution=256, - spp=1, - device='cuda', - hierarchical_mask=False - ): - assert not hierarchical_mask - - mtx_in = torch.tensor(camera_mv_bx4x4, dtype=torch.float32, device=device) if not torch.is_tensor(camera_mv_bx4x4) else camera_mv_bx4x4 - v_pos = xfm_points(mesh_v_pos_bxnx3, mtx_in) # Rotate it to camera coordinates - v_pos_clip = self.camera.project(v_pos) # Projection in the camera - - v_nrm = compute_vertex_normal(mesh_v_pos_bxnx3[0], mesh_t_pos_idx_fx3.long()) # vertex normals in world coordinates - - # Render the image, - # Here we only return the feature (3D location) at each pixel, which will be used as the input for neural render - num_layers = 1 - mask_pyramid = None - assert mesh_t_pos_idx_fx3.shape[0] > 0 # Make sure we have shapes - mesh_v_feat_bxnxd = torch.cat([mesh_v_feat_bxnxd.repeat(v_pos.shape[0], 1, 1), v_pos], dim=-1) # Concatenate the pos - - with dr.DepthPeeler(self.ctx, v_pos_clip, mesh_t_pos_idx_fx3, [resolution * spp, resolution * spp]) as peeler: - for _ in range(num_layers): - rast, db = peeler.rasterize_next_layer() - gb_feat, _ = interpolate(mesh_v_feat_bxnxd, rast, mesh_t_pos_idx_fx3) - - hard_mask = torch.clamp(rast[..., -1:], 0, 1) - antialias_mask = dr.antialias( - hard_mask.clone().contiguous(), rast, v_pos_clip, - mesh_t_pos_idx_fx3) - - depth = gb_feat[..., -2:-1] - ori_mesh_feature = gb_feat[..., :-4] - - normal, _ = interpolate(v_nrm[None, ...], rast, mesh_t_pos_idx_fx3) - normal = dr.antialias(normal.clone().contiguous(), rast, v_pos_clip, mesh_t_pos_idx_fx3) - normal = F.normalize(normal, dim=-1) - normal = torch.lerp(torch.zeros_like(normal), (normal + 1.0) / 2.0, hard_mask.float()) # black background - - return ori_mesh_feature, antialias_mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal diff --git a/instant-mesh/src/models/geometry/rep_3d/__init__.py b/instant-mesh/src/models/geometry/rep_3d/__init__.py deleted file mode 100644 index a3d5628a8433298477d1963f92578d47106b4a0f..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/__init__.py +++ /dev/null @@ -1,18 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -import numpy as np - - -class Geometry(): - def __init__(self): - pass - - def forward(self): - pass diff --git a/instant-mesh/src/models/geometry/rep_3d/dmtet.py b/instant-mesh/src/models/geometry/rep_3d/dmtet.py deleted file mode 100644 index b6a709380abac0bbf66fd1c8582485f3982223e4..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/dmtet.py +++ /dev/null @@ -1,504 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -import numpy as np -import os -from . import Geometry -from .dmtet_utils import get_center_boundary_index -import torch.nn.functional as F - - -############################################################################### -# DMTet utility functions -############################################################################### -def create_mt_variable(device): - triangle_table = torch.tensor( - [ - [-1, -1, -1, -1, -1, -1], - [1, 0, 2, -1, -1, -1], - [4, 0, 3, -1, -1, -1], - [1, 4, 2, 1, 3, 4], - [3, 1, 5, -1, -1, -1], - [2, 3, 0, 2, 5, 3], - [1, 4, 0, 1, 5, 4], - [4, 2, 5, -1, -1, -1], - [4, 5, 2, -1, -1, -1], - [4, 1, 0, 4, 5, 1], - [3, 2, 0, 3, 5, 2], - [1, 3, 5, -1, -1, -1], - [4, 1, 2, 4, 3, 1], - [3, 0, 4, -1, -1, -1], - [2, 0, 1, -1, -1, -1], - [-1, -1, -1, -1, -1, -1] - ], dtype=torch.long, device=device) - - num_triangles_table = torch.tensor([0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long, device=device) - base_tet_edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=device) - v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=device)) - return triangle_table, num_triangles_table, base_tet_edges, v_id - - -def sort_edges(edges_ex2): - with torch.no_grad(): - order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long() - order = order.unsqueeze(dim=1) - a = torch.gather(input=edges_ex2, index=order, dim=1) - b = torch.gather(input=edges_ex2, index=1 - order, dim=1) - return torch.stack([a, b], -1) - - -############################################################################### -# marching tetrahedrons (differentiable) -############################################################################### - -def marching_tets(pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id): - with torch.no_grad(): - occ_n = sdf_n > 0 - occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) - occ_sum = torch.sum(occ_fx4, -1) - valid_tets = (occ_sum > 0) & (occ_sum < 4) - occ_sum = occ_sum[valid_tets] - - # find all vertices - all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) - all_edges = sort_edges(all_edges) - unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) - - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) - idx_map = mapping[idx_map] # map edges to verts - - interp_v = unique_edges[mask_edges] # .long() - edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) - edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) - edges_to_interp_sdf[:, -1] *= -1 - - denominator = edges_to_interp_sdf.sum(1, keepdim=True) - - edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator - verts = (edges_to_interp * edges_to_interp_sdf).sum(1) - - idx_map = idx_map.reshape(-1, 6) - - tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) - num_triangles = num_triangles_table[tetindex] - - # Generate triangle indices - faces = torch.cat( - ( - torch.gather( - input=idx_map[num_triangles == 1], dim=1, - index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), - torch.gather( - input=idx_map[num_triangles == 2], dim=1, - index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), - ), dim=0) - return verts, faces - - -def create_tetmesh_variables(device='cuda'): - tet_table = torch.tensor( - [[-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1], - [0, 4, 5, 6, -1, -1, -1, -1, -1, -1, -1, -1], - [1, 4, 7, 8, -1, -1, -1, -1, -1, -1, -1, -1], - [1, 0, 8, 7, 0, 5, 8, 7, 0, 5, 6, 8], - [2, 5, 7, 9, -1, -1, -1, -1, -1, -1, -1, -1], - [2, 0, 9, 7, 0, 4, 9, 7, 0, 4, 6, 9], - [2, 1, 9, 5, 1, 4, 9, 5, 1, 4, 8, 9], - [6, 0, 1, 2, 6, 1, 2, 8, 6, 8, 2, 9], - [3, 6, 8, 9, -1, -1, -1, -1, -1, -1, -1, -1], - [3, 0, 9, 8, 0, 4, 9, 8, 0, 4, 5, 9], - [3, 1, 9, 6, 1, 4, 9, 6, 1, 4, 7, 9], - [5, 0, 1, 3, 5, 1, 3, 7, 5, 7, 3, 9], - [3, 2, 8, 6, 2, 5, 8, 6, 2, 5, 7, 8], - [4, 0, 2, 3, 4, 2, 3, 7, 4, 7, 3, 8], - [4, 1, 2, 3, 4, 2, 3, 5, 4, 5, 3, 6], - [-1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1]], dtype=torch.long, device=device) - num_tets_table = torch.tensor([0, 1, 1, 3, 1, 3, 3, 3, 1, 3, 3, 3, 3, 3, 3, 0], dtype=torch.long, device=device) - return tet_table, num_tets_table - - -def marching_tets_tetmesh( - pos_nx3, sdf_n, tet_fx4, triangle_table, num_triangles_table, base_tet_edges, v_id, - return_tet_mesh=False, ori_v=None, num_tets_table=None, tet_table=None): - with torch.no_grad(): - occ_n = sdf_n > 0 - occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) - occ_sum = torch.sum(occ_fx4, -1) - valid_tets = (occ_sum > 0) & (occ_sum < 4) - occ_sum = occ_sum[valid_tets] - - # find all vertices - all_edges = tet_fx4[valid_tets][:, base_tet_edges].reshape(-1, 2) - all_edges = sort_edges(all_edges) - unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) - - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=sdf_n.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), dtype=torch.long, device=sdf_n.device) - idx_map = mapping[idx_map] # map edges to verts - - interp_v = unique_edges[mask_edges] # .long() - edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3) - edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1) - edges_to_interp_sdf[:, -1] *= -1 - - denominator = edges_to_interp_sdf.sum(1, keepdim=True) - - edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator - verts = (edges_to_interp * edges_to_interp_sdf).sum(1) - - idx_map = idx_map.reshape(-1, 6) - - tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1) - num_triangles = num_triangles_table[tetindex] - - # Generate triangle indices - faces = torch.cat( - ( - torch.gather( - input=idx_map[num_triangles == 1], dim=1, - index=triangle_table[tetindex[num_triangles == 1]][:, :3]).reshape(-1, 3), - torch.gather( - input=idx_map[num_triangles == 2], dim=1, - index=triangle_table[tetindex[num_triangles == 2]][:, :6]).reshape(-1, 3), - ), dim=0) - if not return_tet_mesh: - return verts, faces - occupied_verts = ori_v[occ_n] - mapping = torch.ones((pos_nx3.shape[0]), dtype=torch.long, device="cuda") * -1 - mapping[occ_n] = torch.arange(occupied_verts.shape[0], device="cuda") - tet_fx4 = mapping[tet_fx4.reshape(-1)].reshape((-1, 4)) - - idx_map = torch.cat([tet_fx4[valid_tets] + verts.shape[0], idx_map], -1) # t x 10 - tet_verts = torch.cat([verts, occupied_verts], 0) - num_tets = num_tets_table[tetindex] - - tets = torch.cat( - ( - torch.gather(input=idx_map[num_tets == 1], dim=1, index=tet_table[tetindex[num_tets == 1]][:, :4]).reshape( - -1, - 4), - torch.gather(input=idx_map[num_tets == 3], dim=1, index=tet_table[tetindex[num_tets == 3]][:, :12]).reshape( - -1, - 4), - ), dim=0) - # add fully occupied tets - fully_occupied = occ_fx4.sum(-1) == 4 - tet_fully_occupied = tet_fx4[fully_occupied] + verts.shape[0] - tets = torch.cat([tets, tet_fully_occupied]) - - return verts, faces, tet_verts, tets - - -############################################################################### -# Compact tet grid -############################################################################### - -def compact_tets(pos_nx3, sdf_n, tet_fx4): - with torch.no_grad(): - # Find surface tets - occ_n = sdf_n > 0 - occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4) - occ_sum = torch.sum(occ_fx4, -1) - valid_tets = (occ_sum > 0) & (occ_sum < 4) # one value per tet, these are the surface tets - - valid_vtx = tet_fx4[valid_tets].reshape(-1) - unique_vtx, idx_map = torch.unique(valid_vtx, dim=0, return_inverse=True) - new_pos = pos_nx3[unique_vtx] - new_sdf = sdf_n[unique_vtx] - new_tets = idx_map.reshape(-1, 4) - return new_pos, new_sdf, new_tets - - -############################################################################### -# Subdivide volume -############################################################################### - -def batch_subdivide_volume(tet_pos_bxnx3, tet_bxfx4, grid_sdf): - device = tet_pos_bxnx3.device - # get new verts - tet_fx4 = tet_bxfx4[0] - edges = [0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3] - all_edges = tet_fx4[:, edges].reshape(-1, 2) - all_edges = sort_edges(all_edges) - unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True) - idx_map = idx_map + tet_pos_bxnx3.shape[1] - all_values = torch.cat([tet_pos_bxnx3, grid_sdf], -1) - mid_points_pos = all_values[:, unique_edges.reshape(-1)].reshape( - all_values.shape[0], -1, 2, - all_values.shape[-1]).mean(2) - new_v = torch.cat([all_values, mid_points_pos], 1) - new_v, new_sdf = new_v[..., :3], new_v[..., 3] - - # get new tets - - idx_a, idx_b, idx_c, idx_d = tet_fx4[:, 0], tet_fx4[:, 1], tet_fx4[:, 2], tet_fx4[:, 3] - idx_ab = idx_map[0::6] - idx_ac = idx_map[1::6] - idx_ad = idx_map[2::6] - idx_bc = idx_map[3::6] - idx_bd = idx_map[4::6] - idx_cd = idx_map[5::6] - - tet_1 = torch.stack([idx_a, idx_ab, idx_ac, idx_ad], dim=1) - tet_2 = torch.stack([idx_b, idx_bc, idx_ab, idx_bd], dim=1) - tet_3 = torch.stack([idx_c, idx_ac, idx_bc, idx_cd], dim=1) - tet_4 = torch.stack([idx_d, idx_ad, idx_cd, idx_bd], dim=1) - tet_5 = torch.stack([idx_ab, idx_ac, idx_ad, idx_bd], dim=1) - tet_6 = torch.stack([idx_ab, idx_ac, idx_bd, idx_bc], dim=1) - tet_7 = torch.stack([idx_cd, idx_ac, idx_bd, idx_ad], dim=1) - tet_8 = torch.stack([idx_cd, idx_ac, idx_bc, idx_bd], dim=1) - - tet_np = torch.cat([tet_1, tet_2, tet_3, tet_4, tet_5, tet_6, tet_7, tet_8], dim=0) - tet_np = tet_np.reshape(1, -1, 4).expand(tet_pos_bxnx3.shape[0], -1, -1) - tet = tet_np.long().to(device) - - return new_v, tet, new_sdf - - -############################################################################### -# Adjacency -############################################################################### -def tet_to_tet_adj_sparse(tet_tx4): - # include self connection!!!!!!!!!!!!!!!!!!! - with torch.no_grad(): - t = tet_tx4.shape[0] - device = tet_tx4.device - idx_array = torch.LongTensor( - [0, 1, 2, - 1, 0, 3, - 2, 3, 0, - 3, 2, 1]).to(device).reshape(4, 3).unsqueeze(0).expand(t, -1, -1) # (t, 4, 3) - - # get all faces - all_faces = torch.gather(input=tet_tx4.unsqueeze(1).expand(-1, 4, -1), index=idx_array, dim=-1).reshape( - -1, - 3) # (tx4, 3) - all_faces_tet_idx = torch.arange(t, device=device).unsqueeze(-1).expand(-1, 4).reshape(-1) - # sort and group - all_faces_sorted, _ = torch.sort(all_faces, dim=1) - - all_faces_unique, inverse_indices, counts = torch.unique( - all_faces_sorted, dim=0, return_counts=True, - return_inverse=True) - tet_face_fx3 = all_faces_unique[counts == 2] - counts = counts[inverse_indices] # tx4 - valid = (counts == 2) - - group = inverse_indices[valid] - # print (inverse_indices.shape, group.shape, all_faces_tet_idx.shape) - _, indices = torch.sort(group) - all_faces_tet_idx_grouped = all_faces_tet_idx[valid][indices] - tet_face_tetidx_fx2 = torch.stack([all_faces_tet_idx_grouped[::2], all_faces_tet_idx_grouped[1::2]], dim=-1) - - tet_adj_idx = torch.cat([tet_face_tetidx_fx2, torch.flip(tet_face_tetidx_fx2, [1])]) - adj_self = torch.arange(t, device=tet_tx4.device) - adj_self = torch.stack([adj_self, adj_self], -1) - tet_adj_idx = torch.cat([tet_adj_idx, adj_self]) - - tet_adj_idx = torch.unique(tet_adj_idx, dim=0) - values = torch.ones( - tet_adj_idx.shape[0], device=tet_tx4.device).float() - adj_sparse = torch.sparse.FloatTensor( - tet_adj_idx.t(), values, torch.Size([t, t])) - - # normalization - neighbor_num = 1.0 / torch.sparse.sum( - adj_sparse, dim=1).to_dense() - values = torch.index_select(neighbor_num, 0, tet_adj_idx[:, 0]) - adj_sparse = torch.sparse.FloatTensor( - tet_adj_idx.t(), values, torch.Size([t, t])) - return adj_sparse - - -############################################################################### -# Compact grid -############################################################################### - -def get_tet_bxfx4x3(bxnxz, bxfx4): - n_batch, z = bxnxz.shape[0], bxnxz.shape[2] - gather_input = bxnxz.unsqueeze(2).expand( - n_batch, bxnxz.shape[1], 4, z) - gather_index = bxfx4.unsqueeze(-1).expand( - n_batch, bxfx4.shape[1], 4, z).long() - tet_bxfx4xz = torch.gather( - input=gather_input, dim=1, index=gather_index) - - return tet_bxfx4xz - - -def shrink_grid(tet_pos_bxnx3, tet_bxfx4, grid_sdf): - with torch.no_grad(): - assert tet_pos_bxnx3.shape[0] == 1 - - occ = grid_sdf[0] > 0 - occ_sum = get_tet_bxfx4x3(occ.unsqueeze(0).unsqueeze(-1), tet_bxfx4).reshape(-1, 4).sum(-1) - mask = (occ_sum > 0) & (occ_sum < 4) - - # build connectivity graph - adj_matrix = tet_to_tet_adj_sparse(tet_bxfx4[0]) - mask = mask.float().unsqueeze(-1) - - # Include a one ring of neighbors - for i in range(1): - mask = torch.sparse.mm(adj_matrix, mask) - mask = mask.squeeze(-1) > 0 - - mapping = torch.zeros((tet_pos_bxnx3.shape[1]), device=tet_pos_bxnx3.device, dtype=torch.long) - new_tet_bxfx4 = tet_bxfx4[:, mask].long() - selected_verts_idx = torch.unique(new_tet_bxfx4) - new_tet_pos_bxnx3 = tet_pos_bxnx3[:, selected_verts_idx] - mapping[selected_verts_idx] = torch.arange(selected_verts_idx.shape[0], device=tet_pos_bxnx3.device) - new_tet_bxfx4 = mapping[new_tet_bxfx4.reshape(-1)].reshape(new_tet_bxfx4.shape) - new_grid_sdf = grid_sdf[:, selected_verts_idx] - return new_tet_pos_bxnx3, new_tet_bxfx4, new_grid_sdf - - -############################################################################### -# Regularizer -############################################################################### - -def sdf_reg_loss(sdf, all_edges): - sdf_f1x6x2 = sdf[all_edges.reshape(-1)].reshape(-1, 2) - mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) - sdf_f1x6x2 = sdf_f1x6x2[mask] - sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits( - sdf_f1x6x2[..., 0], - (sdf_f1x6x2[..., 1] > 0).float()) + \ - torch.nn.functional.binary_cross_entropy_with_logits( - sdf_f1x6x2[..., 1], - (sdf_f1x6x2[..., 0] > 0).float()) - return sdf_diff - - -def sdf_reg_loss_batch(sdf, all_edges): - sdf_f1x6x2 = sdf[:, all_edges.reshape(-1)].reshape(sdf.shape[0], -1, 2) - mask = torch.sign(sdf_f1x6x2[..., 0]) != torch.sign(sdf_f1x6x2[..., 1]) - sdf_f1x6x2 = sdf_f1x6x2[mask] - sdf_diff = torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 0], (sdf_f1x6x2[..., 1] > 0).float()) + \ - torch.nn.functional.binary_cross_entropy_with_logits(sdf_f1x6x2[..., 1], (sdf_f1x6x2[..., 0] > 0).float()) - return sdf_diff - - -############################################################################### -# Geometry interface -############################################################################### -class DMTetGeometry(Geometry): - def __init__( - self, grid_res=64, scale=2.0, device='cuda', renderer=None, - render_type='neural_render', args=None): - super(DMTetGeometry, self).__init__() - self.grid_res = grid_res - self.device = device - self.args = args - tets = np.load('data/tets/%d_compress.npz' % (grid_res)) - self.verts = torch.from_numpy(tets['vertices']).float().to(self.device) - # Make sure the tet is zero-centered and length is equal to 1 - length = self.verts.max(dim=0)[0] - self.verts.min(dim=0)[0] - length = length.max() - mid = (self.verts.max(dim=0)[0] + self.verts.min(dim=0)[0]) / 2.0 - self.verts = (self.verts - mid.unsqueeze(dim=0)) / length - if isinstance(scale, list): - self.verts[:, 0] = self.verts[:, 0] * scale[0] - self.verts[:, 1] = self.verts[:, 1] * scale[1] - self.verts[:, 2] = self.verts[:, 2] * scale[1] - else: - self.verts = self.verts * scale - self.indices = torch.from_numpy(tets['tets']).long().to(self.device) - self.triangle_table, self.num_triangles_table, self.base_tet_edges, self.v_id = create_mt_variable(self.device) - self.tet_table, self.num_tets_table = create_tetmesh_variables(self.device) - # Parameters for regularization computation - edges = torch.tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long, device=self.device) - all_edges = self.indices[:, edges].reshape(-1, 2) - all_edges_sorted = torch.sort(all_edges, dim=1)[0] - self.all_edges = torch.unique(all_edges_sorted, dim=0) - - # Parameters used for fix boundary sdf - self.center_indices, self.boundary_indices = get_center_boundary_index(self.verts) - self.renderer = renderer - self.render_type = render_type - - def getAABB(self): - return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values - - def get_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): - if indices is None: - indices = self.indices - verts, faces = marching_tets( - v_deformed_nx3, sdf_n, indices, self.triangle_table, - self.num_triangles_table, self.base_tet_edges, self.v_id) - faces = torch.cat( - [faces[:, 0:1], - faces[:, 2:3], - faces[:, 1:2], ], dim=-1) - return verts, faces - - def get_tet_mesh(self, v_deformed_nx3, sdf_n, with_uv=False, indices=None): - if indices is None: - indices = self.indices - verts, faces, tet_verts, tets = marching_tets_tetmesh( - v_deformed_nx3, sdf_n, indices, self.triangle_table, - self.num_triangles_table, self.base_tet_edges, self.v_id, return_tet_mesh=True, - num_tets_table=self.num_tets_table, tet_table=self.tet_table, ori_v=v_deformed_nx3) - faces = torch.cat( - [faces[:, 0:1], - faces[:, 2:3], - faces[:, 1:2], ], dim=-1) - return verts, faces, tet_verts, tets - - def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): - return_value = dict() - if self.render_type == 'neural_render': - tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth = self.renderer.render_mesh( - mesh_v_nx3.unsqueeze(dim=0), - mesh_f_fx3.int(), - camera_mv_bx4x4, - mesh_v_nx3.unsqueeze(dim=0), - resolution=resolution, - device=self.device, - hierarchical_mask=hierarchical_mask - ) - - return_value['tex_pos'] = tex_pos - return_value['mask'] = mask - return_value['hard_mask'] = hard_mask - return_value['rast'] = rast - return_value['v_pos_clip'] = v_pos_clip - return_value['mask_pyramid'] = mask_pyramid - return_value['depth'] = depth - else: - raise NotImplementedError - - return return_value - - def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): - # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 - v_list = [] - f_list = [] - n_batch = v_deformed_bxnx3.shape[0] - all_render_output = [] - for i_batch in range(n_batch): - verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) - v_list.append(verts_nx3) - f_list.append(faces_fx3) - render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) - all_render_output.append(render_output) - - # Concatenate all render output - return_keys = all_render_output[0].keys() - return_value = dict() - for k in return_keys: - value = [v[k] for v in all_render_output] - return_value[k] = value - # We can do concatenation outside of the render - return return_value diff --git a/instant-mesh/src/models/geometry/rep_3d/dmtet_utils.py b/instant-mesh/src/models/geometry/rep_3d/dmtet_utils.py deleted file mode 100644 index 8d466a9e78c49d947c115707693aa18d759885ad..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/dmtet_utils.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch - - -def get_center_boundary_index(verts): - length_ = torch.sum(verts ** 2, dim=-1) - center_idx = torch.argmin(length_) - boundary_neg = verts == verts.max() - boundary_pos = verts == verts.min() - boundary = torch.bitwise_or(boundary_pos, boundary_neg) - boundary = torch.sum(boundary.float(), dim=-1) - boundary_idx = torch.nonzero(boundary) - return center_idx, boundary_idx.squeeze(dim=-1) diff --git a/instant-mesh/src/models/geometry/rep_3d/extract_texture_map.py b/instant-mesh/src/models/geometry/rep_3d/extract_texture_map.py deleted file mode 100644 index a5d62bb5a6c5cdf632fb504db3d2dfa99a3abbd3..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/extract_texture_map.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -import xatlas -import numpy as np -import nvdiffrast.torch as dr - - -# ============================================================================================== -def interpolate(attr, rast, attr_idx, rast_db=None): - return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') - - -def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): - vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) - - # Convert to tensors - indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) - - uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) - mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) - # mesh_v_tex. ture - uv_clip = uvs[None, ...] * 2.0 - 1.0 - - # pad to four component coordinate - uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) - - # rasterize - rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) - - # Interpolate world space position - gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) - mask = rast[..., 3:4] > 0 - return uvs, mesh_tex_idx, gb_pos, mask diff --git a/instant-mesh/src/models/geometry/rep_3d/flexicubes.py b/instant-mesh/src/models/geometry/rep_3d/flexicubes.py deleted file mode 100644 index 26d7b91b6266d802baaf55b64238629cd0f740d0..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/flexicubes.py +++ /dev/null @@ -1,579 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -import torch -from .tables import * - -__all__ = [ - 'FlexiCubes' -] - - -class FlexiCubes: - """ - This class implements the FlexiCubes method for extracting meshes from scalar fields. - It maintains a series of lookup tables and indices to support the mesh extraction process. - FlexiCubes, a differentiable variant of the Dual Marching Cubes (DMC) scheme, enhances - the geometric fidelity and mesh quality of reconstructed meshes by dynamically adjusting - the surface representation through gradient-based optimization. - - During instantiation, the class loads DMC tables from a file and transforms them into - PyTorch tensors on the specified device. - - Attributes: - device (str): Specifies the computational device (default is "cuda"). - dmc_table (torch.Tensor): Dual Marching Cubes (DMC) table that encodes the edges - associated with each dual vertex in 256 Marching Cubes (MC) configurations. - num_vd_table (torch.Tensor): Table holding the number of dual vertices in each of - the 256 MC configurations. - check_table (torch.Tensor): Table resolving ambiguity in cases C16 and C19 - of the DMC configurations. - tet_table (torch.Tensor): Lookup table used in tetrahedralizing the isosurface. - quad_split_1 (torch.Tensor): Indices for splitting a quad into two triangles - along one diagonal. - quad_split_2 (torch.Tensor): Alternative indices for splitting a quad into - two triangles along the other diagonal. - quad_split_train (torch.Tensor): Indices for splitting a quad into four triangles - during training by connecting all edges to their midpoints. - cube_corners (torch.Tensor): Defines the positions of a standard unit cube's - eight corners in 3D space, ordered starting from the origin (0,0,0), - moving along the x-axis, then y-axis, and finally z-axis. - Used as a blueprint for generating a voxel grid. - cube_corners_idx (torch.Tensor): Cube corners indexed as powers of 2, used - to retrieve the case id. - cube_edges (torch.Tensor): Edge connections in a cube, listed in pairs. - Used to retrieve edge vertices in DMC. - edge_dir_table (torch.Tensor): A mapping tensor that associates edge indices with - their corresponding axis. For instance, edge_dir_table[0] = 0 indicates that the - first edge is oriented along the x-axis. - dir_faces_table (torch.Tensor): A tensor that maps the corresponding axis of shared edges - across four adjacent cubes to the shared faces of these cubes. For instance, - dir_faces_table[0] = [5, 4] implies that for four cubes sharing an edge along - the x-axis, the first and second cubes share faces indexed as 5 and 4, respectively. - This tensor is only utilized during isosurface tetrahedralization. - adj_pairs (torch.Tensor): - A tensor containing index pairs that correspond to neighboring cubes that share the same edge. - qef_reg_scale (float): - The scaling factor applied to the regularization loss to prevent issues with singularity - when solving the QEF. This parameter is only used when a 'grad_func' is specified. - weight_scale (float): - The scale of weights in FlexiCubes. Should be between 0 and 1. - """ - - def __init__(self, device="cuda", qef_reg_scale=1e-3, weight_scale=0.99): - - self.device = device - self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False) - self.num_vd_table = torch.tensor(num_vd_table, - dtype=torch.long, device=device, requires_grad=False) - self.check_table = torch.tensor( - check_table, - dtype=torch.long, device=device, requires_grad=False) - - self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False) - self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False) - self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False) - self.quad_split_train = torch.tensor( - [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False) - - self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [ - 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device) - self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False)) - self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6, - 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False) - - self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1], - dtype=torch.long, device=device) - self.dir_faces_table = torch.tensor([ - [[5, 4], [3, 2], [4, 5], [2, 3]], - [[5, 4], [1, 0], [4, 5], [0, 1]], - [[3, 2], [1, 0], [2, 3], [0, 1]] - ], dtype=torch.long, device=device) - self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device) - self.qef_reg_scale = qef_reg_scale - self.weight_scale = weight_scale - - def construct_voxel_grid(self, res): - """ - Generates a voxel grid based on the specified resolution. - - Args: - res (int or list[int]): The resolution of the voxel grid. If an integer - is provided, it is used for all three dimensions. If a list or tuple - of 3 integers is provided, they define the resolution for the x, - y, and z dimensions respectively. - - Returns: - (torch.Tensor, torch.Tensor): Returns the vertices and the indices of the - cube corners (index into vertices) of the constructed voxel grid. - The vertices are centered at the origin, with the length of each - dimension in the grid being one. - """ - base_cube_f = torch.arange(8).to(self.device) - if isinstance(res, int): - res = (res, res, res) - voxel_grid_template = torch.ones(res, device=self.device) - - res = torch.tensor([res], dtype=torch.float, device=self.device) - coords = torch.nonzero(voxel_grid_template).float() / res # N, 3 - verts = (self.cube_corners.unsqueeze(0) / res + coords.unsqueeze(1)).reshape(-1, 3) - cubes = (base_cube_f.unsqueeze(0) + - torch.arange(coords.shape[0], device=self.device).unsqueeze(1) * 8).reshape(-1) - - verts_rounded = torch.round(verts * 10**5) / (10**5) - verts_unique, inverse_indices = torch.unique(verts_rounded, dim=0, return_inverse=True) - cubes = inverse_indices[cubes.reshape(-1)].reshape(-1, 8) - - return verts_unique - 0.5, cubes - - def __call__(self, x_nx3, s_n, cube_fx8, res, beta_fx12=None, alpha_fx8=None, - gamma_f=None, training=False, output_tetmesh=False, grad_func=None): - r""" - Main function for mesh extraction from scalar field using FlexiCubes. This function converts - discrete signed distance fields, encoded on voxel grids and additional per-cube parameters, - to triangle or tetrahedral meshes using a differentiable operation as described in - `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_. FlexiCubes enhances - mesh quality and geometric fidelity by adjusting the surface representation based on gradient - optimization. The output surface is differentiable with respect to the input vertex positions, - scalar field values, and weight parameters. - - If you intend to extract a surface mesh from a fixed Signed Distance Field without the - optimization of parameters, it is suggested to provide the "grad_func" which should - return the surface gradient at any given 3D position. When grad_func is provided, the process - to determine the dual vertex position adapts to solve a Quadratic Error Function (QEF), as - described in the `Manifold Dual Contouring`_ paper, and employs an smart splitting strategy. - Please note, this approach is non-differentiable. - - For more details and example usage in optimization, refer to the - `Flexible Isosurface Extraction for Gradient-Based Mesh Optimization`_ SIGGRAPH 2023 paper. - - Args: - x_nx3 (torch.Tensor): Coordinates of the voxel grid vertices, can be deformed. - s_n (torch.Tensor): Scalar field values at each vertex of the voxel grid. Negative values - denote that the corresponding vertex resides inside the isosurface. This affects - the directions of the extracted triangle faces and volume to be tetrahedralized. - cube_fx8 (torch.Tensor): Indices of 8 vertices for each cube in the voxel grid. - res (int or list[int]): The resolution of the voxel grid. If an integer is provided, it - is used for all three dimensions. If a list or tuple of 3 integers is provided, they - specify the resolution for the x, y, and z dimensions respectively. - beta_fx12 (torch.Tensor, optional): Weight parameters for the cube edges to adjust dual - vertices positioning. Defaults to uniform value for all edges. - alpha_fx8 (torch.Tensor, optional): Weight parameters for the cube corners to adjust dual - vertices positioning. Defaults to uniform value for all vertices. - gamma_f (torch.Tensor, optional): Weight parameters to control the splitting of - quadrilaterals into triangles. Defaults to uniform value for all cubes. - training (bool, optional): If set to True, applies differentiable quad splitting for - training. Defaults to False. - output_tetmesh (bool, optional): If set to True, outputs a tetrahedral mesh, otherwise, - outputs a triangular mesh. Defaults to False. - grad_func (callable, optional): A function to compute the surface gradient at specified - 3D positions (input: Nx3 positions). The function should return gradients as an Nx3 - tensor. If None, the original FlexiCubes algorithm is utilized. Defaults to None. - - Returns: - (torch.Tensor, torch.LongTensor, torch.Tensor): Tuple containing: - - Vertices for the extracted triangular/tetrahedral mesh. - - Faces for the extracted triangular/tetrahedral mesh. - - Regularizer L_dev, computed per dual vertex. - - .. _Flexible Isosurface Extraction for Gradient-Based Mesh Optimization: - https://research.nvidia.com/labs/toronto-ai/flexicubes/ - .. _Manifold Dual Contouring: - https://people.engr.tamu.edu/schaefer/research/dualsimp_tvcg.pdf - """ - - surf_cubes, occ_fx8 = self._identify_surf_cubes(s_n, cube_fx8) - if surf_cubes.sum() == 0: - return torch.zeros( - (0, 3), - device=self.device), torch.zeros( - (0, 4), - dtype=torch.long, device=self.device) if output_tetmesh else torch.zeros( - (0, 3), - dtype=torch.long, device=self.device), torch.zeros( - (0), - device=self.device) - beta_fx12, alpha_fx8, gamma_f = self._normalize_weights(beta_fx12, alpha_fx8, gamma_f, surf_cubes) - - case_ids = self._get_case_id(occ_fx8, surf_cubes, res) - - surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(s_n, cube_fx8, surf_cubes) - - vd, L_dev, vd_gamma, vd_idx_map = self._compute_vd( - x_nx3, cube_fx8[surf_cubes], surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func) - vertices, faces, s_edges, edge_indices = self._triangulate( - s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func) - if not output_tetmesh: - return vertices, faces, L_dev - else: - vertices, tets = self._tetrahedralize( - x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, - surf_cubes, training) - return vertices, tets, L_dev - - def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges): - """ - Regularizer L_dev as in Equation 8 - """ - dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1) - mean_l2 = torch.zeros_like(vd[:, 0]) - mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float() - mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs() - return mad - - def _normalize_weights(self, beta_fx12, alpha_fx8, gamma_f, surf_cubes): - """ - Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones. - """ - n_cubes = surf_cubes.shape[0] - - if beta_fx12 is not None: - beta_fx12 = (torch.tanh(beta_fx12) * self.weight_scale + 1) - else: - beta_fx12 = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device) - - if alpha_fx8 is not None: - alpha_fx8 = (torch.tanh(alpha_fx8) * self.weight_scale + 1) - else: - alpha_fx8 = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device) - - if gamma_f is not None: - gamma_f = torch.sigmoid(gamma_f) * self.weight_scale + (1 - self.weight_scale)/2 - else: - gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device) - - return beta_fx12[surf_cubes], alpha_fx8[surf_cubes], gamma_f[surf_cubes] - - @torch.no_grad() - def _get_case_id(self, occ_fx8, surf_cubes, res): - """ - Obtains the ID of topology cases based on cell corner occupancy. This function resolves the - ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the - supplementary material. It should be noted that this function assumes a regular grid. - """ - case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1) - - problem_config = self.check_table.to(self.device)[case_ids] - to_check = problem_config[..., 0] == 1 - problem_config = problem_config[to_check] - if not isinstance(res, (list, tuple)): - res = [res, res, res] - - # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array, - # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes). - # This allows efficient checking on adjacent cubes. - problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long) - vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3 - vol_idx_problem = vol_idx[surf_cubes][to_check] - problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config - vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4] - - within_range = ( - vol_idx_problem_adj[..., 0] >= 0) & ( - vol_idx_problem_adj[..., 0] < res[0]) & ( - vol_idx_problem_adj[..., 1] >= 0) & ( - vol_idx_problem_adj[..., 1] < res[1]) & ( - vol_idx_problem_adj[..., 2] >= 0) & ( - vol_idx_problem_adj[..., 2] < res[2]) - - vol_idx_problem = vol_idx_problem[within_range] - vol_idx_problem_adj = vol_idx_problem_adj[within_range] - problem_config = problem_config[within_range] - problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0], - vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]] - # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted. - to_invert = (problem_config_adj[..., 0] == 1) - idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert] - case_ids.index_put_((idx,), problem_config[to_invert][..., -1]) - return case_ids - - @torch.no_grad() - def _identify_surf_edges(self, s_n, cube_fx8, surf_cubes): - """ - Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge - can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge - and marks the cube edges with this index. - """ - occ_n = s_n < 0 - all_edges = cube_fx8[surf_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) - - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1 - - surf_edges_mask = mask_edges[_idx_map] - counts = counts[_idx_map] - - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_fx8.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_fx8.device) - # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index - # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1. - idx_map = mapping[_idx_map] - surf_edges = unique_edges[mask_edges] - return surf_edges, idx_map, counts, surf_edges_mask - - @torch.no_grad() - def _identify_surf_cubes(self, s_n, cube_fx8): - """ - Identifies grid cubes that intersect with the underlying surface by checking if the signs at - all corners are not identical. - """ - occ_n = s_n < 0 - occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) - _occ_sum = torch.sum(occ_fx8, -1) - surf_cubes = (_occ_sum > 0) & (_occ_sum < 8) - return surf_cubes, occ_fx8 - - def _linear_interp(self, edges_weight, edges_x): - """ - Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'. - """ - edge_dim = edges_weight.dim() - 2 - assert edges_weight.shape[edge_dim] == 2 - edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), - - torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)], edge_dim) - denominator = edges_weight.sum(edge_dim) - ue = (edges_x * edges_weight).sum(edge_dim) / denominator - return ue - - def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3=None): - p_bxnx3 = p_bxnx3.reshape(-1, 7, 3) - norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3) - c_bx3 = c_bx3.reshape(-1, 3) - A = norm_bxnx3 - B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True) - - A_reg = (torch.eye(3, device=p_bxnx3.device) * self.qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1) - B_reg = (self.qef_reg_scale * c_bx3).unsqueeze(-1) - A = torch.cat([A, A_reg], 1) - B = torch.cat([B, B_reg], 1) - dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1) - return dual_verts - - def _compute_vd(self, x_nx3, surf_cubes_fx8, surf_edges, s_n, case_ids, beta_fx12, alpha_fx8, gamma_f, idx_map, grad_func): - """ - Computes the location of dual vertices as described in Section 4.2 - """ - alpha_nx12x2 = torch.index_select(input=alpha_fx8, index=self.cube_edges, dim=1).reshape(-1, 12, 2) - surf_edges_x = torch.index_select(input=x_nx3, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3) - surf_edges_s = torch.index_select(input=s_n, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1) - zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x) - - idx_map = idx_map.reshape(-1, 12) - num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0) - edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], [] - - total_num_vd = 0 - vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False) - if grad_func is not None: - normals = torch.nn.functional.normalize(grad_func(zero_crossing), dim=-1) - vd = [] - for num in torch.unique(num_vd): - cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching) - curr_num_vd = cur_cubes.sum() * num - curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7) - curr_edge_group_to_vd = torch.arange( - curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd - total_num_vd += curr_num_vd - curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[ - cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group) - - curr_mask = (curr_edge_group != -1) - edge_group.append(torch.masked_select(curr_edge_group, curr_mask)) - edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask)) - edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask)) - vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True)) - vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1)) - - if grad_func is not None: - with torch.no_grad(): - cube_e_verts_idx = idx_map[cur_cubes] - curr_edge_group[~curr_mask] = 0 - - verts_group_idx = torch.gather(input=cube_e_verts_idx, dim=1, index=curr_edge_group) - verts_group_idx[verts_group_idx == -1] = 0 - verts_group_pos = torch.index_select( - input=zero_crossing, index=verts_group_idx.reshape(-1), dim=0).reshape(-1, num.item(), 7, 3) - v0 = x_nx3[surf_cubes_fx8[cur_cubes][:, 0]].reshape(-1, 1, 1, 3).repeat(1, num.item(), 1, 1) - curr_mask = curr_mask.reshape(-1, num.item(), 7, 1) - verts_centroid = (verts_group_pos * curr_mask).sum(2) / (curr_mask.sum(2)) - - normals_bx7x3 = torch.index_select(input=normals, index=verts_group_idx.reshape(-1), dim=0).reshape( - -1, num.item(), 7, - 3) - curr_mask = curr_mask.squeeze(2) - vd.append(self._solve_vd_QEF((verts_group_pos - v0) * curr_mask, normals_bx7x3 * curr_mask, - verts_centroid - v0.squeeze(2)) + v0.reshape(-1, 3)) - edge_group = torch.cat(edge_group) - edge_group_to_vd = torch.cat(edge_group_to_vd) - edge_group_to_cube = torch.cat(edge_group_to_cube) - vd_num_edges = torch.cat(vd_num_edges) - vd_gamma = torch.cat(vd_gamma) - - if grad_func is not None: - vd = torch.cat(vd) - L_dev = torch.zeros([1], device=self.device) - else: - vd = torch.zeros((total_num_vd, 3), device=self.device) - beta_sum = torch.zeros((total_num_vd, 1), device=self.device) - - idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group) - - x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3) - s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1) - - zero_crossing_group = torch.index_select( - input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3) - - alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1) - ue_group = self._linear_interp(s_group * alpha_group, x_group) - - beta_group = torch.gather(input=beta_fx12.reshape(-1), dim=0, - index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1) - beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group) - vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum - L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges) - - v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd - - vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube * - 12 + edge_group, src=v_idx[edge_group_to_vd]) - - return vd, L_dev, vd_gamma, vd_idx_map - - def _triangulate(self, s_n, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, grad_func): - """ - Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into - triangles based on the gamma parameter, as described in Section 4.3. - """ - with torch.no_grad(): - group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes. - group = idx_map.reshape(-1)[group_mask] - vd_idx = vd_idx_map[group_mask] - edge_indices, indices = torch.sort(group, stable=True) - quad_vd_idx = vd_idx[indices].reshape(-1, 4) - - # Ensure all face directions point towards the positive SDF to maintain consistent winding. - s_edges = s_n[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2) - flip_mask = s_edges[:, 0] > 0 - quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]], - quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]])) - if grad_func is not None: - # when grad_func is given, split quadrilaterals along the diagonals with more consistent gradients. - with torch.no_grad(): - vd_gamma = torch.nn.functional.normalize(grad_func(vd), dim=-1) - quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) - gamma_02 = (quad_gamma[:, 0] * quad_gamma[:, 2]).sum(-1, keepdims=True) - gamma_13 = (quad_gamma[:, 1] * quad_gamma[:, 3]).sum(-1, keepdims=True) - else: - quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4) - gamma_02 = torch.index_select(input=quad_gamma, index=torch.tensor( - 0, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(2, device=self.device), dim=1) - gamma_13 = torch.index_select(input=quad_gamma, index=torch.tensor( - 1, device=self.device), dim=1) * torch.index_select(input=quad_gamma, index=torch.tensor(3, device=self.device), dim=1) - if not training: - mask = (gamma_02 > gamma_13).squeeze(1) - faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device) - faces[mask] = quad_vd_idx[mask][:, self.quad_split_1] - faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2] - faces = faces.reshape(-1, 3) - else: - vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3) - vd_02 = (torch.index_select(input=vd_quad, index=torch.tensor(0, device=self.device), dim=1) + - torch.index_select(input=vd_quad, index=torch.tensor(2, device=self.device), dim=1)) / 2 - vd_13 = (torch.index_select(input=vd_quad, index=torch.tensor(1, device=self.device), dim=1) + - torch.index_select(input=vd_quad, index=torch.tensor(3, device=self.device), dim=1)) / 2 - weight_sum = (gamma_02 + gamma_13) + 1e-8 - vd_center = ((vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / - weight_sum.unsqueeze(-1)).squeeze(1) - vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0] - vd = torch.cat([vd, vd_center]) - faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2) - faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3) - return vd, faces, s_edges, edge_indices - - def _tetrahedralize( - self, x_nx3, s_n, cube_fx8, vertices, faces, surf_edges, s_edges, vd_idx_map, case_ids, edge_indices, - surf_cubes, training): - """ - Tetrahedralizes the interior volume to produce a tetrahedral mesh, as described in Section 4.5. - """ - occ_n = s_n < 0 - occ_fx8 = occ_n[cube_fx8.reshape(-1)].reshape(-1, 8) - occ_sum = torch.sum(occ_fx8, -1) - - inside_verts = x_nx3[occ_n] - mapping_inside_verts = torch.ones((occ_n.shape[0]), dtype=torch.long, device=self.device) * -1 - mapping_inside_verts[occ_n] = torch.arange(occ_n.sum(), device=self.device) + vertices.shape[0] - """ - For each grid edge connecting two grid vertices with different - signs, we first form a four-sided pyramid by connecting one - of the grid vertices with four mesh vertices that correspond - to the grid edge and then subdivide the pyramid into two tetrahedra - """ - inside_verts_idx = mapping_inside_verts[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1, 2)[ - s_edges < 0]] - if not training: - inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 2).reshape(-1) - else: - inside_verts_idx = inside_verts_idx.unsqueeze(1).expand(-1, 4).reshape(-1) - - tets_surface = torch.cat([faces, inside_verts_idx.unsqueeze(-1)], -1) - """ - For each grid edge connecting two grid vertices with the - same sign, the tetrahedron is formed by the two grid vertices - and two vertices in consecutive adjacent cells - """ - inside_cubes = (occ_sum == 8) - inside_cubes_center = x_nx3[cube_fx8[inside_cubes].reshape(-1)].reshape(-1, 8, 3).mean(1) - inside_cubes_center_idx = torch.arange( - inside_cubes_center.shape[0], device=inside_cubes.device) + vertices.shape[0] + inside_verts.shape[0] - - surface_n_inside_cubes = surf_cubes | inside_cubes - edge_center_vertex_idx = torch.ones(((surface_n_inside_cubes).sum(), 13), - dtype=torch.long, device=x_nx3.device) * -1 - surf_cubes = surf_cubes[surface_n_inside_cubes] - inside_cubes = inside_cubes[surface_n_inside_cubes] - edge_center_vertex_idx[surf_cubes, :12] = vd_idx_map.reshape(-1, 12) - edge_center_vertex_idx[inside_cubes, 12] = inside_cubes_center_idx - - all_edges = cube_fx8[surface_n_inside_cubes][:, self.cube_edges].reshape(-1, 2) - unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True) - unique_edges = unique_edges.long() - mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 2 - mask = mask_edges[_idx_map] - counts = counts[_idx_map] - mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=self.device) * -1 - mapping[mask_edges] = torch.arange(mask_edges.sum(), device=self.device) - idx_map = mapping[_idx_map] - - group_mask = (counts == 4) & mask - group = idx_map.reshape(-1)[group_mask] - edge_indices, indices = torch.sort(group) - cube_idx = torch.arange((_idx_map.shape[0] // 12), dtype=torch.long, - device=self.device).unsqueeze(1).expand(-1, 12).reshape(-1)[group_mask] - edge_idx = torch.arange((12), dtype=torch.long, device=self.device).unsqueeze( - 0).expand(_idx_map.shape[0] // 12, -1).reshape(-1)[group_mask] - # Identify the face shared by the adjacent cells. - cube_idx_4 = cube_idx[indices].reshape(-1, 4) - edge_dir = self.edge_dir_table[edge_idx[indices]].reshape(-1, 4)[..., 0] - shared_faces_4x2 = self.dir_faces_table[edge_dir].reshape(-1) - cube_idx_4x2 = cube_idx_4[:, self.adj_pairs].reshape(-1) - # Identify an edge of the face with different signs and - # select the mesh vertex corresponding to the identified edge. - case_ids_expand = torch.ones((surface_n_inside_cubes).sum(), dtype=torch.long, device=x_nx3.device) * 255 - case_ids_expand[surf_cubes] = case_ids - cases = case_ids_expand[cube_idx_4x2] - quad_edge = edge_center_vertex_idx[cube_idx_4x2, self.tet_table[cases, shared_faces_4x2]].reshape(-1, 2) - mask = (quad_edge == -1).sum(-1) == 0 - inside_edge = mapping_inside_verts[unique_edges[mask_edges][edge_indices].reshape(-1)].reshape(-1, 2) - tets_inside = torch.cat([quad_edge, inside_edge], -1)[mask] - - tets = torch.cat([tets_surface, tets_inside]) - vertices = torch.cat([vertices, inside_verts, inside_cubes_center]) - return vertices, tets diff --git a/instant-mesh/src/models/geometry/rep_3d/flexicubes_geometry.py b/instant-mesh/src/models/geometry/rep_3d/flexicubes_geometry.py deleted file mode 100644 index bf050ee20361f78957839942f83fe77efde231b7..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/flexicubes_geometry.py +++ /dev/null @@ -1,120 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -import numpy as np -import os -from . import Geometry -from .flexicubes import FlexiCubes # replace later -from .dmtet import sdf_reg_loss_batch -import torch.nn.functional as F - -def get_center_boundary_index(grid_res, device): - v = torch.zeros((grid_res + 1, grid_res + 1, grid_res + 1), dtype=torch.bool, device=device) - v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = True - center_indices = torch.nonzero(v.reshape(-1)) - - v[grid_res // 2 + 1, grid_res // 2 + 1, grid_res // 2 + 1] = False - v[:2, ...] = True - v[-2:, ...] = True - v[:, :2, ...] = True - v[:, -2:, ...] = True - v[:, :, :2] = True - v[:, :, -2:] = True - boundary_indices = torch.nonzero(v.reshape(-1)) - return center_indices, boundary_indices - -############################################################################### -# Geometry interface -############################################################################### -class FlexiCubesGeometry(Geometry): - def __init__( - self, grid_res=64, scale=2.0, device='cuda', renderer=None, - render_type='neural_render', args=None): - super(FlexiCubesGeometry, self).__init__() - self.grid_res = grid_res - self.device = device - self.args = args - self.fc = FlexiCubes(device, weight_scale=0.5) - self.verts, self.indices = self.fc.construct_voxel_grid(grid_res) - if isinstance(scale, list): - self.verts[:, 0] = self.verts[:, 0] * scale[0] - self.verts[:, 1] = self.verts[:, 1] * scale[1] - self.verts[:, 2] = self.verts[:, 2] * scale[1] - else: - self.verts = self.verts * scale - - all_edges = self.indices[:, self.fc.cube_edges].reshape(-1, 2) - self.all_edges = torch.unique(all_edges, dim=0) - - # Parameters used for fix boundary sdf - self.center_indices, self.boundary_indices = get_center_boundary_index(self.grid_res, device) - self.renderer = renderer - self.render_type = render_type - - def getAABB(self): - return torch.min(self.verts, dim=0).values, torch.max(self.verts, dim=0).values - - def get_mesh(self, v_deformed_nx3, sdf_n, weight_n=None, with_uv=False, indices=None, is_training=False): - if indices is None: - indices = self.indices - - verts, faces, v_reg_loss = self.fc(v_deformed_nx3, sdf_n, indices, self.grid_res, - beta_fx12=weight_n[:, :12], alpha_fx8=weight_n[:, 12:20], - gamma_f=weight_n[:, 20], training=is_training - ) - return verts, faces, v_reg_loss - - - def render_mesh(self, mesh_v_nx3, mesh_f_fx3, camera_mv_bx4x4, resolution=256, hierarchical_mask=False): - return_value = dict() - if self.render_type == 'neural_render': - tex_pos, mask, hard_mask, rast, v_pos_clip, mask_pyramid, depth, normal = self.renderer.render_mesh( - mesh_v_nx3.unsqueeze(dim=0), - mesh_f_fx3.int(), - camera_mv_bx4x4, - mesh_v_nx3.unsqueeze(dim=0), - resolution=resolution, - device=self.device, - hierarchical_mask=hierarchical_mask - ) - - return_value['tex_pos'] = tex_pos - return_value['mask'] = mask - return_value['hard_mask'] = hard_mask - return_value['rast'] = rast - return_value['v_pos_clip'] = v_pos_clip - return_value['mask_pyramid'] = mask_pyramid - return_value['depth'] = depth - return_value['normal'] = normal - else: - raise NotImplementedError - - return return_value - - def render(self, v_deformed_bxnx3=None, sdf_bxn=None, camera_mv_bxnviewx4x4=None, resolution=256): - # Here I assume a batch of meshes (can be different mesh and geometry), for the other shapes, the batch is 1 - v_list = [] - f_list = [] - n_batch = v_deformed_bxnx3.shape[0] - all_render_output = [] - for i_batch in range(n_batch): - verts_nx3, faces_fx3 = self.get_mesh(v_deformed_bxnx3[i_batch], sdf_bxn[i_batch]) - v_list.append(verts_nx3) - f_list.append(faces_fx3) - render_output = self.render_mesh(verts_nx3, faces_fx3, camera_mv_bxnviewx4x4[i_batch], resolution) - all_render_output.append(render_output) - - # Concatenate all render output - return_keys = all_render_output[0].keys() - return_value = dict() - for k in return_keys: - value = [v[k] for v in all_render_output] - return_value[k] = value - # We can do concatenation outside of the render - return return_value diff --git a/instant-mesh/src/models/geometry/rep_3d/tables.py b/instant-mesh/src/models/geometry/rep_3d/tables.py deleted file mode 100644 index 5873e7727b5595a1e4fbc3bd10ae5be8f3d06cca..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/geometry/rep_3d/tables.py +++ /dev/null @@ -1,791 +0,0 @@ -# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. -dmc_table = [ -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 6, 7, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 5, 6, 8], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 5, 9, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 6, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 7, 8, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [5, 6, 10, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [5, 6, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 9, 10, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [4, 6, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 8, 9, 10, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 6, 7, 8, 10, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 6, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [1, 2, 5, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 6, 9, -1, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 6, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 6, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 6, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 6, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 6, 7, 9], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 6, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 6, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 6, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 6, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 3, 5, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, 6, 7, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 6, 9, 11, -1], [4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 6, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 6, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 6, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 6, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 6, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 6, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 6, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[6, 7, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [5, 7, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [4, 5, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 9, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [4, 7, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 10, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 8, 10, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[8, 9, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 9, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 8, 10, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 10, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 5, 7, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [2, 3, 4, 5, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 9, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 8, 9, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 7, 10], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 8, 9, 10, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 9, 10, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 8, 10, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 10, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 5, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 5, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 5, 7, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 5, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 4, 5, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 5, 8, 9, 11], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [1, 2, 4, 7, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 4, 7, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 4, 7, 8, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 2, 8, 9, 11, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 2, 3, 9, 11, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 2, 8, 11, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[2, 3, 11, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 5, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 5, 7, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 5, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[5, 7, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 5, 8, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 5, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 5, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 5, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 4, 7, 9, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 4, 7, 8, 9, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 4, 7, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[4, 7, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[1, 3, 8, 9, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 1, 9, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[0, 3, 8, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]], -[[-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1], [-1, -1, -1, -1, -1, -1, -1]] -] -num_vd_table = [0, 1, 1, 1, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 3, 1, 2, 2, -2, 1, 2, 1, 2, 1, 1, 2, 1, 1, 2, 2, 2, 1, 2, 3, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, -1, 2, 1, 2, 2, 1, 1, 2, 1, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 2, 3, 2, 2, 1, 1, 1, 1, -1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, 3, 2, 2, 2, 2, 2, 1, 3, 4, 2, -2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 1, 1, 2, 2, 2, 2, 2, -3, 2, 1, 2, 1, 1, 1, 1, 1, 1, 2, 2, 3, 2, 3, 2, 4, 2, 2, 2, 2, 1, 2, 1, 2, 1, 1, -2, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, -1, 2, 1, 1, 1, 2, 2, 2, 1, 1, 2, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, -1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, -1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0] -check_table = [ -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 194], -[1, -1, 0, 0, 193], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 164], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 161], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 152], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 145], -[1, 0, 0, 1, 144], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 137], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 133], -[1, 0, 1, 0, 132], -[1, 1, 0, 0, 131], -[1, 1, 0, 0, 130], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 100], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 98], -[0, 0, 0, 0, 0], -[1, 0, 0, 1, 96], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 88], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 82], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 74], -[0, 0, 0, 0, 0], -[1, 0, 1, 0, 72], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 70], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 67], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 65], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 56], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 52], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 44], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 1, 0, 0, 40], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 38], -[1, 0, -1, 0, 37], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 33], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 28], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 26], -[1, 0, 0, -1, 25], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, -1, 0, 0, 20], -[0, 0, 0, 0, 0], -[1, 0, -1, 0, 18], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 9], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[1, 0, 0, -1, 6], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0], -[0, 0, 0, 0, 0] -] -tet_table = [ -[-1, -1, -1, -1, -1, -1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, -1], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, -1], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, -1, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, -1, 2, 4, 4, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, 5, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, -1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[-1, 1, 1, 4, 4, 1], -[0, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[8, 8, 8, 8, 8, 8], -[1, 1, 1, 4, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 4, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 5, 5, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[6, 6, 6, 6, 6, 6], -[6, -1, 0, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 4, -1, 6, 4, 6], -[6, 4, 0, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 1, 1, 6, 1, 6], -[5, 5, 5, 5, 5, 5], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 4, 2, 2, 4, 2], -[0, 4, 0, 4, 4, 0], -[2, 0, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 1, 1, 6, -1, 6], -[6, 1, 1, 6, 0, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[4, 1, 1, 4, 4, 1], -[0, 1, 1, 0, 0, 1], -[4, 0, 0, 4, 4, 4], -[2, 2, 2, 2, 2, 2], -[6, 1, 1, 6, 4, 6], -[6, 1, 1, 6, 4, 6], -[6, 0, 0, 6, 0, 6], -[6, 2, 2, 6, 2, 6], -[5, 1, 1, 5, 5, 1], -[0, 1, 1, 0, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 4, 1], -[0, 4, 0, 4, 4, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 5, 0, 5, 0, 5], -[5, 5, 5, 5, 5, 5], -[5, 5, 5, 5, 5, 5], -[0, 5, 0, 5, 0, 5], -[-1, 5, 0, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[4, 5, -1, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[4, 5, 0, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[6, 6, 6, 6, 6, 6], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 5, 2, 5, -1, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 0, 5], -[1, 5, 1, 5, 1, 5], -[2, 5, 2, 5, 4, 5], -[0, 5, 0, 5, 0, 5], -[2, 5, 2, 5, 4, 5], -[1, 5, 1, 5, 1, 5], -[2, 4, 2, 4, 4, 2], -[0, 4, 0, 4, 4, 4], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 6, 2, 6, 6, 2], -[0, 0, 0, 0, 0, 0], -[2, 0, 2, 0, 0, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[4, 1, 1, 1, 4, 1], -[0, 1, 1, 1, 0, 1], -[4, 0, 0, 4, 4, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 4, 1], -[0, 0, 0, 0, 0, 0], -[4, 0, 0, 4, 4, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[6, 0, 0, 6, 0, 6], -[0, 0, 0, 0, 0, 0], -[6, 6, 6, 6, 6, 6], -[5, 5, 5, 5, 5, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 0, 5, 0, 5], -[5, 5, 1, 5, 1, 5], -[4, 4, 4, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[4, 4, 0, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[4, 4, 4, 4, 4, 4], -[4, 4, 0, 4, 4, 4], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[8, 8, 8, 8, 8, 8], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 0, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 1, 1, 4, 4, 1], -[2, 2, 2, 2, 2, 2], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 0, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 4, 2, 4, 4, 2], -[1, 1, 1, 1, 1, 1], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[2, 2, 2, 2, 2, 2], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[5, 5, 5, 5, 5, 5], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[4, 4, 4, 4, 4, 4], -[1, 1, 1, 1, 1, 1], -[0, 0, 0, 0, 0, 0], -[0, 0, 0, 0, 0, 0], -[12, 12, 12, 12, 12, 12] -] diff --git a/instant-mesh/src/models/lrm.py b/instant-mesh/src/models/lrm.py deleted file mode 100644 index eea9ee3353d74fb60451fec87f6c2c30816f64ae..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/lrm.py +++ /dev/null @@ -1,196 +0,0 @@ -# Copyright (c) 2023, Zexin He -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import torch.nn as nn -import mcubes -import nvdiffrast.torch as dr -from einops import rearrange, repeat - -from .encoder.dino_wrapper import DinoWrapper -from .decoder.transformer import TriplaneTransformer -from .renderer.synthesizer import TriplaneSynthesizer -from ..utils.mesh_util import xatlas_uvmap - - -class InstantNeRF(nn.Module): - """ - Full model of the large reconstruction model. - """ - def __init__( - self, - encoder_freeze: bool = False, - encoder_model_name: str = 'facebook/dino-vitb16', - encoder_feat_dim: int = 768, - transformer_dim: int = 1024, - transformer_layers: int = 16, - transformer_heads: int = 16, - triplane_low_res: int = 32, - triplane_high_res: int = 64, - triplane_dim: int = 80, - rendering_samples_per_ray: int = 128, - ): - super().__init__() - - # modules - self.encoder = DinoWrapper( - model_name=encoder_model_name, - freeze=encoder_freeze, - ) - - self.transformer = TriplaneTransformer( - inner_dim=transformer_dim, - num_layers=transformer_layers, - num_heads=transformer_heads, - image_feat_dim=encoder_feat_dim, - triplane_low_res=triplane_low_res, - triplane_high_res=triplane_high_res, - triplane_dim=triplane_dim, - ) - - self.synthesizer = TriplaneSynthesizer( - triplane_dim=triplane_dim, - samples_per_ray=rendering_samples_per_ray, - ) - - def forward_planes(self, images, cameras): - # images: [B, V, C_img, H_img, W_img] - # cameras: [B, V, 16] - B = images.shape[0] - - # encode images - image_feats = self.encoder(images, cameras) - image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) - - # transformer generating planes - planes = self.transformer(image_feats) - - return planes - - def forward(self, images, cameras, render_cameras, render_size: int): - # images: [B, V, C_img, H_img, W_img] - # cameras: [B, V, 16] - # render_cameras: [B, M, D_cam_render] - # render_size: int - B, M = render_cameras.shape[:2] - - planes = self.forward_planes(images, cameras) - - # render target views - render_results = self.synthesizer(planes, render_cameras, render_size) - - return { - 'planes': planes, - **render_results, - } - - def get_texture_prediction(self, planes, tex_pos, hard_mask=None): - ''' - Predict Texture given triplanes - :param planes: the triplane feature map - :param tex_pos: Position we want to query the texture field - :param hard_mask: 2D silhoueete of the rendered image - ''' - tex_pos = torch.cat(tex_pos, dim=0) - if not hard_mask is None: - tex_pos = tex_pos * hard_mask.float() - batch_size = tex_pos.shape[0] - tex_pos = tex_pos.reshape(batch_size, -1, 3) - ################### - # We use mask to get the texture location (to save the memory) - if hard_mask is not None: - n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) - sample_tex_pose_list = [] - max_point = n_point_list.max() - expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 - for i in range(tex_pos.shape[0]): - tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) - if tex_pos_one_shape.shape[1] < max_point: - tex_pos_one_shape = torch.cat( - [tex_pos_one_shape, torch.zeros( - 1, max_point - tex_pos_one_shape.shape[1], 3, - device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) - sample_tex_pose_list.append(tex_pos_one_shape) - tex_pos = torch.cat(sample_tex_pose_list, dim=0) - - tex_feat = self.synthesizer.forward_points(planes, tex_pos)['rgb'] - - if hard_mask is not None: - final_tex_feat = torch.zeros( - planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) - expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 - for i in range(planes.shape[0]): - final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) - tex_feat = final_tex_feat - - return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) - - def extract_mesh( - self, - planes: torch.Tensor, - mesh_resolution: int = 256, - mesh_threshold: int = 10.0, - use_texture_map: bool = False, - texture_resolution: int = 1024, - **kwargs, - ): - ''' - Extract a 3D mesh from triplane nerf. Only support batch_size 1. - :param planes: triplane features - :param mesh_resolution: marching cubes resolution - :param mesh_threshold: iso-surface threshold - :param use_texture_map: use texture map or vertex color - :param texture_resolution: the resolution of texture map - ''' - assert planes.shape[0] == 1 - device = planes.device - - grid_out = self.synthesizer.forward_grid( - planes=planes, - grid_size=mesh_resolution, - ) - - vertices, faces = mcubes.marching_cubes( - grid_out['sigma'].squeeze(0).squeeze(-1).cpu().numpy(), - mesh_threshold, - ) - vertices = vertices / (mesh_resolution - 1) * 2 - 1 - - if not use_texture_map: - # query vertex colors - vertices_tensor = torch.tensor(vertices, dtype=torch.float32, device=device).unsqueeze(0) - vertices_colors = self.synthesizer.forward_points( - planes, vertices_tensor)['rgb'].squeeze(0).cpu().numpy() - vertices_colors = (vertices_colors * 255).astype(np.uint8) - - return vertices, faces, vertices_colors - - # use x-atlas to get uv mapping for the mesh - vertices = torch.tensor(vertices, dtype=torch.float32, device=device) - faces = torch.tensor(faces.astype(int), dtype=torch.long, device=device) - - ctx = dr.RasterizeCudaContext(device=device) - uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( - ctx, vertices, faces, resolution=texture_resolution) - tex_hard_mask = tex_hard_mask.float() - - # query the texture field to get the RGB color for texture map - tex_feat = self.get_texture_prediction( - planes, [gb_pos], tex_hard_mask) - background_feature = torch.zeros_like(tex_feat) - img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) - texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) - - return vertices, faces, uvs, mesh_tex_idx, texture_map \ No newline at end of file diff --git a/instant-mesh/src/models/lrm_mesh.py b/instant-mesh/src/models/lrm_mesh.py deleted file mode 100644 index b0f278e6bf73d3320c05c24809de862220f53a00..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/lrm_mesh.py +++ /dev/null @@ -1,385 +0,0 @@ -# Copyright (c) 2023, Tencent Inc -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import numpy as np -import torch -import torch.nn as nn -import nvdiffrast.torch as dr -from einops import rearrange, repeat - -from .encoder.dino_wrapper import DinoWrapper -from .decoder.transformer import TriplaneTransformer -from .renderer.synthesizer_mesh import TriplaneSynthesizer -from .geometry.camera.perspective_camera import PerspectiveCamera -from .geometry.render.neural_render import NeuralRender -from .geometry.rep_3d.flexicubes_geometry import FlexiCubesGeometry -from ..utils.mesh_util import xatlas_uvmap - - -class InstantMesh(nn.Module): - """ - Full model of the large reconstruction model. - """ - def __init__( - self, - encoder_freeze: bool = False, - encoder_model_name: str = 'facebook/dino-vitb16', - encoder_feat_dim: int = 768, - transformer_dim: int = 1024, - transformer_layers: int = 16, - transformer_heads: int = 16, - triplane_low_res: int = 32, - triplane_high_res: int = 64, - triplane_dim: int = 80, - rendering_samples_per_ray: int = 128, - grid_res: int = 128, - grid_scale: float = 2.0, - ): - super().__init__() - - # attributes - self.grid_res = grid_res - self.grid_scale = grid_scale - self.deformation_multiplier = 4.0 - - # modules - self.encoder = DinoWrapper( - model_name=encoder_model_name, - freeze=encoder_freeze, - ) - - self.transformer = TriplaneTransformer( - inner_dim=transformer_dim, - num_layers=transformer_layers, - num_heads=transformer_heads, - image_feat_dim=encoder_feat_dim, - triplane_low_res=triplane_low_res, - triplane_high_res=triplane_high_res, - triplane_dim=triplane_dim, - ) - - self.synthesizer = TriplaneSynthesizer( - triplane_dim=triplane_dim, - samples_per_ray=rendering_samples_per_ray, - ) - - def init_flexicubes_geometry(self, device, fovy=50.0, use_renderer=True): - camera = PerspectiveCamera(fovy=fovy, device=device) - if use_renderer: - renderer = NeuralRender(device, camera_model=camera) - else: - renderer = None - self.geometry = FlexiCubesGeometry( - grid_res=self.grid_res, - scale=self.grid_scale, - renderer=renderer, - render_type='neural_render', - device=device, - ) - - def forward_planes(self, images, cameras): - # images: [B, V, C_img, H_img, W_img] - # cameras: [B, V, 16] - B = images.shape[0] - - # encode images - image_feats = self.encoder(images, cameras) - image_feats = rearrange(image_feats, '(b v) l d -> b (v l) d', b=B) - - # decode triplanes - planes = self.transformer(image_feats) - - return planes - - def get_sdf_deformation_prediction(self, planes): - ''' - Predict SDF and deformation for tetrahedron vertices - :param planes: triplane feature map for the geometry - ''' - init_position = self.geometry.verts.unsqueeze(0).expand(planes.shape[0], -1, -1) - - # Step 1: predict the SDF and deformation - sdf, deformation, weight = torch.utils.checkpoint.checkpoint( - self.synthesizer.get_geometry_prediction, - planes, - init_position, - self.geometry.indices, - use_reentrant=False, - ) - - # Step 2: Normalize the deformation to avoid the flipped triangles. - deformation = 1.0 / (self.grid_res * self.deformation_multiplier) * torch.tanh(deformation) - sdf_reg_loss = torch.zeros(sdf.shape[0], device=sdf.device, dtype=torch.float32) - - #### - # Step 3: Fix some sdf if we observe empty shape (full positive or full negative) - sdf_bxnxnxn = sdf.reshape((sdf.shape[0], self.grid_res + 1, self.grid_res + 1, self.grid_res + 1)) - sdf_less_boundary = sdf_bxnxnxn[:, 1:-1, 1:-1, 1:-1].reshape(sdf.shape[0], -1) - pos_shape = torch.sum((sdf_less_boundary > 0).int(), dim=-1) - neg_shape = torch.sum((sdf_less_boundary < 0).int(), dim=-1) - zero_surface = torch.bitwise_or(pos_shape == 0, neg_shape == 0) - if torch.sum(zero_surface).item() > 0: - update_sdf = torch.zeros_like(sdf[0:1]) - max_sdf = sdf.max() - min_sdf = sdf.min() - update_sdf[:, self.geometry.center_indices] += (1.0 - min_sdf) # greater than zero - update_sdf[:, self.geometry.boundary_indices] += (-1 - max_sdf) # smaller than zero - new_sdf = torch.zeros_like(sdf) - for i_batch in range(zero_surface.shape[0]): - if zero_surface[i_batch]: - new_sdf[i_batch:i_batch + 1] += update_sdf - update_mask = (new_sdf == 0).float() - # Regulraization here is used to push the sdf to be a different sign (make it not fully positive or fully negative) - sdf_reg_loss = torch.abs(sdf).mean(dim=-1).mean(dim=-1) - sdf_reg_loss = sdf_reg_loss * zero_surface.float() - sdf = sdf * update_mask + new_sdf * (1 - update_mask) - - # Step 4: Here we remove the gradient for the bad sdf (full positive or full negative) - final_sdf = [] - final_def = [] - for i_batch in range(zero_surface.shape[0]): - if zero_surface[i_batch]: - final_sdf.append(sdf[i_batch: i_batch + 1].detach()) - final_def.append(deformation[i_batch: i_batch + 1].detach()) - else: - final_sdf.append(sdf[i_batch: i_batch + 1]) - final_def.append(deformation[i_batch: i_batch + 1]) - sdf = torch.cat(final_sdf, dim=0) - deformation = torch.cat(final_def, dim=0) - return sdf, deformation, sdf_reg_loss, weight - - def get_geometry_prediction(self, planes=None): - ''' - Function to generate mesh with give triplanes - :param planes: triplane features - ''' - # Step 1: first get the sdf and deformation value for each vertices in the tetrahedon grid. - sdf, deformation, sdf_reg_loss, weight = self.get_sdf_deformation_prediction(planes) - v_deformed = self.geometry.verts.unsqueeze(dim=0).expand(sdf.shape[0], -1, -1) + deformation - tets = self.geometry.indices - n_batch = planes.shape[0] - v_list = [] - f_list = [] - flexicubes_surface_reg_list = [] - - # Step 2: Using marching tet to obtain the mesh - for i_batch in range(n_batch): - verts, faces, flexicubes_surface_reg = self.geometry.get_mesh( - v_deformed[i_batch], - sdf[i_batch].squeeze(dim=-1), - with_uv=False, - indices=tets, - weight_n=weight[i_batch].squeeze(dim=-1), - is_training=self.training, - ) - flexicubes_surface_reg_list.append(flexicubes_surface_reg) - v_list.append(verts) - f_list.append(faces) - - flexicubes_surface_reg = torch.cat(flexicubes_surface_reg_list).mean() - flexicubes_weight_reg = (weight ** 2).mean() - - return v_list, f_list, sdf, deformation, v_deformed, (sdf_reg_loss, flexicubes_surface_reg, flexicubes_weight_reg) - - def get_texture_prediction(self, planes, tex_pos, hard_mask=None): - ''' - Predict Texture given triplanes - :param planes: the triplane feature map - :param tex_pos: Position we want to query the texture field - :param hard_mask: 2D silhoueete of the rendered image - ''' - tex_pos = torch.cat(tex_pos, dim=0) - if not hard_mask is None: - tex_pos = tex_pos * hard_mask.float() - batch_size = tex_pos.shape[0] - tex_pos = tex_pos.reshape(batch_size, -1, 3) - ################### - # We use mask to get the texture location (to save the memory) - if hard_mask is not None: - n_point_list = torch.sum(hard_mask.long().reshape(hard_mask.shape[0], -1), dim=-1) - sample_tex_pose_list = [] - max_point = n_point_list.max() - expanded_hard_mask = hard_mask.reshape(batch_size, -1, 1).expand(-1, -1, 3) > 0.5 - for i in range(tex_pos.shape[0]): - tex_pos_one_shape = tex_pos[i][expanded_hard_mask[i]].reshape(1, -1, 3) - if tex_pos_one_shape.shape[1] < max_point: - tex_pos_one_shape = torch.cat( - [tex_pos_one_shape, torch.zeros( - 1, max_point - tex_pos_one_shape.shape[1], 3, - device=tex_pos_one_shape.device, dtype=torch.float32)], dim=1) - sample_tex_pose_list.append(tex_pos_one_shape) - tex_pos = torch.cat(sample_tex_pose_list, dim=0) - - tex_feat = torch.utils.checkpoint.checkpoint( - self.synthesizer.get_texture_prediction, - planes, - tex_pos, - use_reentrant=False, - ) - - if hard_mask is not None: - final_tex_feat = torch.zeros( - planes.shape[0], hard_mask.shape[1] * hard_mask.shape[2], tex_feat.shape[-1], device=tex_feat.device) - expanded_hard_mask = hard_mask.reshape(hard_mask.shape[0], -1, 1).expand(-1, -1, final_tex_feat.shape[-1]) > 0.5 - for i in range(planes.shape[0]): - final_tex_feat[i][expanded_hard_mask[i]] = tex_feat[i][:n_point_list[i]].reshape(-1) - tex_feat = final_tex_feat - - return tex_feat.reshape(planes.shape[0], hard_mask.shape[1], hard_mask.shape[2], tex_feat.shape[-1]) - - def render_mesh(self, mesh_v, mesh_f, cam_mv, render_size=256): - ''' - Function to render a generated mesh with nvdiffrast - :param mesh_v: List of vertices for the mesh - :param mesh_f: List of faces for the mesh - :param cam_mv: 4x4 rotation matrix - :return: - ''' - return_value_list = [] - for i_mesh in range(len(mesh_v)): - return_value = self.geometry.render_mesh( - mesh_v[i_mesh], - mesh_f[i_mesh].int(), - cam_mv[i_mesh], - resolution=render_size, - hierarchical_mask=False - ) - return_value_list.append(return_value) - - return_keys = return_value_list[0].keys() - return_value = dict() - for k in return_keys: - value = [v[k] for v in return_value_list] - return_value[k] = value - - mask = torch.cat(return_value['mask'], dim=0) - hard_mask = torch.cat(return_value['hard_mask'], dim=0) - tex_pos = return_value['tex_pos'] - depth = torch.cat(return_value['depth'], dim=0) - normal = torch.cat(return_value['normal'], dim=0) - return mask, hard_mask, tex_pos, depth, normal - - def forward_geometry(self, planes, render_cameras, render_size=256): - ''' - Main function of our Generator. It first generate 3D mesh, then render it into 2D image - with given `render_cameras`. - :param planes: triplane features - :param render_cameras: cameras to render generated 3D shape - ''' - B, NV = render_cameras.shape[:2] - - # Generate 3D mesh first - mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) - - # Render the mesh into 2D image (get 3d position of each image plane) - cam_mv = render_cameras - run_n_view = cam_mv.shape[1] - antilias_mask, hard_mask, tex_pos, depth, normal = self.render_mesh(mesh_v, mesh_f, cam_mv, render_size=render_size) - - tex_hard_mask = hard_mask - tex_pos = [torch.cat([pos[i_view:i_view + 1] for i_view in range(run_n_view)], dim=2) for pos in tex_pos] - tex_hard_mask = torch.cat( - [torch.cat( - [tex_hard_mask[i * run_n_view + i_view: i * run_n_view + i_view + 1] - for i_view in range(run_n_view)], dim=2) - for i in range(planes.shape[0])], dim=0) - - # Querying the texture field to predict the texture feature for each pixel on the image - tex_feat = self.get_texture_prediction(planes, tex_pos, tex_hard_mask) - background_feature = torch.ones_like(tex_feat) # white background - - # Merge them together - img_feat = tex_feat * tex_hard_mask + background_feature * (1 - tex_hard_mask) - - # We should split it back to the original image shape - img_feat = torch.cat( - [torch.cat( - [img_feat[i:i + 1, :, render_size * i_view: render_size * (i_view + 1)] - for i_view in range(run_n_view)], dim=0) for i in range(len(tex_pos))], dim=0) - - img = img_feat.clamp(0, 1).permute(0, 3, 1, 2).unflatten(0, (B, NV)) - antilias_mask = antilias_mask.permute(0, 3, 1, 2).unflatten(0, (B, NV)) - depth = -depth.permute(0, 3, 1, 2).unflatten(0, (B, NV)) # transform negative depth to positive - normal = normal.permute(0, 3, 1, 2).unflatten(0, (B, NV)) - - out = { - 'img': img, - 'mask': antilias_mask, - 'depth': depth, - 'normal': normal, - 'sdf': sdf, - 'mesh_v': mesh_v, - 'mesh_f': mesh_f, - 'sdf_reg_loss': sdf_reg_loss, - } - return out - - def forward(self, images, cameras, render_cameras, render_size: int): - # images: [B, V, C_img, H_img, W_img] - # cameras: [B, V, 16] - # render_cameras: [B, M, D_cam_render] - # render_size: int - B, M = render_cameras.shape[:2] - - planes = self.forward_planes(images, cameras) - out = self.forward_geometry(planes, render_cameras, render_size=render_size) - - return { - 'planes': planes, - **out - } - - def extract_mesh( - self, - planes: torch.Tensor, - use_texture_map: bool = False, - texture_resolution: int = 1024, - **kwargs, - ): - ''' - Extract a 3D mesh from FlexiCubes. Only support batch_size 1. - :param planes: triplane features - :param use_texture_map: use texture map or vertex color - :param texture_resolution: the resolution of texure map - ''' - assert planes.shape[0] == 1 - device = planes.device - - # predict geometry first - mesh_v, mesh_f, sdf, deformation, v_deformed, sdf_reg_loss = self.get_geometry_prediction(planes) - vertices, faces = mesh_v[0], mesh_f[0] - - if not use_texture_map: - # query vertex colors - vertices_tensor = vertices.unsqueeze(0) - vertices_colors = self.synthesizer.get_texture_prediction( - planes, vertices_tensor).clamp(0, 1).squeeze(0).cpu().numpy() - vertices_colors = (vertices_colors * 255).astype(np.uint8) - - return vertices.cpu().numpy(), faces.cpu().numpy(), vertices_colors - - # use x-atlas to get uv mapping for the mesh - ctx = dr.RasterizeCudaContext(device=device) - uvs, mesh_tex_idx, gb_pos, tex_hard_mask = xatlas_uvmap( - self.geometry.renderer.ctx, vertices, faces, resolution=texture_resolution) - tex_hard_mask = tex_hard_mask.float() - - # query the texture field to get the RGB color for texture map - tex_feat = self.get_texture_prediction( - planes, [gb_pos], tex_hard_mask) - background_feature = torch.zeros_like(tex_feat) - img_feat = torch.lerp(background_feature, tex_feat, tex_hard_mask) - texture_map = img_feat.permute(0, 3, 1, 2).squeeze(0) - - return vertices, faces, uvs, mesh_tex_idx, texture_map \ No newline at end of file diff --git a/instant-mesh/src/models/renderer/__init__.py b/instant-mesh/src/models/renderer/__init__.py deleted file mode 100644 index 2c772e4fa331c678cfff50884be94d7d31835b34..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. diff --git a/instant-mesh/src/models/renderer/synthesizer.py b/instant-mesh/src/models/renderer/synthesizer.py deleted file mode 100644 index 8db9fbdb1703b566117d227c8e4eef04157ccc93..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/synthesizer.py +++ /dev/null @@ -1,203 +0,0 @@ -# ORIGINAL LICENSE -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Modified by Jiale Xu -# The modifications are subject to the same license as the original. - - -import itertools -import torch -import torch.nn as nn - -from .utils.renderer import ImportanceRenderer -from .utils.ray_sampler import RaySampler - - -class OSGDecoder(nn.Module): - """ - Triplane decoder that gives RGB and sigma values from sampled features. - Using ReLU here instead of Softplus in the original implementation. - - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 - """ - def __init__(self, n_features: int, - hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): - super().__init__() - self.net = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 1 + 3), - ) - # init all bias to zero - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.zeros_(m.bias) - - def forward(self, sampled_features, ray_directions): - # Aggregate features by mean - # sampled_features = sampled_features.mean(1) - # Aggregate features by concatenation - _N, n_planes, _M, _C = sampled_features.shape - sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) - x = sampled_features - - N, M, C = x.shape - x = x.contiguous().view(N*M, C) - - x = self.net(x) - x = x.view(N, M, -1) - rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF - sigma = x[..., 0:1] - - return {'rgb': rgb, 'sigma': sigma} - - -class TriplaneSynthesizer(nn.Module): - """ - Synthesizer that renders a triplane volume with planes and a camera. - - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 - """ - - DEFAULT_RENDERING_KWARGS = { - 'ray_start': 'auto', - 'ray_end': 'auto', - 'box_warp': 2., - 'white_back': True, - 'disparity_space_sampling': False, - 'clamp_mode': 'softplus', - 'sampler_bbox_min': -1., - 'sampler_bbox_max': 1., - } - - def __init__(self, triplane_dim: int, samples_per_ray: int): - super().__init__() - - # attributes - self.triplane_dim = triplane_dim - self.rendering_kwargs = { - **self.DEFAULT_RENDERING_KWARGS, - 'depth_resolution': samples_per_ray // 2, - 'depth_resolution_importance': samples_per_ray // 2, - } - - # renderings - self.renderer = ImportanceRenderer() - self.ray_sampler = RaySampler() - - # modules - self.decoder = OSGDecoder(n_features=triplane_dim) - - def forward(self, planes, cameras, render_size=128, crop_params=None): - # planes: (N, 3, D', H', W') - # cameras: (N, M, D_cam) - # render_size: int - assert planes.shape[0] == cameras.shape[0], "Batch size mismatch for planes and cameras" - N, M = cameras.shape[:2] - - cam2world_matrix = cameras[..., :16].view(N, M, 4, 4) - intrinsics = cameras[..., 16:25].view(N, M, 3, 3) - - # Create a batch of rays for volume rendering - ray_origins, ray_directions = self.ray_sampler( - cam2world_matrix=cam2world_matrix.reshape(-1, 4, 4), - intrinsics=intrinsics.reshape(-1, 3, 3), - render_size=render_size, - ) - assert N*M == ray_origins.shape[0], "Batch size mismatch for ray_origins" - assert ray_origins.dim() == 3, "ray_origins should be 3-dimensional" - - # Crop rays if crop_params is available - if crop_params is not None: - ray_origins = ray_origins.reshape(N*M, render_size, render_size, 3) - ray_directions = ray_directions.reshape(N*M, render_size, render_size, 3) - i, j, h, w = crop_params - ray_origins = ray_origins[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) - ray_directions = ray_directions[:, i:i+h, j:j+w, :].reshape(N*M, -1, 3) - - # Perform volume rendering - rgb_samples, depth_samples, weights_samples = self.renderer( - planes.repeat_interleave(M, dim=0), self.decoder, ray_origins, ray_directions, self.rendering_kwargs, - ) - - # Reshape into 'raw' neural-rendered image - if crop_params is not None: - Himg, Wimg = crop_params[2:] - else: - Himg = Wimg = render_size - rgb_images = rgb_samples.permute(0, 2, 1).reshape(N, M, rgb_samples.shape[-1], Himg, Wimg).contiguous() - depth_images = depth_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) - weight_images = weights_samples.permute(0, 2, 1).reshape(N, M, 1, Himg, Wimg) - - out = { - 'images_rgb': rgb_images, - 'images_depth': depth_images, - 'images_weight': weight_images, - } - return out - - def forward_grid(self, planes, grid_size: int, aabb: torch.Tensor = None): - # planes: (N, 3, D', H', W') - # grid_size: int - # aabb: (N, 2, 3) - if aabb is None: - aabb = torch.tensor([ - [self.rendering_kwargs['sampler_bbox_min']] * 3, - [self.rendering_kwargs['sampler_bbox_max']] * 3, - ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1) - assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb" - N = planes.shape[0] - - # create grid points for triplane query - grid_points = [] - for i in range(N): - grid_points.append(torch.stack(torch.meshgrid( - torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), - torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), - torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), - indexing='ij', - ), dim=-1).reshape(-1, 3)) - cube_grid = torch.stack(grid_points, dim=0).to(planes.device) - - features = self.forward_points(planes, cube_grid) - - # reshape into grid - features = { - k: v.reshape(N, grid_size, grid_size, grid_size, -1) - for k, v in features.items() - } - return features - - def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**20): - # planes: (N, 3, D', H', W') - # points: (N, P, 3) - N, P = points.shape[:2] - - # query triplane in chunks - outs = [] - for i in range(0, points.shape[1], chunk_size): - chunk_points = points[:, i:i+chunk_size] - - # query triplane - chunk_out = self.renderer.run_model_activated( - planes=planes, - decoder=self.decoder, - sample_coordinates=chunk_points, - sample_directions=torch.zeros_like(chunk_points), - options=self.rendering_kwargs, - ) - outs.append(chunk_out) - - # concatenate the outputs - point_features = { - k: torch.cat([out[k] for out in outs], dim=1) - for k in outs[0].keys() - } - return point_features diff --git a/instant-mesh/src/models/renderer/synthesizer_mesh.py b/instant-mesh/src/models/renderer/synthesizer_mesh.py deleted file mode 100644 index dc31838315b33781560b3623c030443eeae24147..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/synthesizer_mesh.py +++ /dev/null @@ -1,141 +0,0 @@ -# ORIGINAL LICENSE -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# Modified by Jiale Xu -# The modifications are subject to the same license as the original. - -import itertools -import torch -import torch.nn as nn - -from .utils.renderer import generate_planes, project_onto_planes, sample_from_planes - - -class OSGDecoder(nn.Module): - """ - Triplane decoder that gives RGB and sigma values from sampled features. - Using ReLU here instead of Softplus in the original implementation. - - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 - """ - def __init__(self, n_features: int, - hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): - super().__init__() - - self.net_sdf = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 1), - ) - self.net_rgb = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 3), - ) - self.net_deformation = nn.Sequential( - nn.Linear(3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 3), - ) - self.net_weight = nn.Sequential( - nn.Linear(8 * 3 * n_features, hidden_dim), - activation(), - *itertools.chain(*[[ - nn.Linear(hidden_dim, hidden_dim), - activation(), - ] for _ in range(num_layers - 2)]), - nn.Linear(hidden_dim, 21), - ) - - # init all bias to zero - for m in self.modules(): - if isinstance(m, nn.Linear): - nn.init.zeros_(m.bias) - - def get_geometry_prediction(self, sampled_features, flexicubes_indices): - _N, n_planes, _M, _C = sampled_features.shape - sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) - - sdf = self.net_sdf(sampled_features) - deformation = self.net_deformation(sampled_features) - - grid_features = torch.index_select(input=sampled_features, index=flexicubes_indices.reshape(-1), dim=1) - grid_features = grid_features.reshape( - sampled_features.shape[0], flexicubes_indices.shape[0], flexicubes_indices.shape[1] * sampled_features.shape[-1]) - weight = self.net_weight(grid_features) * 0.1 - - return sdf, deformation, weight - - def get_texture_prediction(self, sampled_features): - _N, n_planes, _M, _C = sampled_features.shape - sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) - - rgb = self.net_rgb(sampled_features) - rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF - - return rgb - - -class TriplaneSynthesizer(nn.Module): - """ - Synthesizer that renders a triplane volume with planes and a camera. - - Reference: - EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L19 - """ - - DEFAULT_RENDERING_KWARGS = { - 'ray_start': 'auto', - 'ray_end': 'auto', - 'box_warp': 2., - 'white_back': True, - 'disparity_space_sampling': False, - 'clamp_mode': 'softplus', - 'sampler_bbox_min': -1., - 'sampler_bbox_max': 1., - } - - def __init__(self, triplane_dim: int, samples_per_ray: int): - super().__init__() - - # attributes - self.triplane_dim = triplane_dim - self.rendering_kwargs = { - **self.DEFAULT_RENDERING_KWARGS, - 'depth_resolution': samples_per_ray // 2, - 'depth_resolution_importance': samples_per_ray // 2, - } - - # modules - self.plane_axes = generate_planes() - self.decoder = OSGDecoder(n_features=triplane_dim) - - def get_geometry_prediction(self, planes, sample_coordinates, flexicubes_indices): - plane_axes = self.plane_axes.to(planes.device) - sampled_features = sample_from_planes( - plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) - - sdf, deformation, weight = self.decoder.get_geometry_prediction(sampled_features, flexicubes_indices) - return sdf, deformation, weight - - def get_texture_prediction(self, planes, sample_coordinates): - plane_axes = self.plane_axes.to(planes.device) - sampled_features = sample_from_planes( - plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=self.rendering_kwargs['box_warp']) - - rgb = self.decoder.get_texture_prediction(sampled_features) - return rgb diff --git a/instant-mesh/src/models/renderer/utils/__init__.py b/instant-mesh/src/models/renderer/utils/__init__.py deleted file mode 100644 index 2c772e4fa331c678cfff50884be94d7d31835b34..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/utils/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. diff --git a/instant-mesh/src/models/renderer/utils/math_utils.py b/instant-mesh/src/models/renderer/utils/math_utils.py deleted file mode 100644 index 4cf9d2b811e0acbc7923bc9126e010b52cb1a8af..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/utils/math_utils.py +++ /dev/null @@ -1,118 +0,0 @@ -# MIT License - -# Copyright (c) 2022 Petr Kellnhofer - -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: - -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. - -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -import torch - -def transform_vectors(matrix: torch.Tensor, vectors4: torch.Tensor) -> torch.Tensor: - """ - Left-multiplies MxM @ NxM. Returns NxM. - """ - res = torch.matmul(vectors4, matrix.T) - return res - - -def normalize_vecs(vectors: torch.Tensor) -> torch.Tensor: - """ - Normalize vector lengths. - """ - return vectors / (torch.norm(vectors, dim=-1, keepdim=True)) - -def torch_dot(x: torch.Tensor, y: torch.Tensor): - """ - Dot product of two tensors. - """ - return (x * y).sum(-1) - - -def get_ray_limits_box(rays_o: torch.Tensor, rays_d: torch.Tensor, box_side_length): - """ - Author: Petr Kellnhofer - Intersects rays with the [-1, 1] NDC volume. - Returns min and max distance of entry. - Returns -1 for no intersection. - https://www.scratchapixel.com/lessons/3d-basic-rendering/minimal-ray-tracer-rendering-simple-shapes/ray-box-intersection - """ - o_shape = rays_o.shape - rays_o = rays_o.detach().reshape(-1, 3) - rays_d = rays_d.detach().reshape(-1, 3) - - - bb_min = [-1*(box_side_length/2), -1*(box_side_length/2), -1*(box_side_length/2)] - bb_max = [1*(box_side_length/2), 1*(box_side_length/2), 1*(box_side_length/2)] - bounds = torch.tensor([bb_min, bb_max], dtype=rays_o.dtype, device=rays_o.device) - is_valid = torch.ones(rays_o.shape[:-1], dtype=bool, device=rays_o.device) - - # Precompute inverse for stability. - invdir = 1 / rays_d - sign = (invdir < 0).long() - - # Intersect with YZ plane. - tmin = (bounds.index_select(0, sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] - tmax = (bounds.index_select(0, 1 - sign[..., 0])[..., 0] - rays_o[..., 0]) * invdir[..., 0] - - # Intersect with XZ plane. - tymin = (bounds.index_select(0, sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] - tymax = (bounds.index_select(0, 1 - sign[..., 1])[..., 1] - rays_o[..., 1]) * invdir[..., 1] - - # Resolve parallel rays. - is_valid[torch.logical_or(tmin > tymax, tymin > tmax)] = False - - # Use the shortest intersection. - tmin = torch.max(tmin, tymin) - tmax = torch.min(tmax, tymax) - - # Intersect with XY plane. - tzmin = (bounds.index_select(0, sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] - tzmax = (bounds.index_select(0, 1 - sign[..., 2])[..., 2] - rays_o[..., 2]) * invdir[..., 2] - - # Resolve parallel rays. - is_valid[torch.logical_or(tmin > tzmax, tzmin > tmax)] = False - - # Use the shortest intersection. - tmin = torch.max(tmin, tzmin) - tmax = torch.min(tmax, tzmax) - - # Mark invalid. - tmin[torch.logical_not(is_valid)] = -1 - tmax[torch.logical_not(is_valid)] = -2 - - return tmin.reshape(*o_shape[:-1], 1), tmax.reshape(*o_shape[:-1], 1) - - -def linspace(start: torch.Tensor, stop: torch.Tensor, num: int): - """ - Creates a tensor of shape [num, *start.shape] whose values are evenly spaced from start to end, inclusive. - Replicates but the multi-dimensional bahaviour of numpy.linspace in PyTorch. - """ - # create a tensor of 'num' steps from 0 to 1 - steps = torch.arange(num, dtype=torch.float32, device=start.device) / (num - 1) - - # reshape the 'steps' tensor to [-1, *([1]*start.ndim)] to allow for broadcastings - # - using 'steps.reshape([-1, *([1]*start.ndim)])' would be nice here but torchscript - # "cannot statically infer the expected size of a list in this contex", hence the code below - for i in range(start.ndim): - steps = steps.unsqueeze(-1) - - # the output starts at 'start' and increments until 'stop' in each dimension - out = start[None] + steps * (stop - start)[None] - - return out diff --git a/instant-mesh/src/models/renderer/utils/ray_marcher.py b/instant-mesh/src/models/renderer/utils/ray_marcher.py deleted file mode 100644 index ea1db43478de703509cdd04c684f92f8e283c5ad..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/utils/ray_marcher.py +++ /dev/null @@ -1,72 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. -# -# Modified by Jiale Xu -# The modifications are subject to the same license as the original. - - -""" -The ray marcher takes the raw output of the implicit representation and uses the volume rendering equation to produce composited colors and depths. -Based off of the implementation in MipNeRF (this one doesn't do any cone tracing though!) -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - - -class MipRayMarcher2(nn.Module): - def __init__(self, activation_factory): - super().__init__() - self.activation_factory = activation_factory - - def run_forward(self, colors, densities, depths, rendering_options, normals=None): - dtype = colors.dtype - deltas = depths[:, :, 1:] - depths[:, :, :-1] - colors_mid = (colors[:, :, :-1] + colors[:, :, 1:]) / 2 - densities_mid = (densities[:, :, :-1] + densities[:, :, 1:]) / 2 - depths_mid = (depths[:, :, :-1] + depths[:, :, 1:]) / 2 - - # using factory mode for better usability - densities_mid = self.activation_factory(rendering_options)(densities_mid).to(dtype) - - density_delta = densities_mid * deltas - - alpha = 1 - torch.exp(-density_delta).to(dtype) - - alpha_shifted = torch.cat([torch.ones_like(alpha[:, :, :1]), 1-alpha + 1e-10], -2) - weights = alpha * torch.cumprod(alpha_shifted, -2)[:, :, :-1] - weights = weights.to(dtype) - - composite_rgb = torch.sum(weights * colors_mid, -2) - weight_total = weights.sum(2) - # composite_depth = torch.sum(weights * depths_mid, -2) / weight_total - composite_depth = torch.sum(weights * depths_mid, -2) - - # clip the composite to min/max range of depths - composite_depth = torch.nan_to_num(composite_depth, float('inf')).to(dtype) - composite_depth = torch.clamp(composite_depth, torch.min(depths), torch.max(depths)) - - if rendering_options.get('white_back', False): - composite_rgb = composite_rgb + 1 - weight_total - - # rendered value scale is 0-1, comment out original mipnerf scaling - # composite_rgb = composite_rgb * 2 - 1 # Scale to (-1, 1) - - return composite_rgb, composite_depth, weights - - - def forward(self, colors, densities, depths, rendering_options, normals=None): - if normals is not None: - composite_rgb, composite_depth, composite_normals, weights = self.run_forward(colors, densities, depths, rendering_options, normals) - return composite_rgb, composite_depth, composite_normals, weights - - composite_rgb, composite_depth, weights = self.run_forward(colors, densities, depths, rendering_options) - return composite_rgb, composite_depth, weights diff --git a/instant-mesh/src/models/renderer/utils/ray_sampler.py b/instant-mesh/src/models/renderer/utils/ray_sampler.py deleted file mode 100644 index ae5151dda467e826ce346986bd486d4465c906f2..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/utils/ray_sampler.py +++ /dev/null @@ -1,141 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. -# -# Modified by Jiale Xu -# The modifications are subject to the same license as the original. - - -""" -The ray sampler is a module that takes in camera matrices and resolution and batches of rays. -Expects cam2world matrices that use the OpenCV camera coordinate system conventions. -""" - -import torch - -class RaySampler(torch.nn.Module): - def __init__(self): - super().__init__() - self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None - - - def forward(self, cam2world_matrix, intrinsics, render_size): - """ - Create batches of rays and return origins and directions. - - cam2world_matrix: (N, 4, 4) - intrinsics: (N, 3, 3) - render_size: int - - ray_origins: (N, M, 3) - ray_dirs: (N, M, 2) - """ - - dtype = cam2world_matrix.dtype - device = cam2world_matrix.device - N, M = cam2world_matrix.shape[0], render_size**2 - cam_locs_world = cam2world_matrix[:, :3, 3] - fx = intrinsics[:, 0, 0] - fy = intrinsics[:, 1, 1] - cx = intrinsics[:, 0, 2] - cy = intrinsics[:, 1, 2] - sk = intrinsics[:, 0, 1] - - uv = torch.stack(torch.meshgrid( - torch.arange(render_size, dtype=dtype, device=device), - torch.arange(render_size, dtype=dtype, device=device), - indexing='ij', - )) - uv = uv.flip(0).reshape(2, -1).transpose(1, 0) - uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) - - x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) - y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) - z_cam = torch.ones((N, M), dtype=dtype, device=device) - - x_lift = (x_cam - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y_cam/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z_cam - y_lift = (y_cam - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z_cam - - cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1).to(dtype) - - _opencv2blender = torch.tensor([ - [1, 0, 0, 0], - [0, -1, 0, 0], - [0, 0, -1, 0], - [0, 0, 0, 1], - ], dtype=dtype, device=device).unsqueeze(0).repeat(N, 1, 1) - - cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) - - world_rel_points = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] - - ray_dirs = world_rel_points - cam_locs_world[:, None, :] - ray_dirs = torch.nn.functional.normalize(ray_dirs, dim=2).to(dtype) - - ray_origins = cam_locs_world.unsqueeze(1).repeat(1, ray_dirs.shape[1], 1) - - return ray_origins, ray_dirs - - -class OrthoRaySampler(torch.nn.Module): - def __init__(self): - super().__init__() - self.ray_origins_h, self.ray_directions, self.depths, self.image_coords, self.rendering_options = None, None, None, None, None - - - def forward(self, cam2world_matrix, ortho_scale, render_size): - """ - Create batches of rays and return origins and directions. - - cam2world_matrix: (N, 4, 4) - ortho_scale: float - render_size: int - - ray_origins: (N, M, 3) - ray_dirs: (N, M, 3) - """ - - N, M = cam2world_matrix.shape[0], render_size**2 - - uv = torch.stack(torch.meshgrid( - torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), - torch.arange(render_size, dtype=torch.float32, device=cam2world_matrix.device), - indexing='ij', - )) - uv = uv.flip(0).reshape(2, -1).transpose(1, 0) - uv = uv.unsqueeze(0).repeat(cam2world_matrix.shape[0], 1, 1) - - x_cam = uv[:, :, 0].view(N, -1) * (1./render_size) + (0.5/render_size) - y_cam = uv[:, :, 1].view(N, -1) * (1./render_size) + (0.5/render_size) - z_cam = torch.zeros((N, M), device=cam2world_matrix.device) - - x_lift = (x_cam - 0.5) * ortho_scale - y_lift = (y_cam - 0.5) * ortho_scale - - cam_rel_points = torch.stack((x_lift, y_lift, z_cam, torch.ones_like(z_cam)), dim=-1) - - _opencv2blender = torch.tensor([ - [1, 0, 0, 0], - [0, -1, 0, 0], - [0, 0, -1, 0], - [0, 0, 0, 1], - ], dtype=torch.float32, device=cam2world_matrix.device).unsqueeze(0).repeat(N, 1, 1) - - cam2world_matrix = torch.bmm(cam2world_matrix, _opencv2blender) - - ray_origins = torch.bmm(cam2world_matrix, cam_rel_points.permute(0, 2, 1)).permute(0, 2, 1)[:, :, :3] - - ray_dirs_cam = torch.stack([ - torch.zeros((N, M), device=cam2world_matrix.device), - torch.zeros((N, M), device=cam2world_matrix.device), - torch.ones((N, M), device=cam2world_matrix.device), - ], dim=-1) - ray_dirs = torch.bmm(cam2world_matrix[:, :3, :3], ray_dirs_cam.permute(0, 2, 1)).permute(0, 2, 1) - - return ray_origins, ray_dirs diff --git a/instant-mesh/src/models/renderer/utils/renderer.py b/instant-mesh/src/models/renderer/utils/renderer.py deleted file mode 100644 index 95c4c728efbd0283b8ddd7dc6a1b28d1510efa97..0000000000000000000000000000000000000000 --- a/instant-mesh/src/models/renderer/utils/renderer.py +++ /dev/null @@ -1,323 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# SPDX-License-Identifier: LicenseRef-NvidiaProprietary -# -# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual -# property and proprietary rights in and to this material, related -# documentation and any modifications thereto. Any use, reproduction, -# disclosure or distribution of this material and related documentation -# without an express license agreement from NVIDIA CORPORATION or -# its affiliates is strictly prohibited. -# -# Modified by Jiale Xu -# The modifications are subject to the same license as the original. - - -""" -The renderer is a module that takes in rays, decides where to sample along each -ray, and computes pixel colors using the volume rendering equation. -""" - -import torch -import torch.nn as nn -import torch.nn.functional as F - -from .ray_marcher import MipRayMarcher2 -from . import math_utils - - -def generate_planes(): - """ - Defines planes by the three vectors that form the "axes" of the - plane. Should work with arbitrary number of planes and planes of - arbitrary orientation. - - Bugfix reference: https://github.com/NVlabs/eg3d/issues/67 - """ - return torch.tensor([[[1, 0, 0], - [0, 1, 0], - [0, 0, 1]], - [[1, 0, 0], - [0, 0, 1], - [0, 1, 0]], - [[0, 0, 1], - [0, 1, 0], - [1, 0, 0]]], dtype=torch.float32) - -def project_onto_planes(planes, coordinates): - """ - Does a projection of a 3D point onto a batch of 2D planes, - returning 2D plane coordinates. - - Takes plane axes of shape n_planes, 3, 3 - # Takes coordinates of shape N, M, 3 - # returns projections of shape N*n_planes, M, 2 - """ - N, M, C = coordinates.shape - n_planes, _, _ = planes.shape - coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3) - inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3) - projections = torch.bmm(coordinates, inv_planes) - return projections[..., :2] - -def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None): - assert padding_mode == 'zeros' - N, n_planes, C, H, W = plane_features.shape - _, M, _ = coordinates.shape - plane_features = plane_features.view(N*n_planes, C, H, W) - dtype = plane_features.dtype - - coordinates = (2/box_warp) * coordinates # add specific box bounds - - projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1) - output_features = torch.nn.functional.grid_sample( - plane_features, - projected_coordinates.to(dtype), - mode=mode, - padding_mode=padding_mode, - align_corners=False, - ).permute(0, 3, 2, 1).reshape(N, n_planes, M, C) - return output_features - -def sample_from_3dgrid(grid, coordinates): - """ - Expects coordinates in shape (batch_size, num_points_per_batch, 3) - Expects grid in shape (1, channels, H, W, D) - (Also works if grid has batch size) - Returns sampled features of shape (batch_size, num_points_per_batch, feature_channels) - """ - batch_size, n_coords, n_dims = coordinates.shape - sampled_features = torch.nn.functional.grid_sample( - grid.expand(batch_size, -1, -1, -1, -1), - coordinates.reshape(batch_size, 1, 1, -1, n_dims), - mode='bilinear', - padding_mode='zeros', - align_corners=False, - ) - N, C, H, W, D = sampled_features.shape - sampled_features = sampled_features.permute(0, 4, 3, 2, 1).reshape(N, H*W*D, C) - return sampled_features - -class ImportanceRenderer(torch.nn.Module): - """ - Modified original version to filter out-of-box samples as TensoRF does. - - Reference: - TensoRF: https://github.com/apchenstu/TensoRF/blob/main/models/tensorBase.py#L277 - """ - def __init__(self): - super().__init__() - self.activation_factory = self._build_activation_factory() - self.ray_marcher = MipRayMarcher2(self.activation_factory) - self.plane_axes = generate_planes() - - def _build_activation_factory(self): - def activation_factory(options: dict): - if options['clamp_mode'] == 'softplus': - return lambda x: F.softplus(x - 1) # activation bias of -1 makes things initialize better - else: - assert False, "Renderer only supports `clamp_mode`=`softplus`!" - return activation_factory - - def _forward_pass(self, depths: torch.Tensor, ray_directions: torch.Tensor, ray_origins: torch.Tensor, - planes: torch.Tensor, decoder: nn.Module, rendering_options: dict): - """ - Additional filtering is applied to filter out-of-box samples. - Modifications made by Zexin He. - """ - - # context related variables - batch_size, num_rays, samples_per_ray, _ = depths.shape - device = depths.device - - # define sample points with depths - sample_directions = ray_directions.unsqueeze(-2).expand(-1, -1, samples_per_ray, -1).reshape(batch_size, -1, 3) - sample_coordinates = (ray_origins.unsqueeze(-2) + depths * ray_directions.unsqueeze(-2)).reshape(batch_size, -1, 3) - - # filter out-of-box samples - mask_inbox = \ - (rendering_options['sampler_bbox_min'] <= sample_coordinates) & \ - (sample_coordinates <= rendering_options['sampler_bbox_max']) - mask_inbox = mask_inbox.all(-1) - - # forward model according to all samples - _out = self.run_model(planes, decoder, sample_coordinates, sample_directions, rendering_options) - - # set out-of-box samples to zeros(rgb) & -inf(sigma) - SAFE_GUARD = 3 - DATA_TYPE = _out['sigma'].dtype - colors_pass = torch.zeros(batch_size, num_rays * samples_per_ray, 3, device=device, dtype=DATA_TYPE) - densities_pass = torch.nan_to_num(torch.full((batch_size, num_rays * samples_per_ray, 1), -float('inf'), device=device, dtype=DATA_TYPE)) / SAFE_GUARD - colors_pass[mask_inbox], densities_pass[mask_inbox] = _out['rgb'][mask_inbox], _out['sigma'][mask_inbox] - - # reshape back - colors_pass = colors_pass.reshape(batch_size, num_rays, samples_per_ray, colors_pass.shape[-1]) - densities_pass = densities_pass.reshape(batch_size, num_rays, samples_per_ray, densities_pass.shape[-1]) - - return colors_pass, densities_pass - - def forward(self, planes, decoder, ray_origins, ray_directions, rendering_options): - # self.plane_axes = self.plane_axes.to(ray_origins.device) - - if rendering_options['ray_start'] == rendering_options['ray_end'] == 'auto': - ray_start, ray_end = math_utils.get_ray_limits_box(ray_origins, ray_directions, box_side_length=rendering_options['box_warp']) - is_ray_valid = ray_end > ray_start - if torch.any(is_ray_valid).item(): - ray_start[~is_ray_valid] = ray_start[is_ray_valid].min() - ray_end[~is_ray_valid] = ray_start[is_ray_valid].max() - depths_coarse = self.sample_stratified(ray_origins, ray_start, ray_end, rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) - else: - # Create stratified depth samples - depths_coarse = self.sample_stratified(ray_origins, rendering_options['ray_start'], rendering_options['ray_end'], rendering_options['depth_resolution'], rendering_options['disparity_space_sampling']) - - # Coarse Pass - colors_coarse, densities_coarse = self._forward_pass( - depths=depths_coarse, ray_directions=ray_directions, ray_origins=ray_origins, - planes=planes, decoder=decoder, rendering_options=rendering_options) - - # Fine Pass - N_importance = rendering_options['depth_resolution_importance'] - if N_importance > 0: - _, _, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) - - depths_fine = self.sample_importance(depths_coarse, weights, N_importance) - - colors_fine, densities_fine = self._forward_pass( - depths=depths_fine, ray_directions=ray_directions, ray_origins=ray_origins, - planes=planes, decoder=decoder, rendering_options=rendering_options) - - all_depths, all_colors, all_densities = self.unify_samples(depths_coarse, colors_coarse, densities_coarse, - depths_fine, colors_fine, densities_fine) - - rgb_final, depth_final, weights = self.ray_marcher(all_colors, all_densities, all_depths, rendering_options) - else: - rgb_final, depth_final, weights = self.ray_marcher(colors_coarse, densities_coarse, depths_coarse, rendering_options) - - return rgb_final, depth_final, weights.sum(2) - - def run_model(self, planes, decoder, sample_coordinates, sample_directions, options): - plane_axes = self.plane_axes.to(planes.device) - sampled_features = sample_from_planes(plane_axes, planes, sample_coordinates, padding_mode='zeros', box_warp=options['box_warp']) - - out = decoder(sampled_features, sample_directions) - if options.get('density_noise', 0) > 0: - out['sigma'] += torch.randn_like(out['sigma']) * options['density_noise'] - return out - - def run_model_activated(self, planes, decoder, sample_coordinates, sample_directions, options): - out = self.run_model(planes, decoder, sample_coordinates, sample_directions, options) - out['sigma'] = self.activation_factory(options)(out['sigma']) - return out - - def sort_samples(self, all_depths, all_colors, all_densities): - _, indices = torch.sort(all_depths, dim=-2) - all_depths = torch.gather(all_depths, -2, indices) - all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) - all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) - return all_depths, all_colors, all_densities - - def unify_samples(self, depths1, colors1, densities1, depths2, colors2, densities2, normals1=None, normals2=None): - all_depths = torch.cat([depths1, depths2], dim = -2) - all_colors = torch.cat([colors1, colors2], dim = -2) - all_densities = torch.cat([densities1, densities2], dim = -2) - - if normals1 is not None and normals2 is not None: - all_normals = torch.cat([normals1, normals2], dim = -2) - else: - all_normals = None - - _, indices = torch.sort(all_depths, dim=-2) - all_depths = torch.gather(all_depths, -2, indices) - all_colors = torch.gather(all_colors, -2, indices.expand(-1, -1, -1, all_colors.shape[-1])) - all_densities = torch.gather(all_densities, -2, indices.expand(-1, -1, -1, 1)) - - if all_normals is not None: - all_normals = torch.gather(all_normals, -2, indices.expand(-1, -1, -1, all_normals.shape[-1])) - return all_depths, all_colors, all_normals, all_densities - - return all_depths, all_colors, all_densities - - def sample_stratified(self, ray_origins, ray_start, ray_end, depth_resolution, disparity_space_sampling=False): - """ - Return depths of approximately uniformly spaced samples along rays. - """ - N, M, _ = ray_origins.shape - if disparity_space_sampling: - depths_coarse = torch.linspace(0, - 1, - depth_resolution, - device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) - depth_delta = 1/(depth_resolution - 1) - depths_coarse += torch.rand_like(depths_coarse) * depth_delta - depths_coarse = 1./(1./ray_start * (1. - depths_coarse) + 1./ray_end * depths_coarse) - else: - if type(ray_start) == torch.Tensor: - depths_coarse = math_utils.linspace(ray_start, ray_end, depth_resolution).permute(1,2,0,3) - depth_delta = (ray_end - ray_start) / (depth_resolution - 1) - depths_coarse += torch.rand_like(depths_coarse) * depth_delta[..., None] - else: - depths_coarse = torch.linspace(ray_start, ray_end, depth_resolution, device=ray_origins.device).reshape(1, 1, depth_resolution, 1).repeat(N, M, 1, 1) - depth_delta = (ray_end - ray_start)/(depth_resolution - 1) - depths_coarse += torch.rand_like(depths_coarse) * depth_delta - - return depths_coarse - - def sample_importance(self, z_vals, weights, N_importance): - """ - Return depths of importance sampled points along rays. See NeRF importance sampling for more. - """ - with torch.no_grad(): - batch_size, num_rays, samples_per_ray, _ = z_vals.shape - - z_vals = z_vals.reshape(batch_size * num_rays, samples_per_ray) - weights = weights.reshape(batch_size * num_rays, -1) # -1 to account for loss of 1 sample in MipRayMarcher - - # smooth weights - weights = torch.nn.functional.max_pool1d(weights.unsqueeze(1), 2, 1, padding=1) - weights = torch.nn.functional.avg_pool1d(weights, 2, 1).squeeze() - weights = weights + 0.01 - - z_vals_mid = 0.5 * (z_vals[: ,:-1] + z_vals[: ,1:]) - importance_z_vals = self.sample_pdf(z_vals_mid, weights[:, 1:-1], - N_importance).detach().reshape(batch_size, num_rays, N_importance, 1) - return importance_z_vals - - def sample_pdf(self, bins, weights, N_importance, det=False, eps=1e-5): - """ - Sample @N_importance samples from @bins with distribution defined by @weights. - Inputs: - bins: (N_rays, N_samples_+1) where N_samples_ is "the number of coarse samples per ray - 2" - weights: (N_rays, N_samples_) - N_importance: the number of samples to draw from the distribution - det: deterministic or not - eps: a small number to prevent division by zero - Outputs: - samples: the sampled samples - """ - N_rays, N_samples_ = weights.shape - weights = weights + eps # prevent division by zero (don't do inplace op!) - pdf = weights / torch.sum(weights, -1, keepdim=True) # (N_rays, N_samples_) - cdf = torch.cumsum(pdf, -1) # (N_rays, N_samples), cumulative distribution function - cdf = torch.cat([torch.zeros_like(cdf[: ,:1]), cdf], -1) # (N_rays, N_samples_+1) - # padded to 0~1 inclusive - - if det: - u = torch.linspace(0, 1, N_importance, device=bins.device) - u = u.expand(N_rays, N_importance) - else: - u = torch.rand(N_rays, N_importance, device=bins.device) - u = u.contiguous() - - inds = torch.searchsorted(cdf, u, right=True) - below = torch.clamp_min(inds-1, 0) - above = torch.clamp_max(inds, N_samples_) - - inds_sampled = torch.stack([below, above], -1).view(N_rays, 2*N_importance) - cdf_g = torch.gather(cdf, 1, inds_sampled).view(N_rays, N_importance, 2) - bins_g = torch.gather(bins, 1, inds_sampled).view(N_rays, N_importance, 2) - - denom = cdf_g[...,1]-cdf_g[...,0] - denom[denom 0 and radius > 0 - - elevation = np.deg2rad(elevation) - - camera_positions = [] - for i in range(M): - azimuth = 2 * np.pi * i / M - x = radius * np.cos(elevation) * np.cos(azimuth) - y = radius * np.cos(elevation) * np.sin(azimuth) - z = radius * np.sin(elevation) - camera_positions.append([x, y, z]) - camera_positions = np.array(camera_positions) - camera_positions = torch.from_numpy(camera_positions).float() - extrinsics = center_looking_at_camera_pose(camera_positions) - return extrinsics - - -def FOV_to_intrinsics(fov, device='cpu'): - """ - Creates a 3x3 camera intrinsics matrix from the camera field of view, specified in degrees. - Note the intrinsics are returned as normalized by image size, rather than in pixel units. - Assumes principal point is at image center. - """ - focal_length = 0.5 / np.tan(np.deg2rad(fov) * 0.5) - intrinsics = torch.tensor([[focal_length, 0, 0.5], [0, focal_length, 0.5], [0, 0, 1]], device=device) - return intrinsics - - -def get_zero123plus_input_cameras(batch_size=1, radius=4.0, fov=30.0): - """ - Get the input camera parameters. - """ - azimuths = np.array([30, 90, 150, 210, 270, 330]).astype(float) - elevations = np.array([20, -10, 20, -10, 20, -10]).astype(float) - - c2ws = spherical_camera_pose(azimuths, elevations, radius) - c2ws = c2ws.float().flatten(-2) - - Ks = FOV_to_intrinsics(fov).unsqueeze(0).repeat(6, 1, 1).float().flatten(-2) - - extrinsics = c2ws[:, :12] - intrinsics = torch.stack([Ks[:, 0], Ks[:, 4], Ks[:, 2], Ks[:, 5]], dim=-1) - cameras = torch.cat([extrinsics, intrinsics], dim=-1) - - return cameras.unsqueeze(0).repeat(batch_size, 1, 1) diff --git a/instant-mesh/src/utils/infer_util.py b/instant-mesh/src/utils/infer_util.py deleted file mode 100644 index 89cd078214afcc0e3dadafea5fbbb9ac005ea476..0000000000000000000000000000000000000000 --- a/instant-mesh/src/utils/infer_util.py +++ /dev/null @@ -1,84 +0,0 @@ -import os -import imageio -import rembg -import torch -import numpy as np -import PIL.Image -from PIL import Image -from typing import Any - - -def remove_background(image: PIL.Image.Image, - rembg_session: Any = None, - force: bool = False, - **rembg_kwargs, -) -> PIL.Image.Image: - do_remove = True - if image.mode == "RGBA" and image.getextrema()[3][0] < 255: - do_remove = False - do_remove = do_remove or force - if do_remove: - image = rembg.remove(image, session=rembg_session, **rembg_kwargs) - return image - - -def resize_foreground( - image: PIL.Image.Image, - ratio: float, -) -> PIL.Image.Image: - image = np.array(image) - assert image.shape[-1] == 4 - alpha = np.where(image[..., 3] > 0) - y1, y2, x1, x2 = ( - alpha[0].min(), - alpha[0].max(), - alpha[1].min(), - alpha[1].max(), - ) - # crop the foreground - fg = image[y1:y2, x1:x2] - # pad to square - size = max(fg.shape[0], fg.shape[1]) - ph0, pw0 = (size - fg.shape[0]) // 2, (size - fg.shape[1]) // 2 - ph1, pw1 = size - fg.shape[0] - ph0, size - fg.shape[1] - pw0 - new_image = np.pad( - fg, - ((ph0, ph1), (pw0, pw1), (0, 0)), - mode="constant", - constant_values=((0, 0), (0, 0), (0, 0)), - ) - - # compute padding according to the ratio - new_size = int(new_image.shape[0] / ratio) - # pad to size, double side - ph0, pw0 = (new_size - size) // 2, (new_size - size) // 2 - ph1, pw1 = new_size - size - ph0, new_size - size - pw0 - new_image = np.pad( - new_image, - ((ph0, ph1), (pw0, pw1), (0, 0)), - mode="constant", - constant_values=((0, 0), (0, 0), (0, 0)), - ) - new_image = PIL.Image.fromarray(new_image) - return new_image - - -def images_to_video( - images: torch.Tensor, - output_path: str, - fps: int = 30, -) -> None: - # images: (N, C, H, W) - video_dir = os.path.dirname(output_path) - video_name = os.path.basename(output_path) - os.makedirs(video_dir, exist_ok=True) - - frames = [] - for i in range(len(images)): - frame = (images[i].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8) - assert frame.shape[0] == images.shape[2] and frame.shape[1] == images.shape[3], \ - f"Frame shape mismatch: {frame.shape} vs {images.shape}" - assert frame.min() >= 0 and frame.max() <= 255, \ - f"Frame value out of range: {frame.min()} ~ {frame.max()}" - frames.append(frame) - imageio.mimwrite(output_path, np.stack(frames), fps=fps, quality=10) \ No newline at end of file diff --git a/instant-mesh/src/utils/mesh_util.py b/instant-mesh/src/utils/mesh_util.py deleted file mode 100644 index 0ec4663eeaa5c54209e08771969ec4f2a739c0b4..0000000000000000000000000000000000000000 --- a/instant-mesh/src/utils/mesh_util.py +++ /dev/null @@ -1,181 +0,0 @@ -# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property -# and proprietary rights in and to this software, related documentation -# and any modifications thereto. Any use, reproduction, disclosure or -# distribution of this software and related documentation without an express -# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited. - -import torch -import xatlas -import trimesh -import cv2 -import numpy as np -import nvdiffrast.torch as dr -from PIL import Image - - -def save_obj(pointnp_px3, facenp_fx3, colornp_px3, fpath): - - pointnp_px3 = pointnp_px3 @ np.array([[1, 0, 0], [0, 1, 0], [0, 0, -1]]) - facenp_fx3 = facenp_fx3[:, [2, 1, 0]] - - mesh = trimesh.Trimesh( - vertices=pointnp_px3, - faces=facenp_fx3, - vertex_colors=colornp_px3, - ) - mesh.export(fpath, 'obj') - - -def save_glb(pointnp_px3, facenp_fx3, colornp_px3, fpath): - - pointnp_px3 = pointnp_px3 @ np.array([[-1, 0, 0], [0, 1, 0], [0, 0, -1]]) - - mesh = trimesh.Trimesh( - vertices=pointnp_px3, - faces=facenp_fx3, - vertex_colors=colornp_px3, - ) - mesh.export(fpath, 'glb') - - -def save_obj_with_mtl(pointnp_px3, tcoords_px2, facenp_fx3, facetex_fx3, texmap_hxwx3, fname): - import os - fol, na = os.path.split(fname) - na, _ = os.path.splitext(na) - - matname = '%s/%s.mtl' % (fol, na) - fid = open(matname, 'w') - fid.write('newmtl material_0\n') - fid.write('Kd 1 1 1\n') - fid.write('Ka 0 0 0\n') - fid.write('Ks 0.4 0.4 0.4\n') - fid.write('Ns 10\n') - fid.write('illum 2\n') - fid.write('map_Kd %s.png\n' % na) - fid.close() - #### - - fid = open(fname, 'w') - fid.write('mtllib %s.mtl\n' % na) - - for pidx, p in enumerate(pointnp_px3): - pp = p - fid.write('v %f %f %f\n' % (pp[0], pp[1], pp[2])) - - for pidx, p in enumerate(tcoords_px2): - pp = p - fid.write('vt %f %f\n' % (pp[0], pp[1])) - - fid.write('usemtl material_0\n') - for i, f in enumerate(facenp_fx3): - f1 = f + 1 - f2 = facetex_fx3[i] + 1 - fid.write('f %d/%d %d/%d %d/%d\n' % (f1[0], f2[0], f1[1], f2[1], f1[2], f2[2])) - fid.close() - - # save texture map - lo, hi = 0, 1 - img = np.asarray(texmap_hxwx3, dtype=np.float32) - img = (img - lo) * (255 / (hi - lo)) - img = img.clip(0, 255) - mask = np.sum(img.astype(np.float32), axis=-1, keepdims=True) - mask = (mask <= 3.0).astype(np.float32) - kernel = np.ones((3, 3), 'uint8') - dilate_img = cv2.dilate(img, kernel, iterations=1) - img = img * (1 - mask) + dilate_img * mask - img = img.clip(0, 255).astype(np.uint8) - Image.fromarray(np.ascontiguousarray(img[::-1, :, :]), 'RGB').save(f'{fol}/{na}.png') - - -def loadobj(meshfile): - v = [] - f = [] - meshfp = open(meshfile, 'r') - for line in meshfp.readlines(): - data = line.strip().split(' ') - data = [da for da in data if len(da) > 0] - if len(data) != 4: - continue - if data[0] == 'v': - v.append([float(d) for d in data[1:]]) - if data[0] == 'f': - data = [da.split('/')[0] for da in data] - f.append([int(d) for d in data[1:]]) - meshfp.close() - - # torch need int64 - facenp_fx3 = np.array(f, dtype=np.int64) - 1 - pointnp_px3 = np.array(v, dtype=np.float32) - return pointnp_px3, facenp_fx3 - - -def loadobjtex(meshfile): - v = [] - vt = [] - f = [] - ft = [] - meshfp = open(meshfile, 'r') - for line in meshfp.readlines(): - data = line.strip().split(' ') - data = [da for da in data if len(da) > 0] - if not ((len(data) == 3) or (len(data) == 4) or (len(data) == 5)): - continue - if data[0] == 'v': - assert len(data) == 4 - - v.append([float(d) for d in data[1:]]) - if data[0] == 'vt': - if len(data) == 3 or len(data) == 4: - vt.append([float(d) for d in data[1:3]]) - if data[0] == 'f': - data = [da.split('/') for da in data] - if len(data) == 4: - f.append([int(d[0]) for d in data[1:]]) - ft.append([int(d[1]) for d in data[1:]]) - elif len(data) == 5: - idx1 = [1, 2, 3] - data1 = [data[i] for i in idx1] - f.append([int(d[0]) for d in data1]) - ft.append([int(d[1]) for d in data1]) - idx2 = [1, 3, 4] - data2 = [data[i] for i in idx2] - f.append([int(d[0]) for d in data2]) - ft.append([int(d[1]) for d in data2]) - meshfp.close() - - # torch need int64 - facenp_fx3 = np.array(f, dtype=np.int64) - 1 - ftnp_fx3 = np.array(ft, dtype=np.int64) - 1 - pointnp_px3 = np.array(v, dtype=np.float32) - uvs = np.array(vt, dtype=np.float32) - return pointnp_px3, facenp_fx3, uvs, ftnp_fx3 - - -# ============================================================================================== -def interpolate(attr, rast, attr_idx, rast_db=None): - return dr.interpolate(attr.contiguous(), rast, attr_idx, rast_db=rast_db, diff_attrs=None if rast_db is None else 'all') - - -def xatlas_uvmap(ctx, mesh_v, mesh_pos_idx, resolution): - vmapping, indices, uvs = xatlas.parametrize(mesh_v.detach().cpu().numpy(), mesh_pos_idx.detach().cpu().numpy()) - - # Convert to tensors - indices_int64 = indices.astype(np.uint64, casting='same_kind').view(np.int64) - - uvs = torch.tensor(uvs, dtype=torch.float32, device=mesh_v.device) - mesh_tex_idx = torch.tensor(indices_int64, dtype=torch.int64, device=mesh_v.device) - # mesh_v_tex. ture - uv_clip = uvs[None, ...] * 2.0 - 1.0 - - # pad to four component coordinate - uv_clip4 = torch.cat((uv_clip, torch.zeros_like(uv_clip[..., 0:1]), torch.ones_like(uv_clip[..., 0:1])), dim=-1) - - # rasterize - rast, _ = dr.rasterize(ctx, uv_clip4, mesh_tex_idx.int(), (resolution, resolution)) - - # Interpolate world space position - gb_pos, _ = interpolate(mesh_v[None, ...], rast, mesh_pos_idx.int()) - mask = rast[..., 3:4] > 0 - return uvs, mesh_tex_idx, gb_pos, mask diff --git a/instant-mesh/src/utils/train_util.py b/instant-mesh/src/utils/train_util.py deleted file mode 100644 index 2e65421bffa8cc42c1517e86f2dfd8183caf52ab..0000000000000000000000000000000000000000 --- a/instant-mesh/src/utils/train_util.py +++ /dev/null @@ -1,26 +0,0 @@ -import importlib - - -def count_params(model, verbose=False): - total_params = sum(p.numel() for p in model.parameters()) - if verbose: - print(f"{model.__class__.__name__} has {total_params*1.e-6:.2f} M params.") - return total_params - - -def instantiate_from_config(config): - if not "target" in config: - if config == '__is_first_stage__': - return None - elif config == "__is_unconditional__": - return None - raise KeyError("Expected key `target` to instantiate.") - return get_obj_from_str(config["target"])(**config.get("params", dict())) - - -def get_obj_from_str(string, reload=False): - module, cls = string.rsplit(".", 1) - if reload: - module_imp = importlib.import_module(module) - importlib.reload(module_imp) - return getattr(importlib.import_module(module, package=None), cls) diff --git a/instant-mesh/utils.py b/instant-mesh/utils.py deleted file mode 100644 index 99882dba2bbd2b56dc6caf53dd30d423f749a05a..0000000000000000000000000000000000000000 --- a/instant-mesh/utils.py +++ /dev/null @@ -1,178 +0,0 @@ -import os -import imageio -import numpy as np -import torch -import rembg -from PIL import Image -from torchvision.transforms import v2 -from pytorch_lightning import seed_everything -from omegaconf import OmegaConf -from einops import rearrange, repeat -from tqdm import tqdm -from diffusers import DiffusionPipeline, EulerAncestralDiscreteScheduler - -from src.utils.train_util import instantiate_from_config -from src.utils.camera_util import ( - FOV_to_intrinsics, - get_zero123plus_input_cameras, - get_circular_camera_poses, -) -from src.utils.mesh_util import save_obj, save_glb -from src.utils.infer_util import remove_background, resize_foreground, images_to_video - -import tempfile -from functools import partial - -from huggingface_hub import hf_hub_download - -import gradio as gr -import shutil -import spaces - - -def get_render_cameras(batch_size=1, M=120, radius=2.5, elevation=10.0, is_flexicubes=False): - """ - Get the rendering camera parameters. - """ - c2ws = get_circular_camera_poses(M=M, radius=radius, elevation=elevation) - if is_flexicubes: - cameras = torch.linalg.inv(c2ws) - cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1, 1) - else: - extrinsics = c2ws.flatten(-2) - intrinsics = FOV_to_intrinsics(50.0).unsqueeze(0).repeat(M, 1, 1).float().flatten(-2) - cameras = torch.cat([extrinsics, intrinsics], dim=-1) - cameras = cameras.unsqueeze(0).repeat(batch_size, 1, 1) - return cameras - - -import shutil - -def find_cuda(): - # Check if CUDA_HOME or CUDA_PATH environment variables are set - cuda_home = os.environ.get('CUDA_HOME') or os.environ.get('CUDA_PATH') - - if cuda_home and os.path.exists(cuda_home): - return cuda_home - - # Search for the nvcc executable in the system's PATH - nvcc_path = shutil.which('nvcc') - - if nvcc_path: - # Remove the 'bin/nvcc' part to get the CUDA installation path - cuda_path = os.path.dirname(os.path.dirname(nvcc_path)) - return cuda_path - - return None - -def check_input_image(input_image): - if input_image is None: - raise gr.Error("No image uploaded!") - - -def preprocess(input_image, do_remove_background): - - rembg_session = rembg.new_session() if do_remove_background else None - - if do_remove_background: - input_image = remove_background(input_image, rembg_session) - input_image = resize_foreground(input_image, 0.85) - - return input_image - - -@spaces.GPU -def generate_mvs(input_image, sample_steps, sample_seed): - - seed_everything(sample_seed) - - # sampling - z123_image = pipeline( - input_image, - num_inference_steps=sample_steps - ).images[0] - - show_image = np.asarray(z123_image, dtype=np.uint8) - show_image = torch.from_numpy(show_image) # (960, 640, 3) - show_image = rearrange(show_image, '(n h) (m w) c -> (n m) h w c', n=3, m=2) - show_image = rearrange(show_image, '(n m) h w c -> (n h) (m w) c', n=2, m=3) - show_image = Image.fromarray(show_image.numpy()) - - return z123_image, show_image - - -@spaces.GPU -def make3d(images): - - global model - if IS_FLEXICUBES: - model.init_flexicubes_geometry(device, use_renderer=False) - model = model.eval() - - images = np.asarray(images, dtype=np.float32) / 255.0 - images = torch.from_numpy(images).permute(2, 0, 1).contiguous().float() # (3, 960, 640) - images = rearrange(images, 'c (n h) (m w) -> (n m) c h w', n=3, m=2) # (6, 3, 320, 320) - - input_cameras = get_zero123plus_input_cameras(batch_size=1, radius=4.0).to(device) - render_cameras = get_render_cameras(batch_size=1, radius=2.5, is_flexicubes=IS_FLEXICUBES).to(device) - - images = images.unsqueeze(0).to(device) - images = v2.functional.resize(images, (320, 320), interpolation=3, antialias=True).clamp(0, 1) - - mesh_fpath = tempfile.NamedTemporaryFile(suffix=f".obj", delete=False).name - print(mesh_fpath) - mesh_basename = os.path.basename(mesh_fpath).split('.')[0] - mesh_dirname = os.path.dirname(mesh_fpath) - video_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.mp4") - mesh_glb_fpath = os.path.join(mesh_dirname, f"{mesh_basename}.glb") - - with torch.no_grad(): - # get triplane - planes = model.forward_planes(images, input_cameras) - - # # get video - # chunk_size = 20 if IS_FLEXICUBES else 1 - # render_size = 384 - - # frames = [] - # for i in tqdm(range(0, render_cameras.shape[1], chunk_size)): - # if IS_FLEXICUBES: - # frame = model.forward_geometry( - # planes, - # render_cameras[:, i:i+chunk_size], - # render_size=render_size, - # )['img'] - # else: - # frame = model.synthesizer( - # planes, - # cameras=render_cameras[:, i:i+chunk_size], - # render_size=render_size, - # )['images_rgb'] - # frames.append(frame) - # frames = torch.cat(frames, dim=1) - - # images_to_video( - # frames[0], - # video_fpath, - # fps=30, - # ) - - # print(f"Video saved to {video_fpath}") - - # get mesh - mesh_out = model.extract_mesh( - planes, - use_texture_map=False, - **infer_config, - ) - - vertices, faces, vertex_colors = mesh_out - vertices = vertices[:, [1, 2, 0]] - - save_glb(vertices, faces, vertex_colors, mesh_glb_fpath) - save_obj(vertices, faces, vertex_colors, mesh_fpath) - - print(f"Mesh saved to {mesh_fpath}") - - return mesh_fpath, mesh_glb_fpath - diff --git a/instant-mesh/zero123plus/pipeline.py b/instant-mesh/zero123plus/pipeline.py deleted file mode 100644 index 0088218346b36f07662d051670e51c658df59f1f..0000000000000000000000000000000000000000 --- a/instant-mesh/zero123plus/pipeline.py +++ /dev/null @@ -1,406 +0,0 @@ -from typing import Any, Dict, Optional -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.schedulers import KarrasDiffusionSchedulers - -import numpy -import torch -import torch.nn as nn -import torch.utils.checkpoint -import torch.distributed -import transformers -from collections import OrderedDict -from PIL import Image -from torchvision import transforms -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer - -import diffusers -from diffusers import ( - AutoencoderKL, - DDPMScheduler, - DiffusionPipeline, - EulerAncestralDiscreteScheduler, - UNet2DConditionModel, - ImagePipelineOutput -) -from diffusers.image_processor import VaeImageProcessor -from diffusers.models.attention_processor import Attention, AttnProcessor, XFormersAttnProcessor, AttnProcessor2_0 -from diffusers.utils.import_utils import is_xformers_available - - -def to_rgb_image(maybe_rgba: Image.Image): - if maybe_rgba.mode == 'RGB': - return maybe_rgba - elif maybe_rgba.mode == 'RGBA': - rgba = maybe_rgba - img = numpy.random.randint(255, 256, size=[rgba.size[1], rgba.size[0], 3], dtype=numpy.uint8) - img = Image.fromarray(img, 'RGB') - img.paste(rgba, mask=rgba.getchannel('A')) - return img - else: - raise ValueError("Unsupported image type.", maybe_rgba.mode) - - -class ReferenceOnlyAttnProc(torch.nn.Module): - def __init__( - self, - chained_proc, - enabled=False, - name=None - ) -> None: - super().__init__() - self.enabled = enabled - self.chained_proc = chained_proc - self.name = name - - def __call__( - self, attn: Attention, hidden_states, encoder_hidden_states=None, attention_mask=None, - mode="w", ref_dict: dict = None, is_cfg_guidance = False - ) -> Any: - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - if self.enabled and is_cfg_guidance: - res0 = self.chained_proc(attn, hidden_states[:1], encoder_hidden_states[:1], attention_mask) - hidden_states = hidden_states[1:] - encoder_hidden_states = encoder_hidden_states[1:] - if self.enabled: - if mode == 'w': - ref_dict[self.name] = encoder_hidden_states - elif mode == 'r': - encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict.pop(self.name)], dim=1) - elif mode == 'm': - encoder_hidden_states = torch.cat([encoder_hidden_states, ref_dict[self.name]], dim=1) - else: - assert False, mode - res = self.chained_proc(attn, hidden_states, encoder_hidden_states, attention_mask) - if self.enabled and is_cfg_guidance: - res = torch.cat([res0, res]) - return res - - -class RefOnlyNoisedUNet(torch.nn.Module): - def __init__(self, unet: UNet2DConditionModel, train_sched: DDPMScheduler, val_sched: EulerAncestralDiscreteScheduler) -> None: - super().__init__() - self.unet = unet - self.train_sched = train_sched - self.val_sched = val_sched - - unet_lora_attn_procs = dict() - for name, _ in unet.attn_processors.items(): - if torch.__version__ >= '2.0': - default_attn_proc = AttnProcessor2_0() - elif is_xformers_available(): - default_attn_proc = XFormersAttnProcessor() - else: - default_attn_proc = AttnProcessor() - unet_lora_attn_procs[name] = ReferenceOnlyAttnProc( - default_attn_proc, enabled=name.endswith("attn1.processor"), name=name - ) - unet.set_attn_processor(unet_lora_attn_procs) - - def __getattr__(self, name: str): - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.unet, name) - - def forward_cond(self, noisy_cond_lat, timestep, encoder_hidden_states, class_labels, ref_dict, is_cfg_guidance, **kwargs): - if is_cfg_guidance: - encoder_hidden_states = encoder_hidden_states[1:] - class_labels = class_labels[1:] - self.unet( - noisy_cond_lat, timestep, - encoder_hidden_states=encoder_hidden_states, - class_labels=class_labels, - cross_attention_kwargs=dict(mode="w", ref_dict=ref_dict), - **kwargs - ) - - def forward( - self, sample, timestep, encoder_hidden_states, class_labels=None, - *args, cross_attention_kwargs, - down_block_res_samples=None, mid_block_res_sample=None, - **kwargs - ): - cond_lat = cross_attention_kwargs['cond_lat'] - is_cfg_guidance = cross_attention_kwargs.get('is_cfg_guidance', False) - noise = torch.randn_like(cond_lat) - if self.training: - noisy_cond_lat = self.train_sched.add_noise(cond_lat, noise, timestep) - noisy_cond_lat = self.train_sched.scale_model_input(noisy_cond_lat, timestep) - else: - noisy_cond_lat = self.val_sched.add_noise(cond_lat, noise, timestep.reshape(-1)) - noisy_cond_lat = self.val_sched.scale_model_input(noisy_cond_lat, timestep.reshape(-1)) - ref_dict = {} - self.forward_cond( - noisy_cond_lat, timestep, - encoder_hidden_states, class_labels, - ref_dict, is_cfg_guidance, **kwargs - ) - weight_dtype = self.unet.dtype - return self.unet( - sample, timestep, - encoder_hidden_states, *args, - class_labels=class_labels, - cross_attention_kwargs=dict(mode="r", ref_dict=ref_dict, is_cfg_guidance=is_cfg_guidance), - down_block_additional_residuals=[ - sample.to(dtype=weight_dtype) for sample in down_block_res_samples - ] if down_block_res_samples is not None else None, - mid_block_additional_residual=( - mid_block_res_sample.to(dtype=weight_dtype) - if mid_block_res_sample is not None else None - ), - **kwargs - ) - - -def scale_latents(latents): - latents = (latents - 0.22) * 0.75 - return latents - - -def unscale_latents(latents): - latents = latents / 0.75 + 0.22 - return latents - - -def scale_image(image): - image = image * 0.5 / 0.8 - return image - - -def unscale_image(image): - image = image / 0.5 * 0.8 - return image - - -class DepthControlUNet(torch.nn.Module): - def __init__(self, unet: RefOnlyNoisedUNet, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0) -> None: - super().__init__() - self.unet = unet - if controlnet is None: - self.controlnet = diffusers.ControlNetModel.from_unet(unet.unet) - else: - self.controlnet = controlnet - DefaultAttnProc = AttnProcessor2_0 - if is_xformers_available(): - DefaultAttnProc = XFormersAttnProcessor - self.controlnet.set_attn_processor(DefaultAttnProc()) - self.conditioning_scale = conditioning_scale - - def __getattr__(self, name: str): - try: - return super().__getattr__(name) - except AttributeError: - return getattr(self.unet, name) - - def forward(self, sample, timestep, encoder_hidden_states, class_labels=None, *args, cross_attention_kwargs: dict, **kwargs): - cross_attention_kwargs = dict(cross_attention_kwargs) - control_depth = cross_attention_kwargs.pop('control_depth') - down_block_res_samples, mid_block_res_sample = self.controlnet( - sample, - timestep, - encoder_hidden_states=encoder_hidden_states, - controlnet_cond=control_depth, - conditioning_scale=self.conditioning_scale, - return_dict=False, - ) - return self.unet( - sample, - timestep, - encoder_hidden_states=encoder_hidden_states, - down_block_res_samples=down_block_res_samples, - mid_block_res_sample=mid_block_res_sample, - cross_attention_kwargs=cross_attention_kwargs - ) - - -class ModuleListDict(torch.nn.Module): - def __init__(self, procs: dict) -> None: - super().__init__() - self.keys = sorted(procs.keys()) - self.values = torch.nn.ModuleList(procs[k] for k in self.keys) - - def __getitem__(self, key): - return self.values[self.keys.index(key)] - - -class SuperNet(torch.nn.Module): - def __init__(self, state_dict: Dict[str, torch.Tensor]): - super().__init__() - state_dict = OrderedDict((k, state_dict[k]) for k in sorted(state_dict.keys())) - self.layers = torch.nn.ModuleList(state_dict.values()) - self.mapping = dict(enumerate(state_dict.keys())) - self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())} - - # .processor for unet, .self_attn for text encoder - self.split_keys = [".processor", ".self_attn"] - - # we add a hook to state_dict() and load_state_dict() so that the - # naming fits with `unet.attn_processors` - def map_to(module, state_dict, *args, **kwargs): - new_state_dict = {} - for key, value in state_dict.items(): - num = int(key.split(".")[1]) # 0 is always "layers" - new_key = key.replace(f"layers.{num}", module.mapping[num]) - new_state_dict[new_key] = value - - return new_state_dict - - def remap_key(key, state_dict): - for k in self.split_keys: - if k in key: - return key.split(k)[0] + k - return key.split('.')[0] - - def map_from(module, state_dict, *args, **kwargs): - all_keys = list(state_dict.keys()) - for key in all_keys: - replace_key = remap_key(key, state_dict) - new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}") - state_dict[new_key] = state_dict[key] - del state_dict[key] - - self._register_state_dict_hook(map_to) - self._register_load_state_dict_pre_hook(map_from, with_module=True) - - -class Zero123PlusPipeline(diffusers.StableDiffusionPipeline): - tokenizer: transformers.CLIPTokenizer - text_encoder: transformers.CLIPTextModel - vision_encoder: transformers.CLIPVisionModelWithProjection - - feature_extractor_clip: transformers.CLIPImageProcessor - unet: UNet2DConditionModel - scheduler: diffusers.schedulers.KarrasDiffusionSchedulers - - vae: AutoencoderKL - ramping: nn.Linear - - feature_extractor_vae: transformers.CLIPImageProcessor - - depth_transforms_multi = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize([0.5], [0.5]) - ]) - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - vision_encoder: transformers.CLIPVisionModelWithProjection, - feature_extractor_clip: CLIPImageProcessor, - feature_extractor_vae: CLIPImageProcessor, - ramping_coefficients: Optional[list] = None, - safety_checker=None, - ): - DiffusionPipeline.__init__(self) - - self.register_modules( - vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, - unet=unet, scheduler=scheduler, safety_checker=None, - vision_encoder=vision_encoder, - feature_extractor_clip=feature_extractor_clip, - feature_extractor_vae=feature_extractor_vae - ) - self.register_to_config(ramping_coefficients=ramping_coefficients) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - - def prepare(self): - train_sched = DDPMScheduler.from_config(self.scheduler.config) - if isinstance(self.unet, UNet2DConditionModel): - self.unet = RefOnlyNoisedUNet(self.unet, train_sched, self.scheduler).eval() - - def add_controlnet(self, controlnet: Optional[diffusers.ControlNetModel] = None, conditioning_scale=1.0): - self.prepare() - self.unet = DepthControlUNet(self.unet, controlnet, conditioning_scale) - return SuperNet(OrderedDict([('controlnet', self.unet.controlnet)])) - - def encode_condition_image(self, image: torch.Tensor): - image = self.vae.encode(image).latent_dist.sample() - return image - - @torch.no_grad() - def __call__( - self, - image: Image.Image = None, - prompt = "", - *args, - num_images_per_prompt: Optional[int] = 1, - guidance_scale=4.0, - depth_image: Image.Image = None, - output_type: Optional[str] = "pil", - width=640, - height=960, - num_inference_steps=28, - return_dict=True, - **kwargs - ): - self.prepare() - if image is None: - raise ValueError("Inputting embeddings not supported for this pipeline. Please pass an image.") - assert not isinstance(image, torch.Tensor) - image = to_rgb_image(image) - image_1 = self.feature_extractor_vae(images=image, return_tensors="pt").pixel_values - image_2 = self.feature_extractor_clip(images=image, return_tensors="pt").pixel_values - if depth_image is not None and hasattr(self.unet, "controlnet"): - depth_image = to_rgb_image(depth_image) - depth_image = self.depth_transforms_multi(depth_image).to( - device=self.unet.controlnet.device, dtype=self.unet.controlnet.dtype - ) - image = image_1.to(device=self.vae.device, dtype=self.vae.dtype) - image_2 = image_2.to(device=self.vae.device, dtype=self.vae.dtype) - cond_lat = self.encode_condition_image(image) - if guidance_scale > 1: - negative_lat = self.encode_condition_image(torch.zeros_like(image)) - cond_lat = torch.cat([negative_lat, cond_lat]) - encoded = self.vision_encoder(image_2, output_hidden_states=False) - global_embeds = encoded.image_embeds - global_embeds = global_embeds.unsqueeze(-2) - - if hasattr(self, "encode_prompt"): - encoder_hidden_states = self.encode_prompt( - prompt, - self.device, - num_images_per_prompt, - False - )[0] - else: - encoder_hidden_states = self._encode_prompt( - prompt, - self.device, - num_images_per_prompt, - False - ) - ramp = global_embeds.new_tensor(self.config.ramping_coefficients).unsqueeze(-1) - encoder_hidden_states = encoder_hidden_states + global_embeds * ramp - cak = dict(cond_lat=cond_lat) - if hasattr(self.unet, "controlnet"): - cak['control_depth'] = depth_image - latents: torch.Tensor = super().__call__( - None, - *args, - cross_attention_kwargs=cak, - guidance_scale=guidance_scale, - num_images_per_prompt=num_images_per_prompt, - prompt_embeds=encoder_hidden_states, - num_inference_steps=num_inference_steps, - output_type='latent', - width=width, - height=height, - **kwargs - ).images - latents = unscale_latents(latents) - if not output_type == "latent": - image = unscale_image(self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]) - else: - image = latents - - image = self.image_processor.postprocess(image, output_type=output_type) - if not return_dict: - return (image,) - - return ImagePipelineOutput(images=image)