Spaces:
Running
on
T4
Running
on
T4
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torchvision | |
import torchvision.models as models | |
from typing import Union, List, Tuple | |
import os | |
import video3d.utils.misc as misc | |
import torch.nn.functional as F | |
class Lift_Encoder(nn.Module): | |
def __init__( | |
self, | |
cin, | |
feat_dim, | |
grid_scale=7., | |
grid_size=32, | |
optim_latent=False, | |
img_size=256, | |
with_z_feature=False, | |
cam_pos_z_offset=10. | |
): | |
super().__init__() | |
''' | |
unproject the input feature map to tri-plane, each plane is (-1, -1)*grid_scale to (1, 1)*scale | |
''' | |
self.cin = cin | |
self.nf = feat_dim | |
self.grid_scale = grid_scale | |
self.grid_size = grid_size | |
self.img_size = img_size | |
self.with_z_feature = with_z_feature | |
self.cam_pos_z_offset = cam_pos_z_offset | |
self.feature_projector = nn.Linear(cin, feat_dim, bias=False) | |
self.plane_latent = None | |
if optim_latent: | |
self.optim_latent = nn.Parameter(torch.rand(3, feat_dim, grid_size, grid_size)) | |
else: | |
self.optim_latent = None | |
if with_z_feature: | |
self.conv_bottleneck = nn.Conv2d(feat_dim+1, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") | |
else: | |
self.conv_bottleneck = nn.Conv2d(feat_dim, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") | |
#TODO: implement an upsampler for input feature map here? | |
self.conv_1 = nn.Conv2d(feat_dim, 4*feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") | |
self.conv_2 = nn.Conv2d(feat_dim, 4*feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") | |
self.up = nn.PixelShuffle(2) | |
self.conv_enc = nn.Conv2d(feat_dim, feat_dim // 2, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") | |
self.conv_dec = nn.Conv2d(feat_dim // 2, feat_dim, kernel_size=3, stride=1, padding=1, dilation=1, padding_mode="replicate") | |
self.feature_fusion = nn.Linear(3*feat_dim, feat_dim, bias=False) | |
def get_coords(self, grid_size): | |
with torch.no_grad(): | |
lines = torch.arange(0, grid_size) | |
grids_x, grids_y = torch.meshgrid([lines, lines], indexing="ij") | |
grids = torch.stack([grids_x, grids_y], dim=-1) | |
grids = (grids - self.grid_size // 2) / (self.grid_size // 2) | |
grids = grids * self.grid_scale | |
plane_z0 = torch.cat([grids, torch.zeros(list(grids.shape[:-1]) + [1])], dim=-1) # [S, S, 3] | |
plane_y0 = plane_z0.clone()[..., [0, 2, 1]] | |
plane_x0 = plane_z0.clone()[..., [2, 0, 1]] | |
planes = torch.stack([plane_x0, plane_y0, plane_z0], dim=0) | |
return planes # [3, S, S, 3] | |
def get_uv_z(self, pts, mvp): | |
cam4 = torch.matmul(torch.nn.functional.pad(pts, pad=(0,1), mode='constant', value=1.0), torch.transpose(mvp, 1, 2)) | |
cam3 = cam4[..., :3] / cam4[..., 3:4] | |
cam_uv = cam3[..., :2] | |
# cam_uv = cam_uv.detach() | |
cam_depth = cam3 + torch.FloatTensor([0, 0, self.cam_pos_z_offset]).to(pts.device).view(1, 1, 3) | |
cam_depth = cam_depth / self.grid_scale * 2 | |
cam_depth = cam_depth[..., 2:3] | |
return cam_uv, cam_depth | |
def unproject(self, feature_map, mvp): | |
''' | |
feature_map: [B, C, h, w] | |
mvp: [B, 4, 4] | |
''' | |
self.plane_latent = None | |
bs, C, h, w = feature_map.shape | |
device = feature_map.device | |
feature_map = self.feature_projector(feature_map.permute(0, 2, 3, 1).reshape(-1, C)).reshape(bs, h, w, self.nf).permute(0, 3, 1, 2) | |
feature_map = self.up(self.conv_1(feature_map)) | |
feature_map = self.up(self.conv_2(feature_map)) | |
plane_coords = self.get_coords(self.grid_size) | |
plane_coords = plane_coords.unsqueeze(0).repeat(bs, 1, 1, 1, 1) | |
plane_coords = plane_coords.to(device) | |
plane_pts = plane_coords.reshape(bs, -1, 3) # [B, N_POINTS, 3] | |
plane_uv, plane_z = self.get_uv_z(plane_pts, mvp) | |
plane_uv = plane_uv.detach() | |
plane_z = plane_z.detach() | |
nP = plane_pts.shape[1] | |
plane_feature = F.grid_sample(feature_map, plane_uv.reshape(bs, 1, nP, 2), mode="bilinear", padding_mode="zeros").squeeze(dim=-2).permute(0, 2, 1) | |
if self.with_z_feature: | |
plane_feature = torch.cat([plane_feature, plane_z], dim=-1) | |
plane_feature = plane_feature.reshape(plane_feature.shape[0], 3, self.grid_size, self.grid_size, plane_feature.shape[-1]) | |
return plane_feature | |
def conv_plane(self, plane_feature): | |
bs, _, nh, nw, nC = plane_feature.shape | |
plane_feature = plane_feature.reshape(-1, nh, nw, nC).permute(0, 3, 1, 2) # [bs*3, nC, nh, nw] | |
plane_feature = self.conv_bottleneck(plane_feature) | |
x = self.conv_dec(self.conv_enc(plane_feature)) | |
out = x + plane_feature | |
out = out.reshape(bs, 3, out.shape[-3], out.shape[-2], out.shape[-1]) | |
if self.optim_latent is not None: | |
optim_latent = self.optim_latent.unsqueeze(0).repeat(bs, 1, 1, 1, 1) | |
out = out + optim_latent | |
return out | |
def sample_plane(self, pts, feat): | |
''' | |
pts: [B, K, 3] | |
feat: [B, 3, C, h, w] | |
''' | |
pts_x, pts_y, pts_z = pts.unbind(dim=-1) | |
pts_x0 = torch.stack([pts_y, pts_z], dim=-1) | |
pts_y0 = torch.stack([pts_x, pts_z], dim=-1) | |
pts_z0 = torch.stack([pts_x, pts_y], dim=-1) | |
feat_x0 = F.grid_sample(feat[:, 0, :, :], pts_x0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1) | |
feat_y0 = F.grid_sample(feat[:, 0, :, :], pts_y0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1) | |
feat_z0 = F.grid_sample(feat[:, 0, :, :], pts_z0.unsqueeze(1), mode="bilinear", padding_mode="border").squeeze(-2).permute(0, 2, 1) | |
pts_feat = torch.cat([feat_x0, feat_y0, feat_z0], dim=-1) | |
return pts_feat | |
def forward(self, feature_map, mvp, pts, inference="unproject"): | |
''' | |
inference = "unproject" or "sample" | |
''' | |
assert inference in ["unproject", "sample"] | |
if inference == "unproject": | |
plane_feature = self.unproject(feature_map, mvp) | |
plane_feature = self.conv_plane(plane_feature) | |
self.plane_latent = plane_feature.clone().detach() # this is just for test case | |
if inference == "unproject": | |
feat_to_sample = plane_feature | |
else: | |
new_bs = pts.shape[0] | |
feat_to_sample = self.plane_latent[:new_bs] | |
pts_feature = self.sample_plane(pts, feat_to_sample) | |
pts_feature = self.feature_fusion(pts_feature) # [B, K, C] | |
return pts_feature |