# Generator for GenHead, modified from EG3D: https://github.com/NVlabs/eg3d
# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

import torch
from torch import nn
from torch_utils import persistence
from models.stylegan.networks_stylegan2 import Generator as StyleGAN2Backbone
from models.stylegan.networks_stylegan2 import ToRGBLayer, FullyConnectedLayer, SynthesisNetwork
from models.stylegan.superresolution import SuperresolutionPatchMLP
from training.deformer.deformation import DeformationModule, DeformationModuleOnlyHead
from training.deformer.deform_utils import cam_world_matrix_transform
from training.volumetric_rendering.renderer import ImportanceRenderer, DeformImportanceRenderer, PartDeformImportanceRenderer, DeformImportanceRendererNew
from training.volumetric_rendering.ray_sampler import RaySampler
import dnnlib
import torch.nn.functional as F
from torch_utils.ops import upfirdn2d
import copy

# Baseline generator without separate deformation for face, eyes, and mouth
class TriPlaneGeneratorDeform(torch.nn.Module):
    def __init__(self,
        z_dim,                          # Input latent (Z) dimensionality.
        c_dim,                          # Conditioning label (C) dimensionality.
        w_dim,                          # Intermediate latent (W) dimensionality.
        img_resolution,                 # Output resolution.
        img_channels,                   # Number of output color channels.
        sr_num_fp16_res     = 0,
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        rendering_kwargs    = {},
        deformation_kwargs  = {},
        sr_kwargs           = {},
        has_background      = True,
        has_superresolution = False,
        flame_condition     = False,
        flame_full          = False,
        dynamic_texture     = False,    # Deprecated
        random_combine      = True,
        triplane_resolution = 256,
        triplane_channels   = 96,
        masked_sampling     = None,
        has_patch_sr        = False,
        add_block           = False,
        **synthesis_kwargs,             # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.flame_condition = flame_condition
        self.has_background = has_background
        self.has_superresolution = has_superresolution
        self.dynamic_texture = dynamic_texture
        decoder_output_dim = 32 if has_superresolution else 3

        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.renderer = DeformImportanceRenderer()
        self.ray_sampler = RaySampler()
        self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
        self.to_dynamic = None
        if self.has_background:
            self.background = StyleGAN2Backbone(z_dim, 0, w_dim, img_resolution=64, mapping_kwargs={'num_layers': 8}, channel_base=16384, img_channels=decoder_output_dim)
        if self.has_superresolution:
            self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
        else:
            self.superresolution = None
        self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': decoder_output_dim})
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs
        self.deformer = DeformationModule(flame_full=flame_full, dynamic_texture=dynamic_texture, **deformation_kwargs)

        self._last_planes = None

    def _deformer(self, shape_params, exp_params, pose_params, eye_pose_params, ws=None, c_deform=None, cache_backbone=False, use_cached_backbone=False, use_rotation_limits=False, eye_blink_params=None):
        # Returns a closure that maps observation-space coordinates to the canonical space via the FLAME-driven deformer.
        return lambda coordinates: self.deformer(coordinates, shape_params, exp_params, pose_params, eye_pose_params, ws, c_deform, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, use_rotation_limits=use_rotation_limits)

    def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)

    def synthesis(self, ws, z_bg, c, _deformer, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, use_dynamic=False, use_rotation_limits=None, smpl_param=None, patch_scale=None, chunk=None, run_full=None, uv=None, diff_dynamic=False, forward_mode='train', eye_blink_params=None, ws_super=None, **synthesis_kwargs):
        if forward_mode == 'train':
            face_ws = ws
            dynamic_ws = ws
        # for inversion only
        elif ws.shape[1] >= self.backbone.num_ws + self.background.num_ws:
            face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:self.backbone.num_ws+self.background.num_ws, :], ws[:, self.backbone.num_ws+self.background.num_ws:, :]
        else:
            face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:-1, :], ws[:, self.backbone.num_ws-1:self.backbone.num_ws, :]

        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        world2cam_matrix = cam_world_matrix_transform(cam2world_matrix)
        cam_z = world2cam_matrix[:, 2, 3]
        intrinsics = c[:, 16:25].view(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        # Create a batch of rays for volume rendering
        ray_origins, ray_directions, _ = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)

        # Create triplanes by running StyleGAN backbone
        N, M, _ = ray_origins.shape
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
        else:
            planes, last_featuremap = self.backbone.synthesis(face_ws, update_emas=update_emas, **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])

        if self.dynamic_texture:
            pass  # deprecated
        else:
            dynamic_planes = None

        # Perform volume rendering
        feature_samples, depth_samples, all_depths, all_weights, T_bg, offset, dist_to_surface, vts_mask, vts_mask_region, coarse_sample_points, coarse_triplane_features = self.renderer(planes, self.decoder, _deformer, ray_origins, ray_directions, self.rendering_kwargs, dynamic=self.dynamic_texture, cam_z=cam_z)  # channels last
        weights_samples = all_weights.sum(2)

        # Reshape into 'raw' neural-rendered image
        H = W = self.neural_rendering_resolution
        feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # background
        if self.has_background:
            if forward_mode == 'train':
                background = self.background(z_bg, c, **synthesis_kwargs)
            else:
                background, _ = self.background.synthesis(bg_ws, update_emas=update_emas, **synthesis_kwargs)
            background = torch.sigmoid(background)  # (0,1), (N,3,H,W)
            if background.shape[-1] != neural_rendering_resolution:
                background = F.interpolate(background, size=(neural_rendering_resolution, neural_rendering_resolution), mode='bilinear')
            T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
            feature_image = feature_image + T_bg*background
        else:
            T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
            background = 0.

        feature_image = 2*feature_image - 1
        rgb_image = feature_image[:, :3]
        if self.superresolution is not None:
            sr_image = self.superresolution(rgb_image, feature_image, face_ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
        else:
            sr_image = None

        return {'image': rgb_image, 'image_feature': feature_image, 'image_sr': sr_image, 'image_depth': depth_image, 'background': 2*background-1, 'interval': all_depths.squeeze(-1), 'all_weights': all_weights.squeeze(-1), 'T_bg': T_bg,
                'seg': (1 - T_bg)*2 - 1, 'offset': offset, 'dist_to_surface': dist_to_surface, 'vts_mask': vts_mask, 'vts_mask_region': vts_mask_region, 'dynamic_planes': dynamic_planes, 'coarse_sample_points': coarse_sample_points, 'coarse_triplane_features': coarse_triplane_features}

    def sample(self, coordinates, directions, shape_params, exp_params, pose_params, eye_pose_params, z, c, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        planes, last_featuremap = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])

        # target space to canonical space deformation
        if use_deform:
            _deformer = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
            out_deform = _deformer(coordinates)
            coordinates = out_deform['canonical']
            offset = out_deform['offset']
            dynamic_mask = out_deform['dynamic_mask']
        else:
            offset = torch.zeros_like(coordinates)
            dynamic_mask = None

        out = self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs, dynamic_mask=dynamic_mask)
        out['canonical'] = coordinates
        out['offset'] = offset
        return out

    def sample_mixed(self, coordinates, directions, shape_params, exp_params, pose_params, eye_pose_params, ws, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
        planes, last_featuremap = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        # planes = torch.tanh(planes)
        planes = planes.view(len(planes), 3, -1, planes.shape[-2], planes.shape[-1])

        # target space to canonical space deformation
        if use_deform:
            _deformer = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
            out_deform = _deformer(coordinates)
            coordinates = out_deform['canonical']
            offset = out_deform['offset']
            dynamic_mask = out_deform['dynamic_mask']
        else:
            offset = torch.zeros_like(coordinates)
            dynamic_mask = None

        out = self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs, dynamic_mask=dynamic_mask)
        out['canonical'] = coordinates
        out['offset'] = offset
        return out

    def forward(self, shape_params, exp_params, pose_params, eye_pose_params, z, z_bg, c, c_compose, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, patch_scale=None, **synthesis_kwargs):
        # Render a batch of generated images.
        _deformer = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
        c_compose_condition = c_compose.clone()
        if self.flame_condition:
            c_compose_condition = torch.cat([c_compose_condition, shape_params, exp_params], dim=-1)
        ws = self.mapping(z, c_compose_condition, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)

        # Render correspondence map as condition to the discriminator
        uv = self.deformer.renderer(shape_params, exp_params, pose_params, eye_pose_params, c, half_size=int(self.img_resolution/2))[0]

        img = self.synthesis(ws, z_bg, c, _deformer=_deformer, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
        img['uv'] = uv
        return img
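
# Usage sketch for the baseline generator (illustrative only; the latent and camera dimensions
# below are assumptions for this sketch, not values prescribed by this file):
#   G = TriPlaneGeneratorDeform(z_dim=512, c_dim=25, w_dim=512, img_resolution=512, img_channels=3,
#                               rendering_kwargs=rendering_kwargs, deformation_kwargs=deformation_kwargs)
#   c   = torch.cat([cam2world.view(-1, 16), intrinsics.view(-1, 9)], dim=1)   # 4x4 extrinsics + 3x3 intrinsics
#   out = G(shape_params, exp_params, pose_params, eye_pose_params,            # FLAME parameters from a tracker
#           z=torch.randn(1, 512), z_bg=torch.randn(1, 512), c=c, c_compose=c)
#   # out['image'] is the raw neural rendering, out['image_sr'] the super-resolved image (if enabled),
#   # and out['uv'] the FLAME correspondence map passed to the discriminator.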

# Generator used in GenHead, with part-wise deformation
class PartTriPlaneGeneratorDeform(torch.nn.Module):
    def __init__(self,
        z_dim,                          # Input latent (Z) dimensionality.
        c_dim,                          # Conditioning label (C) dimensionality.
        w_dim,                          # Intermediate latent (W) dimensionality.
        img_resolution,                 # Output resolution.
        img_channels,                   # Number of output color channels.
        triplane_channels,
        sr_num_fp16_res     = 0,
        mapping_kwargs      = {},       # Arguments for MappingNetwork.
        rendering_kwargs    = {},
        deformation_kwargs  = {},
        sr_kwargs           = {},
        has_background      = True,
        has_superresolution = False,
        has_patch_sr        = False,
        flame_condition     = True,
        flame_full          = False,
        dynamic_texture     = False,    # Deprecated
        random_combine      = True,
        add_block           = False,
        triplane_resolution = 256,
        masked_sampling     = False,
        **synthesis_kwargs,             # Arguments for SynthesisNetwork.
    ):
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.flame_condition = flame_condition
        self.dynamic_texture = dynamic_texture
        self.has_background = has_background
        self.has_superresolution = has_superresolution
        self.has_patch_sr = has_patch_sr
        decoder_output_dim = 32 if has_superresolution else 3

        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.renderer = PartDeformImportanceRenderer() if triplane_channels > 96 else DeformImportanceRenderer()
        self.mouth_part = True  # if triplane_channels>192 else False
        self.mouth_dynamic = False
        self.masked_sampling = masked_sampling
        self.ray_sampler = RaySampler()
        self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=triplane_resolution, img_channels=triplane_channels, mapping_kwargs=mapping_kwargs, add_block=add_block, **synthesis_kwargs)
        if self.has_background:
            self.background = StyleGAN2Backbone(z_dim, 0, w_dim, img_resolution=64, mapping_kwargs={'num_layers': 8}, channel_base=16384, img_channels=decoder_output_dim)
        if self.has_superresolution:
            self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
        else:
            self.superresolution = None
        self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': decoder_output_dim})
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs
        if self.has_patch_sr:
            self.patch_sr = SuperresolutionPatchMLP(channels=32, img_resolution=None, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=True)
        self.to_dynamic = None
        self.to_dynamic_sr = None
        self.deformer = DeformationModule(flame_full=flame_full, dynamic_texture=dynamic_texture, part=True, **deformation_kwargs)

        self._last_planes = None
        self._last_dynamic_planes = None
        self.max_pool = nn.MaxPool2d(kernel_size=7, stride=1, padding=3)

    def _warping(self, images, flows):
        # images: list of (B, M, C, H, W) tensors
        # flows:  (B, M, 2, H, W) inverse warping flow
        warp_images = []
        B, M, _, H_f, W_f = flows.shape
        flows = flows.view(B*M, 2, H_f, W_f)
        for im in images:
            B, M, C, H, W = im.shape
            im = im.view(B*M, C, H, W)
            y, x = torch.meshgrid(torch.linspace(-1, 1, H, dtype=torch.float32, device=im.device), torch.linspace(-1, 1, W, dtype=torch.float32, device=im.device), indexing='ij')
            xy = torch.stack([x, y], dim=-1).unsqueeze(0).repeat(B*M, 1, 1, 1)  # (B*M, H, W, 2)
            if H_f != H:
                _flows = F.interpolate(flows, size=(H, W), mode='bilinear', align_corners=True)
            else:
                _flows = flows
            _flows = _flows.permute(0, 2, 3, 1)  # (B*M, H, W, 2)
            uv = _flows + xy
            warp_image = F.grid_sample(im, uv, mode='bilinear', padding_mode='zeros', align_corners=True)  # (B*M, C, H, W)
            warp_image = warp_image.view(B, M, C, H, W)
            warp_images.append(warp_image)
        return warp_images
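
    # Illustrative call for _warping (hypothetical sizes): with one feature map of shape (B, M, C, H, W)
    # and an inverse flow of shape (B, M, 2, H, W) expressed in normalized [-1, 1] coordinates,
    #   warped = self._warping([torch.randn(2, 1, 32, 64, 64)], torch.randn(2, 1, 2, 64, 64))
    # returns a list with one (2, 1, 32, 64, 64) tensor resampled by F.grid_sample.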

    def _deformer(self, shape_params, exp_params, pose_params, eye_pose_params, eye_blink_params=None, exp_params_dynamics=None, cache_backbone=False, use_cached_backbone=False, use_rotation_limits=False):
        return lambda coordinates, mouth: self.deformer(coordinates, shape_params, exp_params, pose_params, eye_pose_params, eye_blink_params=eye_blink_params, exp_params_dynamics=exp_params_dynamics, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, use_rotation_limits=use_rotation_limits, mouth=mouth)

    def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)

    def synthesis(self, ws, z_bg, c, _deformer, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, use_dynamic=False, use_rotation_limits=None, smpl_param=None, eye_blink_params=None, patch_scale=1.0, run_full=True, uv=None, chunk=None, diff_dynamic=False, dense_eye=False, forward_mode='train', ws_super=None, **synthesis_kwargs):
        if forward_mode == 'train':
            face_ws = ws
            dynamic_ws = ws
        # for inversion only
        elif ws.shape[1] >= self.backbone.num_ws + self.background.num_ws:
            face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:self.backbone.num_ws+self.background.num_ws, :], ws[:, self.backbone.num_ws+self.background.num_ws:, :]
        else:
            face_ws, bg_ws, dynamic_ws = ws[:, :self.backbone.num_ws, :], ws[:, self.backbone.num_ws:-1, :], ws[:, self.backbone.num_ws-1:self.backbone.num_ws, :]

        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        world2cam_matrix = cam_world_matrix_transform(cam2world_matrix)
        cam_z = world2cam_matrix[:, 2, 3]
        intrinsics = c[:, 16:25].view(-1, 3, 3)
        N = cam2world_matrix.shape[0]

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        elif self.training:
            self.neural_rendering_resolution = neural_rendering_resolution
        H = W = neural_rendering_resolution

        # Render eye, face-without-eye, head, and mouth masks with the FLAME mesh renderer
        with torch.no_grad():
            eye_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size=int(self.img_resolution/2), eye_blink_params=eye_blink_params, eye_mask=True, use_rotation_limits=use_rotation_limits)[1]
            face_wo_eye_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size=int(self.img_resolution/2), eye_blink_params=eye_blink_params, face_woeye=True, use_rotation_limits=use_rotation_limits)[1]
            eye_mask = eye_mask * (1-face_wo_eye_mask)

            blur_sigma = 1
            blur_size = blur_sigma * 3
            f = torch.arange(-blur_size, blur_size + 1, device=eye_mask.device).div(blur_sigma).square().neg().exp2()
            eye_mask_sr = upfirdn2d.filter2d(eye_mask, f / f.sum())
            eye_mask = torch.nn.functional.interpolate(eye_mask, size=(self.neural_rendering_resolution), mode='bilinear', align_corners=False, antialias=True)

            head_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size=int(self.img_resolution/2), eye_blink_params=eye_blink_params, only_face=False, cull_backfaces=False, use_rotation_limits=use_rotation_limits)[1]
            if self.mouth_part:
                head_wo_mouth_mask = self.deformer.renderer(smpl_param[0], smpl_param[1], smpl_param[2], smpl_param[3], c, half_size=int(self.img_resolution/2), eye_blink_params=eye_blink_params, only_face=False, cull_backfaces=True, noinmouth=True, use_rotation_limits=use_rotation_limits)[1]
                mouth_mask = head_mask * (1-head_wo_mouth_mask)
                mouth_mask_sr = -self.max_pool(-mouth_mask)
                mouth_mask_sr = self.max_pool(mouth_mask_sr)

                blur_sigma = 2
                blur_size = blur_sigma * 3
                f = torch.arange(-blur_size, blur_size + 1, device=mouth_mask_sr.device).div(blur_sigma).square().neg().exp2()
                mouth_mask_sr = upfirdn2d.filter2d(mouth_mask_sr, f / f.sum())
                mouth_mask = torch.nn.functional.interpolate(mouth_mask, size=(self.neural_rendering_resolution), mode='bilinear', align_corners=False, antialias=True)
                mouth_mask_sr = (mouth_mask_sr + torch.nn.functional.interpolate(mouth_mask, size=(self.img_resolution), mode='bilinear', align_corners=False, antialias=True)).clamp(max=1)
                # mouth_mask_sr[:,:,:128,:] *= 0 # for visualization only (deprecated)
            else:
                mouth_mask = None
                mouth_mask_sr = None

            head_mask = torch.nn.functional.interpolate(head_mask, size=(neural_rendering_resolution), mode='bilinear', align_corners=False, antialias=True)
            head_mask_sr = torch.nn.functional.interpolate(head_mask, size=(self.img_resolution), mode='bilinear', align_corners=False, antialias=True)

        # Create triplanes by running StyleGAN backbone
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
            dynamic_planes = self._last_dynamic_planes
        else:
            planes, last_featuremap = self.backbone.synthesis(face_ws, update_emas=update_emas, **synthesis_kwargs)
            # Reshape output into groups of three 32-channel planes
            if not isinstance(planes, list):
                planes = [planes]
                # last_featuremap = [last_featuremap]
            planes = [p.view(len(p), -1, 32, p.shape[-2], p.shape[-1]) for p in planes]
            if self.dynamic_texture:
                pass  # deprecated
            else:
                dynamic_planes = None
        if cache_backbone:
            self._last_planes = planes
            self._last_dynamic_planes = dynamic_planes

        if self.has_background:
            if forward_mode == 'train':
                background = self.background(z_bg, c, **synthesis_kwargs)
            else:
                background, _ = self.background.synthesis(bg_ws, update_emas=update_emas, **synthesis_kwargs)
            background = torch.sigmoid(background)  # (0,1), (N,3,H,W)
            background = F.interpolate(background, size=(256, 256), mode='bilinear')
            background_feature = F.interpolate(background, size=(self.img_resolution, self.img_resolution), mode='bilinear')
            if background.shape[-1] != neural_rendering_resolution:
                background = F.interpolate(background, size=(neural_rendering_resolution, neural_rendering_resolution), mode='bilinear')

        # Create a batch of rays for volume rendering
        output = {}
        coarse_sample_points = coarse_triplane_features = None
        if run_full:
            ray_origins, ray_directions, _ = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
            mouth_mask_flat = mouth_mask[:, 0].view(mouth_mask.shape[0], -1, 1) if neural_rendering_resolution == self.neural_rendering_resolution else mouth_mask_sr[:, 0].view(mouth_mask_sr.shape[0], -1, 1)
            eye_mask_flat = eye_mask[:, 0].view(eye_mask.shape[0], -1, 1) if neural_rendering_resolution == self.neural_rendering_resolution else eye_mask_sr[:, 0].view(eye_mask_sr.shape[0], -1, 1)

            # Perform volume rendering
            if chunk is None:
                feature_samples, depth_samples, all_depths, all_weights, T_bg, offset, dist_to_surface, densities_face_ineye, densities_face_inmouth, vts_mask, vts_mask_region, coarse_sample_points, coarse_triplane_features, eye_mask_sel, mouth_mask_sel \
                    = self.renderer((planes[0:1], planes)[neural_rendering_resolution > 64], self.decoder, _deformer, ray_origins, ray_directions, self.rendering_kwargs, eye_mask=eye_mask_flat, mouth_mask=mouth_mask_flat, mouth_dynamic=self.mouth_dynamic, auto_fuse=(False, True)[neural_rendering_resolution > 64], cam_z=cam_z)  # channels last

                if dense_eye:  # only for batchsize=1
                    is_eye_region = (eye_mask_flat != 0).squeeze(0).squeeze(-1)
                    if torch.sum(is_eye_region.to(torch.float32)) == 0:
                        feature_samples_eye = 0
                        T_bg_eye = 0
                    else:
                        ray_origins_eye = ray_origins[:, is_eye_region]
                        ray_directions_eye = ray_directions[:, is_eye_region]
                        eye_mask_eye = eye_mask_flat[:, is_eye_region]
                        rendering_kwargs_eye = copy.deepcopy(self.rendering_kwargs)
                        rendering_kwargs_eye['depth_resolution'] = 128
                        rendering_kwargs_eye['depth_resolution_importance'] = 128
                        feature_samples_eye, depth_samples_eye, all_depths_eye, all_weights_eye, T_bg_eye, offset_eye, dist_to_surface_eye, densities_face_ineye_eye, densities_face_inmouth_eye, vts_mask_eye, vts_mask_region_eye, _, _, _, _ = self.renderer((planes[0:1], planes)[neural_rendering_resolution > 64], self.decoder, _deformer, ray_origins_eye, ray_directions_eye, rendering_kwargs_eye, eye_mask=eye_mask_eye, mouth_mask=None, mouth_dynamic=self.mouth_dynamic, auto_fuse=(False, True)[neural_rendering_resolution > 64])  # channels last
            else:
                feature_samples, depth_samples, all_depths, all_weights, T_bg, offset, dist_to_surface, densities_face_ineye, densities_face_inmouth, vts_mask, vts_mask_region = list(), list(), list(), list(), list(), list(), list(), list(), list(), list(), list()
                for _ro, _rd, _em, _mm in zip(torch.split(ray_origins, chunk, dim=1), torch.split(ray_directions, chunk, dim=1), torch.split(eye_mask_flat, chunk, dim=1), torch.split(mouth_mask_flat, chunk, dim=1)):
                    _f, _d, _ad, _aw, _tbg, _off, _ds, _dfe, _dfm, _vm, _vmr = self.renderer((planes[0:1], planes)[neural_rendering_resolution > 64], self.decoder, _deformer, _ro, _rd, self.rendering_kwargs, eye_mask=_em, mouth_mask=_mm, mouth_dynamic=self.mouth_dynamic, auto_fuse=(False, True)[neural_rendering_resolution > 64], cam_z=cam_z)  # channels last
                    feature_samples.append(_f)
                    depth_samples.append(_d)
                    all_depths.append(_ad)
                    all_weights.append(_aw)
                    T_bg.append(_tbg)
                    offset.append(_off)
                    dist_to_surface.append(_ds)
                    densities_face_ineye.append(_dfe)
                    densities_face_inmouth.append(_dfm)
                    vts_mask.append(_vm)
                    vts_mask_region.append(_vmr)
                feature_samples = torch.cat(feature_samples, 1)
                depth_samples = torch.cat(depth_samples, 1)
                all_depths = torch.cat(all_depths, 1)
                all_weights = torch.cat(all_weights, 1)
                T_bg = torch.cat(T_bg, 1)
                offset = torch.cat(offset, 1)
                dist_to_surface = torch.cat(dist_to_surface, 1)
                densities_face_ineye = torch.cat(densities_face_ineye, 1)
                densities_face_inmouth = torch.cat(densities_face_inmouth, 1)
                vts_mask = torch.cat(vts_mask, 1)
                vts_mask_region = torch.cat(vts_mask_region, 1)

            weights_samples = all_weights.sum(2)
            if dense_eye:
                feature_samples[:, is_eye_region] = feature_samples_eye
                T_bg[:, is_eye_region] = T_bg_eye

            # Reshape into 'raw' neural-rendered image
            feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
            depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

            # background
            if feature_image.shape[-1] <= 128:
                T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
                if self.has_background:
                    feature_image = feature_image + T_bg*background
                else:
                    feature_image = feature_image + T_bg
                feature_image = 2*feature_image - 1
                rgb_image = feature_image[:, :3]
                if self.superresolution is not None and rgb_image.shape[-1] <= 128:
                    sr_image = self.superresolution(rgb_image, feature_image, ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
                else:
                    sr_image = rgb_image
            else:
                rgb_image = feature_image[:, :3]
                if self.has_background:
                    background_sr = 2*background - 1
                    background_sr = self.superresolution(background_sr[:, :3], background_sr, ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
                    background_sr = (background_sr + 1) * 0.5
                    T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
                    rgb_image = rgb_image + T_bg*background_sr
                else:
                    T_bg = T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
                    rgb_image = rgb_image + T_bg
                rgb_image = 2*rgb_image - 1
                sr_image = rgb_image

                if self.has_patch_sr:
                    rgb_image_ = rgb_image
                    background_feature = torch.cat([background_sr, background_feature[:, 3:]], dim=1)
                    feature_image_ = feature_image + T_bg * background_feature
                    feature_image_ = 2*feature_image_ - 1
                    sr_rgb_image = self.patch_sr(rgb_image_[:, :3], feature_image_, torch.ones_like(ws))
                    output.update({'image_raw_sr': sr_rgb_image})

            output.update({'image': rgb_image, 'image_feature': feature_image, 'image_sr': sr_image, 'background': 2*background-1, 'image_depth': depth_image, 'interval': all_depths.squeeze(-1), 'all_weights': all_weights.squeeze(-1),
                           'T_bg': T_bg, 'offset': offset, 'dist_to_surface': dist_to_surface, 'eye_mask': eye_mask, 'head_mask': head_mask, 'mouth_mask': mouth_mask, 'mouth_mask_sr': mouth_mask_sr+eye_mask_sr,
                           'densities_face_ineye': densities_face_ineye, 'densities_face_inmouth': densities_face_inmouth, 'seg': (1 - T_bg)*2 - 1, 'vts_mask': vts_mask, 'vts_mask_region': vts_mask_region, 'dynamic_planes': dynamic_planes, 'uv': uv,
                           'coarse_sample_points': coarse_sample_points, 'coarse_triplane_features': coarse_triplane_features, 'eye_mask_sel': eye_mask_sel, 'mouth_mask_sel': mouth_mask_sel})

        if patch_scale < 1:
            patch_ray_origins, patch_ray_directions, patch_info = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution, patch_scale=patch_scale, mask=[eye_mask_sr+mouth_mask_sr, head_mask_sr], masked_sampling=self.masked_sampling)
            if self.has_background:
                background_sr = 2*background - 1
                background_sr = self.superresolution(background_sr[:, :3], background_sr, ws, ws_super=ws_super, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
                background_sr = (background_sr+1)*0.5
                patch_background_sr = []
                patch_background_feature = []
            patch_eye_mask = []
            if uv is not None:
                patch_uv = []
            if run_full:
                patch_sr_image = []
                patch_rgb_image = []
                sr_image_ = sr_image.detach()
                rgb_image_ = rgb_image.detach()
                rgb_image_ = torch.nn.functional.interpolate(rgb_image_, size=(sr_image_.shape[-1]), mode='bilinear', align_corners=False, antialias=True)
            if self.mouth_part:
                patch_mouth_mask = []

            for i in range(len(patch_info)):
                top, left = patch_info[i]
                patch_eye_mask.append(eye_mask_sr[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
                if uv is not None:
                    patch_uv.append(uv[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
                if run_full:
                    patch_sr_image.append(sr_image_[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
                    patch_rgb_image.append(rgb_image_[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
                if self.mouth_part:
                    patch_mouth_mask.append(mouth_mask_sr[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
                if self.has_background:
                    patch_background_sr.append(background_sr[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])
                    patch_background_feature.append(background_feature[i:i+1, :, top:top+neural_rendering_resolution, left:left+neural_rendering_resolution])

            patch_eye_mask = torch.cat(patch_eye_mask, 0)
            if uv is not None:
                patch_uv = torch.cat(patch_uv, 0)
            else:
                patch_uv = None
            if run_full:
                patch_sr_image = torch.cat(patch_sr_image, 0)
                patch_rgb_image = torch.cat(patch_rgb_image, 0)
                output.update({'patch_image': patch_sr_image, 'patch_image_gr': patch_rgb_image})
            if self.mouth_part:
                patch_mouth_mask = torch.cat(patch_mouth_mask, 0)
            if self.has_background:
                patch_background_sr = torch.cat(patch_background_sr, 0)
                patch_background_feature = torch.cat(patch_background_feature, 0)

            # Perform volume rendering
            patch_mouth_mask_flat = patch_mouth_mask[:, 0].view(mouth_mask.shape[0], -1, 1)
            patch_eye_mask_flat = patch_eye_mask[:, 0].view(eye_mask.shape[0], -1, 1)
            patch_feature_samples, patch_depth_samples, patch_all_depths, patch_all_weights, patch_T_bg, patch_offset, patch_dist_to_surface, patch_densities_face_ineye, patch_densities_face_inmouth, patch_vts_mask, patch_vts_mask_region, _, _, _, _ = self.renderer(planes, self.decoder, _deformer, patch_ray_origins, patch_ray_directions, self.rendering_kwargs, eye_mask=patch_eye_mask_flat, mouth_mask=patch_mouth_mask_flat, mouth_dynamic=self.mouth_dynamic, auto_fuse=True, cam_z=cam_z)  # channels last

            # Reshape into 'raw' neural-rendered image
            patch_feature_image = patch_feature_samples.permute(0, 2, 1).reshape(N, patch_feature_samples.shape[-1], H, W).contiguous()
            patch_depth_image = patch_depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
            patch_rgb_image = patch_feature_image[:, :3]
            if self.has_background:
                patch_T_bg = patch_T_bg.permute(0, 2, 1).reshape(N, 1, H, W)
                patch_rgb_image = patch_rgb_image + patch_T_bg * patch_background_sr
            patch_rgb_image = 2*patch_rgb_image - 1

            if self.has_patch_sr:
                patch_rgb_image_ = patch_rgb_image.clone().detach()
                patch_background_feature = torch.cat([patch_background_sr, patch_background_feature[:, 3:]], dim=1)
                patch_feature_image_ = patch_feature_image.clone().detach() + patch_T_bg.clone().detach() * patch_background_feature.clone().detach()
                patch_feature_image_ = 2*patch_feature_image_ - 1
                sr_patch_rgb_image = self.patch_sr(patch_rgb_image_[:, :3], patch_feature_image_, torch.ones_like(ws))
                output.update({'patch_image_raw_sr': sr_patch_rgb_image})

            output.update({'patch_image_raw': patch_rgb_image, 'patch_seg': (1 - patch_T_bg)*2 - 1, 'patch_T_bg': patch_T_bg, 'patch_uv': patch_uv, 'patch_mouth_mask': patch_mouth_mask, 'patch_all_depths': patch_all_depths.squeeze(-1), 'patch_all_weights': patch_all_weights.squeeze(-1)})

        return output

    def sample(self, coordinates, directions, shape_params, exp_params, pose_params, eye_pose_params, z, c, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
        ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
        planes, _ = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        # planes = torch.tanh(planes)
        planes = planes.view(len(planes), -1, 32, planes.shape[-2], planes.shape[-1])

        # target space to canonical space deformation
        if use_deform:
            _deformer = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
            out_deform = _deformer(coordinates)
            coordinates = out_deform['canonical']
            offset = out_deform['offset']
        else:
            offset = torch.zeros_like(coordinates)

        out = self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
        out['canonical'] = coordinates
        out['offset'] = offset
        # out['offset'] = torch.zeros_like(coordinates)
        return out

    def sample_mixed(self, coordinates, directions, shape_params, exp_params, pose_params, eye_pose_params, ws, use_deform=True, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
        # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
        planes, last_featuremap = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
        # planes = planes.view(len(planes), -1, 32, planes.shape[-2], planes.shape[-1])
        if not isinstance(planes, list):
            planes = [planes]
            last_featuremap = [last_featuremap]
        planes = [p.view(len(p), -1, 32, p.shape[-2], p.shape[-1]) for p in planes]
        dynamic_planes = None

        # target space to canonical space deformation
        if use_deform:
            _deformer = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
            out_deform = _deformer(coordinates, mouth=self.mouth_part)
            coordinates_eye = out_deform['canonical_eye']
            coordinates_face = out_deform['canonical_face']
            coordinates_mouth = out_deform['canonical_mouth']
            if out_deform['dynamic_mask'] is not None:
                dynamic_mask = out_deform['dynamic_mask'] * (1-out_deform['inside_bbox_eye'][..., None])
            else:
                dynamic_mask = out_deform['dynamic_mask']
            offset = torch.zeros_like(coordinates_eye)
        else:
            offset = torch.zeros_like(coordinates)
            dynamic_mask = None

        plane_eye = [p[:, :3] for p in planes]
        if dynamic_mask is not None:
            if self.mouth_part and self.mouth_dynamic:
                plane_mouth = [torch.cat([p[:, :3], p[:, -3:]], dim=2) for p in planes]
            else:
                plane_mouth = [p[:, :3] for p in planes]
            plane_face = [torch.cat([p[:, 3:6], p[:, -3:]], dim=2) for p in planes]
        else:
            if self.mouth_part:
                plane_mouth = [p[:, :3] for p in planes]
            plane_face = [p[:, 3:6] for p in planes]

        out_eye = self.renderer.run_model(plane_eye, self.decoder, coordinates_eye, directions, self.rendering_kwargs, dynamic_mask=None)
        out_face = self.renderer.run_model(plane_face, self.decoder, coordinates_face, directions, self.rendering_kwargs, dynamic_mask=dynamic_mask)
        out_eye['canonical'] = coordinates_eye
        out_face['canonical'] = coordinates_face
        out_eye['offset'] = offset
        out_face['offset'] = offset
        if self.mouth_part:
            out_mouth = self.renderer.run_model(plane_mouth, self.decoder, coordinates_mouth, directions, self.rendering_kwargs, dynamic_mask=(None, dynamic_mask)[self.mouth_dynamic])
            out_mouth['canonical'] = coordinates_mouth
            out_mouth['offset'] = offset
            return out_eye, out_face, out_mouth
        else:
            return out_eye, out_face

    def forward(self, shape_params, exp_params, pose_params, eye_pose_params, z, z_bg, c, c_compose, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, patch_scale=1.0, **synthesis_kwargs):
        # Render a batch of generated images.
        _deformer = self._deformer(shape_params, exp_params, pose_params, eye_pose_params)
        c_compose_condition = c_compose.clone()
        if self.flame_condition:
            c_compose_condition = torch.cat([c_compose_condition, shape_params], dim=-1)
        ws = self.mapping(z, c_compose_condition, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)

        # Render correspondence map as condition to the discriminator
        render_out = self.deformer.renderer(shape_params, exp_params, pose_params, eye_pose_params, c, half_size=int(self.img_resolution/2))
        uv = render_out[0]
        landmarks = render_out[-1]

        img = self.synthesis(ws, z_bg, c, _deformer=_deformer, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, smpl_param=(shape_params, exp_params, pose_params, eye_pose_params), patch_scale=patch_scale, **synthesis_kwargs)
        img['uv'] = uv
        img['landmarks'] = landmarks
        return img
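
# Usage sketch for the part-wise generator (illustrative only; the triplane channel count and latent
# dimensions below are assumptions for this sketch and should be taken from the training config):
#   G = PartTriPlaneGeneratorDeform(z_dim=512, c_dim=25, w_dim=512, img_resolution=512, img_channels=3,
#                                   triplane_channels=9*32, rendering_kwargs=rendering_kwargs,
#                                   deformation_kwargs=deformation_kwargs)
#   c   = torch.cat([cam2world.view(-1, 16), intrinsics.view(-1, 9)], dim=1)
#   out = G(shape_params, exp_params, pose_params, eye_pose_params,
#           z=torch.randn(1, 512), z_bg=torch.randn(1, 512), c=c, c_compose=c, patch_scale=1.0)
#   # Besides out['image'] / out['image_sr'], this variant also returns part masks (eye/mouth/head),
#   # the FLAME correspondence map out['uv'], out['landmarks'], and, when patch_scale < 1, patch renderings.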

def zero_init(m):
    with torch.no_grad():
        nn.init.constant_(m.weight, 0)
        nn.init.constant_(m.bias, 0)

class OSGDecoder(torch.nn.Module):
    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64
        self.net = torch.nn.Sequential(
            FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']),
            torch.nn.Softplus()
        )
        self.out_sigma = FullyConnectedLayer(self.hidden_dim, 1, lr_multiplier=options['decoder_lr_mul'])
        self.out_rgb = FullyConnectedLayer(self.hidden_dim, options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])
        self.out_sigma.apply(zero_init)

    def forward(self, sampled_features, ray_directions):
        # Aggregate features
        sampled_features = sampled_features.mean(1)
        x = sampled_features

        N, M, C = x.shape
        x = x.view(N*M, C)

        x = self.net(x)
        rgb = self.out_rgb(x)
        sigma = self.out_sigma(x)
        rgb = rgb.view(N, M, -1)
        sigma = sigma.view(N, M, -1)
        rgb = torch.sigmoid(rgb)*(1 + 2*0.001) - 0.001  # Uses sigmoid clamping from MipNeRF
        return {'rgb': rgb, 'sigma': sigma}
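
# Minimal shape check for OSGDecoder (illustrative; the plane count of 3 and the batch/ray sizes are
# arbitrary choices for this sketch, not values prescribed by the training code).
if __name__ == '__main__':
    decoder = OSGDecoder(32, {'decoder_lr_mul': 1, 'decoder_output_dim': 32})
    sampled_features = torch.randn(2, 3, 1024, 32)   # (batch, planes, rays*samples, channels); averaged over planes
    ray_directions = torch.randn(2, 1024, 3)         # unused by the decoder, kept for interface parity
    out = decoder(sampled_features, ray_directions)
    print(out['rgb'].shape, out['sigma'].shape)      # torch.Size([2, 1024, 32]) torch.Size([2, 1024, 1])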