import torch
import torch.nn as nn
import torch.nn.functional as F

class Lift_Encoder(nn.Module):
def __init__(
self,
cin,
feat_dim,
grid_scale=7.,
grid_size=32,
optim_latent=False,
img_size=256,
with_z_feature=False,
cam_pos_z_offset=10.
):
        '''
        Unproject the input feature map onto a tri-plane; each plane spans
        (-1, -1) * grid_scale to (1, 1) * grid_scale.
        '''
        super().__init__()
self.cin = cin
self.nf = feat_dim
self.grid_scale = grid_scale
self.grid_size = grid_size
self.img_size = img_size
self.with_z_feature = with_z_feature
self.cam_pos_z_offset = cam_pos_z_offset
self.feature_projector = nn.Linear(cin, feat_dim, bias=False)
        self.plane_latent = None  # cached tri-plane latent, filled by forward()
        if optim_latent:
            # optional learnable residual added onto the convolved tri-plane
            self.optim_latent = nn.Parameter(torch.rand(3, feat_dim, grid_size, grid_size))
        else:
            self.optim_latent = None
if with_z_feature:
self.conv_bottleneck = nn.Conv2d(feat_dim+1, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate")
else:
self.conv_bottleneck = nn.Conv2d(feat_dim, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate")
        # TODO: implement an upsampler for the input feature map here?
self.conv_1 = nn.Conv2d(feat_dim, 4*feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate")
self.conv_2 = nn.Conv2d(feat_dim, 4*feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate")
self.up = nn.PixelShuffle(2)
self.conv_enc = nn.Conv2d(feat_dim, feat_dim // 2, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate")
self.conv_dec = nn.Conv2d(feat_dim // 2, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate")
self.feature_fusion = nn.Linear(3*feat_dim, feat_dim, bias=False)

    def get_coords(self, grid_size):
        '''
        Build the three axis-aligned planes (x = 0, y = 0, z = 0), each an
        S x S grid of 3D points spanning roughly [-grid_scale, grid_scale).
        '''
        with torch.no_grad():
            lines = torch.arange(0, grid_size)
            grids_x, grids_y = torch.meshgrid([lines, lines], indexing="ij")
            grids = torch.stack([grids_x, grids_y], dim=-1)
            grids = (grids - grid_size // 2) / (grid_size // 2)  # normalize to roughly [-1, 1)
            grids = grids * self.grid_scale
            plane_z0 = torch.cat([grids, torch.zeros(list(grids.shape[:-1]) + [1])], dim=-1)  # [S, S, 3], the z = 0 plane
            plane_y0 = plane_z0.clone()[..., [0, 2, 1]]  # the y = 0 plane
            plane_x0 = plane_z0.clone()[..., [2, 0, 1]]  # the x = 0 plane
            planes = torch.stack([plane_x0, plane_y0, plane_z0], dim=0)
            return planes  # [3, S, S, 3]
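
    # For example (an illustrative sketch, assuming the module was built with
    # grid_size=4 and grid_scale=7.): get_coords(4) returns a [3, 4, 4, 3]
    # tensor whose last plane (z = 0) holds the points (x, y, 0) with
    # x, y in {-7.0, -3.5, 0.0, 3.5}.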

    def get_uv_z(self, pts, mvp):
        '''
        pts: [B, N, 3] world-space points; mvp: [B, 4, 4] model-view-projection.
        Returns clip-space uv (in [-1, 1] for visible points) and a normalized depth.
        '''
        # homogeneous transform: pad points to [B, N, 4] and apply the MVP matrix
        cam4 = torch.matmul(F.pad(pts, pad=(0, 1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2))
        cam3 = cam4[..., :3] / cam4[..., 3:4]  # perspective divide
        cam_uv = cam3[..., :2]
        # shift by the camera z offset, then rescale depth by the grid scale
        cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(pts.device).view(1, 1, 3)
        cam_depth = cam_depth / self.grid_scale * 2
        cam_depth = cam_depth[..., 2:3]
        return cam_uv, cam_depth
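
    # For example (a sketch, assuming an identity mvp, grid_scale=7. and
    # cam_pos_z_offset=10.): a point (1, 2, 3) maps to uv = (1, 2) and
    # depth = (3 + 10) / 7 * 2 = 26/7.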

    def unproject(self, feature_map, mvp):
        '''
        feature_map: [B, C, h, w]
        mvp: [B, 4, 4]
        '''
        self.plane_latent = None  # invalidate any cached tri-plane
        bs, C, h, w = feature_map.shape
        device = feature_map.device
        # project input channels to nf, then upsample 4x with two conv + pixel-shuffle stages
        feature_map = self.feature_projector(feature_map.permute(0, 2, 3, 1).reshape(-1, C)).reshape(bs, h, w, self.nf).permute(0, 3, 1, 2)
        feature_map = self.up(self.conv_1(feature_map))  # [B, nf, 2h, 2w]
        feature_map = self.up(self.conv_2(feature_map))  # [B, nf, 4h, 4w]
plane_coords = self.get_coords(self.grid_size)
plane_coords = plane_coords.unsqueeze(0).repeat(bs, 1, 1, 1, 1)
plane_coords = plane_coords.to(device)
plane_pts = plane_coords.reshape(bs, -1, 3) # [B, N_POINTS, 3]
        plane_uv, plane_z = self.get_uv_z(plane_pts, mvp)
        plane_uv = plane_uv.detach()  # no gradients through the projected coordinates
        plane_z = plane_z.detach()
        nP = plane_pts.shape[1]
        # bilinearly sample image features at the projected plane points -> [B, N_POINTS, nf]
        plane_feature = F.grid_sample(feature_map, plane_uv.reshape(bs, 1, nP, 2), mode="bilinear", padding_mode="zeros").squeeze(dim=-2).permute(0, 2, 1)
        if self.with_z_feature:
            plane_feature = torch.cat([plane_feature, plane_z], dim=-1)  # append normalized depth as an extra channel
plane_feature = plane_feature.reshape(plane_feature.shape[0], 3, self.grid_size, self.grid_size, plane_feature.shape[-1])
return plane_feature

    def conv_plane(self, plane_feature):
        bs, _, nh, nw, nC = plane_feature.shape
        plane_feature = plane_feature.reshape(-1, nh, nw, nC).permute(0, 3, 1, 2)  # [bs*3, nC, nh, nw]
        plane_feature = self.conv_bottleneck(plane_feature)
        # residual encoder/decoder refinement
        x = self.conv_dec(self.conv_enc(plane_feature))
        out = x + plane_feature
        out = out.reshape(bs, 3, out.shape[-3], out.shape[-2], out.shape[-1])
        if self.optim_latent is not None:
            # add the learnable per-plane latent, broadcast over the batch
            optim_latent = self.optim_latent.unsqueeze(0).repeat(bs, 1, 1, 1, 1)
            out = out + optim_latent
        return out
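
    # Note the layout change: conv_plane takes the channels-last
    # [B, 3, S, S, nf(+1)] tensor from unproject and returns a channels-first
    # [B, 3, nf, S, S] tri-plane latent, ready for grid_sample below.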

    def sample_plane(self, pts, feat):
        '''
        pts: [B, K, 3], assumed normalized to [-1, 1]
        feat: [B, 3, C, h, w]
        '''
        pts_x, pts_y, pts_z = pts.unbind(dim=-1)
        # 2D coordinates of each point projected onto the three planes
        pts_x0 = torch.stack([pts_y, pts_z], dim=-1)
        pts_y0 = torch.stack([pts_x, pts_z], dim=-1)
        pts_z0 = torch.stack([pts_x, pts_y], dim=-1)
        # sample each projection from its corresponding plane
        feat_x0 = F.grid_sample(feat[:, 0], pts_x0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1)
        feat_y0 = F.grid_sample(feat[:, 1], pts_y0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1)
        feat_z0 = F.grid_sample(feat[:, 2], pts_z0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1)
        pts_feat = torch.cat([feat_x0, feat_y0, feat_z0], dim=-1)
        return pts_feat
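
    # E.g. (a sketch): with feat of shape [B, 3, C, S, S] and K query points in
    # [-1, 1], sample_plane returns [B, K, 3*C] -- one C-dim feature per plane,
    # concatenated for feature_fusion in forward().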

    def forward(self, feature_map, mvp, pts, inference="unproject"):
        '''
        inference: "unproject" builds the tri-plane from feature_map and mvp,
        then samples it at pts; "sample" reuses the tri-plane cached by the
        last "unproject" call.
        '''
        assert inference in ["unproject", "sample"]
        if inference == "unproject":
            plane_feature = self.unproject(feature_map, mvp)
            plane_feature = self.conv_plane(plane_feature)
            self.plane_latent = plane_feature.clone().detach()  # cached for test-time "sample" calls
            feat_to_sample = plane_feature
        else:
            assert self.plane_latent is not None, "run inference='unproject' once before 'sample'"
            new_bs = pts.shape[0]
            feat_to_sample = self.plane_latent[:new_bs]
        pts_feature = self.sample_plane(pts, feat_to_sample)
        pts_feature = self.feature_fusion(pts_feature)  # [B, K, C]
        return pts_feature
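

# A minimal smoke test -- an illustrative sketch, not part of the original
# module. The identity mvp and the random inputs below are hypothetical
# stand-ins; in the real pipeline they come from the camera and the image
# encoder.
if __name__ == "__main__":
    enc = Lift_Encoder(cin=64, feat_dim=32, grid_scale=1., grid_size=16,
                       optim_latent=True, with_z_feature=True)
    B, h, w = 2, 8, 8
    feature_map = torch.randn(B, 64, h, w)
    mvp = torch.eye(4).unsqueeze(0).repeat(B, 1, 1)  # hypothetical camera
    pts = torch.rand(B, 100, 3) * 2 - 1  # queries, assumed pre-normalized to [-1, 1]
    out = enc(feature_map, mvp, pts, inference="unproject")
    print(out.shape)  # expected: torch.Size([2, 100, 32])
    out2 = enc(feature_map, mvp, pts[:1], inference="sample")
    print(out2.shape)  # expected: torch.Size([1, 100, 32])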