Spaces:
Sleeping
Sleeping
import torch | |
import trimesh | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import logging | |
logger = logging.getLogger(__name__) | |
class PrimSDF(nn.Module): | |
def __init__(self, mesh_obj=None, f_sdf=None, geo_fn=None, asset_list=None, num_prims=1024, dim_feat=6, prim_shape=8, init_scale=0.05, sdf2alpha_var=0.005, auto_scale_init=True, init_sampling="uniform"): | |
super().__init__() | |
self.num_prims = num_prims | |
# 6 channels features - [SDF, R, G, B, roughness, metallic] | |
self.dim_feat = dim_feat | |
self.prim_shape = prim_shape | |
self.sdf_sampled_point = None | |
self.auto_scale_init = auto_scale_init | |
self.init_sampling = init_sampling | |
self.sdf2alpha_var = sdf2alpha_var | |
# assume the mesh is normalized to [-1, 1] cube | |
self.mesh_obj = mesh_obj | |
self.f_sdf = f_sdf | |
# N x (D x S^3 + 3(Global Translation) + 1(Global Scale)) | |
self.srt_param = nn.parameter.Parameter(torch.zeros(self.num_prims, 1 + 3)) | |
self.feat_param = nn.parameter.Parameter(torch.zeros(self.num_prims, self.dim_feat * (self.prim_shape ** 3))) | |
self.geo_start_index = 0 | |
self.geo_end_index = self.geo_start_index + self.prim_shape ** 3 # non-inclusive | |
self.tex_start_index = self.geo_end_index | |
self.tex_end_index = self.tex_start_index + self.prim_shape ** 3 * 3 # non-inclusive | |
self.mat_start_index = self.tex_end_index | |
self.mat_end_index = self.mat_start_index + self.prim_shape ** 3 * 2 | |
# sampled_point -> local grid | |
# local_grid - [prim_shape^3, 3] | |
xx = torch.linspace(-1, 1, self.prim_shape) | |
# two ways to sample xyz-axis aligned local grids: 1st is ij indexing | |
meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing='ij') | |
local_grid = torch.stack((meshz, meshy, meshx), dim=-1).reshape(-1, 3) | |
self.local_grid = local_grid | |
# second is xy indexing, equivalent to the first one | |
# meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing='xy') | |
# local_grid = torch.stack((meshz, meshx, meshy), dim=-1).reshape(-1, 3) | |
if self.f_sdf is not None and geo_fn is not None and asset_list is not None: | |
self._init_param(init_scale=init_scale, geo_fn=geo_fn, asset_list=asset_list, sampling=self.init_sampling) | |
def _init_param(self, init_scale, geo_fn, asset_list, sampling="uniform"): | |
pass | |
def forward(self, x): | |
# x - [bs, 3] | |
bs = x.shape[0] | |
weights = self.prim_weight(x) | |
output = self.grid_sample_feat(x, weights) | |
preds = {} | |
preds['sdf'] = output[:, 0:1] | |
# RGB | |
preds['tex'] = torch.clip(output[:, 1:4], min=0.0, max=1.0) | |
# roughness, metallic | |
preds['mat'] = torch.clip(output[:, 4:6], min=0.0, max=1.0) | |
return preds | |
def grid_sample_feat(self, x, weights): | |
# implementation of I_V -> trilinear grid sample of V_i | |
# x - [bs, 3] | |
# weights - [bs, n_prims] | |
bs = x.shape[0] | |
sampled_point = (x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...] | |
mask = weights > 0 | |
ind_bs, ind_nprim = torch.where(weights > 0) | |
masked_sampled_point = sampled_point[ind_bs, ind_nprim, :].reshape(ind_nprim.shape[0], 1, 1, 1, 3) | |
feat4sample = self.feat[ind_nprim, :].reshape(ind_nprim.shape[0], self.dim_feat, self.prim_shape, self.prim_shape, self.prim_shape) | |
sampled_feat = F.grid_sample(feat4sample, masked_sampled_point, mode='bilinear', padding_mode='zeros', align_corners=True).reshape(ind_nprim.shape[0], self.dim_feat) | |
weighted_sampled_feat = sampled_feat * weights[mask][:, None] | |
weighted_feat = torch.zeros(bs, self.dim_feat).to(x) | |
weighted_feat.index_add_(0, ind_bs, weighted_sampled_feat) | |
# at inference time, fill in approximated SDF value for region not covered by prims | |
if not self.training: | |
# get mask for points not covered by prims | |
bs_mask = weights.sum(1) <= 0 | |
# get nearest prim index | |
dist = torch.norm(x[bs_mask, None, :] - self.pos[None, ...], p=2, dim=-1) | |
_, min_dist_ind = dist.min(1) | |
nearest_prim_pos = self.pos[min_dist_ind, :] | |
nearest_prim_scale = self.scale[min_dist_ind, :] | |
# in each nearest prim, get nearest voxel points | |
candidate_nearest_pts = nearest_prim_pos[:, None, :] + nearest_prim_scale[..., None] * self.local_grid.to(x)[None, :] | |
pts_dist = torch.norm(x[bs_mask, None, :] - candidate_nearest_pts, p=2, dim=-1) | |
min_dist, min_dist_pts_ind = pts_dist.min(1) | |
# get the SDF value as a nearest valid SDF value | |
min_pts_sdf = self.feat_geo[min_dist_ind, min_dist_pts_ind] | |
# approximate SDF value with the same sign distance + L2 distance | |
approx_sdf = min_pts_sdf + min_dist * torch.sign(min_pts_sdf) | |
weighted_feat[bs_mask, 0:1] = approx_sdf[:, None] | |
return weighted_feat | |
def prim_weight(self, x): | |
# x - [bs, 3] | |
weights = F.relu(1 - torch.norm((x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...], p = float('inf'), dim=-1)) | |
# weight - [bs, N] | |
normalized_weights = weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6) | |
return normalized_weights | |
def sdf2alpha(self, sdf): | |
return torch.exp(-(sdf / self.sdf2alpha_var) ** 2) | |
def pos(self): | |
return self.srt_param[:, 1:4] | |
def scale(self): | |
return self.srt_param[:, 0:1] | |
def feat(self): | |
return self.feat_param | |
def feat_geo(self): | |
return self.feat_param[:, self.geo_start_index:self.geo_end_index] | |
def feat_tex(self): | |
return self.feat_param[:, self.tex_start_index:self.tex_end_index] | |
def feat_mat(self): | |
return self.feat_param[:, self.mat_start_index:self.mat_end_index] |