# THE CODE WAS TAKEN AND ADAPTED FROM https://pengsongyou.github.io/sap
# @inproceedings{Peng2021SAP,
#     author    = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
#     title     = {Shape As Points: A Differentiable Poisson Solver},
#     booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
#     year      = {2021}
# }
import time

import numpy as np
import torch
import torch.nn as nn

from .dpsr import DPSR
from .utils import point_rasterize, grid_interp, mc_from_psr, \
    calc_inters_points


class PSR2Mesh(torch.autograd.Function):
    @staticmethod
    def forward(ctx, psr_grid):
        """
        In the forward pass we receive a Tensor containing the input and return
        a Tensor containing the output. ctx is a context object that can be used
        to stash information for backward computation. You can cache arbitrary
        objects for use in the backward pass using the ctx.save_for_backward method.
        """
        verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts.unsqueeze(0)
        faces = faces.unsqueeze(0)
        normals = normals.unsqueeze(0)

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(verts, normals, res)

        return verts, faces, normals

    @staticmethod
    def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
        """
        In the backward pass we receive a Tensor containing the gradient of the loss
        with respect to the output, and we need to compute the gradient of the loss
        with respect to the input.
        """
        vert_pts, normals, res = ctx.saved_tensors
        res = (res.item(), res.item(), res.item())

        # matrix multiplication between dL/dV and dV/dPSR
        # dV/dPSR = - normals
        grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
        grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res)  # b x 1 x res x res x res

        return grad_grid
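
# Illustrative usage sketch (not part of the original SAP code): PSR2Mesh is
# invoked through its .apply interface on a PSR indicator grid, for instance one
# produced by the DPSR module imported above. The DPSR arguments shown here are
# assumptions made for illustration only:
#
#   dpsr = DPSR(res=(128, 128, 128), sig=10)           # assumed constructor arguments
#   psr_grid = dpsr(points, point_normals)              # grid of shape (1, res, res, res)
#   verts, faces, normals = PSR2Mesh.apply(psr_grid)    # differentiable w.r.t. psr_grid
#
# Gradients arriving at the vertices are pushed back onto the grid in backward()
# above via dV/dPSR = -normals and point_rasterize.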


class PSR2SurfacePoints(torch.autograd.Function):
    @staticmethod
    def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
        verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts * 2. - 1.  # within the range of [-1, 1]

        p_all, n_all, mask_all = [], [], []
        for i in range(len(poses)):
            pose = poses[i]
            if mask_sample is not None:
                p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
            else:
                p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
            n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
            p_all.append(p_inters)
            n_all.append(n_inters)
            mask_all.append(mask)

        p_inters_all = torch.cat(p_all, dim=0)
        n_inters_all = torch.cat(n_all, dim=0)
        mask_visible = torch.stack(mask_all, dim=0)

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(p_inters_all, n_inters_all, res)

        return p_inters_all, mask_visible

    @staticmethod
    def backward(ctx, dL_dp, dL_dmask):
        pts, pts_n, res = ctx.saved_tensors
        res = (res.item(), res.item(), res.item())

        # grad from the p_inters via MLP renderer
        grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
        grad_grid_pts = point_rasterize((pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res)  # b x 1 x res x res x res

        return grad_grid_pts, None, None, None, None, None
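
# Illustrative usage sketch (not part of the original SAP code): PSR2SurfacePoints
# intersects the rays of each camera pose with the extracted mesh and returns the
# intersection points together with a per-pose visibility mask; the argument names
# below simply mirror forward() above:
#
#   p_inters_all, mask_visible = PSR2SurfacePoints.apply(
#       psr_grid, poses, img_size, uv, psr_grad, mask_sample)
#
# Only psr_grid receives a gradient in backward(); the remaining inputs are treated
# as constants, hence the trailing None returns.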


# Resnet Blocks from https://github.com/autonomousvision/shape_as_points/blob/12757682f1075d83738b52f96747463b77343caf/src/network/utils.py
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.
    Args:
        size_in (int): input dimension
        size_out (int): output dimension
        size_h (int): hidden dimension
    '''

    def __init__(self, size_in, size_out=None, size_h=None, siren=False):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in
        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out

        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)

        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx
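

if __name__ == '__main__':
    # Minimal sanity check for ResnetBlockFC (illustrative, not part of the original
    # SAP code). Because of the relative imports at the top, this file has to be run
    # as a module, e.g. python -m <package>.<this_module>.
    # At initialization the residual branch contributes only the bias of fc_1, since
    # its weight is zero-initialized, so the block starts close to the shortcut
    # projection.
    block = ResnetBlockFC(size_in=32, size_out=64)
    x = torch.randn(8, 32)
    out = block(x)
    assert out.shape == (8, 64)
    print('ResnetBlockFC output shape:', tuple(out.shape))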