# Generator for GenHead, modified from EG3D: https://github.com/NVlabs/eg3d
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
from torch import nn
from torch_utils import persistence
from models.stylegan.networks_stylegan2 import Generator as StyleGAN2Backbone
from models.stylegan.networks_stylegan2 import ToRGBLayer, FullyConnectedLayer, SynthesisNetwork
from models.stylegan.superresolution import SuperresolutionPatchMLP
from training.deformer.deformation import DeformationModule, DeformationModuleOnlyHead
from training.deformer.deform_utils import cam_world_matrix_transform
from training.volumetric_rendering.renderer import ImportanceRenderer, DeformImportanceRenderer, PartDeformImportanceRenderer, DeformImportanceRendererNew
from training.volumetric_rendering.ray_sampler import RaySampler
import dnnlib
import torch.nn.functional as F
from torch_utils.ops import upfirdn2d
import copy
# Baseline generator without separate deformation for face, eyes, and mouth
@persistence.persistent_class
class TriPlaneGeneratorDeform(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
sr_num_fp16_res = 0,
mapping_kwargs = {}, # Arguments for MappingNetwork.
rendering_kwargs = {},
deformation_kwargs = {},
sr_kwargs = {},
has_background = True,
has_superresolution = False,
flame_condition = False,
flame_full = False,
dynamic_texture = False, # Deprecated
random_combine = True,
triplane_resolution = 256,
triplane_channels = 96,
masked_sampling = None,
has_patch_sr = False,
add_block = False,
**synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim=z_dim
self.c_dim=c_dim
self.w_dim=w_dim
self.flame_condition = flame_condition
self.has_background = has_background
self.has_superresolution = has_superresolution
self.dynamic_texture = dynamic_texture
decoder_output_dim = 32 if has_superresolution else 3
self.img_resolution=img_resolution
self.img_channels=img_channels
self.renderer = DeformImportanceRenderer()
self.ray_sampler = RaySampler()
self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
self.to_dynamic = None
if self.has_background:
self.background = StyleGAN2Backbone(z_dim, 0, w_dim, img_resolution=64, mapping_kwargs={'num_layers':8}, channel_base=16384, img_channels=decoder_output_dim)
if self.has_superresolution:
self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
else:
self.superresolution = None
self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': decoder_output_dim})
self.neural_rendering_resolution = 64
self.rendering_kwargs = rendering_kwargs
self.deformer = DeformationModule(flame_full=flame_full,dynamic_texture=dynamic_texture, **deformation_kwargs)
self._last_planes = None
# ws/c_deform default to None since the call sites below omit them (assumed optional in DeformationModule)
def _deformer(self,shape_params,exp_params,pose_params,eye_pose_params,ws=None,c_deform=None,cache_backbone=False, use_cached_backbone=False, use_rotation_limits=False, eye_blink_params=None):
return lambda coordinates: self.deformer(coordinates, shape_params,exp_params,pose_params,eye_pose_params,ws,c_deform,cache_backbone=cache_backbone,use_cached_backbone=use_cached_backbone, use_rotation_limits=use_rotation_limits)
def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
if self.rendering_kwargs['c_gen_conditioning_zero']:
c = torch.zeros_like(c)
return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
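# Note: 'c_scale' defaults to 0 via .get(), so label/camera conditioning is effectively
# disabled unless rendering_kwargs sets it explicitly; 'c_gen_conditioning_zero' additionally
# zeroes c outright. A hypothetical configuration might look like:
#   rendering_kwargs = {'c_gen_conditioning_zero': False, 'c_scale': 1.0, ...}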
def synthesis(self, ws, z_bg, c, _deformer, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, use_dynamic=False,use_rotation_limits=None, smpl_param=None, patch_scale=None, chunk=None, run_full=None, uv=None, diff_dynamic=False, forward_mode='train', eye_blink_params=None, ws_super=None, **synthesis_kwargs):
if forward_mode == 'train':
face_ws = ws
dynamic_ws = ws
# for inversion only
elif ws.shape[1] >= self.backbone.num_ws + self.background.num_ws:
face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:self.backbone.num_ws+self.background.num_ws, :], ws[:, self.backbone.num_ws+self.background.num_ws:, :]
else:
face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:-1, :], ws[:, self.backbone.num_ws-1:self.backbone.num_ws, :]
cam2world_matrix = c[:, :16].view(-1, 4, 4)
world2cam_matrix = cam_world_matrix_transform(cam2world_matrix)
cam_z = world2cam_matrix[:,2,3]
intrinsics = c[:, 16:25].view(-1, 3, 3)
if neural_rendering_resolution is None:
neural_rendering_resolution = self.neural_rendering_resolution
else:
self.neural_rendering_resolution = neural_rendering_resolution
# Create a batch of rays for volume rendering
ray_origins, ray_directions, _ = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
# Create triplanes by running StyleGAN backbone
N, M, _ = ray_origins.shape
if use_cached_backbone and self._last_planes is not None:
planes = self._last_planes
else:
planes, last_featuremap = self.backbone.synthesis(face_ws, update_emas=update_emas, **synthesis_kwargs)
if cache_backbone:
self._last_planes = planes
planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])
if self.dynamic_texture:
pass # deprecated
else:
dynamic_planes = None
# Perform volume rendering
feature_samples, depth_samples, all_depths, all_weights, T_bg, offset, dist_to_surface, vts_mask, vts_mask_region, coarse_sample_points, coarse_triplane_features = self.renderer(planes, self.decoder, _deformer, ray_origins, ray_directions, self.rendering_kwargs, dynamic=self.dynamic_texture, cam_z=cam_z) # channels last
weights_samples = all_weights.sum(2)
# Reshape into 'raw' neural-rendered image
H = W = self.neural_rendering_resolution
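# feature_samples come back channels-last as (N, H*W, C); permute/reshape into a standard
# (N, C, H, W) feature image for background compositing and super-resolution below.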
feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
# background
if self.has_background:
if forward_mode == 'train':
background = self.background(z_bg, c, **synthesis_kwargs)
else:
background, _ = self.background.synthesis(bg_ws, update_emas=update_emas, **synthesis_kwargs)
background = torch.sigmoid(background) # values in (0,1), shape (N,3,H,W)
if background.shape[-1] != neural_rendering_resolution:
background = F.interpolate(background,size=(neural_rendering_resolution,neural_rendering_resolution),mode='bilinear')
T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
feature_image = feature_image + T_bg*background
else:
T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
background = 0.
feature_image = 2*feature_image - 1
rgb_image = feature_image[:, :3]
if self.superresolution is not None:
sr_image = self.superresolution(rgb_image, feature_image, face_ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
else:
sr_image = None
return {'image': rgb_image, 'image_feature':feature_image, 'image_sr':sr_image, 'image_depth': depth_image, 'background':2*background-1, 'interval':all_depths.squeeze(-1), 'all_weights':all_weights.squeeze(-1), 'T_bg': T_bg, \
'seg': (1 - T_bg)*2 - 1, 'offset':offset, 'dist_to_surface':dist_to_surface, 'vts_mask':vts_mask, 'vts_mask_region':vts_mask_region, 'dynamic_planes':dynamic_planes, 'coarse_sample_points':coarse_sample_points, 'coarse_triplane_features':coarse_triplane_features}
def sample(self, coordinates, directions, shape_params,exp_params,pose_params,eye_pose_params, z, c, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
planes, last_featuremap = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])
# target space to canonical space deformation
if use_deform:
_deformer = self._deformer(shape_params,exp_params,pose_params,eye_pose_params)
out_deform = _deformer(coordinates)
coordinates = out_deform['canonical']
offset = out_deform['offset']
dynamic_mask = out_deform['dynamic_mask']
else:
coordinates = coordinates
offset = torch.zeros_like(coordinates)
dynamic_mask = None
out = self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs, dynamic_mask=dynamic_mask)
out['canonical'] = coordinates
out['offset'] = offset
return out
def sample_mixed(self, coordinates, directions, shape_params,exp_params,pose_params,eye_pose_params, ws, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
planes, last_featuremap = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
# planes = torch.tanh(planes)
planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])
# target space to canonical space deformation
if use_deform:
_deformer = self._deformer(shape_params,exp_params,pose_params,eye_pose_params)
out_deform = _deformer(coordinates)
coordinates = out_deform['canonical']
offset = out_deform['offset']
dynamic_mask = out_deform['dynamic_mask']
else:
coordinates = coordinates
offset = torch.zeros_like(coordinates)
dynamic_mask = None
out = self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs, dynamic_mask=dynamic_mask)
out['canonical'] = coordinates
out['offset'] = offset
return out
def forward(self, shape_params,exp_params,pose_params,eye_pose_params, z, z_bg, c, c_compose, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, patch_scale=None, **synthesis_kwargs):
# Render a batch of generated images.
_deformer = self._deformer(shape_params,exp_params,pose_params,eye_pose_params)
c_compose_condition = c_compose.clone()
if self.flame_condition:
c_compose_condition = torch.cat([c_compose_condition,shape_params,exp_params],dim=-1)
ws = self.mapping(z, c_compose_condition, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
# Render correspondence map as condition to the discriminator
uv = self.deformer.renderer(shape_params, exp_params, pose_params, eye_pose_params, c, half_size=int(self.img_resolution/2))[0]
img = self.synthesis(ws, z_bg, c, _deformer=_deformer, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
img['uv'] = uv
return img
# Generator used in GenHead, with part-wise deformation
@persistence.persistent_class
class PartTriPlaneGeneratorDeform(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
triplane_channels,
sr_num_fp16_res = 0,
mapping_kwargs = {}, # Arguments for MappingNetwork.
rendering_kwargs = {},
deformation_kwargs = {},
sr_kwargs = {},
has_background = True,
has_superresolution = False,
has_patch_sr = False,
flame_condition = True,
flame_full = False,
dynamic_texture = False, # Deprecated
random_combine = True,
add_block = False,
triplane_resolution = 256,
masked_sampling = False,
**synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim=z_dim
self.c_dim=c_dim
self.w_dim=w_dim
self.flame_condition = flame_condition
self.dynamic_texture = dynamic_texture
self.has_background = has_background
self.has_superresolution = has_superresolution
self.has_patch_sr = has_patch_sr
decoder_output_dim = 32 if has_superresolution else 3
self.img_resolution=img_resolution
self.img_channels=img_channels
self.renderer = PartDeformImportanceRenderer() if triplane_channels>96 else DeformImportanceRenderer()
self.mouth_part = True #if triplane_channels>192 else False
self.mouth_dynamic = False
self.masked_sampling = masked_sampling
self.ray_sampler = RaySampler()
self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=triplane_resolution, img_channels=triplane_channels, mapping_kwargs=mapping_kwargs, add_block=add_block, **synthesis_kwargs)
if self.has_background:
self.background = StyleGAN2Backbone(z_dim, 0, w_dim, img_resolution=64, mapping_kwargs={'num_layers':8}, channel_base=16384, img_channels=decoder_output_dim)
if self.has_superresolution:
self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
else:
self.superresolution = None
self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': decoder_output_dim})
self.neural_rendering_resolution = 64
self.rendering_kwargs = rendering_kwargs
if self.has_patch_sr:
self.patch_sr = SuperresolutionPatchMLP(channels=32, img_resolution=None, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=True)
self.to_dynamic = None
self.to_dynamic_sr = None
self.deformer = DeformationModule(flame_full=flame_full,dynamic_texture=dynamic_texture,part=True,**deformation_kwargs)
self._last_planes = None
self._last_dynamic_planes = None
self.max_pool = nn.MaxPool2d(kernel_size=7, stride=1, padding=3)
def _warping(self,images,flows):
# images [(B, M, C, H, W)]
# flows (B, M, 2, H, W) # inverse warping flow
warp_images = []
B, M, _, H_f, W_f = flows.shape
flows = flows.view(B*M,2,H_f, W_f)
for im in images:
B, M, C, H, W = im.shape
im = im.view(B*M, C, H, W)
y, x = torch.meshgrid(torch.linspace(-1, 1, H, dtype=torch.float32, device=im.device), torch.linspace(-1, 1, W, dtype=torch.float32, device=im.device), indexing='ij')
xy = torch.stack([x, y], dim=-1).unsqueeze(0).repeat(B*M,1,1,1) #(B*M,H,W,2), matching the flattened batch of im
if H_f != H:
_flows = F.interpolate(flows,size=(H,W), mode='bilinear', align_corners=True)
else:
_flows = flows
_flows = _flows.permute(0,2,3,1) #(B,H,W,2)
uv = _flows + xy
warp_image = F.grid_sample(im, uv, mode='bilinear', padding_mode='zeros', align_corners=True) #(B,C,H,W)
warp_image = warp_image.view(B, M, C, H, W)
warp_images.append(warp_image)
return warp_images
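# Minimal usage sketch (hypothetical tensors): warp a list of feature maps with an inverse
# flow; flow values are offsets in grid_sample's normalized [-1, 1] coordinate convention.
#   feats = torch.randn(1, 4, 32, 64, 64)      # (B, M, C, H, W)
#   flow = torch.zeros(1, 4, 2, 64, 64)        # (B, M, 2, H, W); zero flow = identity warp
#   warped, = self._warping([feats], flow)     # list in, list out; same shape as feats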
def _deformer(self,shape_params,exp_params,pose_params,eye_pose_params,eye_blink_params=None,exp_params_dynamics=None,cache_backbone=False, use_cached_backbone=False,use_rotation_limits=False):
return lambda coordinates, mouth: self.deformer(coordinates, shape_params,exp_params,pose_params,eye_pose_params,eye_blink_params=eye_blink_params,exp_params_dynamics=exp_params_dynamics,cache_backbone=cache_backbone,use_cached_backbone=use_cached_backbone,use_rotation_limits=use_rotation_limits,mouth=mouth)
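# Usage sketch: the closure binds the FLAME parameters so the renderer can query it with
# posed-space points plus a flag for resolving the mouth part separately (see sample_mixed):
#   deform_fn = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
#   out = deform_fn(points, mouth=self.mouth_part)
#   # out['canonical_eye'] / out['canonical_face'] / out['canonical_mouth'] hold the
#   # per-part canonical coordinates corresponding to 'points'.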
def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
if self.rendering_kwargs['c_gen_conditioning_zero']:
c = torch.zeros_like(c)
return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
def synthesis(self, ws, z_bg, c, _deformer, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, use_dynamic=False, use_rotation_limits=None, smpl_param=None, eye_blink_params=None, patch_scale=1.0, run_full=True, uv=None, chunk=None, diff_dynamic=False, dense_eye=False, forward_mode='train', ws_super=None, **synthesis_kwargs):
if forward_mode == 'train':
face_ws = ws
dynamic_ws = ws
# for inversion only
elif ws.shape[1] >= self.backbone.num_ws + self.background.num_ws:
face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:self.backbone.num_ws+self.background.num_ws, :], ws[:, self.backbone.num_ws+self.background.num_ws:, :]
else:
face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:-1, :], ws[:, self.backbone.num_ws-1:self.backbone.num_ws, :]
cam2world_matrix = c[:, :16].view(-1, 4, 4)
world2cam_matrix = cam_world_matrix_transform(cam2world_matrix)
cam_z = world2cam_matrix[:,2,3]
intrinsics = c[:, 16:25].view(-1, 3, 3)
N = cam2world_matrix.shape[0]
if neural_rendering_resolution is None:
neural_rendering_resolution = self.neural_rendering_resolution
elif self.training:
self.neural_rendering_resolution = neural_rendering_resolution
H = W = neural_rendering_resolution
with torch.no_grad():
eye_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size = int(self.img_resolution/2), eye_blink_params=eye_blink_params, eye_mask=True, use_rotation_limits=use_rotation_limits)[1]
face_wo_eye_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size = int(self.img_resolution/2), eye_blink_params=eye_blink_params, face_woeye=True, use_rotation_limits=use_rotation_limits)[1]
eye_mask = eye_mask * (1-face_wo_eye_mask)
blur_sigma = 1
blur_size = blur_sigma * 3
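# Build a small 1D blur kernel, exp2(-(x/sigma)^2), the same base-2 Gaussian used for
# upfirdn2d blurring elsewhere in StyleGAN2/EG3D; it softens the rasterized eye mask
# before it is resized to the neural rendering resolution.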
f = torch.arange(-blur_size, blur_size + 1, device=eye_mask.device).div(blur_sigma).square().neg().exp2()
eye_mask_sr = upfirdn2d.filter2d(eye_mask, f / f.sum())
eye_mask = torch.nn.functional.interpolate(eye_mask, size=(self.neural_rendering_resolution), mode='bilinear', align_corners=False, antialias=True)
head_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size = int(self.img_resolution/2), eye_blink_params=eye_blink_params, only_face=False,cull_backfaces=False, use_rotation_limits=use_rotation_limits)[1]
if self.mouth_part:
head_wo_mouth_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size = int(self.img_resolution/2), eye_blink_params=eye_blink_params, only_face=False,cull_backfaces=True, noinmouth=True, use_rotation_limits=use_rotation_limits)[1]
mouth_mask = head_mask * (1-head_wo_mouth_mask)
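# Morphological opening on the mouth mask: -max_pool(-x) erodes (a min-pool) and the
# following max_pool dilates back, removing thin slivers where the two rasterized masks
# disagree before blurring.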
mouth_mask_sr = -self.max_pool(-mouth_mask)
mouth_mask_sr = self.max_pool(mouth_mask_sr)
blur_sigma = 2
blur_size = blur_sigma * 3
f = torch.arange(-blur_size, blur_size + 1, device=mouth_mask_sr.device).div(blur_sigma).square().neg().exp2()
mouth_mask_sr = upfirdn2d.filter2d(mouth_mask_sr, f / f.sum())
mouth_mask = torch.nn.functional.interpolate(mouth_mask, size=(self.neural_rendering_resolution), mode='bilinear', align_corners=False, antialias=True)
mouth_mask_sr = (mouth_mask_sr + torch.nn.functional.interpolate(mouth_mask, size=(self.img_resolution), mode='bilinear', align_corners=False, antialias=True)).clamp(max=1)
# mouth_mask_sr[:,:,:128,:] *= 0 # for visualization only (deprecated)
else:
mouth_mask = None
mouth_mask_sr = None
head_mask = torch.nn.functional.interpolate(head_mask, size=(neural_rendering_resolution), mode='bilinear', align_corners=False, antialias=True)
head_mask_sr = torch.nn.functional.interpolate(head_mask, size=(self.img_resolution), mode='bilinear', align_corners=False, antialias=True)
# Create triplanes by running StyleGAN backbone
# N, M, _ = ray_origins.shape
if use_cached_backbone and self._last_planes is not None:
planes = self._last_planes
dynamic_planes = self._last_dynamic_planes
else:
planes, last_featuremap = self.backbone.synthesis(face_ws, update_emas=update_emas, **synthesis_kwargs)
# Reshape output into three 32-channel planes
if not isinstance(planes, list):
planes = [planes]
# last_featuremap = [last_featuremap]
planes = [p.view(len(p), -1, 32, p.shape[-2], p.shape[-1]) for p in planes]
if self.dynamic_texture:
pass # deprecated
else:
dynamic_planes = None
if cache_backbone:
self._last_planes = planes
self._last_dynamic_planes = dynamic_planes
if self.has_background:
if forward_mode == 'train':
background = self.background(z_bg, c, **synthesis_kwargs)
else:
background, _ = self.background.synthesis(bg_ws, update_emas=update_emas, **synthesis_kwargs)
background = torch.sigmoid(background) # values in (0,1), shape (N,3,H,W)
background = F.interpolate(background,size=(256, 256),mode='bilinear')
background_feature = F.interpolate(background,size=(self.img_resolution,self.img_resolution),mode='bilinear')
if background.shape[-1] != neural_rendering_resolution:
background = F.interpolate(background,size=(neural_rendering_resolution,neural_rendering_resolution),mode='bilinear')
# Create a batch of rays for volume rendering
output = {}
coarse_sample_points = coarse_triplane_features = None
if run_full:
ray_origins, ray_directions, _ = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
mouth_mask_flat = mouth_mask[:, 0].view(mouth_mask.shape[0], -1, 1) if neural_rendering_resolution==self.neural_rendering_resolution else mouth_mask_sr[:, 0].view(mouth_mask_sr.shape[0], -1, 1)
eye_mask_flat = eye_mask[:, 0].view(eye_mask.shape[0], -1, 1) if neural_rendering_resolution==self.neural_rendering_resolution else eye_mask_sr[:, 0].view(eye_mask_sr.shape[0], -1, 1)
# Perform volume rendering
if chunk is None:
feature_samples, depth_samples, all_depths, all_weights, T_bg, offset, dist_to_surface, densities_face_ineye, densities_face_inmouth,vts_mask, vts_mask_region, coarse_sample_points, coarse_triplane_features, eye_mask_sel, mouth_mask_sel \
= self.renderer((planes[0:1], planes)[neural_rendering_resolution>64], self.decoder, _deformer, ray_origins, ray_directions, self.rendering_kwargs, eye_mask=eye_mask_flat, mouth_mask=mouth_mask_flat, mouth_dynamic=self.mouth_dynamic, auto_fuse=(False, True)[neural_rendering_resolution>64],cam_z=cam_z) # channels last
if dense_eye: # only for batchsize=1
is_eye_region = (eye_mask_flat!=0).squeeze(0).squeeze(-1)
if torch.sum(is_eye_region.to(torch.float32)) == 0:
feature_samples_eye = 0
T_bg_eye = 0
else:
ray_origins_eye = ray_origins[:,is_eye_region]
ray_directions_eye = ray_directions[:,is_eye_region]
eye_mask_eye = eye_mask_flat[:,is_eye_region]
rendering_kwargs_eye = copy.deepcopy(self.rendering_kwargs)
rendering_kwargs_eye['depth_resolution'] = 128
rendering_kwargs_eye['depth_resolution_importance'] = 128
feature_samples_eye, depth_samples_eye, all_depths_eye, all_weights_eye, T_bg_eye, offset_eye, dist_to_surface_eye, densities_face_ineye_eye, densities_face_inmouth_eye,vts_mask_eye, vts_mask_region_eye, _, _, _, _ = self.renderer((planes[0:1], planes)[neural_rendering_resolution>64], self.decoder, _deformer, ray_origins_eye, ray_directions_eye, rendering_kwargs_eye, eye_mask=eye_mask_eye, mouth_mask=None, mouth_dynamic=self.mouth_dynamic, auto_fuse=(False, True)[neural_rendering_resolution>64]) # channels last
else:
feature_samples, depth_samples, all_depths, all_weights, T_bg, offset, dist_to_surface, densities_face_ineye, densities_face_inmouth, vts_mask, vts_mask_region = list(), list(), list(), list(), list(), list(), list(), list(), list(), list(), list()
for _ro, _rd, _em, _mm in zip(torch.split(ray_origins, chunk, dim=1), torch.split(ray_directions, chunk, dim=1), torch.split(eye_mask_flat, chunk, dim=1), torch.split(mouth_mask_flat, chunk, dim=1)):
_f, _d, _ad, _aw, _tbg, _off, _ds, _dfe, _dfm, _vm, _vmr = self.renderer((planes[0:1], planes)[neural_rendering_resolution>64], self.decoder, _deformer, _ro, _rd, self.rendering_kwargs, eye_mask=_em, mouth_mask=_mm, mouth_dynamic=self.mouth_dynamic, auto_fuse=(False, True)[neural_rendering_resolution>64],cam_z=cam_z) # channels last
feature_samples.append(_f)
depth_samples.append(_d)
all_depths.append(_ad)
all_weights.append(_aw)
T_bg.append(_tbg)
offset.append(_off)
dist_to_surface.append(_ds)
densities_face_ineye.append(_dfe)
densities_face_inmouth.append(_dfm)
vts_mask.append(_vm)
vts_mask_region.append(_vmr)
feature_samples = torch.cat(feature_samples, 1)
depth_samples = torch.cat(depth_samples, 1)
all_depths = torch.cat(all_depths, 1)
all_weights = torch.cat(all_weights, 1)
T_bg = torch.cat(T_bg, 1)
offset = torch.cat(offset, 1)
dist_to_surface = torch.cat(dist_to_surface, 1)
densities_face_ineye = torch.cat(densities_face_ineye, 1)
densities_face_inmouth = torch.cat(densities_face_inmouth, 1)
vts_mask = torch.cat(vts_mask, 1)
vts_mask_region = torch.cat(vts_mask_region, 1)
weights_samples = all_weights.sum(2)
if dense_eye:
feature_samples[:,is_eye_region] = feature_samples_eye
T_bg[:,is_eye_region] = T_bg_eye
# Reshape into 'raw' neural-rendered image
# H = W = self.neural_rendering_resolution
feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
# background
if feature_image.shape[-1]<=128:
T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
if self.has_background:
feature_image = feature_image + T_bg*background
else:
feature_image = feature_image + T_bg
feature_image = 2*feature_image - 1
rgb_image = feature_image[:, :3]
if self.superresolution is not None and rgb_image.shape[-1]<=128:
sr_image = self.superresolution(rgb_image, feature_image, ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
else:
sr_image = rgb_image
else:
rgb_image = feature_image[:, :3]
if self.has_background:
background_sr = 2*background - 1
background_sr = self.superresolution(background_sr[:, :3], background_sr, ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
background_sr = (background_sr + 1) * 0.5
T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
rgb_image = rgb_image + T_bg*background_sr
else:
T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
rgb_image = rgb_image + T_bg
rgb_image = 2*rgb_image - 1
sr_image = rgb_image
if self.has_patch_sr:
rgb_image_ = rgb_image
background_feature = torch.cat([background_sr,background_feature[:,3:]],dim=1)
feature_image_ = feature_image + T_bg * background_feature
feature_image_ = 2*feature_image_ - 1
sr_rgb_image = self.patch_sr(rgb_image_[:, :3], feature_image_, torch.ones_like(ws))
output.update({'image_raw_sr': sr_rgb_image})
output.update({'image': rgb_image, 'image_feature':feature_image,'image_sr':sr_image, 'background':2*background-1, 'image_depth': depth_image, 'interval':all_depths.squeeze(-1), 'all_weights':all_weights.squeeze(-1), \
'T_bg': T_bg, 'offset':offset, 'dist_to_surface':dist_to_surface, 'eye_mask': eye_mask, 'head_mask': head_mask, 'mouth_mask': mouth_mask, 'mouth_mask_sr': mouth_mask_sr+eye_mask_sr, \
'densities_face_ineye': densities_face_ineye, 'densities_face_inmouth': densities_face_inmouth, 'seg': (1 - T_bg)*2 - 1, 'vts_mask':vts_mask, 'vts_mask_region':vts_mask_region, 'dynamic_planes':dynamic_planes, 'uv': uv,\
'coarse_sample_points':coarse_sample_points, 'coarse_triplane_features':coarse_triplane_features, 'eye_mask_sel':eye_mask_sel, 'mouth_mask_sel':mouth_mask_sel})
if patch_scale<1:
patch_ray_origins, patch_ray_directions, patch_info = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution, patch_scale=patch_scale, mask=[eye_mask_sr+mouth_mask_sr, head_mask_sr], masked_sampling=self.masked_sampling)
if self.has_background:
background_sr = 2*background - 1
background_sr = self.superresolution(background_sr[:, :3], background_sr, ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
background_sr = (background_sr+1)*0.5
patch_background_sr = []
patch_background_feature = []
patch_eye_mask = []
if uv is not None:
patch_uv = []
if run_full:
patch_sr_image = []
patch_rgb_image = []
sr_image_ = sr_image.detach()
rgb_image_ = rgb_image.detach()
rgb_image_ = torch.nn.functional.interpolate(rgb_image_, size=(sr_image_.shape[-1]),
mode='bilinear', align_corners=False, antialias=True)
if self.mouth_part:
patch_mouth_mask = []
for i in range(len(patch_info)):
top, left = patch_info[i]
patch_eye_mask.append(eye_mask_sr[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
if uv is not None:
patch_uv.append(uv[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
if run_full:
patch_sr_image.append(sr_image_[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
patch_rgb_image.append(rgb_image_[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
if self.mouth_part:
patch_mouth_mask.append(mouth_mask_sr[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
if self.has_background:
patch_background_sr.append(background_sr[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
patch_background_feature.append(background_feature[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
patch_eye_mask = torch.cat(patch_eye_mask, 0)
if uv is not None:
patch_uv = torch.cat(patch_uv, 0)
else:
patch_uv = None
if run_full:
patch_sr_image = torch.cat(patch_sr_image, 0)
patch_rgb_image = torch.cat(patch_rgb_image, 0)
output.update({'patch_image': patch_sr_image, 'patch_image_gr': patch_rgb_image})
if self.mouth_part:
patch_mouth_mask = torch.cat(patch_mouth_mask, 0)
if self.has_background:
patch_background_sr = torch.cat(patch_background_sr, 0)
patch_background_feature = torch.cat(patch_background_feature, 0)
# Perform volume rendering
patch_mouth_mask_flat = patch_mouth_mask[:, 0].view(mouth_mask.shape[0], -1, 1)
patch_eye_mask_flat = patch_eye_mask[:, 0].view(eye_mask.shape[0], -1, 1)
patch_feature_samples, patch_depth_samples, patch_all_depths, patch_all_weights, patch_T_bg, patch_offset, patch_dist_to_surface, patch_densities_face_ineye, patch_densities_face_inmouth, patch_vts_mask, patch_vts_mask_region, _, _, _, _ = self.renderer(planes, self.decoder, _deformer, patch_ray_origins, patch_ray_directions, self.rendering_kwargs, eye_mask=patch_eye_mask_flat, mouth_mask=patch_mouth_mask_flat, mouth_dynamic=self.mouth_dynamic, auto_fuse=True,cam_z=cam_z) # channels last
# Reshape into 'raw' neural-rendered image
patch_feature_image = patch_feature_samples.permute(0, 2, 1).reshape(N, patch_feature_samples.shape[-1], H, W).contiguous()
patch_depth_image = patch_depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
patch_rgb_image = patch_feature_image[:, :3]
if self.has_background:
patch_T_bg = patch_T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
patch_rgb_image = patch_rgb_image + patch_T_bg * patch_background_sr
patch_rgb_image = 2*patch_rgb_image-1
if self.has_patch_sr:
patch_rgb_image_ = patch_rgb_image.clone().detach()
patch_background_feature = torch.cat([patch_background_sr,patch_background_feature[:,3:]],dim=1)
patch_feature_image_ = patch_feature_image.clone().detach() + patch_T_bg.clone().detach() * patch_background_feature.clone().detach()
patch_feature_image_ = 2*patch_feature_image_ - 1
sr_patch_rgb_image = self.patch_sr(patch_rgb_image_[:, :3], patch_feature_image_, torch.ones_like(ws))
output.update({'patch_image_raw_sr': sr_patch_rgb_image})
output.update({'patch_image_raw': patch_rgb_image, 'patch_seg': (1 - patch_T_bg)*2 - 1, 'patch_T_bg': patch_T_bg, 'patch_uv': patch_uv, 'patch_mouth_mask': patch_mouth_mask, 'patch_all_depths': patch_all_depths.squeeze(-1), 'patch_all_weights': patch_all_weights.squeeze(-1)})
return output
def sample(self, coordinates, directions, shape_params,exp_params,pose_params,eye_pose_params, z, c, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
planes, last_featuremap = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
# planes = torch.tanh(planes)
planes = planes.view(len(planes), -1, 32, planes.shape[-2], planes.shape[-1])
# target space to canonical space deformation
if use_deform:
_deformer = self._deformer(shape_params,exp_params,pose_params,eye_pose_params)
out_deform = _deformer(coordinates)
coordinates = out_deform['canonical']
offset = out_deform['offset']
else:
coordinates = coordinates
offset = torch.zeros_like(coordinates)
out = self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
out['canonical'] = coordinates
out['offset'] = offset
# out['offset'] = torch.zeros_like(coordinates)
return out
def sample_mixed(self, coordinates, directions, shape_params,exp_params,pose_params,eye_pose_params, ws, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
planes, last_featuremap = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
# planes = planes.view(len(planes), -1, 32, planes.shape[-2], planes.shape[-1])
if not isinstance(planes, list):
planes = [planes]
last_featuremap = [last_featuremap]
planes = [p.view(len(p), -1, 32, p.shape[-2], p.shape[-1]) for p in planes]
dynamic_planes = None
# target space to canonical space deformation
if use_deform:
_deformer = self._deformer(shape_params,exp_params,pose_params,eye_pose_params)
out_deform = _deformer(coordinates, mouth=self.mouth_part)
coordinates_eye = out_deform['canonical_eye']
coordinates_face = out_deform['canonical_face']
coordinates_mouth = out_deform['canonical_mouth']
if out_deform['dynamic_mask'] is not None:
dynamic_mask = out_deform['dynamic_mask'] * (1-out_deform['inside_bbox_eye'][..., None])
else:
dynamic_mask = out_deform['dynamic_mask']
offset = torch.zeros_like(coordinates_eye)
else:
coordinates = coordinates
offset = torch.zeros_like(coordinates)
dynamic_mask = None
plane_eye = [p[:, :3] for p in planes]
if dynamic_mask is not None:
if self.mouth_part and self.mouth_dynamic:
plane_mouth = [torch.cat([p[:, :3],p[:, -3:]],dim=2) for p in planes]
else:
plane_mouth = [p[:, :3] for p in planes]
plane_face = [torch.cat([p[:, 3:6],p[:, -3:]],dim=2) for p in planes]
else:
if self.mouth_part:
plane_mouth = [p[:, :3] for p in planes]
plane_face = [p[:, 3:6] for p in planes]
out_eye = self.renderer.run_model(plane_eye, self.decoder, coordinates_eye, directions, self.rendering_kwargs,dynamic_mask=None)
out_face = self.renderer.run_model(plane_face, self.decoder, coordinates_face, directions, self.rendering_kwargs,dynamic_mask=dynamic_mask)
out_eye['canonical'] = coordinates_eye
out_face['canonical'] = coordinates_face
out_eye['offset'] = offset
out_face['offset'] = offset
if self.mouth_part:
out_mouth = self.renderer.run_model(plane_mouth, self.decoder, coordinates_mouth, directions, self.rendering_kwargs,dynamic_mask=(None, dynamic_mask)[self.mouth_dynamic])
out_mouth['canonical'] = coordinates_mouth
out_mouth['offset'] = offset
return out_eye, out_face, out_mouth
else:
return out_eye, out_face
def forward(self, shape_params,exp_params,pose_params,eye_pose_params, z, z_bg, c, c_compose, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, patch_scale=1.0, **synthesis_kwargs):
# Render a batch of generated images.
_deformer = self._deformer(shape_params,exp_params,pose_params,eye_pose_params)
c_compose_condition = c_compose.clone()
if self.flame_condition:
c_compose_condition = torch.cat([c_compose_condition,shape_params],dim=-1)
ws = self.mapping(z, c_compose_condition, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
# Render correspondence map as condition to the discriminator
render_out = self.deformer.renderer(shape_params, exp_params, pose_params, eye_pose_params, c, half_size=int(self.img_resolution/2))
uv = render_out[0]
landmarks = render_out[-1]
img = self.synthesis(ws, z_bg, c, _deformer=_deformer, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, smpl_param=(shape_params, exp_params, pose_params, eye_pose_params), patch_scale=patch_scale, **synthesis_kwargs)
img['uv'] = uv
img['landmarks'] = landmarks
return img
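# Minimal usage sketch (hypothetical shapes/sizes; actual dimensions depend on the FLAME
# model and dataset labels):
#   G = PartTriPlaneGeneratorDeform(z_dim=512, c_dim=25, w_dim=512, img_resolution=512,
#                                   img_channels=3, triplane_channels=96*3, ...)
#   z, z_bg = torch.randn(1, 512), torch.randn(1, 512)
#   c = c_compose = torch.cat([cam2world.view(1, 16), intrinsics.view(1, 9)], dim=1)
#   out = G(shape_params, exp_params, pose_params, eye_pose_params, z, z_bg, c, c_compose)
#   img_sr, img_raw, depth = out['image_sr'], out['image'], out['image_depth']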
def zero_init(m):
with torch.no_grad():
nn.init.constant_(m.weight,0)
nn.init.constant_(m.bias,0)
@persistence.persistent_class
class OSGDecoder(torch.nn.Module):
def __init__(self, n_features, options):
super().__init__()
self.hidden_dim = 64
self.net = torch.nn.Sequential(
FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']),
torch.nn.Softplus()
)
self.out_sigma = FullyConnectedLayer(self.hidden_dim, 1, lr_multiplier=options['decoder_lr_mul'])
self.out_rgb = FullyConnectedLayer(self.hidden_dim, options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])
self.out_sigma.apply(zero_init)
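# Zero-initializing the density head means every query initially predicts the same raw
# sigma of 0, giving a flat, unbiased starting volume instead of random density noise.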
def forward(self, sampled_features, ray_directions):
# Aggregate features
sampled_features = sampled_features.mean(1)
x = sampled_features
N, M, C = x.shape
x = x.view(N*M, C)
x = self.net(x)
rgb = self.out_rgb(x)
sigma = self.out_sigma(x)
rgb = rgb.view(N, M, -1)
sigma = sigma.view(N, M, -1)
rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
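# Maps rgb into (-0.001, 1.001) rather than (0, 1), keeping gradients alive near the
# extremes (the "wide sigmoid" trick from MipNeRF).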
return {'rgb': rgb, 'sigma': sigma}
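# Usage sketch (hypothetical shapes): the decoder consumes per-plane features already
# sampled by the renderer and averages them across planes before the shared MLP.
#   dec = OSGDecoder(32, {'decoder_lr_mul': 1, 'decoder_output_dim': 32})
#   feats = torch.randn(2, 3, 4096, 32)      # (N, n_planes, M samples, C)
#   out = dec(feats, ray_directions=None)    # ray_directions is unused here
#   # out['rgb']: (2, 4096, 32), out['sigma']: (2, 4096, 1)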