Spaces:
Build error
Build error
from skimage import measure | |
import numpy as np | |
import torch | |
from .sdf import create_grid, eval_grid_octree, eval_grid | |
from skimage import measure | |
def reconstruction(net, cuda, calib_tensor, | |
resolution, b_min, b_max, | |
use_octree=False, num_samples=10000, transform=None): | |
''' | |
Reconstruct meshes from sdf predicted by the network. | |
:param net: a BasePixImpNet object. call image filter beforehead. | |
:param cuda: cuda device | |
:param calib_tensor: calibration tensor | |
:param resolution: resolution of the grid cell | |
:param b_min: bounding box corner [x_min, y_min, z_min] | |
:param b_max: bounding box corner [x_max, y_max, z_max] | |
:param use_octree: whether to use octree acceleration | |
:param num_samples: how many points to query each gpu iteration | |
:return: marching cubes results. | |
''' | |
# First we create a grid by resolution | |
# and transforming matrix for grid coordinates to real world xyz | |
coords, mat = create_grid(resolution, resolution, resolution, | |
b_min, b_max, transform=transform) | |
# Then we define the lambda function for cell evaluation | |
def eval_func(points): | |
points = np.expand_dims(points, axis=0) | |
points = np.repeat(points, net.num_views, axis=0) | |
samples = torch.from_numpy(points).to(device=cuda).float() | |
net.query(samples, calib_tensor) | |
pred = net.get_preds()[0][0] | |
return pred.detach().cpu().numpy() | |
# Then we evaluate the grid | |
if use_octree: | |
sdf = eval_grid_octree(coords, eval_func, num_samples=num_samples) | |
else: | |
sdf = eval_grid(coords, eval_func, num_samples=num_samples) | |
# Finally we do marching cubes | |
try: | |
verts, faces, normals, values = measure.marching_cubes_lewiner(sdf, 0.5) | |
# transform verts into world coordinate system | |
verts = np.matmul(mat[:3, :3], verts.T) + mat[:3, 3:4] | |
verts = verts.T | |
return verts, faces, normals, values | |
except: | |
print('error cannot marching cubes') | |
return -1 | |
def save_obj_mesh(mesh_path, verts, faces): | |
file = open(mesh_path, 'w') | |
for v in verts: | |
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) | |
for f in faces: | |
f_plus = f + 1 | |
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1])) | |
file.close() | |
def save_obj_mesh_with_color(mesh_path, verts, faces, colors): | |
file = open(mesh_path, 'w') | |
for idx, v in enumerate(verts): | |
c = colors[idx] | |
file.write('v %.4f %.4f %.4f %.4f %.4f %.4f\n' % (v[0], v[1], v[2], c[0], c[1], c[2])) | |
for f in faces: | |
f_plus = f + 1 | |
file.write('f %d %d %d\n' % (f_plus[0], f_plus[2], f_plus[1])) | |
file.close() | |
def save_obj_mesh_with_uv(mesh_path, verts, faces, uvs): | |
file = open(mesh_path, 'w') | |
for idx, v in enumerate(verts): | |
vt = uvs[idx] | |
file.write('v %.4f %.4f %.4f\n' % (v[0], v[1], v[2])) | |
file.write('vt %.4f %.4f\n' % (vt[0], vt[1])) | |
for f in faces: | |
f_plus = f + 1 | |
file.write('f %d/%d %d/%d %d/%d\n' % (f_plus[0], f_plus[0], | |
f_plus[2], f_plus[2], | |
f_plus[1], f_plus[1])) | |
file.close() | |