Spaces:
Sleeping
Sleeping
File size: 7,338 Bytes
801501a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import skimage.measure
import time
from ..custom_types import *
from .train_utils import Logger
from .. import constants
def mcubes_skimage(pytorch_3d_occ_tensor: T, voxel_grid_origin: List[float], voxel_size: float) -> Optional[T_Mesh]:
    """Extract the level-0 isosurface of a 3-D grid with scikit-image marching cubes.

    Args:
        pytorch_3d_occ_tensor: 3-D tensor of grid values; ``.numpy()`` is called on
            it, so it must already live on the CPU.
        voxel_grid_origin: world coordinates added to the vertex positions, i.e. the
            position of grid index (0, 0, 0).
        voxel_size: spacing between adjacent grid points, identical on all three axes.

    Returns:
        ``(vertices, faces)`` as float32 / int64 tensors, or None if marching cubes
        raised an exception (the failure is reported with a print).
    """
    numpy_3d_occ_tensor = pytorch_3d_occ_tensor.numpy()
    try:
        # Newer scikit-image exposes ``marching_cubes``; older releases only ``marching_cubes_lewiner``.
        if hasattr(skimage.measure, 'marching_cubes'):
            marching_cubes = skimage.measure.marching_cubes
        else:
            marching_cubes = skimage.measure.marching_cubes_lewiner
        verts, faces, _normals, _values = marching_cubes(numpy_3d_occ_tensor, level=0.0, spacing=[voxel_size] * 3)
    except Exception:
        # Best effort: report and give up on this mesh. Catch Exception rather than
        # BaseException so KeyboardInterrupt / SystemExit still propagate.
        print("mc failed")
        return None
    # Shift vertices (already scaled by ``spacing``) from grid space into world space.
    mesh_points = verts + np.asarray(voxel_grid_origin, dtype=verts.dtype)
    return torch.from_numpy(mesh_points).float(), torch.from_numpy(faces.copy()).long()
class MarchingCubesMeshing:
    """Evaluate an implicit decoder on a grid over [-1, 1]^3 and extract a mesh.

    Grids are filled coarse-to-fine (see ``fill_recursive``). The decoder is
    evaluated densely only at resolutions <= ``min_res``. At finer levels it is
    re-evaluated only around coarse cells whose value is below 0.3.
    """

    def __init__(self, device: D, max_batch: int = 64 ** 3, min_res: int = 64, scale: float = 1, verbose: bool = False):
        """
        Args:
            device: device the cached sample grids are moved to.
            max_batch: maximum number of points passed to the decoder per call
                (forced to 32 ** 3 when ``constants.IS_WINDOWS`` is true).
            min_res: resolution at or below which the grid is evaluated densely.
            scale: factor applied to query coordinates before they reach the decoder.
            verbose: show a progress logger while decoding.
        """
        self.device = device
        self.max_batch = 32 ** 3 if constants.IS_WINDOWS else max_batch
        self.min_res = min_res
        self.scale = scale
        self.verbose = verbose
        # tuned resolution -> (4, res ** 3) sample tensor, reused across calls
        self.sample_cache = {}

    def fill_samples(self, decoder, samples, device: Optional[D] = None) -> T:
        """Evaluate ``decoder`` on every point of ``samples`` in batches, in place.

        Args:
            decoder: called with an (n, 3) coordinate tensor per batch; its output,
                after ``squeeze()``, must fit into n slots.
            samples: (4, N) tensor; rows 0-2 hold the coordinates and row 3
                receives the decoder output.
            device: if given, each coordinate batch is moved there before decoding.

        Returns:
            ``samples`` itself, with row 3 overwritten.
        """
        num_samples = samples.shape[1]
        num_iters = num_samples // self.max_batch + int(num_samples % self.max_batch != 0)
        sample_coords = samples[:3]
        if self.verbose:
            logger = Logger()
            logger.start(num_iters, tag='meshing')
        for i in range(num_iters):
            begin = i * self.max_batch
            stop = min(begin + self.max_batch, num_samples)
            sample_subset = sample_coords[:, begin:stop]
            if device is not None:
                sample_subset = sample_subset.to(device)
            # The decoder takes one point per row, hence the transpose.
            samples[3, begin:stop] = decoder(sample_subset.T * self.scale).squeeze().detach()
            if self.verbose:
                logger.reset_iter()
        if self.verbose:
            logger.stop()
        return samples

    def fill_recursive(self, decoder, samples: T, stride: int, base_res: int, depth: int) -> T:
        """Fill ``samples`` coarse-to-fine, evaluating the decoder only where needed.

        At or below ``min_res``, every sample is evaluated. Otherwise:
        1. The grid is average-pooled by ``stride`` and filled recursively.
        2. Fine samples whose dilated coarse value is below 0.3 are re-evaluated.
        3. All other fine samples keep their current row-3 value (1 when reached
           through ``get_grid``).

        Args:
            samples: (4, base_res ** 3) tensor laid out as in ``get_res_samples``.
            stride: pooling / upsampling factor between levels. ``base_res`` must
                stay divisible by it at every level; ``tune_resolution`` ensures
                this for stride 2.
            base_res: edge length of the grid held by ``samples``.
            depth: recursion level. It decreases toward coarser levels and widens
                the dilation kernel at finer ones.

        Returns:
            ``samples``, updated in place.
        """
        if base_res <= self.min_res:
            return self.fill_samples(decoder, samples)
        kernel_size = 7 + 4 * depth
        padding = (kernel_size // 2,) * 6
        res = base_res // stride
        # Coarse grid: averaged coordinates (and averaged row 3) of each stride^3 cell.
        coarse = nnf.avg_pool3d(samples.view(1, 4, base_res, base_res, base_res), stride, stride).view(4, -1)
        coarse = self.fill_recursive(decoder, coarse, stride, res, depth - 1)
        # Coarse cells below the threshold, dilated by kernel_size // 2 cells ...
        mask = coarse[-1, :].lt(.3).view(1, 1, res, res, res).float()
        mask = nnf.max_pool3d(nnf.pad(mask, padding, mode='replicate'), kernel_size, 1)
        # ... then nearest-neighbour upsampled back to base_res.
        mask = nnf.interpolate(mask, scale_factor=stride).flatten().bool()
        samples[:, mask] = self.fill_samples(decoder, samples[:, mask])
        return samples

    def tune_resolution(self, res: int):
        """Round ``res`` so it halves exactly down to a value <= ``min_res``.

        ``res`` is floor-halved until it is <= ``min_res``, then multiplied back by
        the same power of two (e.g. 257 -> 256 with ``min_res=64``).
        """
        factor = 1
        while res > self.min_res:
            res //= 2
            factor *= 2
        return res * factor

    @staticmethod
    def get_res_samples(res):
        """Build a (4, res ** 3) tensor of grid points spanning [-1, 1]^3.

        Rows 0-2 hold the x, y, z coordinates. Row 2 varies fastest, matching a
        row-major reshape to (res, res, res). Row 3 is initialised to 1.
        """
        voxel_origin = torch.tensor([-1., -1., -1.])
        voxel_size = 2.0 / (res - 1)
        flat_index = torch.arange(0, res ** 3, 1, dtype=torch.int64)
        samples = torch.ones(4, res ** 3)
        samples.requires_grad = False
        # Integer grid index of each flat position (row 0 slowest, row 2 fastest).
        quotient = torch.div(flat_index, res, rounding_mode='floor')
        samples[2, :] = (flat_index % res).float()
        samples[1, :] = (quotient % res).float()
        samples[0, :] = (torch.div(quotient, res, rounding_mode='floor') % res).float()
        # Grid index -> world coordinate.
        samples[:3] = samples[:3] * voxel_size + voxel_origin[:, None]
        return samples

    def register_resolution(self, res: int):
        """Return the sample grid for ``tune_resolution(res)`` together with that resolution.

        The grid is created and moved to ``self.device`` on first use. On a cache
        hit, row 3 is reset to 1 so values from a previous call do not leak in.
        """
        res = self.tune_resolution(res)
        samples = self.sample_cache.get(res)
        if samples is None:
            samples = self.get_res_samples(res).to(self.device)
            self.sample_cache[res] = samples
        else:
            samples[3, :] = 1
        return samples, res

    def get_grid(self, decoder, res):
        """Evaluate ``decoder`` on a grid and return the values as an (r, r, r) tensor.

        ``r`` is ``tune_resolution(res)``, which may differ from ``res``.
        NOTE(review): the result is a view into ``self.sample_cache``. A later call
        with the same tuned resolution overwrites it, so clone it if it must
        outlive that call.
        """
        stride = 2
        samples, res = self.register_resolution(res)
        depth = int(np.ceil(np.log2(res) - np.log2(self.min_res)))
        samples = self.fill_recursive(decoder, samples, stride, res, depth)
        return samples[3].reshape(res, res, res)

    def occ_meshing(self, decoder, res: int = 256, get_time: bool = False, verbose=False):
        """Sample ``decoder`` over [-1, 1]^3 and run marching cubes at level 0.

        Args:
            decoder: implicit function, evaluated through ``fill_samples``.
            res: requested grid resolution (rounded by ``tune_resolution``).
            get_time: if True, return the sampling time in seconds instead of a mesh.
            verbose: print timing information.

        Returns:
            The sampling time if ``get_time`` is True. Otherwise the result of
            ``mcubes_skimage``: (vertices, faces), or None if it failed.
        """
        start = time.time()
        voxel_origin = [-1., -1., -1.]
        occ_values = self.get_grid(decoder, res)
        # get_grid may round ``res``; the spacing must match the grid actually
        # sampled, otherwise the mesh no longer spans [-1, 1].
        voxel_size = 2.0 / (occ_values.shape[0] - 1)
        # Always defined, so get_time also works when verbose is False.
        end = time.time()
        if verbose:
            print("sampling took: %f" % (end - start))
        if get_time:
            return end - start
        mesh_a = mcubes_skimage(occ_values.data.cpu(), voxel_origin, voxel_size)
        if verbose:
            end_b = time.time()
            print("mcube took: %f" % (end_b - end))
            print("meshing took: %f" % (end_b - start))
        return mesh_a
def create_mesh_old(decoder, res=256, max_batch=64 ** 3, scale=1, device=CPU, verbose=False, get_time: bool = False):
    """Densely evaluate ``decoder`` on a res^3 grid over [-1, 1]^3 and mesh it.

    Legacy path without the coarse-to-fine pruning of
    ``MarchingCubesMeshing.occ_meshing``: every grid point is decoded.

    Args:
        decoder: implicit function, called in batches by ``fill_samples``.
        res: grid resolution along each axis (used as given, not tuned).
        max_batch, scale, verbose: forwarded to ``MarchingCubesMeshing``.
        device: device each coordinate batch is moved to before decoding.
        get_time: if True, return the sampling time in seconds instead of a mesh.

    Returns:
        The sampling time if ``get_time`` is True, else the result of ``mcubes_skimage``.
    """
    meshing = MarchingCubesMeshing(device, max_batch=max_batch, scale=scale, verbose=verbose)
    start = time.time()
    # NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
    voxel_origin = [-1, -1, -1]
    voxel_size = 2.0 / (res - 1)
    # fill_samples expects a (4, N) layout: coordinates in rows 0-2, values in row 3.
    # The previous (N, 4) tensor made it decode 4 meaningless "points". Build the
    # grid with the class helper, which uses the same origin and point ordering.
    samples = MarchingCubesMeshing.get_res_samples(res)
    samples = meshing.fill_samples(decoder, samples, device=device)
    sdf_values = samples[3].reshape(res, res, res)
    end = time.time()
    print("sampling took: %f" % (end - start))
    if get_time:
        return end - start
    return mcubes_skimage(
        sdf_values.data.cpu(),
        voxel_origin,
        voxel_size,
    )
|