# THE CODE WAS TAKEN AND ADAPTED FROM https://pengsongyou.github.io/sap
# @inproceedings{Peng2021SAP,
#   author    = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
#   title     = {Shape As Points: A Differentiable Poisson Solver},
#   booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
#   year      = {2021}
# }
import torch
import io, os, logging, urllib.parse
import yaml
import trimesh
import imageio
import numbers
import math
import numpy as np
from collections import OrderedDict
from plyfile import PlyData
from torch import nn
from torch.nn import functional as F
from torch.utils import model_zoo
from skimage import measure, img_as_float32
from igl import adjacency_matrix, connected_components
# PyTorch3D symbols used by mesh_rasterization() and calc_inters_points()
# below; these imports were missing from the original file.
from pytorch3d.structures import Meshes
from pytorch3d.renderer import PerspectiveCameras
from pytorch3d.renderer.mesh.rasterize_meshes import rasterize_meshes
##################################################
# Below are functions for DPSR
def fftfreqs(res, dtype=torch.float32, exact=True):
    """
    Helper function to return frequency tensors
    :param res: n_dims int tuple of number of frequency modes
    :return: omega, tensor of shape (*res_spatial, n_dims), stacking the
             per-dimension FFT frequencies (rfft frequencies for the last dim)
    """
    n_dims = len(res)
    freqs = []
    for dim in range(n_dims - 1):
        r_ = res[dim]
        freq = np.fft.fftfreq(r_, d=1/r_)
        freqs.append(torch.tensor(freq, dtype=dtype))
    r_ = res[-1]
    if exact:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_), dtype=dtype))
    else:
        freqs.append(torch.tensor(np.fft.rfftfreq(r_, d=1/r_)[:-1], dtype=dtype))
    omega = torch.meshgrid(freqs)
    omega = torch.stack(list(omega), dim=-1)
    return omega
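# Example (a minimal sketch; the resolution is illustrative). The rfft on the
# last axis keeps res//2 + 1 modes:
#   >>> fftfreqs((8, 8, 8)).shape
#   torch.Size([8, 8, 5, 3])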
def img(x, deg=1):  # multiply a complex tensor by i**deg (last dim holds [real, imag])
    """
    multiply tensor x by i ** deg
    """
    deg %= 4
    if deg == 0:
        res = x
    elif deg == 1:
        res = x[..., [1, 0]]  # advanced indexing copies, so negating in place is safe
        res[..., 0] = -res[..., 0]
    elif deg == 2:
        res = -x
    elif deg == 3:
        res = x[..., [1, 0]]
        res[..., 1] = -res[..., 1]
    return res
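# Example: i * (1 + 2i) = -2 + i, with complex numbers stored as [real, imag]:
#   >>> img(torch.tensor([[1., 2.]]), deg=1)
#   tensor([[-2., 1.]])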
def spec_gaussian_filter(res, sig):
    omega = fftfreqs(res, dtype=torch.float64)  # [dim0, dim1, dim2, d]
    dis = torch.sqrt(torch.sum(omega ** 2, dim=-1))
    filter_ = torch.exp(-0.5*((sig*2*dis/res[0])**2)).unsqueeze(-1).unsqueeze(-1)
    filter_.requires_grad = False
    return filter_
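# Example (illustrative resolution): a spectral Gaussian low-pass filter for a
# 32^3 grid, shaped to broadcast against an rfft volume:
#   >>> spec_gaussian_filter((32, 32, 32), sig=2.0).shape
#   torch.Size([32, 32, 17, 1, 1])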
def grid_interp(grid, pts, batched=True):
    """
    :param grid: tensor of shape (batch, *size, in_features)
    :param pts: tensor of shape (batch, num_points, dim) within range (0, 1)
    :return: tensor of shape (batch, num_points, in_features), the grid values
             multilinearly interpolated at the query points
    """
    if not batched:
        grid = grid.unsqueeze(0)
        pts = pts.unsqueeze(0)
    dim = pts.shape[-1]
    bs = grid.shape[0]
    size = torch.tensor(grid.shape[1:-1]).to(grid.device).type(pts.dtype)
    cubesize = 1.0 / size
    ind0 = torch.floor(pts / cubesize).long()  # (batch, num_points, dim)
    ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long()  # periodic wrap-around
    ind01 = torch.stack((ind0, ind1), dim=0)  # (2, batch, num_points, dim)
    tmp = torch.tensor([0, 1], dtype=torch.long)
    com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)  # all 2**dim corner combinations
    dim_ = torch.arange(dim).repeat(com_.shape[0], 1)  # (2**dim, dim)
    ind_ = ind01[com_, ..., dim_]  # (2**dim, dim, batch, num_points)
    ind_n = ind_.permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    ind_b = torch.arange(bs).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)  # (batch, num_points, 2**dim)
    # latent code on neighbor nodes
    if dim == 2:
        lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1]]  # (batch, num_points, 2**dim, in_features)
    else:
        lat = grid.clone()[ind_b, ind_n[..., 0], ind_n[..., 1], ind_n[..., 2]]  # (batch, num_points, 2**dim, in_features)
    # weights of neighboring nodes
    xyz0 = ind0.type(cubesize.dtype) * cubesize  # (batch, num_points, dim)
    xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize  # (batch, num_points, dim)
    xyz01 = torch.stack((xyz0, xyz1), dim=0)  # (2, batch, num_points, dim)
    # distance to the opposite corner gives the multilinear weight of each corner
    pos_ = xyz01[1-com_, ..., dim_].permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    pos_ = pos_.type(pts.dtype)
    dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize  # (batch, num_points, 2**dim, dim)
    weights = torch.prod(dxyz_, dim=-1, keepdim=False)  # (batch, num_points, 2**dim)
    query_values = torch.sum(lat * weights.unsqueeze(-1), dim=-2)  # (batch, num_points, in_features)
    if not batched:
        query_values = query_values.squeeze(0)
    return query_values
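# Example (a minimal sketch with illustrative shapes): bilinear interpolation
# of a 4x4 feature grid at 10 random 2D query points.
#   >>> grid = torch.rand(1, 4, 4, 2)   # (batch, *size, in_features)
#   >>> pts = torch.rand(1, 10, 2)      # query points in (0, 1)
#   >>> grid_interp(grid, pts).shape
#   torch.Size([1, 10, 2])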
def scatter_to_grid(inds, vals, size):
    """
    Scatter update values into empty tensor of size size.
    :param inds: (#values, dims)
    :param vals: (#values)
    :param size: tuple for size. len(size)=dims
    """
    dims = inds.shape[1]
    assert(inds.shape[0] == vals.shape[0])
    assert(len(size) == dims)
    dev = vals.device
    result = torch.zeros(*size, device=dev).view(-1).type(vals.dtype)  # flatten
    # fold the n-dimensional indices into flat (row-major) indices
    fac = [np.prod(size[i+1:]) for i in range(len(size)-1)] + [1]
    fac = torch.tensor(fac, device=dev).type(inds.dtype)
    inds_fold = torch.sum(inds*fac, dim=-1)  # [#values,]
    result.scatter_add_(0, inds_fold, vals)
    result = result.view(*size)
    return result
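# Example: duplicate indices accumulate, as with np.add.at.
#   >>> inds = torch.tensor([[0, 1], [0, 1], [2, 3]])
#   >>> scatter_to_grid(inds, torch.tensor([1., 2., 3.]), (4, 4))[0, 1]
#   tensor(3.)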
def point_rasterize(pts, vals, size):
    """
    :param pts: point coords, tensor of shape (batch, num_points, dim) within range (0, 1)
    :param vals: point values, tensor of shape (batch, num_points, features)
    :param size: len(size)=dim tuple for grid size
    :return rasterized values (batch, features, res0, res1, res2)
    """
    dim = pts.shape[-1]
    assert(pts.shape[:2] == vals.shape[:2])
    assert(pts.shape[2] == dim)
    size_list = list(size)
    size = torch.tensor(size).to(pts.device).float()
    cubesize = 1.0 / size
    bs = pts.shape[0]
    nf = vals.shape[-1]
    npts = pts.shape[1]
    dev = pts.device
    ind0 = torch.floor(pts / cubesize).long()  # (batch, num_points, dim)
    ind1 = torch.fmod(torch.ceil(pts / cubesize), size).long()  # periodic wrap-around
    ind01 = torch.stack((ind0, ind1), dim=0)  # (2, batch, num_points, dim)
    tmp = torch.tensor([0, 1], dtype=torch.long)
    com_ = torch.stack(torch.meshgrid(tuple([tmp] * dim)), dim=-1).view(-1, dim)  # all 2**dim corner combinations
    dim_ = torch.arange(dim).repeat(com_.shape[0], 1)  # (2**dim, dim)
    ind_ = ind01[com_, ..., dim_]  # (2**dim, dim, batch, num_points)
    ind_n = ind_.permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    ind_b = torch.arange(bs, device=dev).expand(ind_n.shape[1], ind_n.shape[2], bs).permute(2, 0, 1)  # (batch, num_points, 2**dim)
    # weights of neighboring nodes
    xyz0 = ind0.type(cubesize.dtype) * cubesize  # (batch, num_points, dim)
    xyz1 = (ind0.type(cubesize.dtype) + 1) * cubesize  # (batch, num_points, dim)
    xyz01 = torch.stack((xyz0, xyz1), dim=0)  # (2, batch, num_points, dim)
    # distance to the opposite corner gives the multilinear weight of each corner
    pos_ = xyz01[1-com_, ..., dim_].permute(2, 3, 0, 1)  # (batch, num_points, 2**dim, dim)
    pos_ = pos_.type(pts.dtype)
    dxyz_ = torch.abs(pts.unsqueeze(-2) - pos_) / cubesize  # (batch, num_points, 2**dim, dim)
    weights = torch.prod(dxyz_, dim=-1, keepdim=False)  # (batch, num_points, 2**dim)
    ind_b = ind_b.unsqueeze(-1).unsqueeze(-1)  # (batch, num_points, 2**dim, 1, 1)
    ind_n = ind_n.unsqueeze(-2)  # (batch, num_points, 2**dim, 1, dim)
    ind_f = torch.arange(nf, device=dev).view(1, 1, 1, nf, 1)  # (1, 1, 1, nf, 1)
    ind_b = ind_b.expand(bs, npts, 2**dim, nf, 1)
    ind_n = ind_n.expand(bs, npts, 2**dim, nf, dim).to(dev)
    ind_f = ind_f.expand(bs, npts, 2**dim, nf, 1)
    inds = torch.cat([ind_b, ind_f, ind_n], dim=-1)  # (batch, num_points, 2**dim, nf, 1+1+dim)
    # weighted values
    vals = weights.unsqueeze(-1) * vals.unsqueeze(-2)  # (batch, num_points, 2**dim, nf)
    inds = inds.view(-1, dim+2).long()  # (bs*npts*2**dim*nf, 1+1+dim)
    vals = vals.reshape(-1)  # (bs*npts*2**dim*nf)
    raster = scatter_to_grid(inds, vals, [bs, nf] + size_list)
    return raster  # [batch, nf, res, res, res]
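# Example (a minimal sketch; shapes are illustrative): splat 100 random points
# carrying a single feature onto an 8^3 grid with trilinear weights.
#   >>> pts = torch.rand(1, 100, 3)     # (batch, num_points, dim) in (0, 1)
#   >>> vals = torch.ones(1, 100, 1)    # (batch, num_points, features)
#   >>> point_rasterize(pts, vals, (8, 8, 8)).shape
#   torch.Size([1, 1, 8, 8, 8])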
##################################################
# Below are general utility functions
class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.n = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.n = n
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    # The two helpers below assume tensor-valued `val`/`n` updates and report
    # the average over the non-empty channels.
    def valcavg(self):
        return self.val.sum().item() / (self.n != 0).sum().item()

    def avgcavg(self):
        return self.avg.sum().item() / (self.count != 0).sum().item()
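# Example: track a running average over batches of different sizes.
#   >>> meter = AverageMeter()
#   >>> meter.update(0.5, n=4); meter.update(1.0, n=1)
#   >>> meter.avg
#   0.6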
def load_model_manual(state_dict, model):
    # add/remove the 'module.' prefix so checkpoints saved with or without
    # DataParallel can be loaded either way
    new_state_dict = OrderedDict()
    is_model_parallel = isinstance(model, torch.nn.DataParallel)
    for k, v in state_dict.items():
        if k.startswith('module.') != is_model_parallel:
            if k.startswith('module.'):
                k = k[7:]  # remove 'module.'
            else:
                k = 'module.' + k  # add 'module.'
        new_state_dict[k] = v
    model.load_state_dict(new_state_dict)
def mc_from_psr(psr_grid, pytorchify=False, real_scale=False, zero_level=0):
    '''
    Run marching cubes from PSR grid
    '''
    batch_size = psr_grid.shape[0]
    s = psr_grid.shape[-1]  # size of psr_grid
    psr_grid_numpy = psr_grid.squeeze().detach().cpu().numpy()
    if batch_size > 1:
        verts, faces, normals = [], [], []
        for i in range(batch_size):
            verts_cur, faces_cur, normals_cur, values = measure.marching_cubes(psr_grid_numpy[i], level=zero_level)
            verts.append(verts_cur)
            faces.append(faces_cur)
            normals.append(normals_cur)
        verts = np.stack(verts, axis=0)
        faces = np.stack(faces, axis=0)
        normals = np.stack(normals, axis=0)
    else:
        try:
            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy, level=zero_level)
        except ValueError:  # zero_level lies outside the value range of the grid
            verts, faces, normals, values = measure.marching_cubes(psr_grid_numpy)
    if real_scale:
        verts = verts / (s-1)  # scale to range [0, 1]
    else:
        verts = verts / s  # scale to range [0, 1)
    if pytorchify:
        device = psr_grid.device
        verts = torch.Tensor(np.ascontiguousarray(verts)).to(device)
        faces = torch.Tensor(np.ascontiguousarray(faces)).to(device)
        normals = torch.Tensor(np.ascontiguousarray(-normals)).to(device)
    return verts, faces, normals
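# Example (a minimal sketch): extract the zero level set of a sphere SDF.
#   >>> x = torch.linspace(-1, 1, 32)
#   >>> sdf = torch.stack(torch.meshgrid(x, x, x), dim=-1).norm(dim=-1) - 0.5
#   >>> verts, faces, normals = mc_from_psr(sdf[None], pytorchify=True)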
def calc_inters_points(verts, faces, pose, img_size, mask_gt=None):
    verts = verts.squeeze()
    faces = faces.squeeze()
    pix_to_face, w, mask = mesh_rasterization(verts, faces, pose, img_size)
    if mask_gt is not None:
        #! only evaluate within the intersection
        mask = mask & mask_gt
    # find the 3D point intersected on the mesh for each masked pixel
    w_masked = w[mask]
    f_p = faces[pix_to_face[mask]].long()  # corresponding faces for each pixel
    # corresponding vertices for p_closest
    v_a, v_b, v_c = verts[f_p[..., 0]], verts[f_p[..., 1]], verts[f_p[..., 2]]
    # calculate the intersection point of each pixel and the mesh via the
    # barycentric coordinates from rasterization
    p_inters = w_masked[..., 0, None] * v_a + \
               w_masked[..., 1, None] * v_b + \
               w_masked[..., 2, None] * v_c
    # (An alternative, unused here, is to backproject the NDC pixel coordinates
    # with the rasterizer's z-buffer via pose.unproject_points; it needs the
    # pixel grid `uv` and `zbuf`, which mesh_rasterization() does not return.)
    # if there are outlier points outside the unit cube, remove them
    if (p_inters.max() > 1) | (p_inters.min() < -1):
        mask_bound = (p_inters >= -1) & (p_inters <= 1)
        mask_bound = (mask_bound.sum(dim=-1) == 3)
        mask[mask == True] = mask_bound
        p_inters = p_inters[mask_bound]
        print('Warning: outlier intersection points found and removed')
    return p_inters, mask, f_p, w_masked
def mesh_rasterization(verts, faces, pose, img_size):
    '''
    Use PyTorch3D to rasterize the mesh given a camera
    '''
    transformed_v = pose.transform_points(verts.detach())  # world -> ndc coordinate system
    if isinstance(pose, PerspectiveCameras):
        transformed_v[..., 2] = 1/transformed_v[..., 2]
    # find p_closest on mesh of each pixel via rasterization
    transformed_mesh = Meshes(verts=[transformed_v], faces=[faces])
    pix_to_face, zbuf, bary_coords, dists = rasterize_meshes(
        transformed_mesh,
        image_size=img_size,
        blur_radius=0,
        faces_per_pixel=1,
        perspective_correct=False
    )
    pix_to_face = pix_to_face.reshape(1, -1)  # B x reso x reso -> B x (reso x reso)
    mask = pix_to_face.clone() != -1
    mask = mask.squeeze()
    pix_to_face = pix_to_face.squeeze()
    w = bary_coords.reshape(-1, 3)
    return pix_to_face, w, mask
def verts_on_largest_mesh(verts, faces):
    '''
    Keep only the largest connected component of the mesh.
    verts: Numpy array or Torch.Tensor (N, 3)
    faces: Numpy array (N, 3)
    '''
    if torch.is_tensor(faces):
        verts = verts.squeeze().detach().cpu().numpy()
        faces = faces.squeeze().int().detach().cpu().numpy()
    A = adjacency_matrix(faces)
    num, conn_idx, conn_size = connected_components(A)
    if num == 0:
        v_large, f_large = verts, faces
    else:
        max_idx = conn_size.argmax()  # find the index of the largest component
        mesh_largest = trimesh.Trimesh(verts, faces)
        connected_comp = mesh_largest.split(only_watertight=False)
        mesh_largest = connected_comp[max_idx]
        v_large, f_large = mesh_largest.vertices, mesh_largest.faces
    v_large = v_large.astype(np.float32)
    return v_large, f_large
def update_recursive(dict1, dict2):
    ''' Update dict1 with the entries of dict2, recursing into nested dicts.
    Args:
        dict1 (dict): dictionary to be updated
        dict2 (dict): dictionary whose entries should be used
    '''
    for k, v in dict2.items():
        if k not in dict1:
            dict1[k] = dict()
        if isinstance(v, dict):
            update_recursive(dict1[k], v)
        else:
            dict1[k] = v
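# Example: nested keys are merged instead of overwritten wholesale.
#   >>> cfg = {'train': {'lr': 1e-3}}
#   >>> update_recursive(cfg, {'train': {'batch_size': 4}})
#   >>> cfg
#   {'train': {'lr': 0.001, 'batch_size': 4}}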
def scale2onet(p, scale=1.2):
    '''
    Scale the point cloud from SAP to ONet range
    '''
    return (p - 0.5) * scale
def update_optimizer(inputs, cfg, epoch, model=None, schedule=None):
    if model is not None:
        if schedule is not None:
            optimizer = torch.optim.Adam([
                {"params": model.parameters(),
                 "lr": schedule[0].get_learning_rate(epoch)},
                {"params": inputs,
                 "lr": schedule[1].get_learning_rate(epoch)}])
        elif 'lr' in cfg['train']:
            optimizer = torch.optim.Adam([
                {"params": model.parameters(),
                 "lr": float(cfg['train']['lr'])},
                {"params": inputs,
                 "lr": float(cfg['train']['lr_pcl'])}])
        else:
            raise Exception('no known learning rate')
    else:
        if schedule is not None:
            optimizer = torch.optim.Adam([inputs], lr=schedule[0].get_learning_rate(epoch))
        else:
            optimizer = torch.optim.Adam([inputs], lr=float(cfg['train']['lr_pcl']))
    return optimizer
def is_url(url):
    scheme = urllib.parse.urlparse(url).scheme
    return scheme in ('http', 'https')

def load_url(url):
    '''Load a module dictionary from url.
    Args:
        url (str): url to saved model
    '''
    print(url)
    print('=> Loading checkpoint from url...')
    state_dict = model_zoo.load_url(url, progress=True)
    return state_dict
class GaussianSmoothing(nn.Module):
    """
    Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed
    separately for each channel in the input using a depthwise convolution.
    Arguments:
        channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well.
        kernel_size (int, sequence): Size of the gaussian kernel.
        sigma (float, sequence): Standard deviation of the gaussian kernel.
        dim (int, optional): The number of dimensions of the data.
            Default value is 3 (volumetric).
    """
    def __init__(self, channels, kernel_size, sigma, dim=3):
        super(GaussianSmoothing, self).__init__()
        if isinstance(kernel_size, numbers.Number):
            kernel_size = [kernel_size] * dim
        if isinstance(sigma, numbers.Number):
            sigma = [sigma] * dim
        # The gaussian kernel is the product of the
        # gaussian function of each dimension.
        kernel = 1
        meshgrids = torch.meshgrid(
            [
                torch.arange(size, dtype=torch.float32)
                for size in kernel_size
            ]
        )
        for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
            mean = (size - 1) / 2
            kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \
                      torch.exp(-((mgrid - mean) / std) ** 2 / 2)
        # Make sure sum of values in gaussian kernel equals 1.
        kernel = kernel / torch.sum(kernel)
        # Reshape to depthwise convolutional weight
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
        self.register_buffer('weight', kernel)
        self.groups = channels
        if dim == 1:
            self.conv = F.conv1d
        elif dim == 2:
            self.conv = F.conv2d
        elif dim == 3:
            self.conv = F.conv3d
        else:
            raise RuntimeError(
                'Only 1, 2 and 3 dimensions are supported. Received {}.'.format(dim)
            )

    def forward(self, input):
        """
        Apply gaussian filter to input.
        Arguments:
            input (torch.Tensor): Input to apply gaussian filter on.
        Returns:
            filtered (torch.Tensor): Filtered output.
        """
        return self.conv(input, weight=self.weight, groups=self.groups)
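# Example (a minimal sketch): the convolution is unpadded, so pad the input by
# kernel_size // 2 on each side to keep the spatial resolution.
#   >>> smoothing = GaussianSmoothing(channels=1, kernel_size=3, sigma=1.0, dim=3)
#   >>> x = torch.rand(1, 1, 32, 32, 32)
#   >>> smoothing(F.pad(x, (1, 1, 1, 1, 1, 1), mode='replicate')).shape
#   torch.Size([1, 1, 32, 32, 32])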
# Originally from https://github.com/amosgropp/IGR/blob/0db06b1273/code/utils/general.py
def get_learning_rate_schedules(schedule_specs):
    schedules = []
    for key in schedule_specs.keys():
        schedules.append(StepLearningRateSchedule(
            schedule_specs[key]['initial'],
            schedule_specs[key]['interval'],
            schedule_specs[key]['factor'],
            schedule_specs[key]['final']))
    return schedules
class LearningRateSchedule:
    def get_learning_rate(self, epoch):
        pass

class StepLearningRateSchedule(LearningRateSchedule):
    def __init__(self, initial, interval, factor, final=1e-6):
        self.initial = float(initial)
        self.interval = interval
        self.factor = factor
        self.final = float(final)

    def get_learning_rate(self, epoch):
        # step decay, clamped from below at 5e-6 and at the configured final lr
        lr = np.maximum(self.initial * (self.factor ** (epoch // self.interval)), 5.0e-6)
        if lr > self.final:
            return lr
        else:
            return self.final
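# Example: an lr of 1e-3 halved every 100 epochs reaches 2.5e-4 at epoch 250.
#   >>> schedule = StepLearningRateSchedule(1e-3, interval=100, factor=0.5, final=1e-5)
#   >>> schedule.get_learning_rate(250)
#   0.00025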
def adjust_learning_rate(lr_schedules, optimizer, epoch):
    for i, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = lr_schedules[i].get_learning_rate(epoch)