import torch |
import torch.nn as nn |
import torch.nn.functional as F |
from .deeplabv3 import DeepLabV3 |
from .simple_encoders.high_resolution_encoder import HighResoEncoder |
from .segformer import LowResolutionViT, TriplanePredictorViT |
import copy |
from utils.commons.hparams import hparams |
class Img2PlaneModel(nn.Module):
    """Predict a tri-plane feature volume from a single reference image.

    The input image is optionally augmented with an alpha mask and/or a
    projected camera vector (selected by the ``img2plane_input_mode``
    hyper-parameter), then with two normalised pixel-coordinate channels.
    The augmented image is fed to a low-resolution branch
    (``DeepLabV3`` -> ``LowResolutionViT``) and a high-resolution branch
    (``HighResoEncoder``); ``TriplanePredictorViT`` fuses both into the planes.
    """

    # Number of image channels (before camera / coordinate channels) and
    # whether a camera embedding is appended, for each supported input mode.
    _INPUT_MODES = {
        'rgb': (3, False),
        'rgb_alpha': (4, False),
        'rgb_camera': (3, True),
        'rgb_alpha_camera': (4, True),
    }

    def __init__(self, out_channels=96, hp=None):
        """Build the encoders and the tri-plane predictor.

        Args:
            out_channels: forwarded to ``TriplanePredictorViT`` as its
                ``out_channels`` argument.
            hp: hyper-parameter mapping. When ``None`` a deep copy of the
                module-level ``hparams`` is used. Must provide
                ``'img2plane_backbone_scale'``; ``'img2plane_input_mode'``
                is optional and defaults to ``'rgb'``.

        Raises:
            ValueError: if ``img2plane_input_mode`` is not a supported mode.
        """
        super().__init__()
        global hparams
        self.hparams = hp if hp is not None else copy.deepcopy(hparams)
        # NOTE: this rebinds the module-level ``hparams`` name to this
        # instance's hyper-parameters; kept for backward compatibility.
        hparams = self.hparams

        self.input_mode = hparams.get("img2plane_input_mode", "rgb")
        if self.input_mode not in self._INPUT_MODES:
            raise ValueError(
                f"Unsupported img2plane_input_mode: {self.input_mode!r}; "
                f"expected one of {sorted(self._INPUT_MODES)}"
            )
        in_channels, uses_camera = self._INPUT_MODES[self.input_mode]
        if uses_camera:
            # Project the 25-dim camera vector to 3 extra image channels.
            self.camera_to_channel = nn.Linear(25, 3)
            in_channels += 3
        # Two extra channels for the pixel-coordinate grids added in forward().
        in_channels += 2

        self.low_reso_encoder = DeepLabV3(in_channels=in_channels)
        self.high_reso_encoder = HighResoEncoder(in_dim=in_channels)
        self.low_reso_vit = LowResolutionViT()
        self.triplane_predictor_vit = TriplanePredictorViT(
            out_channels=out_channels,
            img2plane_backbone_scale=hparams['img2plane_backbone_scale'],
        )

    def forward(self, x, cond=None, **synthesis_kwargs):
        """Predict the three feature planes for a batch of images.

        Args:
            x: input image batch, ``[B, 3, H, W]``.
            cond: optional dict of conditions:
                ``"ref_alphas"``: ``[B, 1, H, W]`` mask, used in the
                ``*alpha*`` input modes. When absent, a mask is derived from
                ``x`` (1 where the channel mean is >= -0.999; presumably
                assumes images in [-1, 1] with a -1 background -- TODO
                confirm).
                ``"ref_cameras"``: ``[B, 25]`` camera vector, required in the
                ``*camera*`` input modes.
            **synthesis_kwargs: accepted for interface compatibility; unused.

        Returns:
            Tensor ``[B, 3, C, h, w]`` holding the xy, xz and zy planes (in
            that order along dim 1), where ``3 * C`` is the channel count and
            ``h, w`` the spatial size of the predictor output.

        Raises:
            ValueError: if a camera input mode is active but
                ``cond["ref_cameras"]`` is not provided.
        """
        bs, _, H, W = x.shape

        if self.input_mode in ('rgb_alpha', 'rgb_alpha_camera'):
            if cond is None or cond.get("ref_alphas") is None:
                ref_alphas = (x.mean(dim=1, keepdim=True) >= -0.999).float()
            else:
                ref_alphas = cond["ref_alphas"]
            x = torch.cat([x, ref_alphas], dim=1)

        if self.input_mode in ('rgb_camera', 'rgb_alpha_camera'):
            if cond is None or cond.get("ref_cameras") is None:
                raise ValueError(
                    f'cond["ref_cameras"] is required when input mode is '
                    f'{self.input_mode!r}'
                )
            ref_cameras = cond["ref_cameras"]
            # Broadcast the per-sample camera embedding over all pixels;
            # torch.cat below materialises it, so expand() suffices.
            camera_feat = self.camera_to_channel(ref_cameras)
            camera_feat = camera_feat.reshape(bs, 3, 1, 1).expand(bs, 3, H, W)
            x = torch.cat([x, camera_feat], dim=1)

        # Normalised pixel coordinates in [0, 1): row index / H and
        # column index / W (the column grid was previously divided by H,
        # which is only equivalent for square inputs).
        row_coords = torch.arange(H, device=x.device) / H
        col_coords = torch.arange(W, device=x.device) / W
        expanded_x = row_coords.view(1, 1, H, 1).expand(bs, 1, H, W)
        expanded_y = col_coords.view(1, 1, 1, W).expand(bs, 1, H, W)
        x = torch.cat([x, expanded_x, expanded_y], dim=1)

        feat_low = self.low_reso_encoder(x)
        feat_low_after_vit = self.low_reso_vit(feat_low)
        feat_high = self.high_reso_encoder(x)
        planes = self.triplane_predictor_vit(feat_low_after_vit, feat_high)

        # Split the channel dim into three planes: [B, 3, C, h, w].
        planes = planes.reshape(
            len(planes), 3, -1, planes.shape[-2], planes.shape[-1]
        )

        # Re-orient each plane; dims below index a [B, C, h, w] slice.
        planes_xy = torch.flip(planes[:, 0], [2])
        planes_xz = torch.flip(planes[:, 1], [2])
        planes_zy = torch.flip(planes[:, 2], [2, 3])

        return torch.stack([planes_xy, planes_xz, planes_zy], dim=1)