salad-demo / salad /spaghetti /utils /mcubes_meshing.py
DveloperY0115's picture
init repo
801501a
import skimage.measure
import time
from ..custom_types import *
from .train_utils import Logger
from .. import constants
def mcubes_skimage(pytorch_3d_occ_tensor: T, voxel_grid_origin: List[float], voxel_size: float) -> T_Mesh:
numpy_3d_occ_tensor = pytorch_3d_occ_tensor.numpy()
try:
marching_cubes = skimage.measure.marching_cubes if 'marching_cubes' in dir(skimage.measure) else skimage.measure.marching_cubes_lewiner
verts, faces, normals, values = marching_cubes(numpy_3d_occ_tensor, level=0.0, spacing=[voxel_size] * 3)
except BaseException:
print("mc failed")
return None
mesh_points = np.zeros_like(verts)
mesh_points[:, 0] = voxel_grid_origin[0] + verts[:, 0]
mesh_points[:, 1] = voxel_grid_origin[1] + verts[:, 1]
mesh_points[:, 2] = voxel_grid_origin[2] + verts[:, 2]
return torch.from_numpy(mesh_points.copy()).float(), torch.from_numpy(faces.copy()).long()
class MarchingCubesMeshing:
def fill_samples(self, decoder, samples, device: Optional[D] = None) -> T:
num_samples = samples.shape[1]
num_iters = num_samples // self.max_batch + int(num_samples % self.max_batch != 0)
sample_coords = samples[:3]
if self.verbose:
logger = Logger()
logger.start(num_iters, tag='meshing')
for i in range(num_iters):
sample_subset = sample_coords[:, i * self.max_batch: min((i + 1) * self.max_batch, num_samples)]
if device is not None:
sample_subset = sample_subset.to(device)
sample_subset = sample_subset.T
samples[3, i * self.max_batch: min((i + 1) * self.max_batch, num_samples)] = (
decoder(sample_subset * self.scale).squeeze().detach()
)
if self.verbose:
logger.reset_iter()
if self.verbose:
logger.stop()
return samples
def fill_recursive(self, decoder, samples: T, stride: int, base_res: int, depth: int) -> T:
if base_res <= self.min_res:
samples_ = self.fill_samples(decoder, samples)
return samples_
kernel_size = 7 + 4 * depth
padding = tuple([kernel_size // 2] * 6)
samples_ = samples.view(1, 4, base_res, base_res, base_res)
samples_ = nnf.avg_pool3d(samples_, stride, stride)
samples_ = samples_.view(4, -1)
res = base_res // stride
samples_lower = self.fill_recursive(decoder, samples_, stride, res, depth - 1)
mask = samples_lower[-1, :].lt(.3)
mask = mask.view(1, 1, res, res, res).float()
mask = nnf.pad(mask, padding, mode='replicate')
mask = nnf.max_pool3d(mask, kernel_size, 1)
mask = nnf.interpolate(mask, scale_factor=stride)
mask = mask.flatten().bool()
samples[:, mask] = self.fill_samples(decoder, samples[:, mask])
return samples
def tune_resolution(self, res: int):
counter = 1
while res > self.min_res:
res = res // 2
counter *= 2
return res * counter
@staticmethod
def get_res_samples(res):
voxel_origin = torch.tensor([-1., -1., -1.])
voxel_size = 2.0 / (res - 1)
overall_index = torch.arange(0, res ** 3, 1, dtype=torch.int64)
samples = torch.ones(4, res ** 3).detach()
samples.requires_grad = False
# transform first 3 columns
# to be the x, y, z index
div_1 = torch.div(overall_index, res, rounding_mode='floor')
samples[2, :] = (overall_index % res).float()
samples[1, :] = (div_1 % res).float()
samples[0, :] = (torch.div(div_1, res, rounding_mode='floor') % res).float()
# transform first 3 columns
# to be the x, y, z coordinate
samples[:3] = samples[:3] * voxel_size + voxel_origin[:, None]
# samples[0, :] = (samples[:, 0] * voxel_size) + voxel_origin[2]
# samples[1, :] = (samples[:, 1] * voxel_size) + voxel_origin[1]
# samples[2, :] = (samples[:, 2] * voxel_size) + voxel_origin[0]
return samples
def register_resolution(self, res: int):
res = self.tune_resolution(res)
if res not in self.sample_cache:
samples = self.get_res_samples(res)
samples = samples.to(self.device)
self.sample_cache[res] = samples
else:
samples = self.sample_cache[res]
samples[3, :] = 1
return samples, res
def get_grid(self, decoder, res):
stride = 2
samples, res = self.register_resolution(res)
depth = int(np.ceil(np.log2(res) - np.log2(self.min_res)))
samples = self.fill_recursive(decoder, samples, stride, res, depth)
occ_values = samples[3]
occ_values = occ_values.reshape(res, res, res)
return occ_values
def occ_meshing(self, decoder, res: int = 256, get_time: bool = False, verbose=False):
start = time.time()
voxel_origin = [-1., -1., -1.]
voxel_size = 2.0 / (res - 1)
occ_values = self.get_grid(decoder, res)
if verbose:
end = time.time()
print("sampling took: %f" % (end - start))
if get_time:
return end - start
mesh_a = mcubes_skimage(occ_values.data.cpu(), voxel_origin, voxel_size)
# mesh_a = mcubes_torch(occ_values, voxel_origin, voxel_size)
if verbose:
end_b = time.time()
print("mcube took: %f" % (end_b - end))
print("meshing took: %f" % (end_b - start))
return mesh_a
def __init__(self, device: D, max_batch: int = 64 ** 3, min_res: int = 64, scale: float = 1, verbose: bool = False):
self.device = device
self.max_batch = 32 ** 3 if constants.IS_WINDOWS else max_batch
self.min_res = min_res
self.scale = scale
self.verbose = verbose
self.sample_cache = {}
def create_mesh_old(decoder, res=256, max_batch=64 ** 3, scale=1, device=CPU, verbose=False, get_time: bool = False):
meshing = MarchingCubesMeshing(device, max_batch=max_batch, scale=scale, verbose=verbose)
start = time.time()
# NOTE: the voxel_origin is actually the (bottom, left, down) corner, not the middle
voxel_origin = [-1, -1, -1]
voxel_size = 2.0 / (res - 1)
overall_index = torch.arange(0, res ** 3, 1, out=torch.LongTensor())
samples = torch.zeros(res ** 3, 4)
# transform first 3 columns
# to be the x, y, z index
samples[:, 2] = overall_index % res
samples[:, 1] = (overall_index.long() // res) % res
samples[:, 0] = ((overall_index.long() // res) // res) % res
# transform first 3 columns
# to be the x, y, z coordinate
samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
samples = meshing.fill_samples(decoder, samples, device=device)
sdf_values = samples[:, 3]
# return sdf_values, samples[:, :3]
sdf_values = sdf_values.reshape(res, res, res)
end = time.time()
print("sampling took: %f" % (end - start))
if get_time:
return end - start
return mcubes_skimage(
sdf_values.data.cpu(),
voxel_origin,
voxel_size,
)