import numpy as np
import torch
import torch.distributed as dist
import os
import copy
import cv2
import hashlib
import tqdm
import scipy
import uuid
import random
from PIL import Image

from utils.commons.hparams import hparams
from utils.commons.tensor_utils import tensors_to_scalars, convert_to_np, move_to_cuda
from utils.nn.model_utils import not_requires_grad, num_params
from utils.nn.schedulers import NoneSchedule
from utils.commons.ckpt_utils import load_ckpt, get_last_checkpoint, restore_weights, restore_opt_state

from utils.commons.base_task import BaseTask
from tasks.os_avatar.loss_utils.vgg19_loss import VGG19Loss
from modules.real3d.facev2v_warp.losses import PerceptualLoss
from tasks.os_avatar.dataset_utils.motion2video_dataset import Img2Plane_Dataset
import lpips
from utils.commons.dataset_utils import data_loader

from modules.real3d.img2plane_baseline import OSAvatar_Img2plane
from modules.eg3ds.models.triplane import TriPlaneGenerator
from modules.eg3ds.models.dual_discriminator import DualDiscriminator
from modules.eg3ds.torch_utils.ops import conv2d_gradfix
from modules.eg3ds.torch_utils.ops import upfirdn2d
from modules.eg3ds.models.dual_discriminator import filtered_resizing
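
# This task trains the one-shot image-to-tri-plane generator against images synthesized
# by a pretrained EG3D generator. Each training iteration runs three optimizer sub-steps:
#   0) generator reconstructs the EG3D-rendered reference view;
#   1) generator reconstructs a novel-view EG3D render, plus tri-plane density regularization;
#   2) dual discriminator, trained with a non-saturating GAN loss and a lazy R1 penalty.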


class ScheduleForImg2Plane(NoneSchedule):
    def __init__(self, optimizer, lr, lr_d, warmup_updates=0):
        self.optimizer = optimizer
        self.constant_lr = self.lr = lr
        self.lr_d = lr_d
        self.warmup_updates = warmup_updates
        self.step(0)

    def step(self, num_updates):
        constant_lr = self.constant_lr
        if self.warmup_updates > 0 and num_updates <= self.warmup_updates:
            warmup = min(num_updates / self.warmup_updates, 1.0)
            self.lr = max(constant_lr * warmup, 1e-7)
        else:
            self.lr = constant_lr
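
        # The generator optimizer has three param groups (see build_optimizer):
        # 0 = img2plane backbone, 1 = tri-plane decoder, 2 = super-resolution module.
        # Groups 1 and 2 stay at lr 0 until their warm-up milestones are reached, and all
        # groups follow an exponential decay (lr_decay_rate / lr_decay_interval), floored at 1e-5.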
        for optim_i in range(len(self.optimizer) - 1):
            self.optimizer[optim_i].param_groups[0]['lr'] = max(1e-5, self.lr * (hparams.get("lr_decay_rate", 0.95)) ** (num_updates // hparams.get("lr_decay_interval", 5_000)))
            self.optimizer[optim_i].param_groups[1]['lr'] = max(1e-5, self.lr * (hparams.get("lr_decay_rate", 0.95)) ** (num_updates // hparams.get("lr_decay_interval", 5_000))) if num_updates >= min(2_000, hparams['start_adv_iters']) else 0
            self.optimizer[optim_i].param_groups[2]['lr'] = max(1e-5, self.lr * (hparams.get("lr_decay_rate", 0.95)) ** (num_updates // hparams.get("lr_decay_interval", 5_000))) if num_updates >= hparams['start_adv_iters'] else 0

        self.optimizer[-1].param_groups[0]['lr'] = self.lr_d * 1
        return self.lr


class OSAvatarImg2PlaneTask(BaseTask):
    def __init__(self):
        super().__init__()
        if hparams['lpips_mode'] == 'vgg19':
            self.criterion_lpips = VGG19Loss()
        elif hparams['lpips_mode'] in ['vgg16', 'vgg']:
            hparams['lpips_mode'] = 'vgg'
            self.criterion_lpips = lpips.LPIPS(net=hparams['lpips_mode'], lpips=True)
        elif hparams['lpips_mode'] == 'vgg19_v2':
            self.criterion_lpips = PerceptualLoss()
        else:
            raise NotImplementedError
        self.gen_tmp_output = {}
        if hparams.get("use_kv_dataset", True):
            self.dataset_cls = Img2Plane_Dataset
        else:
            raise NotImplementedError()
        self.start_adv_iters = hparams["start_adv_iters"]
        self.resample_filter = upfirdn2d.setup_filter([1, 3, 3, 1])

    @data_loader
    def train_dataloader(self):
        train_dataset = self.dataset_cls(prefix='train')
        self.train_dl = train_dataset.get_dataloader()
        return self.train_dl

    @data_loader
    def val_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    @data_loader
    def test_dataloader(self):
        val_dataset = self.dataset_cls(prefix='val')
        self.val_dl = val_dataset.get_dataloader()
        return self.val_dl

    def build_model(self):
        self.eg3d_model = TriPlaneGenerator()
        load_ckpt(self.eg3d_model, hparams['pretrained_eg3d_ckpt'], strict=True)
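        # The pretrained EG3D generator acts as a fixed teacher: its parameters are never
        # handed to an optimizer, and it is only called under torch.no_grad() in prepare_batch
        # to synthesize paired reference/novel-view training images.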
        self.model = OSAvatar_Img2plane()
        load_ckpt(self.model, hparams['pretrained_eg3d_ckpt'], strict=False, verbose=False)
        self.disc = DualDiscriminator()
        load_ckpt(self.disc, hparams['pretrained_eg3d_ckpt'], strict=False, model_name='disc')
        if hparams.get('init_from_ckpt', '') != '':
            load_ckpt(self.model, hparams['init_from_ckpt'], model_name='model', strict=False)
            load_ckpt(self.disc, hparams['init_from_ckpt'], model_name='disc', strict=True)
            print(f"restored weights from {hparams.get('init_from_ckpt', '')}")

        self.img2plane_backbone_params = [p for k, p in self.model.img2plane_backbone.named_parameters() if p.requires_grad]
        self.decoder_params = [p for p in self.model.decoder.parameters() if p.requires_grad]
        self.upsample_params = [p for p in self.model.superresolution.parameters() if p.requires_grad]
        self.disc_params = [p for k, p in self.disc.named_parameters() if p.requires_grad]
        return self.model

    def build_optimizer(self, model):
        self.optimizer_gen = optimizer_gen = torch.optim.Adam(
            self.img2plane_backbone_params,
            lr=hparams['lr_g'],
            betas=(hparams['optimizer_adam_beta1_g'], hparams['optimizer_adam_beta2_g'])
        )
        self.optimizer_gen.add_param_group({
            'params': self.decoder_params,
            'lr': hparams['lr_g'],
            'betas': (hparams['optimizer_adam_beta1_g'], hparams['optimizer_adam_beta2_g'])
        })
        self.optimizer_gen.add_param_group({
            'params': self.upsample_params,
            'lr': hparams['lr_g'],
            'betas': (hparams['optimizer_adam_beta1_g'], hparams['optimizer_adam_beta2_g'])
        })
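
        # Lazy R1 regularization: the discriminator is only regularized every reg_interval_d
        # steps, so its learning rate and Adam betas are rescaled by mb_ratio_d to keep the
        # effective update statistics unchanged (as in the StyleGAN2/EG3D training code).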
        mb_ratio_d = hparams['reg_interval_d'] / (hparams['reg_interval_d'] + 1)
        self.optimizer_disc = optimizer_disc = torch.optim.Adam(
            self.disc_params,
            lr=hparams['lr_d'] * mb_ratio_d,
            betas=(hparams['optimizer_adam_beta1_d'] ** mb_ratio_d, hparams['optimizer_adam_beta2_d'] ** mb_ratio_d))
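        # optimizer_idx 0 and 1 are two generator sub-steps (reference view and multi-view)
        # that share the same Adam instance; optimizer_idx 2 is the discriminator.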
        optimizers = [optimizer_gen] * 2 + [optimizer_disc]
        return optimizers

    def build_scheduler(self, optimizer):
        mb_ratio_d = hparams['reg_interval_d'] / (hparams['reg_interval_d'] + 1)
        return ScheduleForImg2Plane(optimizer, hparams['lr_g'], hparams['lr_d'] * mb_ratio_d, hparams['warmup_updates'])

    def on_train_start(self):
        print("==============================")
        num_params(self.model, model_name="Generator")
        for n, m in self.model.named_children():
            num_params(m, model_name="|-- " + n)
        print("==============================")
        num_params(self.disc, model_name="Discriminator")
        for n, m in self.disc.named_children():
            num_params(m, model_name="|-- " + n)
        print("==============================")

    def forward_G(self, img, camera, cond=None, ret=None, update_emas=False, cache_backbone=True, use_cached_backbone=False):
        """
        img: [B, 3, H, W], the reference image fed to the img2plane generator
        camera: [B, 25], 16-dim flattened cam2world matrix plus 9-dim flattened intrinsics
        """
        G = self.model
        gen_output = G.forward(img=img, camera=camera, cond=cond, ret=ret, update_emas=update_emas, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone)
        return gen_output

    def forward_D(self, img, camera, update_emas=False):
        D = self.disc
        logits = D.forward(img, camera, update_emas=update_emas)
        return logits

    def prepare_batch(self, batch):
        out_batch = {}
        ws_camera = batch['ffhq_ws_cameras']
        camera = batch['ffhq_ref_cameras']
        fake_camera = batch['ffhq_mv_cameras']
        z = torch.randn([camera.shape[0], hparams['z_dim']], device=camera.device)
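        # Use the frozen EG3D teacher to render two views of the same randomly sampled
        # identity: a reference view (its tri-plane backbone features are cached) and a novel
        # "multi-view" render reusing the cached planes; both serve as pseudo ground truth.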
        with torch.no_grad():
            ws = self.eg3d_model.mapping(z, ws_camera, update_emas=False, truncation_psi=0.7)
            ref_img_gen = self.eg3d_model.synthesis(ws, camera, cache_backbone=True, use_cached_backbone=False)
            mv_img_gen = self.eg3d_model.synthesis(ws, fake_camera, cache_backbone=False, use_cached_backbone=True)
            ref_img_gen['image_raw'] = filtered_resizing(ref_img_gen['image'], size=hparams['neural_rendering_resolution'], f=self.resample_filter, filter_mode='antialiased')
            mv_img_gen['image_raw'] = filtered_resizing(mv_img_gen['image'], size=hparams['neural_rendering_resolution'], f=self.resample_filter, filter_mode='antialiased')

        out_batch.update({
            'ffhq_planes': ref_img_gen['plane'],
            'ffhq_ref_imgs_raw': ref_img_gen['image_raw'],
            'ffhq_ref_imgs': ref_img_gen['image'],
            'ffhq_ref_imgs_depth': ref_img_gen['image_depth'],

            'ffhq_mv_imgs_raw': mv_img_gen['image_raw'],
            'ffhq_mv_imgs': mv_img_gen['image'],
            'ffhq_mv_imgs_depth': mv_img_gen['image_depth'],
            'ffhq_mv_imgs_feature': mv_img_gen['image_feature'],
            'ffhq_ref_cameras': batch['ffhq_ref_cameras'],
            'ffhq_mv_cameras': batch['ffhq_mv_cameras'],
        })
        return out_batch

    def run_G_reference_image(self, batch):
        losses = {}
        ret = {}
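        # Throttle the reference-view reconstruction sub-step: before adversarial training
        # starts it only runs every 5 iterations, during the first 5k adversarial iterations
        # every 2 iterations, and afterwards on every iteration.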
        if self.global_step + 1 < self.start_adv_iters:
            if (self.global_step + 1) % 5 != 0:
                return losses
        elif self.global_step < self.start_adv_iters + 5000:
            if (self.global_step + 1) % 2 != 0:
                return losses
        else:
            pass

        with torch.autograd.profiler.record_function('G_ref_forward'):
            camera = batch['ffhq_ref_cameras']
            img = batch['ffhq_ref_imgs']
            img_raw = batch['ffhq_ref_imgs_raw']
            img_depth = batch['ffhq_ref_imgs_depth']

            gen_img = self.forward_G(img, camera, cond={'ref_cameras': batch['ffhq_ref_cameras']}, ret=ret)
            if 'losses' in ret: losses.update(ret['losses'])
            self.gen_tmp_output['recon_ref_imgs'] = gen_img['image'].detach()
            self.gen_tmp_output['recon_ref_imgs_raw'] = gen_img['image_raw'].detach()

            losses['G_ref_plane_l1_mean'] = gen_img['plane'][:, :].detach().abs().mean()
            losses['G_ref_plane_l1_std'] = gen_img['plane'][:, :].detach().abs().std()
            if hparams['use_mse']:
                losses['G_ref_img_mae'] = (gen_img['image'] - img).pow(2).mean()
                losses['G_ref_img_mae_raw'] = (gen_img['image_raw'] - img_raw).pow(2).mean()
            else:
                losses['G_ref_img_mae'] = (gen_img['image'] - img).abs().mean()
                losses['G_ref_img_mae_raw'] = (gen_img['image_raw'] - img_raw).abs().mean()

            pred_img_01 = ((gen_img['image'] + 1) / 2).clamp(0, 1)
            pred_img_01_raw = ((gen_img['image_raw'] + 1) / 2).clamp(0, 1)
            gt_img_01 = ((img + 1) / 2).clamp(0, 1)
            gt_img_01_raw = ((img_raw + 1) / 2).clamp(0, 1)
            losses['G_ref_img_lpips'] = self.criterion_lpips(pred_img_01, gt_img_01).mean()
            losses['G_ref_img_lpips_raw'] = self.criterion_lpips(pred_img_01_raw, gt_img_01_raw).mean()
            losses['G_ref_img_mae_depth'] = (gen_img['image_depth'] - img_depth).detach().abs().mean()
            disc_inp_img = {
                'image': gen_img['image'],
                'image_raw': gen_img['image_raw'],
            }
            gen_logits = self.forward_D(disc_inp_img, camera)
            losses['G_ref_adv'] = torch.nn.functional.softplus(-gen_logits).mean()
        return losses

    def run_G_multiview_image(self, batch):
        losses = {}
        ret = {}
        with torch.autograd.profiler.record_function('G_mv_forward'):
            camera = batch['ffhq_mv_cameras']
            img = batch['ffhq_mv_imgs']
            img_raw = batch['ffhq_mv_imgs_raw']
            img_depth = batch['ffhq_mv_imgs_depth']

            gen_img = self.forward_G(batch['ffhq_ref_imgs'], camera, cond={'ref_cameras': batch['ffhq_ref_cameras']}, ret=ret)
            if 'losses' in ret: losses.update(ret['losses'])
            self.gen_tmp_output['recon_mv_imgs'] = gen_img['image'].detach()
            self.gen_tmp_output['recon_mv_imgs_raw'] = gen_img['image_raw'].detach()
            losses['G_mv_plane_l1_mean'] = gen_img['plane'][:, :].detach().abs().mean()
            losses['G_mv_plane_l1_std'] = gen_img['plane'][:, :].detach().abs().std()
            if hparams['use_mse']:
                losses['G_mv_img_mae'] = (gen_img['image'] - img).pow(2).mean()
                losses['G_mv_img_mae_raw'] = (gen_img['image_raw'] - img_raw).pow(2).mean()
            else:
                losses['G_mv_img_mae'] = (gen_img['image'] - img).abs().mean()
                losses['G_mv_img_mae_raw'] = (gen_img['image_raw'] - img_raw).abs().mean()
            pred_img_01 = (gen_img['image'] + 1) / 2
            pred_img_01_raw = (gen_img['image_raw'] + 1) / 2
            gt_img_01 = (img + 1) / 2
            gt_img_01_raw = (img_raw + 1) / 2
            losses['G_mv_img_lpips'] = self.criterion_lpips(pred_img_01, gt_img_01).mean()
            losses['G_mv_img_lpips_raw'] = self.criterion_lpips(pred_img_01_raw, gt_img_01_raw).mean()
            losses['G_mv_img_mae_depth'] = (gen_img['image_depth'] - img_depth).detach().abs().mean()

            disc_inp_img = {
                'image': gen_img['image'],
                'image_raw': gen_img['image_raw'],
            }
            gen_logits = self.forward_D(disc_inp_img, camera)
            losses['G_mv_adv'] = torch.nn.functional.softplus(-gen_logits).mean()
        return losses

    def run_G_reg(self, batch):
        losses = {}
        imgs = batch['ffhq_mv_imgs']
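
        # Density (sigma) regularization, applied lazily every reg_interval_g steps:
        # sample random 3D points, perturb them slightly, and penalize the L1 difference
        # between the predicted densities at the original and perturbed points, encouraging
        # a smooth density field (the same regularizer used in EG3D).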
        if (self.global_step + 1) % hparams['reg_interval_g'] == 0:
            with torch.autograd.profiler.record_function('G_regularize_forward'):
                initial_coordinates = torch.rand((imgs.shape[0], 1000, 3), device=imgs.device) - 0.5
                perturbed_coordinates = initial_coordinates + torch.randn_like(initial_coordinates) * 5e-3
                all_coordinates = torch.cat([initial_coordinates, perturbed_coordinates], dim=1)
                source_sigma = self.model.sample(coordinates=all_coordinates, directions=torch.randn_like(all_coordinates), img=imgs, cond={'ref_cameras': batch['ffhq_mv_cameras']}, update_emas=False)['sigma']
                source_sigma_initial = source_sigma[:, :source_sigma.shape[1] // 2]
                source_sigma_perturbed = source_sigma[:, source_sigma.shape[1] // 2:]
                density_reg_loss = torch.nn.functional.l1_loss(source_sigma_initial, source_sigma_perturbed)

            losses['G_regularize_density_l1'] = density_reg_loss
        return losses

    def forward_D_main(self, batch):
        """
        Discriminator sub-step; forward_D is called with update_emas=True here, so the
        EMA statistics are updated in this sub-step.
        """
        losses = {}
        with torch.autograd.profiler.record_function('D_minimize_fake_forward'):
            if 'recon_ref_imgs' in self.gen_tmp_output:
                camera = batch['ffhq_ref_cameras']
                disc_inp_img = {
                    'image': self.gen_tmp_output['recon_ref_imgs'],
                    'image_raw': self.gen_tmp_output['recon_ref_imgs_raw'],
                }
                gen_logits = self.forward_D(disc_inp_img, camera, update_emas=True)
                losses['D_minimize_ref_fake'] = torch.nn.functional.softplus(gen_logits).mean()

            if 'recon_mv_imgs' in self.gen_tmp_output:
                camera = batch['ffhq_mv_cameras']
                disc_inp_img = {
                    'image': self.gen_tmp_output['recon_mv_imgs'],
                    'image_raw': self.gen_tmp_output['recon_mv_imgs_raw'],
                }
                gen_logits = self.forward_D(disc_inp_img, camera, update_emas=True)
                losses['D_minimize_mv_fake'] = torch.nn.functional.softplus(gen_logits).mean()

        with torch.autograd.profiler.record_function('D_maximize_true_forward'):
            if hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'eg3d_gen':
                ref_cameras = batch['ffhq_ref_cameras']
                ref_img_tmp_image = batch['ffhq_ref_imgs'].detach().requires_grad_(True)
                ref_img_tmp_image_raw = batch['ffhq_ref_imgs_raw'].detach().requires_grad_(True)
            elif hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'ffhq':
                ref_cameras = batch['ffhq_gt_cameras']
                ref_img_tmp_image = batch['ffhq_gt_imgs'].detach().requires_grad_(True)
                ref_img_tmp_image_raw = batch['ffhq_gt_imgs_raw'].detach().requires_grad_(True)

            ref_img_tmp = {'image': ref_img_tmp_image, 'image_raw': ref_img_tmp_image_raw}
            ref_logits = self.forward_D(ref_img_tmp, ref_cameras)
            losses['D_maximize_ref'] = torch.nn.functional.softplus(-ref_logits).mean()

            if hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'eg3d_gen':
                mv_cameras = batch['ffhq_mv_cameras']
                mv_img_tmp_image = batch['ffhq_mv_imgs'].detach().requires_grad_(True)
                mv_img_tmp_image_raw = batch['ffhq_mv_imgs_raw'].detach().requires_grad_(True)
                mv_img_tmp = {'image': mv_img_tmp_image, 'image_raw': mv_img_tmp_image_raw}
                mv_logits = self.forward_D(mv_img_tmp, mv_cameras)
                losses['D_maximize_mv'] = torch.nn.functional.softplus(-mv_logits).mean()
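
        # Lazy R1 gradient penalty on the real (teacher-rendered) images, computed only every
        # reg_interval_d steps; its loss weight in _training_step is multiplied by
        # reg_interval_d to compensate for the reduced frequency.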
        if (self.global_step + 1) % hparams['reg_interval_d'] == 0 and self.training:
            with torch.autograd.profiler.record_function('D_gradient_penalty_on_real_imgs_forward'), conv2d_gradfix.no_weight_gradients():
                ref_r1_grads = torch.autograd.grad(outputs=[ref_logits.sum()], inputs=[ref_img_tmp['image'], ref_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                ref_r1_grads_image = ref_r1_grads[0]
                ref_r1_grads_image_raw = ref_r1_grads[1]
                ref_r1_penalty_raw = ref_r1_grads_image_raw.square().sum([1, 2, 3]).mean()
                ref_r1_penalty_image = ref_r1_grads_image.square().sum([1, 2, 3]).mean()
                losses['D_ref_gradient_penalty'] = (ref_r1_penalty_image + ref_r1_penalty_raw) / 2

                if hparams.get("ffhq_disc_inp_mode", "eg3d_gen") == 'eg3d_gen':
                    mv_r1_grads = torch.autograd.grad(outputs=[mv_logits.sum()], inputs=[mv_img_tmp['image'], mv_img_tmp['image_raw']], create_graph=True, only_inputs=True)
                    mv_r1_grads_image = mv_r1_grads[0]
                    mv_r1_grads_image_raw = mv_r1_grads[1]
                    mv_r1_penalty_raw = mv_r1_grads_image_raw.square().sum([1, 2, 3]).mean()
                    mv_r1_penalty_image = mv_r1_grads_image.square().sum([1, 2, 3]).mean()
                    losses['D_mv_gradient_penalty'] = (mv_r1_penalty_image + mv_r1_penalty_raw) / 2

        self.gen_tmp_output = {}
        return losses

    def _training_step(self, sample, batch_idx, optimizer_idx):
        if len(sample) == 0:
            return None
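        # Three optimizer sub-steps per iteration: 0 = generator on the reference view,
        # 1 = generator on the novel view plus density regularization, 2 = discriminator.
        # The EG3D teacher batch is synthesized once (at optimizer_idx 0) and cached for
        # the remaining sub-steps.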
        if optimizer_idx == 0:
            sample = self.prepare_batch(sample)
            self.cache_sample = sample
        else:
            sample = self.cache_sample

        losses = {}
        if optimizer_idx == 0:
            losses.update(self.run_G_reference_image(sample))
            loss_weights = {
                'G_ref_img_mae': 1.0,
                'G_ref_img_mae_raw': 1.0,
                'G_ref_img_mae_depth': hparams.get('lambda_mse_depth', 0),
                'G_ref_img_lpips': 0.1,
                'G_ref_img_lpips_raw': 0.1,
                'G_ref_adv': 0.0 if self.global_step < self.start_adv_iters else 0.1,
            }
        elif optimizer_idx == 1:
            losses.update(self.run_G_multiview_image(sample))
            losses.update(self.run_G_reg(sample))
            loss_weights = {
                'G_mv_img_mae': 1.0,
                'G_mv_img_mae_raw': 1.0,
                'G_mv_img_mae_depth': hparams.get('lambda_mse_depth', 0),
                'G_mv_img_lpips': 0.1,
                'G_mv_img_lpips_raw': 0.1,
                'G_mv_adv': 0.0 if self.global_step < self.start_adv_iters else 0.025,
                'G_regularize_density_l1': hparams['lambda_density_reg'],
            }
        elif optimizer_idx == 2:
            losses.update(self.forward_D_main(sample))
            loss_weights = {
                'D_maximize_ref': 1.0,
                'D_minimize_ref_fake': 1.0,
                'D_ref_gradient_penalty': hparams['lambda_gradient_penalty'] * hparams['reg_interval_d'],
                'D_maximize_mv': 1.0,
                'D_minimize_mv_fake': 1.0,
                'D_mv_gradient_penalty': hparams['lambda_gradient_penalty'] * hparams['reg_interval_d'],
            }
        total_loss = sum([loss_weights[k] * v for k, v in losses.items() if isinstance(v, torch.Tensor) and v.requires_grad])
        if len(losses) == 0:
            return None
        return total_loss, losses

    def validation_start(self):
        if self.global_step % hparams['valid_infer_interval'] == 0:
            self.gen_dir = os.path.join(hparams['work_dir'], f'validation_results')
            os.makedirs(self.gen_dir, exist_ok=True)

    @torch.no_grad()
    def validation_step(self, sample, batch_idx):
        outputs = {}
        losses = {}
        if len(sample) == 0:
            return None

        sample = self.prepare_batch(sample)
        rank = 0 if len(set(os.environ['CUDA_VISIBLE_DEVICES'].split(","))) == 1 else dist.get_rank()

        losses.update(self.run_G_reference_image(sample))
        losses.update(self.run_G_multiview_image(sample))
        losses.update(self.forward_D_main(sample))
        outputs['losses'] = losses
        outputs['total_loss'] = sum(outputs['losses'].values())
        outputs = tensors_to_scalars(outputs)

        if self.global_step % hparams['valid_infer_interval'] == 0 \
                and batch_idx < hparams['num_valid_plots'] and rank == 0:
            imgs_ref = sample['ffhq_ref_imgs']
            dummy_cond = {
                'cond_src': torch.zeros([imgs_ref.shape[0], 3, 512, 512], dtype=sample['ffhq_mv_cameras'].dtype, device=sample['ffhq_mv_cameras'].device),
                'cond_drv': torch.zeros([imgs_ref.shape[0], 3, 512, 512], dtype=sample['ffhq_mv_cameras'].dtype, device=sample['ffhq_mv_cameras'].device),
                'ref_cameras': sample['ffhq_ref_cameras'],
            }
            gen_img = self.model.forward(imgs_ref, sample['ffhq_mv_cameras'], cond=dummy_cond, noise_mode='const')
            gen_img_recon = self.model.forward(imgs_ref, sample['ffhq_ref_cameras'], cond=dummy_cond, noise_mode='const')
            imgs_recon = gen_img_recon['image'].permute(0, 2, 3, 1)
            imgs_recon_raw = filtered_resizing(gen_img_recon['image_raw'], size=512, f=self.resample_filter, filter_mode='antialiased').permute(0, 2, 3, 1)
            imgs_recon_depth = gen_img_recon['image_depth'].permute(0, 2, 3, 1)
            imgs_pred = gen_img['image'].permute(0, 2, 3, 1)
            imgs_pred_raw = filtered_resizing(gen_img['image_raw'], size=512, f=self.resample_filter, filter_mode='antialiased').permute(0, 2, 3, 1)
            imgs_pred_depth = gen_img['image_depth'].permute(0, 2, 3, 1)
            imgs_ref = imgs_ref.permute(0, 2, 3, 1)
            imgs_mv = sample['ffhq_mv_imgs'].permute(0, 2, 3, 1)
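
            # Save a horizontal strip per sample: reference view | GT novel view |
            # reconstructed reference (raw rendering, resized to 512) | predicted novel view (raw, resized) |
            # reconstructed reference (super-resolved) | predicted novel view (super-resolved),
            # plus a separate depth strip for the two rendered views.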
            for i in range(len(imgs_pred)):
                idx_string = format(i + batch_idx * hparams['batch_size'], "05d")
                base_fn = f"{idx_string}"
                img_ref_mv_recon_pred = torch.cat([imgs_ref[i], imgs_mv[i], imgs_recon_raw[i], imgs_pred_raw[i], imgs_recon[i], imgs_pred[i]], dim=1)
                self.save_rgb_to_fname(img_ref_mv_recon_pred, f"{self.gen_dir}/images_rgb_iter{self.global_step}/ref_mv_recon_pred_{base_fn}.png")
                img_depth_recon_pred = torch.cat([imgs_recon_depth[i], imgs_pred_depth[i]], dim=1)
                self.save_depth_to_fname(img_depth_recon_pred, f"{self.gen_dir}/images_depth_iter{self.global_step}/recon_pred_{base_fn}.png")

        return outputs

    def validation_end(self, outputs):
        return super().validation_end(outputs)

    @staticmethod
    def save_rgb_to_fname(rgb, fname):
        """
        rgb: [H, W, 3], values in [-1, 1]
        """
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        img = (rgb * 127.5 + 128).clamp(0, 255)
        img = convert_to_np(img).astype(np.uint8)
        Image.fromarray(img, 'RGB').save(fname)

    @staticmethod
    def save_depth_to_fname(depth, fname):
        """
        depth: [H, W, 1], saved as a min-max normalized grayscale image
        """
        os.makedirs(os.path.dirname(fname), exist_ok=True)
        low, high = depth.min(), depth.max()
        img = (depth - low) * (255 / (high - low))
        img = convert_to_np(img)
        img = np.rint(img).clip(0, 255).astype(np.uint8)
        Image.fromarray(img[:, :, 0], 'L').save(fname)