# THE CODE WAS TAKEN AND ADAPTED FROM https://pengsongyou.github.io/sap
# @inproceedings{Peng2021SAP,
#     author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
#     title = {Shape As Points: A Differentiable Poisson Solver},
#     booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
#     year = {2021}
# }

import time

import numpy as np
import torch
import torch.nn as nn

from .utils import (point_rasterize, grid_interp, mc_from_psr,
                    calc_inters_points)
from .dpsr import DPSR


class PSR2Mesh(torch.autograd.Function):
    """Differentiable mesh extraction from a PSR grid.

    The forward pass runs marching cubes (``mc_from_psr``). The backward pass
    uses the approximation dV/dPSR = -normal: each vertex gradient is
    projected onto the negated vertex normal and splatted back onto the grid
    with ``point_rasterize``.
    """

    @staticmethod
    def forward(ctx, psr_grid):
        """Extract a triangle mesh from ``psr_grid``.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            psr_grid (torch.Tensor): PSR grid. ``shape[2]`` is taken as the
                grid resolution, and ``backward`` reuses it for all three
                spatial axes (i.e. the grid is assumed to be cubic).

        Returns:
            tuple: ``(verts, faces, normals)`` as returned by ``mc_from_psr``,
            each with a leading dimension of size 1 added.
        """
        verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts.unsqueeze(0)
        faces = faces.unsqueeze(0)
        normals = normals.unsqueeze(0)

        # Face indices never receive a gradient; say so explicitly.
        ctx.mark_non_differentiable(faces)

        # Resolution is kept as a tensor so it can go through
        # save_for_backward alongside the mesh tensors.
        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(verts, normals, res)

        return verts, faces, normals

    @staticmethod
    def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
        """Map the loss gradient w.r.t. vertices back onto the PSR grid.

        Only ``dL_dVertex`` contributes; ``dL_dFace`` and ``dL_dNormals`` are
        ignored.

        Returns:
            torch.Tensor: gradient w.r.t. ``psr_grid`` as produced by
            ``point_rasterize`` (annotated upstream as
            b x 1 x res x res x res).
        """
        vert_pts, normals, res = ctx.saved_tensors
        res = (res.item(),) * 3
        # Chain rule dL/dPSR = dL/dV . dV/dPSR with dV/dPSR = -normals.
        # Assuming verts/normals are (1, N, 3): (N, 1, 3) @ (N, 3, 1) gives
        # one scalar per vertex, shape (N, 1, 1).
        grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2),
                                 -normals.permute(1, 2, 0))
        # (N, 1, 1) -> (1, N, 1), then splat per-vertex values onto the grid.
        grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res)
        return grad_grid


class PSR2SurfacePoints(torch.autograd.Function):
    """Differentiable ray/surface intersection points of a PSR surface.

    The forward pass extracts a mesh with ``mc_from_psr``, intersects it for
    every pose via ``calc_inters_points`` and samples ``psr_grad`` at the hit
    points to obtain their normals. The backward pass projects the point
    gradients onto the negated normals and splats them onto the grid.
    """

    @staticmethod
    def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
        """Compute surface intersection points for every pose.

        Args:
            ctx: autograd context used to stash tensors for ``backward``.
            psr_grid (torch.Tensor): PSR grid; ``shape[2]`` is used as the
                (cubic) resolution in ``backward``.
            poses: sequence of poses, one per view, each passed to
                ``calc_inters_points``.
            img_size: forwarded unchanged to ``calc_inters_points``.
            uv: accepted for interface compatibility; not used here.
            psr_grad (torch.Tensor): grid interpolated with ``grid_interp``
                at the intersection points to obtain their normals.
            mask_sample: ``None``, or an indexable of per-view masks;
                ``mask_sample[i]`` is passed as ``mask_gt`` for view ``i``.

        Returns:
            tuple: ``(p_inters_all, mask_visible)`` -- intersection points of
            all views concatenated along dim 0, and the per-view masks
            stacked along a new dim 0.
        """
        verts, faces, _ = mc_from_psr(psr_grid, pytorchify=True)
        verts = verts * 2. - 1.  # within the range of [-1, 1]

        p_all, n_all, mask_all = [], [], []
        for i, pose in enumerate(poses):
            # Only pass mask_gt when masks were supplied.
            mask_kwargs = {} if mask_sample is None else {'mask_gt': mask_sample[i]}
            p_inters, mask, _, _ = calc_inters_points(
                verts, faces, pose, img_size, **mask_kwargs)
            # Sample normals at the hits, mapped from [-1, 1] back to [0, 1].
            # squeeze(0) removes only the batch axis added by [None]; a bare
            # squeeze() would also collapse the point axis when a view has
            # exactly one hit, breaking the torch.cat below.
            n_inters = grid_interp(
                psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze(0)
            p_all.append(p_inters)
            n_all.append(n_inters)
            mask_all.append(mask)

        p_inters_all = torch.cat(p_all, dim=0)
        n_inters_all = torch.cat(n_all, dim=0)
        mask_visible = torch.stack(mask_all, dim=0)

        # The visibility mask never receives a gradient.
        ctx.mark_non_differentiable(mask_visible)

        res = torch.tensor(psr_grid.detach().shape[2])
        ctx.save_for_backward(p_inters_all, n_inters_all, res)

        return p_inters_all, mask_visible

    @staticmethod
    def backward(ctx, dL_dp, dL_dmask):
        """Map the loss gradient w.r.t. intersection points onto the grid.

        ``dL_dmask`` is ignored. Only ``psr_grid`` receives a gradient; all
        other inputs get ``None``.
        """
        pts, pts_n, res = ctx.saved_tensors
        res = (res.item(),) * 3
        # Grad from the p_inters via the MLP renderer, projected onto -normal.
        # Assuming points are (N, 3): (N, 1, 3) @ (N, 3, 1) -> (N, 1, 1).
        grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
        # Splat onto the grid at the points mapped from [-1, 1] to [0, 1].
        grad_grid_pts = point_rasterize(
            (pts[None] + 1) / 2, grad_pts.permute(1, 0, 2), res)  # b x 1 x res x res x res
        return grad_grid_pts, None, None, None, None, None


# Resnet Blocks from https://github.com/autonomousvision/shape_as_points/blob/12757682f1075d83738b52f96747463b77343caf/src/network/utils.py
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet Block class.

    Computes ``shortcut(x) + fc_1(relu(fc_0(relu(x))))``, i.e. ReLU is
    applied before each linear layer. ``fc_1.weight`` is zero-initialised,
    so at initialisation the residual branch outputs only ``fc_1``'s bias.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension (defaults to ``size_in``)
        size_h (int): hidden dimension (defaults to
            ``min(size_in, size_out)``)
        siren (bool): accepted for interface compatibility; not used
    '''

    def __init__(self, size_in, size_out=None, size_h=None, siren=False):
        super().__init__()
        # Attributes
        if size_out is None:
            size_out = size_in

        if size_h is None:
            size_h = min(size_in, size_out)

        self.size_in = size_in
        self.size_h = size_h
        self.size_out = size_out
        # Submodules
        self.fc_0 = nn.Linear(size_in, size_h)
        self.fc_1 = nn.Linear(size_h, size_out)
        self.actvn = nn.ReLU()

        # A bias-free linear projection is only needed when the input and
        # output widths differ; otherwise the shortcut is the identity.
        if size_in == size_out:
            self.shortcut = None
        else:
            self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization
        nn.init.zeros_(self.fc_1.weight)

    def forward(self, x):
        net = self.fc_0(self.actvn(x))
        dx = self.fc_1(self.actvn(net))

        if self.shortcut is not None:
            x_s = self.shortcut(x)
        else:
            x_s = x

        return x_s + dx