# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
"""
V4:
1. 使用相同的latent code控制两个stylegan (不共享梯度);
2. 正交投影的参数从2D改成了3D,使三次投影的变换一致;
3. 三平面变成四平面;
4. 三平面的顺序调换;
5. 生成嘴部的动态纹理, 和静态纹理融合 (Styleunet)
"""
import math
import torch
import numpy as np
import torch.nn.functional as F
from pytorch3d.io import load_obj
import cv2
from torchvision.utils import save_image
import dnnlib
from torch_utils import persistence
from training_avatar_texture.networks_stylegan2_next3d import Generator as StyleGAN2Backbone
from training_avatar_texture.networks_stylegan2_styleunet_next3d import Generator as CondStyleGAN2Backbone
from training_avatar_texture.volumetric_rendering.renderer_next3d import ImportanceRenderer
from training_avatar_texture.volumetric_rendering.ray_sampler import RaySampler
from training_avatar_texture.volumetric_rendering.renderer_next3d import Pytorch3dRasterizer, face_vertices, generate_triangles, transform_points, \
batch_orth_proj, angle2matrix
from training_avatar_texture.volumetric_rendering.renderer_next3d import fill_mouth
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
def __init__(self,
z_dim, # Input latent (Z) dimensionality.
c_dim, # Conditioning label (C) dimensionality.
w_dim, # Intermediate latent (W) dimensionality.
img_resolution, # Output resolution.
img_channels, # Number of output color channels.
        topology_path,              # Path of the head template mesh (.obj).
sr_num_fp16_res=0,
mapping_kwargs={}, # Arguments for MappingNetwork.
rendering_kwargs={},
sr_kwargs={},
**synthesis_kwargs, # Arguments for SynthesisNetwork.
):
super().__init__()
self.z_dim = z_dim
self.c_dim = c_dim
self.w_dim = w_dim
self.img_resolution = img_resolution
self.img_channels = img_channels
        self.topology_path = 'flame_head_template.obj'  # NOTE: hard-coded; overrides the topology_path argument
self.renderer = ImportanceRenderer()
self.ray_sampler = RaySampler()
self.texture_backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32, mapping_kwargs=mapping_kwargs,
**synthesis_kwargs) # render neural texture
self.mouth_backbone = CondStyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32, in_size=64, final_size=4,
cond_channels=32, num_cond_res=64, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32 * 3, mapping_ws=self.texture_backbone.num_ws * 2,
mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
        # debug: use split w codes
self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32,
img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res,
sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
self.decoder = OSGDecoder(32, {'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': 32})
self.neural_rendering_resolution = 64
self.rendering_kwargs = rendering_kwargs
self._last_planes = None
self.load_lms = True
# set pytorch3d rasterizer
self.uv_resolution = 256
self.rasterizer = Pytorch3dRasterizer(image_size=256)
verts, faces, aux = load_obj(self.topology_path)
uvcoords = aux.verts_uvs[None, ...] # (N, V, 2)
uvfaces = faces.textures_idx[None, ...] # (N, F, 3)
faces = faces.verts_idx[None, ...]
# faces
dense_triangles = generate_triangles(self.uv_resolution, self.uv_resolution)
self.register_buffer('dense_faces', torch.from_numpy(dense_triangles).long()[None, :, :].contiguous())
self.register_buffer('faces', faces)
self.register_buffer('raw_uvcoords', uvcoords)
# eye masks
        mask = cv2.imread('flame_uv_face_eye_mask.png').astype(np.float32) / 255.
mask = torch.from_numpy(mask[:, :, 0])[None, None, :, :].contiguous()
self.uv_face_mask = F.interpolate(mask, [256, 256])
# mouth mask
self.fill_mouth = True
# uv coords
uvcoords = torch.cat([uvcoords, uvcoords[:, :, 0:1] * 0. + 1.], -1) # [bz, ntv, 3]
        uvcoords = uvcoords * 2 - 1
uvcoords[..., 1] = -uvcoords[..., 1]
face_uvcoords = face_vertices(uvcoords, uvfaces)
self.register_buffer('uvcoords', uvcoords)
self.register_buffer('uvfaces', uvfaces)
self.register_buffer('face_uvcoords', face_uvcoords)
self.orth_scale = torch.tensor([[5.0]])
self.orth_shift = torch.tensor([[0, -0.01, -0.01]])
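        # global shift and scale applied to the FLAME vertices before orthographic projection in rasterize()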
# neural blending
self.neural_blending = CondStyleGAN2Backbone(z_dim, c_dim, w_dim, cond_channels=32, img_resolution=256, img_channels=32, in_size=256,
final_size=32, num_cond_res=256, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False):
if self.rendering_kwargs['c_gen_conditioning_zero']:
c = torch.zeros_like(c)
c = c[:, :25] # remove expression labels
return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi,
truncation_cutoff=truncation_cutoff, update_emas=update_emas)
def synthesis(self, ws, c, v, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False,
**synthesis_kwargs):
# split vertices and landmarks
if self.load_lms:
v, lms = v[:, :5023], v[:, 5023:]
batch_size = ws.shape[0]
eg3d_ws, texture_ws = ws[:, :self.texture_backbone.num_ws], ws[:, self.texture_backbone.num_ws:]
cam2world_matrix = c[:, :16].view(-1, 4, 4)
intrinsics = c[:, 16:25].view(-1, 3, 3)
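        # c packs the camera as a flattened 4x4 cam2world matrix followed by a flattened 3x3 intrinsics matrix (25 values)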
if neural_rendering_resolution is None:
neural_rendering_resolution = self.neural_rendering_resolution
else:
self.neural_rendering_resolution = neural_rendering_resolution
# Create a batch of rays for volume rendering
ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
# Create triplanes by running StyleGAN backbone
N, M, _ = ray_origins.shape
textures = self.texture_backbone.synthesis(texture_ws, update_emas=update_emas, **synthesis_kwargs)
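        # textures: 32-channel neural texture in FLAME UV space (256x256), sampled onto each view below via grid_sample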
        # rasterize the neural texture to four orthogonal views (the two side views are merged later)
rendering_views = [
[0, 0, 0],
[0, 90, 0],
[0, -90, 0],
[90, 0, 0]
]
rendering_images, alpha_images, uvcoords_images, lm2ds = self.rasterize(v, lms, textures, rendering_views, batch_size, ws.device)
# generate front mouth masks
rendering_image_front = rendering_images[0]
mouths_mask = self.gen_mouth_mask(lm2ds[0])
rendering_mouth = [rendering_image_front[i:i + 1, :][:, :, m[0]:m[1], m[2]:m[3]] for i, m in enumerate(mouths_mask)]
rendering_mouth = torch.cat([torch.nn.functional.interpolate(uv, size=(64, 64), mode='bilinear', antialias=True) for uv in rendering_mouth],
0)
# generate mouth front plane and integrate back to face front plane
mouths_plane = self.mouth_backbone.synthesis(rendering_mouth, eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
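        # mouths_plane: per-sample mouth feature plane conditioned on the 64x64 mouth crop, with eg3d_ws as style codes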
rendering_stitch = []
for rendering, m, mouth_plane in zip(rendering_image_front, mouths_mask, mouths_plane):
rendering = rendering.unsqueeze(0)
dummy = torch.zeros_like(rendering)
dummy[:, :] = rendering
            dummy[:, :, m[0]:m[1], m[2]:m[3]] = torch.nn.functional.interpolate(mouth_plane.unsqueeze(0), size=(m[1] - m[0], m[3] - m[2]),
                                                                                 mode='bilinear', antialias=True)
rendering_stitch.append(dummy)
rendering_stitch = torch.cat(rendering_stitch, 0)
rendering_stitch = self.neural_blending.synthesis(rendering_stitch, eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
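        # the neural_blending StyleUNet refines the stitched front plane (presumably to smooth the seam around the pasted mouth region)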
# generate static triplane
static_plane = self.backbone.synthesis(eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
static_plane = static_plane.view(len(static_plane), 3, 32, static_plane.shape[-2], static_plane.shape[-1])
# blend features of neural texture and tri-plane
alpha_image = torch.cat(alpha_images, 1).unsqueeze(2)
rendering_stitch = torch.cat((rendering_stitch, rendering_images[1], rendering_images[2]), 1)
rendering_stitch = rendering_stitch.view(*static_plane.shape)
blended_planes = rendering_stitch * alpha_image + static_plane * (1 - alpha_image)
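        # alpha_image is the per-view rasterization coverage: inside the rendered face the blended planes use the
        # neural-texture features, outside they fall back to the static tri-plane features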
# Perform volume rendering
feature_samples, depth_samples, weights_samples = self.renderer(blended_planes, self.decoder, ray_origins, ray_directions,
self.rendering_kwargs) # channels last
# Reshape into 'raw' neural-rendered image
H = W = self.neural_rendering_resolution
feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
# Run superresolution to get final image
rgb_image = feature_image[:, :3]
sr_image = self.superresolution(rgb_image, feature_image, eg3d_ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
**{k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
def rasterize(self, v, lms, textures, tforms, batch_size, device):
rendering_images, alpha_images, uvcoords_images, transformed_lms = [], [], [], []
for tform in tforms:
v_flip, lms_flip = v.detach().clone(), lms.detach().clone()
            v_flip[..., 1] *= -1
            lms_flip[..., 1] *= -1
            # rasterize the texture to the current orthogonal view
tform = angle2matrix(torch.tensor(tform).reshape(1, -1)).expand(batch_size, -1, -1).to(device)
transformed_vertices = (torch.bmm(v_flip, tform) + self.orth_shift.to(device)) * self.orth_scale.to(device)
transformed_vertices = batch_orth_proj(transformed_vertices, torch.tensor([1., 0, 0]).to(device))
transformed_vertices[:, :, 1:] = -transformed_vertices[:, :, 1:]
transformed_vertices[:, :, 2] = transformed_vertices[:, :, 2] + 10
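            # shift along +z so all vertices end up at positive depth in front of the rasterizer camera (assumed intent)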
transformed_lm = (torch.bmm(lms_flip, tform) + self.orth_shift.to(device)) * self.orth_scale.to(device)
transformed_lm = batch_orth_proj(transformed_lm, torch.tensor([1., 0, 0]).to(device))[:, :, :2]
transformed_lm[:, :, 1:] = -transformed_lm[:, :, 1:]
faces = self.faces.detach().clone()[..., [0, 2, 1]].expand(batch_size, -1, -1)
attributes = self.face_uvcoords.detach().clone()[:, :, [0, 2, 1]].expand(batch_size, -1, -1, -1)
rendering = self.rasterizer(transformed_vertices, faces, attributes, 256, 256)
alpha_image = rendering[:, -1, :, :][:, None, :, :].detach()
            uvcoords_image = rendering[:, :-1, :, :]
            grid = uvcoords_image.permute(0, 2, 3, 1)[:, :, :, :2]
mask_face_eye = F.grid_sample(self.uv_face_mask.expand(batch_size, -1, -1, -1).to(device), grid.detach(), align_corners=False)
alpha_image = mask_face_eye * alpha_image
if self.fill_mouth:
alpha_image = fill_mouth(alpha_image)
uvcoords_image = mask_face_eye * uvcoords_image
rendering_image = F.grid_sample(textures, grid.detach(), align_corners=False)
rendering_images.append(rendering_image)
alpha_images.append(alpha_image)
uvcoords_images.append(uvcoords_image)
transformed_lms.append(transformed_lm)
        rendering_image_side = rendering_images[1] + rendering_images[2]  # merge the two side-view renderings into one plane
        alpha_image_side = (alpha_images[1].bool() | alpha_images[2].bool()).float()
rendering_images = [rendering_images[0], rendering_image_side, rendering_images[3]]
alpha_images = [alpha_images[0], alpha_image_side, alpha_images[3]]
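        # rendering_images / alpha_images are regrouped into [front, merged left+right sides, pitch-rotated fourth view]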
return rendering_images, alpha_images, uvcoords_images, transformed_lms
def sample(self, coordinates, directions, z, c, v, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
if self.load_lms:
v, lms = v[:, :5023], v[:, 5023:]
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
batch_size = ws.shape[0]
eg3d_ws, texture_ws = ws[:, :self.texture_backbone.num_ws], ws[:, self.texture_backbone.num_ws:]
textures = self.texture_backbone.synthesis(texture_ws, update_emas=update_emas, **synthesis_kwargs)
        # rasterize the neural texture to four orthogonal views (the two side views are merged later)
rendering_views = [
[0, 0, 0],
[0, 90, 0],
[0, -90, 0],
[90, 0, 0]
]
rendering_images, alpha_images, uvcoords_images, lm2ds = self.rasterize(v, lms, textures, rendering_views, batch_size, ws.device)
# generate front mouth masks
rendering_image_front = rendering_images[0]
mouths_mask = self.gen_mouth_mask(lm2ds[0])
rendering_mouth = [rendering_image_front[i:i + 1, :][:, :, m[0]:m[1], m[2]:m[3]] for i, m in enumerate(mouths_mask)]
rendering_mouth = torch.cat([torch.nn.functional.interpolate(uv, size=(64, 64), mode='bilinear', antialias=True) for uv in rendering_mouth],
0)
# generate mouth front plane and integrate back to face front plane
mouths_plane = self.mouth_backbone.synthesis(rendering_mouth, eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
rendering_stitch = []
for rendering, m, mouth_plane in zip(rendering_image_front, mouths_mask, mouths_plane):
rendering = rendering.unsqueeze(0)
dummy = torch.zeros_like(rendering)
dummy[:, :] = rendering
            dummy[:, :, m[0]:m[1], m[2]:m[3]] = torch.nn.functional.interpolate(mouth_plane.unsqueeze(0), size=(m[1] - m[0], m[3] - m[2]),
                                                                                 mode='bilinear', antialias=True)
rendering_stitch.append(dummy)
rendering_stitch = torch.cat(rendering_stitch, 0)
rendering_stitch = self.neural_blending.synthesis(rendering_stitch, eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
# generate static triplane
static_plane = self.backbone.synthesis(eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
static_plane = static_plane.view(len(static_plane), 3, 32, static_plane.shape[-2], static_plane.shape[-1])
# blend features of neural texture and tri-plane
alpha_image = torch.cat(alpha_images, 1).unsqueeze(2)
rendering_stitch = torch.cat((rendering_stitch, rendering_images[1], rendering_images[2]), 1)
rendering_stitch = rendering_stitch.view(*static_plane.shape)
blended_planes = rendering_stitch * alpha_image + static_plane * (1 - alpha_image)
return self.renderer.run_model(blended_planes, self.decoder, coordinates, directions, self.rendering_kwargs)
def sample_mixed(self, coordinates, directions, ws, v, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
if self.load_lms:
v, lms = v[:, :5023], v[:, 5023:]
batch_size = ws.shape[0]
eg3d_ws, texture_ws = ws[:, :self.texture_backbone.num_ws], ws[:, self.texture_backbone.num_ws:]
textures = self.texture_backbone.synthesis(texture_ws, update_emas=update_emas, **synthesis_kwargs)
        # rasterize the neural texture to four orthogonal views (the two side views are merged later)
rendering_views = [
[0, 0, 0],
[0, 90, 0],
[0, -90, 0],
[90, 0, 0]
]
rendering_images, alpha_images, uvcoords_images, lm2ds = self.rasterize(v, lms, textures, rendering_views, batch_size, ws.device)
# generate front mouth masks
rendering_image_front = rendering_images[0]
mouths_mask = self.gen_mouth_mask(lm2ds[0])
rendering_mouth = [rendering_image_front[i:i + 1, :][:, :, m[0]:m[1], m[2]:m[3]] for i, m in enumerate(mouths_mask)]
rendering_mouth = torch.cat([torch.nn.functional.interpolate(uv, size=(64, 64), mode='bilinear', antialias=True) for uv in rendering_mouth],
0)
# generate mouth front plane and integrate back to face front plane
mouths_plane = self.mouth_backbone.synthesis(rendering_mouth, eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
rendering_stitch = []
for rendering, m, mouth_plane in zip(rendering_image_front, mouths_mask, mouths_plane):
rendering = rendering.unsqueeze(0)
dummy = torch.zeros_like(rendering)
dummy[:, :] = rendering
            dummy[:, :, m[0]:m[1], m[2]:m[3]] = torch.nn.functional.interpolate(mouth_plane.unsqueeze(0), size=(m[1] - m[0], m[3] - m[2]),
                                                                                 mode='bilinear', antialias=True)
rendering_stitch.append(dummy)
rendering_stitch = torch.cat(rendering_stitch, 0)
rendering_stitch = self.neural_blending.synthesis(rendering_stitch, eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
# generate static triplane
static_plane = self.backbone.synthesis(eg3d_ws, update_emas=update_emas, **synthesis_kwargs)
static_plane = static_plane.view(len(static_plane), 3, 32, static_plane.shape[-2], static_plane.shape[-1])
# blend features of neural texture and tri-plane
alpha_image = torch.cat(alpha_images, 1).unsqueeze(2)
rendering_stitch = torch.cat((rendering_stitch, rendering_images[1], rendering_images[2]), 1)
rendering_stitch = rendering_stitch.view(*static_plane.shape)
blended_planes = rendering_stitch * alpha_image + static_plane * (1 - alpha_image)
return self.renderer.run_model(blended_planes, self.decoder, coordinates, directions, self.rendering_kwargs)
def forward(self, z, c, v, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False,
use_cached_backbone=False, **synthesis_kwargs):
# Render a batch of generated images.
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
return self.synthesis(ws, c, v, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution,
cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
def gen_mouth_mask(self, lms2d):
lm = lms2d.clone().cpu().numpy() # lms2d: (4, 68, 2)
lm[..., 0] = lm[..., 0] * 128 + 128
lm[..., 1] = lm[..., 1] * 128 + 128
lm_mouth_outer = lm[:, 48:60] # left-clockwise
mouth_left = lm_mouth_outer[:, 0]
mouth_right = lm_mouth_outer[:, 6]
mouth_avg = (mouth_left + mouth_right) * 0.5 # (4, 2)
        # horizontal and vertical extents of the outer mouth landmarks
        widths = np.max(lm_mouth_outer[..., 0], axis=1, keepdims=True) - np.min(lm_mouth_outer[..., 0], axis=1, keepdims=True)
        heights = np.max(lm_mouth_outer[..., 1], axis=1, keepdims=True) - np.min(lm_mouth_outer[..., 1], axis=1, keepdims=True)
        mask_res = np.max(np.concatenate((widths, heights), axis=1), axis=1, keepdims=True) * 1.2
mask_res = mask_res.astype(int)
mouth_mask = np.concatenate([(mouth_avg[:, 1:] - mask_res // 2).astype(int), (mouth_avg[:, 1:] + mask_res // 2).astype(int),
(mouth_avg[:, 0:1] - mask_res // 2).astype(int), (mouth_avg[:, 0:1] + mask_res // 2).astype(int)], 1) # (4, 4)
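        # each row of mouth_mask gives [top, bottom, left, right] crop bounds in 256x256 pixel coordinates (square crops)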
return mouth_mask
from training.networks_stylegan2 import FullyConnectedLayer
class OSGDecoder(torch.nn.Module):
def __init__(self, n_features, options):
super().__init__()
self.hidden_dim = 64
self.net = torch.nn.Sequential(
FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']),
torch.nn.Softplus(),
FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])
)
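        # the MLP maps the 32-dim aggregated plane feature to 1 density (sigma) plus 'decoder_output_dim' feature channels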
def forward(self, sampled_features, ray_directions, sampled_embeddings=None):
# Aggregate features
sampled_features = sampled_features.mean(1)
x = sampled_features
N, M, C = x.shape
x = x.view(N * M, C)
x = self.net(x)
x = x.view(N, M, -1)
rgb = torch.sigmoid(x[..., 1:]) * (1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
sigma = x[..., 0:1]
return {'rgb': rgb, 'sigma': sigma}
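

if __name__ == '__main__':
    # Minimal sketch (illustration only, not part of the training pipeline): how the camera label `c`
    # consumed by mapping()/synthesis() is packed -- a flattened 4x4 cam2world matrix followed by a
    # flattened 3x3 intrinsics matrix, 25 values in total; anything beyond 25 entries (e.g. expression
    # labels) is dropped inside mapping(). The matrix values below are placeholders, not dataset values.
    cam2world = torch.eye(4).unsqueeze(0)                         # (1, 4, 4) placeholder extrinsics
    intrinsics = torch.tensor([[4.26, 0.00, 0.50],
                               [0.00, 4.26, 0.50],
                               [0.00, 0.00, 1.00]]).unsqueeze(0)  # (1, 3, 3) placeholder normalized intrinsics
    c = torch.cat([cam2world.reshape(1, 16), intrinsics.reshape(1, 9)], dim=1)  # (1, 25)
    assert c.shape == (1, 25)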