# Hugging Face Space file view (Space status: Running on L4).
# File size: 6,088 bytes; revision 81ecb2b.
import torch
import trimesh
import torch.nn as nn
import torch.nn.functional as F
import logging
logger = logging.getLogger(__name__)
class PrimSDF(nn.Module):
    """Mixture of axis-aligned cubic primitives encoding an SDF plus PBR attributes.

    Each of the ``num_prims`` primitives has a scalar scale (cube half-extent),
    a 3D center, and a dense ``prim_shape^3`` voxel grid with ``dim_feat``
    feature channels. With the default ``dim_feat=6`` the channels are
    ``[SDF, R, G, B, roughness, metallic]``.

    Parameters live in two flat tensors:

    * ``srt_param`` -- ``[num_prims, 4]``: column 0 is the scale, columns 1..3
      the center (see the ``scale`` / ``pos`` properties).
    * ``feat_param`` -- ``[num_prims, dim_feat * S^3]`` with ``S = prim_shape``;
      channel-major, so channel ``c`` occupies columns ``[c * S^3, (c + 1) * S^3)``.
    """

    def __init__(self, mesh_obj=None, f_sdf=None, geo_fn=None, asset_list=None,
                 num_prims=1024, dim_feat=6, prim_shape=8, init_scale=0.05,
                 sdf2alpha_var=0.005, auto_scale_init=True, init_sampling="uniform"):
        """Create zero-initialized primitive parameters.

        Args:
            mesh_obj: mesh kept on the module; assumed normalized to the
                [-1, 1] cube (not otherwise used in this class body).
            f_sdf: SDF source kept on the module. ``_init_param`` is only
                invoked when ``f_sdf``, ``geo_fn`` and ``asset_list`` are all given.
            geo_fn: forwarded to ``_init_param``.
            asset_list: forwarded to ``_init_param``.
            num_prims: number of primitives N.
            dim_feat: feature channels per voxel. The channel slices used by
                ``forward`` and the ``feat_*`` properties assume the default 6
                (``[SDF, R, G, B, roughness, metallic]``).
            prim_shape: voxels per primitive edge S (each grid holds S^3 voxels).
            init_scale: forwarded to ``_init_param``.
            sdf2alpha_var: width of the Gaussian used by ``sdf2alpha``.
            auto_scale_init: stored for initialization code.
            init_sampling: stored and forwarded to ``_init_param`` as ``sampling``.
        """
        super().__init__()
        self.num_prims = num_prims
        # 6 channels features - [SDF, R, G, B, roughness, metallic]
        self.dim_feat = dim_feat
        self.prim_shape = prim_shape
        self.sdf_sampled_point = None
        self.auto_scale_init = auto_scale_init
        self.init_sampling = init_sampling
        self.sdf2alpha_var = sdf2alpha_var
        # assume the mesh is normalized to [-1, 1] cube
        self.mesh_obj = mesh_obj
        self.f_sdf = f_sdf

        # Per-primitive [scale, tx, ty, tz]: N x (1 global scale + 3 global translation).
        self.srt_param = nn.parameter.Parameter(torch.zeros(self.num_prims, 1 + 3))
        # Per-primitive voxel features: N x (D * S^3), channel-major.
        self.feat_param = nn.parameter.Parameter(
            torch.zeros(self.num_prims, self.dim_feat * (self.prim_shape ** 3)))

        # Column ranges of feat_param (end indices are exclusive):
        # geometry = 1 channel, texture = 3 channels, material = 2 channels.
        voxels = self.prim_shape ** 3
        self.geo_start_index = 0
        self.geo_end_index = self.geo_start_index + voxels
        self.tex_start_index = self.geo_end_index
        self.tex_end_index = self.tex_start_index + voxels * 3
        self.mat_start_index = self.tex_end_index
        self.mat_end_index = self.mat_start_index + voxels * 2

        # Local voxel-center coordinates in [-1, 1]^3, shape [S^3, 3].
        # With 'ij' indexing, entry (i, j, k) is stacked as (xx[k], xx[j], xx[i]),
        # so flat index i*S^2 + j*S + k matches the (D, H, W) voxel layout that
        # grid_sample addresses with (x, y, z) coordinates.
        # (Equivalent: indexing='xy' stacked as (meshz, meshx, meshy).)
        xx = torch.linspace(-1, 1, self.prim_shape)
        meshx, meshy, meshz = torch.meshgrid(xx, xx, xx, indexing='ij')
        local_grid = torch.stack((meshz, meshy, meshx), dim=-1).reshape(-1, 3)
        # Non-persistent buffer: follows .to()/.cuda() with the module but is not
        # stored in state_dict, so existing checkpoints still load unchanged.
        self.register_buffer('local_grid', local_grid, persistent=False)

        if self.f_sdf is not None and geo_fn is not None and asset_list is not None:
            self._init_param(init_scale=init_scale, geo_fn=geo_fn,
                             asset_list=asset_list, sampling=self.init_sampling)

    @torch.no_grad()
    def _init_param(self, init_scale, geo_fn, asset_list, sampling="uniform"):
        """Initialize primitive parameters from geometry (placeholder).

        In this version the method is a no-op, so parameters keep their zero
        initialization. NOTE(review): zero scales make ``prim_weight`` divide by
        zero; confirm a real initializer or checkpoint always sets them.
        """

    def forward(self, x):
        """Query SDF and appearance at points ``x`` of shape ``[bs, 3]``.

        Returns:
            dict with
              'sdf': ``[bs, 1]`` signed distance (channel 0, unclipped);
              'tex': ``[bs, 3]`` RGB (channels 1-3) clipped to [0, 1];
              'mat': ``[bs, 2]`` roughness, metallic (channels 4-5) clipped to [0, 1].
        """
        weights = self.prim_weight(x)
        output = self.grid_sample_feat(x, weights)
        return {
            'sdf': output[:, 0:1],
            'tex': torch.clip(output[:, 1:4], min=0.0, max=1.0),
            'mat': torch.clip(output[:, 4:6], min=0.0, max=1.0),
        }

    def grid_sample_feat(self, x, weights):
        """Blend trilinearly sampled primitive features at query points (I_V).

        For every (point, primitive) pair with positive weight, the primitive's
        voxel grid is sampled at the point's local coordinates
        ``(x - pos) / scale``; ``weight * feature`` is accumulated per point.
        In eval mode, points with zero total weight get an extrapolated SDF in
        channel 0 (see ``_approx_uncovered_sdf``); their other channels stay 0.

        Args:
            x: query points, ``[bs, 3]``.
            weights: blend weights, ``[bs, num_prims]`` (from ``prim_weight``).

        Returns:
            ``[bs, dim_feat]`` tensor with the dtype/device of ``x``.
        """
        bs = x.shape[0]
        weighted_feat = x.new_zeros(bs, self.dim_feat)

        ind_bs, ind_nprim = torch.where(weights > 0)
        n_pairs = ind_nprim.shape[0]
        if n_pairs > 0:
            # Local coordinates only for active pairs (avoids a [bs, N, 3] tensor).
            local_pts = (x[ind_bs] - self.pos[ind_nprim]) / self.scale[ind_nprim]
            grid = local_pts.reshape(n_pairs, 1, 1, 1, 3)
            voxels = self.feat[ind_nprim].reshape(
                n_pairs, self.dim_feat, self.prim_shape, self.prim_shape, self.prim_shape)
            # mode='bilinear' on a 5-D input performs trilinear interpolation.
            sampled = F.grid_sample(voxels, grid, mode='bilinear', padding_mode='zeros',
                                    align_corners=True).reshape(n_pairs, self.dim_feat)
            weighted_feat.index_add_(0, ind_bs, sampled * weights[ind_bs, ind_nprim][:, None])

        # At inference time, fill in approximate SDF values for points no primitive covers.
        if not self.training:
            uncovered = weights.sum(1) <= 0
            if uncovered.any():
                weighted_feat[uncovered, 0] = self._approx_uncovered_sdf(x[uncovered])
        return weighted_feat

    def _approx_uncovered_sdf(self, pts):
        """Extrapolate SDF values for points ``pts`` ``[M, 3]`` outside every primitive.

        Picks the primitive whose center is nearest, then that primitive's voxel
        center nearest to the point; returns the stored SDF there pushed away
        from zero by the Euclidean gap (same sign). Returns shape ``[M]``.
        """
        center_dist = torch.linalg.vector_norm(pts[:, None, :] - self.pos[None, ...], ord=2, dim=-1)
        _, prim_idx = center_dist.min(1)
        # World-space voxel centers of each chosen primitive: pos + scale * local_grid.
        voxel_pts = (self.pos[prim_idx, None, :]
                     + self.scale[prim_idx, :, None] * self.local_grid.to(pts)[None, :])
        voxel_dist, voxel_idx = torch.linalg.vector_norm(
            pts[:, None, :] - voxel_pts, ord=2, dim=-1).min(1)
        nearest_sdf = self.feat_geo[prim_idx, voxel_idx]
        # Keep the sign of the nearest stored SDF and add the remaining L2 distance.
        return nearest_sdf + voxel_dist * torch.sign(nearest_sdf)

    def prim_weight(self, x):
        """Per-primitive blend weights for query points ``x`` ``[bs, 3]``.

        Raw weight of primitive i is ``relu(1 - ||(x - pos_i) / scale_i||_inf)``:
        1 at the cube center, falling linearly to 0 at its boundary. Weights are
        normalized to sum to ~1 per point; points inside no cube get all zeros.

        Returns:
            ``[bs, num_prims]`` weights.
        """
        local = (x[:, None, :] - self.pos[None, ...]) / self.scale[None, ...]
        weights = F.relu(1 - torch.linalg.vector_norm(local, ord=float('inf'), dim=-1))
        # The 1e-6 avoids 0/0 for points outside every primitive.
        return weights / (torch.sum(weights, dim=-1, keepdim=True) + 1e-6)

    def sdf2alpha(self, sdf):
        """Map SDF to opacity with a Gaussian ``exp(-(sdf / sdf2alpha_var)^2)``."""
        return torch.exp(-(sdf / self.sdf2alpha_var) ** 2)

    @property
    def pos(self):
        """Primitive centers, ``[num_prims, 3]`` (columns 1..3 of ``srt_param``)."""
        return self.srt_param[:, 1:4]

    @property
    def scale(self):
        """Primitive scales (cube half-extents), ``[num_prims, 1]``."""
        return self.srt_param[:, 0:1]

    @property
    def feat(self):
        """All voxel features, ``[num_prims, dim_feat * S^3]``."""
        return self.feat_param

    @property
    def feat_geo(self):
        """SDF voxel channel, ``[num_prims, S^3]``."""
        return self.feat_param[:, self.geo_start_index:self.geo_end_index]

    @property
    def feat_tex(self):
        """RGB voxel channels, ``[num_prims, 3 * S^3]``."""
        return self.feat_param[:, self.tex_start_index:self.tex_end_index]

    @property
    def feat_mat(self):
        """Roughness/metallic voxel channels, ``[num_prims, 2 * S^3]``."""
        return self.feat_param[:, self.mat_start_index:self.mat_end_index]