# THE CODE WAS TAKEN AND ADAPTED FROM https://pengsongyou.github.io/sap
# @inproceedings{Peng2021SAP,
# author = {Peng, Songyou and Jiang, Chiyu "Max" and Liao, Yiyi and Niemeyer, Michael and Pollefeys, Marc and Geiger, Andreas},
# title = {Shape As Points: A Differentiable Poisson Solver},
# booktitle = {Advances in Neural Information Processing Systems (NeurIPS)},
# year = {2021}
# }
import torch
import numpy as np
import time
from .utils import point_rasterize, grid_interp, mc_from_psr, \
calc_inters_points
from .dpsr import DPSR
import torch.nn as nn
class PSR2Mesh(torch.autograd.Function):
@staticmethod
def forward(ctx, psr_grid):
"""
In the forward pass we receive a Tensor containing the input and return
a Tensor containing the output. ctx is a context object that can be used
to stash information for backward computation. You can cache arbitrary
objects for use in the backward pass using the ctx.save_for_backward method.
"""
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
verts = verts.unsqueeze(0)
faces = faces.unsqueeze(0)
normals = normals.unsqueeze(0)
res = torch.tensor(psr_grid.detach().shape[2])
ctx.save_for_backward(verts, normals, res)
return verts, faces, normals
@staticmethod
def backward(ctx, dL_dVertex, dL_dFace, dL_dNormals):
"""
In the backward pass we receive a Tensor containing the gradient of the loss
with respect to the output, and we need to compute the gradient of the loss
with respect to the input.
"""
vert_pts, normals, res = ctx.saved_tensors
res = (res.item(), res.item(), res.item())
# matrix multiplication between dL/dV and dV/dPSR
# dV/dPSR = - normals
grad_vert = torch.matmul(dL_dVertex.permute(1, 0, 2), -normals.permute(1, 2, 0))
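        # Scatter the per-vertex gradients back onto the PSR grid at the
        # (normalized) vertex locations to obtain dL/dPSR.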
grad_grid = point_rasterize(vert_pts, grad_vert.permute(1, 0, 2), res) # b x 1 x res x res x res
return grad_grid
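# Minimal usage sketch for PSR2Mesh (illustrative only; it assumes the DPSR
# module follows the Shape-As-Points interface, i.e. it is constructed with a
# grid resolution and a smoothing sigma and maps oriented points with
# coordinates in [0, 1] to a PSR indicator grid):
#
#   dpsr = DPSR(res=(128, 128, 128), sig=10)
#   points  = torch.rand(1, 2048, 3)                              # (B, N, 3) in [0, 1]
#   normals = torch.nn.functional.normalize(torch.randn(1, 2048, 3), dim=-1)
#   psr_grid = dpsr(points, normals)                              # PSR indicator grid
#   verts, faces, vert_normals = PSR2Mesh.apply(psr_grid)
#
# A loss on `verts` then backpropagates to `psr_grid` (and through DPSR to the
# input points and normals) via the custom backward defined above.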
class PSR2SurfacePoints(torch.autograd.Function):
@staticmethod
def forward(ctx, psr_grid, poses, img_size, uv, psr_grad, mask_sample):
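        # Extract the current mesh from the PSR grid, then, for every camera
        # pose, intersect the viewing rays with the mesh and collect the
        # intersection points together with surface normals interpolated from
        # the precomputed PSR gradient (psr_grad).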
verts, faces, normals = mc_from_psr(psr_grid, pytorchify=True)
verts = verts * 2. - 1. # within the range of [-1, 1]
p_all, n_all, mask_all = [], [], []
for i in range(len(poses)):
pose = poses[i]
if mask_sample is not None:
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size, mask_gt=mask_sample[i])
else:
p_inters, mask, _, _ = calc_inters_points(verts, faces, pose, img_size)
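            # Sample the precomputed PSR gradient at the intersection points to
            # get surface normals; (p + 1) / 2 maps the points from [-1, 1]
            # back to [0, 1] grid coordinates.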
n_inters = grid_interp(psr_grad[None], (p_inters[None].detach() + 1) / 2).squeeze()
p_all.append(p_inters)
n_all.append(n_inters)
mask_all.append(mask)
p_inters_all = torch.cat(p_all, dim=0)
n_inters_all = torch.cat(n_all, dim=0)
mask_visible = torch.stack(mask_all, dim=0)
res = torch.tensor(psr_grid.detach().shape[2])
ctx.save_for_backward(p_inters_all, n_inters_all, res)
return p_inters_all, mask_visible
@staticmethod
def backward(ctx, dL_dp, dL_dmask):
pts, pts_n, res = ctx.saved_tensors
res = (res.item(), res.item(), res.item())
# grad from the p_inters via MLP renderer
grad_pts = torch.matmul(dL_dp[:, None], -pts_n[..., None])
grad_grid_pts = point_rasterize((pts[None]+1)/2, grad_pts.permute(1, 0, 2), res) # b x 1 x res x res x res
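        # Only psr_grid (the first forward argument) receives a gradient; the
        # remaining inputs (poses, img_size, uv, psr_grad, mask_sample) are
        # treated as non-differentiable.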
return grad_grid_pts, None, None, None, None, None
# Resnet Blocks from https://github.com/autonomousvision/shape_as_points/blob/12757682f1075d83738b52f96747463b77343caf/src/network/utils.py
class ResnetBlockFC(nn.Module):
    ''' Fully connected ResNet block.

    Args:
        size_in (int): input dimension
        size_out (int): output dimension (defaults to size_in)
        size_h (int): hidden dimension (defaults to min(size_in, size_out))
        siren (bool): unused in this adaptation; kept for interface
            compatibility with the original implementation
    '''
def __init__(self, size_in, size_out=None, size_h=None, siren=False):
super().__init__()
# Attributes
if size_out is None:
size_out = size_in
if size_h is None:
size_h = min(size_in, size_out)
self.size_in = size_in
self.size_h = size_h
self.size_out = size_out
# Submodules
self.fc_0 = nn.Linear(size_in, size_h)
self.fc_1 = nn.Linear(size_h, size_out)
self.actvn = nn.ReLU()
if size_in == size_out:
self.shortcut = None
else:
self.shortcut = nn.Linear(size_in, size_out, bias=False)
        # Initialization: zero the last layer's weight so that, apart from
        # fc_1's bias, the residual branch is inactive at the start and the
        # block initially reduces to its shortcut mapping.
nn.init.zeros_(self.fc_1.weight)
def forward(self, x):
net = self.fc_0(self.actvn(x))
dx = self.fc_1(self.actvn(net))
if self.shortcut is not None:
x_s = self.shortcut(x)
else:
x_s = x
return x_s + dx
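# Minimal sanity-check / usage sketch for ResnetBlockFC (illustrative only).
# Because of the relative imports at the top of this file, run it as a module,
# e.g. `python -m <package>.<this_module>`, with names depending on the
# surrounding project layout.
if __name__ == "__main__":
    block = ResnetBlockFC(size_in=32, size_out=64)
    x = torch.randn(8, 32)
    y = block(x)  # -> shape (8, 64)
    # fc_1.weight is zero-initialized, so at this point the residual branch
    # adds only fc_1's bias on top of the linear shortcut.
    print(y.shape)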