import abc
import os
from argparse import Namespace
import wandb
import os.path
from criteria.localitly_regulizer import Space_Regulizer
import torch
from torchvision import transforms
from lpips import LPIPS
from training.projectors import w_projector  # w_plus_projector as w_projector
from configs import global_config, paths_config, hyperparameters
from criteria import l2_loss
from criteria import mask
from criteria import id_loss
from models.e4e.psp import pSp
from utils.log_utils import log_image_from_w
from utils.models_utils import toogle_grad, load_old_G
from torch_utils import misc
from torch_utils.ops import upfirdn2d
import numpy as np
import pickle
import copy


class BaseCoach:
    """Shared state and helpers for generator-tuning coaches.

    Holds the generator ``G``, a second copy ``original_G``, the loss modules
    and the optimizer, and provides pivot-latent lookup/inversion plus the
    combined training loss. Subclasses are expected to implement ``train``.

    NOTE(review): the class does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` marker on ``train`` is not enforced at
    instantiation time.
    """

    def __init__(self, data_loader, use_wandb):
        """Store the inputs, build the loss modules and (re)load the networks.

        Args:
            data_loader: Stored on ``self.data_loader``; not consumed here.
            use_wandb: When truthy, loss values and inversion previews are
                logged through wandb.
        """
        self.use_wandb = use_wandb
        self.data_loader = data_loader
        # In-memory cache of loaded pivots, keyed by image name.
        self.w_pivots = {}
        self.image_counter = 0

        if hyperparameters.first_inv_type == "w+":
            self.initilize_e4e()

        # Preprocessing applied to an image before it is fed to the e4e encoder.
        self.e4e_image_transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.Resize((256, 256)),
                transforms.ToTensor(),
                transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
            ]
        )

        # Initialize loss
        self.lpips_loss = (
            LPIPS(net=hyperparameters.lpips_type).to(global_config.device).eval()
        )
        self.id_loss = (
            id_loss.IDLoss(
                paths_config.ir_se50, official=False, device=global_config.device
            )
            .to(global_config.device)
            .eval()
        )
        if hyperparameters.use_mask:
            self.mask = mask.Mask(device=global_config.device)

        self.restart_training()

        # Initialize checkpoint dir
        self.checkpoint_dir = paths_config.checkpoints_dir
        os.makedirs(self.checkpoint_dir, exist_ok=True)

    def restart_training(self):
        """Reload both generator copies and rebuild regularizer and optimizer."""
        # Initialize networks
        self.G = load_old_G()
        toogle_grad(self.G, True)

        self.original_G = load_old_G()

        self.space_regulizer = Space_Regulizer(self.original_G, self.lpips_loss)
        self.optimizer = self.configure_optimizers()

    def get_inversion(self, w_path_dir, image_name, image):
        """Return the pivot latent for ``image_name`` on the configured device.

        A stored pivot is reused when ``hyperparameters.use_last_w_pivots`` is
        set and one can be loaded; otherwise the pivot is computed and saved as
        ``0.pt`` under the PTI results directory for this image.
        """
        embedding_dir = (
            f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}"
        )
        os.makedirs(embedding_dir, exist_ok=True)

        w_pivot = None
        if hyperparameters.use_last_w_pivots:
            w_pivot = self.load_inversions(w_path_dir, image_name)

        if not hyperparameters.use_last_w_pivots or w_pivot is None:
            w_pivot = self.calc_inversions(image, image_name)
            torch.save(w_pivot, f"{embedding_dir}/0.pt")

        w_pivot = w_pivot.to(global_config.device)
        return w_pivot

    def load_inversions(self, w_path_dir, image_name):
        """Load a previously saved pivot, or return ``None`` if no file exists.

        Results are cached in ``self.w_pivots`` so each file is read once.
        """
        if image_name in self.w_pivots:
            return self.w_pivots[image_name]

        # The directory keyword depends on which inversion type produced it.
        if hyperparameters.first_inv_type == "w+":
            w_potential_path = (
                f"{w_path_dir}/{paths_config.e4e_results_keyword}/{image_name}/0.pt"
            )
        else:
            w_potential_path = (
                f"{w_path_dir}/{paths_config.pti_results_keyword}/{image_name}/0.pt"
            )

        if not os.path.isfile(w_potential_path):
            return None

        w = torch.load(w_potential_path, map_location=global_config.device).to(
            global_config.device
        )
        self.w_pivots[image_name] = w
        return w

    def calc_inversions(self, image, image_name):
        """Compute a pivot latent with e4e ("w+") or with the w projector."""
        if hyperparameters.first_inv_type == "w+":
            w = self.get_e4e_inversion(image)
        else:
            # Rescale from [-1, 1] to [0, 255] before projection
            # (assumes the input image is normalised to [-1, 1] - TODO confirm).
            id_image = torch.squeeze((image.to(global_config.device) + 1) / 2) * 255
            w = w_projector.project(
                self.G,
                id_image,
                device=torch.device(global_config.device),
                w_avg_samples=600,
                num_steps=hyperparameters.first_inv_steps,
                w_name=image_name,
                use_wandb=self.use_wandb,
            )

        return w

    @abc.abstractmethod
    def train(self):
        """Run the training procedure; to be provided by subclasses."""
        pass

    def configure_optimizers(self):
        """Build an Adam optimizer over the synthesis parameters.

        Parameters whose name contains ``"rgb"`` are left out of the optimizer.
        """
        params = [
            param
            for name, param in self.G.synthesis.named_parameters()
            if "rgb" not in name
        ]
        optimizer = torch.optim.Adam(params, lr=hyperparameters.pti_learning_rate)
        return optimizer

    def calc_loss(
        self,
        generated_images,
        real_images,
        log_name,
        new_G,
        use_ball_holder,
        w_batch,
        rgbs,
    ):
        """Combine the enabled loss terms into a single training loss.

        Each term is gated by its ``hyperparameters`` weight. The colour term
        reads ``self.years`` and ``self.color_losses``, which are not set in
        this class (presumably provided by subclasses - verify).

        Returns:
            Tuple ``(loss, l2_loss_val, loss_lpips)``. ``l2_loss_val`` and
            ``loss_lpips`` are ``None`` when the corresponding term is
            disabled (its lambda is not positive).
        """
        loss = 0.0
        # Pre-bind the optional outputs so the return statement cannot hit an
        # unbound local when a term is switched off.
        l2_loss_val = None
        loss_lpips = None

        if hyperparameters.use_mask:
            real_images, generated_images = self.mask(real_images, generated_images)

        if hyperparameters.pt_l2_lambda > 0:
            l2_loss_val = l2_loss.l2_loss(generated_images, real_images, gray=False)
            if self.use_wandb:
                wandb.log(
                    {f"MSE_loss_val_{log_name}": l2_loss_val.detach().cpu()},
                    step=global_config.training_step,
                )
            loss += l2_loss_val * hyperparameters.pt_l2_lambda

        if hyperparameters.pt_lpips_lambda > 0:
            loss_lpips = self.lpips_loss(real_images, generated_images)
            loss_lpips = torch.squeeze(loss_lpips)
            if self.use_wandb:
                wandb.log(
                    {f"LPIPS_loss_val_{log_name}": loss_lpips.detach().cpu()},
                    step=global_config.training_step,
                )
            loss += loss_lpips * hyperparameters.pt_lpips_lambda

        if hyperparameters.color_transfer_lambda > 0:
            for y in self.years:
                color_loss = self.color_losses[y](rgbs[y])
                # Debug aid: print y and the weighted color_loss here if needed.
                loss += color_loss * hyperparameters.color_transfer_lambda

        if hyperparameters.id_lambda > 0:
            loss_id = self.id_loss(real_images, generated_images)
            loss_id = torch.squeeze(loss_id)
            loss += loss_id * hyperparameters.id_lambda

        if use_ball_holder and hyperparameters.use_locality_regularization:
            ball_holder_loss_val = self.space_regulizer.space_regulizer_loss(
                new_G, w_batch, use_wandb=self.use_wandb
            )
            loss += ball_holder_loss_val

        return loss, l2_loss_val, loss_lpips

    def synthesis_block(self, block, x, img, ws, force_fp32=False, fused_modconv=None):
        """Run one synthesis block and also expose its ToRGB output.

        Args:
            block: Synthesis block providing ``conv0``/``conv1``/``torgb`` etc.
            x: Feature tensor from the previous block (ignored by the first
                block, which starts from ``block.const``).
            img: Accumulated image from previous blocks, or ``None``.
            ws: Latents for this block; unbound along dim 1, one per layer.
            force_fp32: Disable fp16 / channels_last for this call.
            fused_modconv: Passed to the layers; derived when ``None``.

        Returns:
            Tuple ``(x, img, y)``. ``y`` is the ToRGB output when the block is
            last or uses the "skip" architecture. Otherwise it is whatever the
            main layers left in ``y``: the residual tensor for "resnet", else
            ``None``.
        """
        w_iter = iter(ws.unbind(dim=1))
        dtype = torch.float16 if block.use_fp16 and not force_fp32 else torch.float32
        memory_format = (
            torch.channels_last
            if block.channels_last and not force_fp32
            else torch.contiguous_format
        )
        if fused_modconv is None:
            with misc.suppress_tracer_warnings():  # this value will be treated as a constant
                fused_modconv = (not block.training) and (
                    dtype == torch.float32 or int(x.shape[0]) == 1
                )

        # Ensure the third return value is always bound, even for blocks that
        # neither use the resnet path nor emit a ToRGB output.
        y = None

        # Input.
        if block.in_channels == 0:
            x = block.const.to(dtype=dtype, memory_format=memory_format)
            x = x.unsqueeze(0).repeat([ws.shape[0], 1, 1, 1])
        else:
            misc.assert_shape(
                x,
                [None, block.in_channels, block.resolution // 2, block.resolution // 2],
            )
            x = x.to(dtype=dtype, memory_format=memory_format)

        # Main layers.
        if block.in_channels == 0:
            x = block.conv1(x, next(w_iter), fused_modconv=fused_modconv)
        elif block.architecture == "resnet":
            y = block.skip(x, gain=np.sqrt(0.5))
            x = block.conv0(x, next(w_iter), fused_modconv=fused_modconv)
            x = block.conv1(
                x,
                next(w_iter),
                fused_modconv=fused_modconv,
                gain=np.sqrt(0.5),
            )
            x = y.add_(x)
        else:
            x = block.conv0(x, next(w_iter), fused_modconv=fused_modconv)
            x = block.conv1(x, next(w_iter), fused_modconv=fused_modconv)

        # ToRGB.
        if img is not None:
            misc.assert_shape(
                img,
                [
                    None,
                    block.img_channels,
                    block.resolution // 2,
                    block.resolution // 2,
                ],
            )
            img = upfirdn2d.upsample2d(img, block.resample_filter)
        if block.is_last or block.architecture == "skip":
            y = block.torgb(x, next(w_iter), fused_modconv=fused_modconv)
            y = y.to(dtype=torch.float32, memory_format=torch.contiguous_format)
            img = img.add_(y) if img is not None else y

        assert x.dtype == dtype
        assert img is None or img.dtype == torch.float32
        return x, img, y

    def forward(self, w):
        """Synthesize images from ``w`` with constant noise in fp32."""
        generated_images = self.G.synthesis(w, noise_mode="const", force_fp32=True)
        return generated_images

    def forward_sibling(self, G_sibling, w):
        """Run ``G_sibling`` block by block, collecting per-block RGB outputs.

        Args:
            G_sibling: Object exposing ``block_resolutions`` and ``b{res}``
                block attributes (presumably a synthesis network - verify
                against callers).
            w: Latents; cast to float32 and sliced per block along dim 1.

        Returns:
            Tuple ``(img, rgbs)`` with the final image and a list holding the
            third return value of ``synthesis_block`` for every block.
        """
        block_ws = []
        rgbs = []
        ws = w.to(torch.float32)

        # Slice the latents so each block gets its conv + ToRGB entries; the
        # offset advances by num_conv only, so consecutive slices overlap.
        w_idx = 0
        for res in G_sibling.block_resolutions:
            block = getattr(G_sibling, f"b{res}")
            block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
            w_idx += block.num_conv

        x = img = None
        for res, cur_ws in zip(G_sibling.block_resolutions, block_ws):
            block = getattr(G_sibling, f"b{res}")
            x, img, rgb_mod = self.synthesis_block(block, x, img, cur_ws)
            rgbs.append(rgb_mod)

        return img, rgbs

    def initilize_e4e(self):
        """Load the e4e encoder checkpoint and freeze it in eval mode."""
        ckpt = torch.load(paths_config.e4e, map_location="cpu")
        opts = ckpt["opts"]
        opts["batch_size"] = hyperparameters.train_batch_size
        opts["checkpoint_path"] = paths_config.e4e
        opts = Namespace(**opts)
        self.e4e_inversion_net = pSp(opts)
        self.e4e_inversion_net.eval()
        self.e4e_inversion_net = self.e4e_inversion_net.to(global_config.device)
        toogle_grad(self.e4e_inversion_net, False)

    def get_e4e_inversion(self, image):
        """Encode the first image of ``image`` into a latent with e4e.

        Only ``image[0]`` is used. The latent preview is logged to wandb when
        ``self.use_wandb`` is set.
        """
        image = (image + 1) / 2
        new_image = self.e4e_image_transform(image[0]).to(global_config.device)
        _, w = self.e4e_inversion_net(
            new_image.unsqueeze(0),
            randomize_noise=False,
            return_latents=True,
            resize=False,
            input_code=False,
        )
        if self.use_wandb:
            log_image_from_w(w, self.G, "First e4e inversion")
        return w