# python3.7
"""Contains the base class for Encoder (GAN Inversion) runner."""
import os
import shutil
import torch
import torch.distributed as dist
import torchvision.transforms as T
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import get_grid_shape
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image
from utils.visualizer import postprocess_tensor
from metrics.inception import build_inception_model
from metrics.fid import extract_feature
from metrics.fid import compute_fid
from metrics.MSSIM import MSSSIM
from metrics.LPIPS import LPIPS
import numpy as np
from .base_runner import BaseRunner
from datasets import BaseDataset
from torch.utils.data import DataLoader
from PIL import Image
from runners.controllers.summary_writer import log_image
import torchvision
from editings.latent_editor import LatentEditor
from editings.styleclip.edit_hfgi import styleclip_edit, load_stylegan_generator, load_direction_calculator
from editings.GradCtrl.manipulate import main as gradctrl
import torch.nn.functional as F
import time

__all__ = ['BaseEncoderRunner']


class BaseEncoderRunner(BaseRunner):
    """Defines the base class for Encoder (GAN inversion) runner.

    Provides evaluation utilities (MSE/LPIPS/SSIM, FID, timing) and
    editing/synthesis helpers on top of `BaseRunner`. Derived classes must
    implement `train_step` (and supply `forward_pass`, `runG`, `runM`,
    `optimize`, which are referenced here but defined elsewhere).
    """

    def __init__(self, config, logger):
        super().__init__(config, logger)
        # Lazily built on first FID evaluation (see `fid` / `fid_attribute`).
        self.inception_model = None

    def build_models(self):
        """Builds models and caches per-module train/val kwargs from config."""
        super().build_models()
        assert 'encoder' in self.models
        assert 'generator_smooth' in self.models
        assert 'discriminator' in self.models
        self.resolution = self.models['generator_smooth'].resolution
        self.G_kwargs_train = self.config.modules['generator_smooth'].get(
            'kwargs_train', dict())
        self.G_kwargs_val = self.config.modules['generator_smooth'].get(
            'kwargs_val', dict())
        self.D_kwargs_train = self.config.modules['discriminator'].get(
            'kwargs_train', dict())
        self.D_kwargs_val = self.config.modules['discriminator'].get(
            'kwargs_val', dict())
        if self.config.use_disc2:
            self.D2_kwargs_train = self.config.modules['discriminator2'].get(
                'kwargs_train', dict())
            self.D2_kwargs_val = self.config.modules['discriminator2'].get(
                'kwargs_val', dict())
        if self.config.mapping_method != 'pretrained':
            self.M_kwargs_train = self.config.modules['mapping'].get(
                'kwargs_train', dict())
            self.M_kwargs_val = self.config.modules['mapping'].get(
                'kwargs_val', dict())
        if self.config.create_mixing_network:
            self.MIX_kwargs_train = self.config.modules['mixer'].get(
                'kwargs_train', dict())
            self.MIX_kwargs_val = self.config.modules['mixer'].get(
                'kwargs_val', dict())

    def train_step(self, data, **train_kwargs):
        raise NotImplementedError('Should be implemented in derived class.')

    def mse(self, mse_num):
        """Evaluates reconstruction quality (MSE, LPIPS, MS-SSIM).

        Args:
            mse_num: Number of validation samples to evaluate, or the string
                "auto" to use the whole validation set. 0 skips evaluation.

        Returns:
            -1 if `mse_num` is 0, otherwise a dict mapping metric name to a
            (mean, std) tuple. NOTE: the accumulators are only filled on
            rank 0; other ranks return zeros.
        """
        if mse_num == 0:
            return -1
        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')
        if mse_num == "auto":
            mse_num = len(self.val_loader.dataset)
        # Each rank evaluates a strided slice of the sample indices.
        indices = list(range(self.rank, mse_num, self.world_size))
        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('MSE-LPIPS-SSIM', total=mse_num)
        lpips = LPIPS()
        ssim = MSSSIM(size_average=False)
        n_evals = 3  # columns: MSE, LPIPS, SSIM
        gather_list = [torch.zeros((self.val_batch_size, n_evals),
                                   device=torch.cuda.current_device())
                       for i in range(self.world_size)]
        all_errors = np.zeros((mse_num, n_evals), dtype=np.float64)
        shared_tensor = torch.zeros((self.val_batch_size, n_evals),
                                    device=torch.cuda.current_device())
        gather_idx = 0
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data,
                                                return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
                # Per-sample metrics for this batch.
                shared_tensor[:, 0] = torch.mean((fakes - real_images)**2,
                                                 dim=(1, 2, 3))  # MSE
                shared_tensor[:, 1] = lpips(real_images, fakes)
                shared_tensor[:, 2] = ssim(real_images, fakes)
                # Collect every rank's batch metrics on all ranks.
                dist.all_gather(gather_list, shared_tensor)
                if self.rank == 0:
                    # NOTE(review): a final partial batch (batch_size <
                    # val_batch_size) would make these slice assignments
                    # shape-mismatch; assumes mse_num is divisible by the
                    # effective batch size — TODO confirm.
                    for t in gather_list:
                        all_errors[gather_idx:gather_idx + batch_size, 0] = \
                            t[:, 0].cpu().numpy()
                        all_errors[gather_idx:gather_idx + batch_size, 1] = \
                            t[:, 1].cpu().numpy()
                        all_errors[gather_idx:gather_idx + batch_size, 2] = \
                            t[:, 2].cpu().numpy()
                        gather_idx = gather_idx + batch_size
            self.logger.update_pbar(task1, batch_size * self.world_size)
        self.logger.close_pbar()
        mean_lst, std_lst = np.mean(all_errors, axis=0), np.std(all_errors,
                                                                axis=0)
        mse_mean, lpips_mean, ssim_mean = (mean_lst[0].item(),
                                           mean_lst[1].item(),
                                           mean_lst[2].item())
        mse_std, lpips_std, ssim_std = (std_lst[0].item(),
                                        std_lst[1].item(),
                                        std_lst[2].item())
        return_vals = {'mse': (mse_mean, mse_std),
                       'lpips': (lpips_mean, lpips_std),
                       'ssim': (ssim_mean, ssim_std)}
        return return_vals

    def fid_attribute(self, fid_num, z=None, ignore_cache=False, align_tf=True,
                      attribute='smile', factor=1, direction=None):
        """Computes FID between attribute-edited fakes and a real reference set.

        Edits inverted latents along an InterfaceGAN direction (`attribute`,
        scaled by `factor`) and compares the edited synthesis distribution
        against real images that already show the target attribute.

        NOTE(review): the `direction` parameter is ignored — it is always
        overwritten by the loaded .pt file; dataset root dirs are hard-coded
        absolute paths. `fid_num` is also overwritten from the dataset size.

        Returns:
            FID value on rank 0; -1 on other ranks (or when no samples).
        """
        self.set_mode('val')
        direction = torch.load(
            f'editings/interfacegan_directions/{attribute}.pt').cuda()
        # Choose the fake-side source set depending on edit direction:
        # removing the attribute starts from images that have it, and vice
        # versa.
        if factor < 0:
            self.config.data['smile']['root_dir'] = f'/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original'
        elif factor > 0:
            self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original"
        fake_loader = self.build_dataset(f"smile")
        fid_num = len(fake_loader.dataset)
        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info(f'Finish building inception model.')
        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not fid_num:
            return -1
        indices = list(range(self.rank, fid_num, self.world_size))
        self.logger.init_pbar()
        # Extract features from fake (edited) images.
        fake_feature_list = []
        task1 = self.logger.add_pbar_task(f'FID-{attribute}_fake',
                                          total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(fake_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='all',
                                                only_enc=True)
                wp = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                edit_wp = wp + factor * direction
                edited_images, _ = self.runG(edit_wp, "synthesis",
                                             highres_outs=eouts)
                fake_feature_list.append(
                    extract_feature(self.inception_model, edited_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        # Stash per-rank features on disk; rank 0 merges them below.
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                np.concatenate(fake_feature_list, axis=0))
        self.logger.close_pbar()
        # Extract features from real images if not cached.
        cached_fid_file = f'{self.work_dir}/real_{attribute}_{factor}_fid.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        if do_real_test:
            real_feature_list = []
            self.logger.init_pbar()
            # Real reference set is the opposite split of the fake source.
            if factor < 0:
                self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original"
            elif factor > 0:
                self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original"
            real_loader = self.build_dataset(f"smile")
            fid_num = len(real_loader.dataset)
            indices = list(range(self.rank, fid_num, self.world_size))
            task2 = self.logger.add_pbar_task(f"{attribute}_real",
                                              total=fid_num)
            for batch_idx in range(0, len(indices), self.val_batch_size):
                sub_indices = indices[batch_idx:batch_idx +
                                      self.val_batch_size]
                batch_size = len(sub_indices)
                data = next(real_loader)
                for key in data:
                    if key != 'name':
                        data[key] = data[key][:batch_size].cuda(
                            torch.cuda.current_device(), non_blocking=True)
                with torch.no_grad():
                    real_images = data['image']
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
                self.logger.update_pbar(task2, batch_size * self.world_size)
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    np.concatenate(real_feature_list, axis=0))
        dist.barrier()
        if self.rank != 0:
            return -1
        self.logger.close_pbar()
        # Collect fake features from all ranks.
        fake_feature_list.clear()
        for rank in range(self.world_size):
            fake_feature_list.append(
                np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
            os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
        fake_features = np.concatenate(fake_feature_list, axis=0)
        feature_dim = fake_features.shape[1]
        feature_num = fake_features.shape[0]
        # Undo the rank-strided interleaving: pad to a multiple of world_size,
        # reshape to (world_size, n, dim), transpose back to sample order.
        pad = feature_num % self.world_size
        if pad:
            pad = self.world_size - pad
            fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
        fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
        fake_features = fake_features.transpose(1, 0, 2)
        fake_features = fake_features.reshape(-1, feature_dim)[:feature_num]
        # Collect (or load) real features.
        if do_real_test:
            real_feature_list.clear()
            for rank in range(self.world_size):
                real_feature_list.append(
                    np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
                os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
            real_features = np.concatenate(real_feature_list, axis=0)
            feature_dim = real_features.shape[1]
            feature_num = real_features.shape[0]
            pad = feature_num % self.world_size
            if pad:
                pad = self.world_size - pad
                real_features = np.pad(real_features, ((0, pad), (0, 0)))
            real_features = real_features.reshape(
                self.world_size, -1, feature_dim)
            real_features = real_features.transpose(1, 0, 2)
            real_features = real_features.reshape(-1,
                                                  feature_dim)[:feature_num]
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
        fid_value = compute_fid(fake_features, real_features)
        return fid_value

    def fid(self, fid_num, z=None, ignore_cache=False, align_tf=True):
        """Computes the FID metric between reconstructions and real images.

        Args:
            fid_num: Maximum number of validation samples to use.
            z: Optional latent codes (np.ndarray) that cap `fid_num`.
            ignore_cache: Recompute real features even if a cache file exists.
            align_tf: Build the TF-aligned inception model.

        Returns:
            FID value on rank 0; -1 on other ranks (or when `fid_num` is 0).
        """
        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')
        fid_num = min(fid_num, len(self.val_loader.dataset))
        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info(f'Finish building inception model.')
        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not fid_num:
            return -1
        indices = list(range(self.rank, fid_num, self.world_size))
        self.logger.init_pbar()
        # Extract features from fake images (and real ones when not cached).
        fake_feature_list = []
        real_feature_list = []
        # Hoisted out of the loop: both values are loop-invariant.
        cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        task1 = self.logger.add_pbar_task('FID', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data,
                                                return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
            # Test-time optimization needs gradients, hence outside no_grad.
            if self.config.test_time_optims != 0:
                wp_mixed = return_dict['wp_mixed']
                fakes = self.optimize(data, wp_mixed)
            with torch.no_grad():
                fake_feature_list.append(
                    extract_feature(self.inception_model, fakes))
            if do_real_test:
                with torch.no_grad():
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                np.concatenate(fake_feature_list, axis=0))
        if do_real_test:
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    np.concatenate(real_feature_list, axis=0))
        dist.barrier()
        if self.rank != 0:
            return -1
        self.logger.close_pbar()
        # Collect fake features from all ranks.
        fake_feature_list.clear()
        for rank in range(self.world_size):
            fake_feature_list.append(
                np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
            os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
        fake_features = np.concatenate(fake_feature_list, axis=0)
        assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
        feature_dim = fake_features.shape[1]
        # Undo the rank-strided interleaving (pad / reshape / transpose).
        pad = fid_num % self.world_size
        if pad:
            pad = self.world_size - pad
            fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
        fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
        fake_features = fake_features.transpose(1, 0, 2)
        fake_features = fake_features.reshape(-1, feature_dim)[:fid_num]
        # Collect (or load) real features.
        if do_real_test:
            real_feature_list.clear()
            for rank in range(self.world_size):
                real_feature_list.append(
                    np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
                os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
            real_features = np.concatenate(real_feature_list, axis=0)
            assert real_features.shape == (fid_num, feature_dim)
            real_features = np.pad(real_features, ((0, pad), (0, 0)))
            real_features = real_features.reshape(
                self.world_size, -1, feature_dim)
            real_features = real_features.transpose(1, 0, 2)
            real_features = real_features.reshape(-1, feature_dim)[:fid_num]
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
            assert real_features.shape == (fid_num, feature_dim)
        fid_value = compute_fid(fake_features, real_features)
        return fid_value

    def val(self, **val_kwargs):
        self.synthesize(**val_kwargs)

    def synthesize(self, num, html_name=None, save_raw_synthesis=False):
        """Synthesizes reconstructions plus InterfaceGAN edits and logs them.

        Args:
            num: Number of batches to synthesize.
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved.
                (default: None)
            save_raw_synthesis: Whether to save raw synthesis on the disk.
                (default: False)
        """
        dist.barrier()
        if self.rank != 0:
            return
        if not html_name and not save_raw_synthesis:
            return
        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')
        if not num:
            return
        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=num)
        for i in range(num):
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='all')
                fakes = return_dict['fakes']
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                log_list_gpu = {"real": real_images, "fake": fakes}
                # Add InterfaceGAN editings (each attribute, both directions)
                # to the log list.
                editings = ['age', 'pose', 'smile']
                for edit in editings:
                    direction = torch.load(
                        f'editings/interfacegan_directions/{edit}.pt').cuda()
                    factors = [+3, -3]
                    for factor in factors:
                        name = f"{edit}_{factor}"
                        edit_wp = wp_mixed + factor * direction
                        edited_images, _ = self.runG(edit_wp, "synthesis",
                                                     highres_outs=eouts)
                        log_list_gpu[name] = edited_images
            # Log image grids to the summary writer. Entries may be either a
            # tensor or a (tensor, min_val) tuple.
            for log_name, log_val in log_list_gpu.items():
                log_im = log_val[0] if type(log_val) is tuple else log_val
                min_val = log_val[1] if type(log_val) is tuple else -1
                cpu_img = postprocess_tensor(log_im.detach().cpu(),
                                             min_val=min_val)
                grid = torchvision.utils.make_grid(cpu_img, nrow=5)
                log_image(name=f"image/{log_name}", grid=grid, iter=self.iter)
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()

    def save_edited_images(self, opts):
        """Inverts a dataset and saves (optionally edited) images to disk.

        Supported `opts.method` values: 'inversion' (no edit),
        'interfacegan', 'ganspace', 'styleclip'. Runs on rank 0 only.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')
        if opts.method == 'inversion':
            pass
        elif opts.method == 'interfacegan':
            direction = torch.load(
                f'editings/interfacegan_directions/{opts.edit}.pt').cuda()
        elif opts.method == 'ganspace':
            ganspace_pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')
            # (pca_idx, start_layer, end_layer, strength) per attribute.
            direction = {
                'eye_openness': (54, 7, 8, 20),
                'smile': (46, 4, 5, -20),
                'beard': (58, 7, 9, -20),
                'white_hair': (57, 7, 10, -24),
                'lipstick': (34, 10, 11, 20),
                'overexposed': (27, 8, 18, 15),
                'screaming': (35, 3, 7, -10),
                'head_angle_up': (11, 1, 4, 10),
            }
            editor = LatentEditor()
        elif opts.method == 'styleclip':
            stylegan_model = load_stylegan_generator(opts.model_path)
            global_direction_calculator = load_direction_calculator(
                opts.calculator_args)
        self.config.data['val']['root_dir'] = opts.dataset
        dataset = BaseDataset(**self.config.data['val'])
        val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                                num_workers=1)
        temp_dir = os.path.join(self.work_dir, opts.output)
        os.makedirs(temp_dir, exist_ok=True)
        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
        for idx, data in enumerate(val_loader):
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                return_dict = self.forward_pass(data, return_vals='all',
                                                only_enc=True)
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                if opts.method == 'interfacegan':
                    wp_mixed = wp_mixed + opts.factor * direction
                if opts.method == 'ganspace':
                    wp_mixed = editor.apply_ganspace_(
                        wp_mixed, ganspace_pca, [direction[opts.edit]])
                edited_images, gouts_edits = self.runG(
                    wp_mixed, "synthesis", highres_outs=eouts, resize=False)
                if opts.method == 'styleclip':
                    edited_images = styleclip_edit(
                        wp_mixed, gouts_edits['additions'], stylegan_model,
                        global_direction_calculator, opts.edit_args)
                    # NOTE(review): resize applied to the styleclip branch;
                    # original indentation was ambiguous — confirm it should
                    # not apply to the other methods.
                    edited_images = T.Resize((256, 256))(edited_images)
                edited_images = postprocess_image(
                    edited_images.detach().cpu().numpy())
                for j in range(edited_images.shape[0]):
                    save_name = data['name'][j]
                    pil_img = Image.fromarray(edited_images[j])
                    pil_img.save(os.path.join(temp_dir, save_name))
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()

    def interface_generate(self, num, edit, factor):
        """Samples `num` latents and saves GT plus +/-`factor` edited images.

        Writes three folders under `work_dir`: ground truth, attribute added,
        attribute removed. Work is sharded across ranks by sample index.
        """
        direction = torch.load(
            f'editings/interfacegan_directions/{edit}.pt').cuda()
        indices = list(range(self.rank, num, self.world_size))
        gt_path = os.path.join(self.work_dir, f'interfacegan_gt')
        smile_add_path = os.path.join(self.work_dir,
                                      f'interfacegan_{edit}_{factor}')
        smile_rm_path = os.path.join(self.work_dir,
                                     f'interfacegan_{edit}_-{factor}')
        if self.rank == 0:
            os.makedirs(gt_path, exist_ok=True)
            os.makedirs(smile_add_path, exist_ok=True)
            os.makedirs(smile_rm_path, exist_ok=True)
        dist.barrier()
        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Interfacegan', total=num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            z = torch.randn((batch_size, 512),
                            device=torch.cuda.current_device())
            w_r = self.runM(z, repeat_w=True)
            # Ground-truth synthesis.
            gt_imgs, _ = self.runG(w_r, resize=False)
            gt_imgs = postprocess_image(gt_imgs.detach().cpu().numpy())
            for i in range(gt_imgs.shape[0]):
                save_name = str(sub_indices[i]) + ".png"
                pil_img = Image.fromarray(gt_imgs[i]).resize((256, 256))
                pil_img.save(os.path.join(gt_path, save_name))
            # Attribute added.
            smile_added, _ = self.runG(w_r + factor * direction, resize=False)
            smile_added = postprocess_image(
                smile_added.detach().cpu().numpy())
            for i in range(gt_imgs.shape[0]):
                save_name = str(sub_indices[i]) + ".png"
                pil_img = Image.fromarray(smile_added[i]).resize((256, 256))
                pil_img.save(os.path.join(smile_add_path, save_name))
            # Attribute removed.
            smile_removed, _ = self.runG(w_r - factor * direction,
                                         resize=False)
            smile_removed = postprocess_image(
                smile_removed.detach().cpu().numpy())
            for i in range(gt_imgs.shape[0]):
                save_name = str(sub_indices[i]) + ".png"
                pil_img = Image.fromarray(smile_removed[i]).resize((256, 256))
                pil_img.save(os.path.join(smile_rm_path, save_name))
            self.logger.update_pbar(task, batch_size * self.world_size)
        self.logger.close_pbar()

    def grad_edit(self, edit, factor, dataset=None):
        """Applies a GradCtrl edit to every image of `dataset` and saves them.

        Runs on rank 0 only; outputs go to `work_dir/fakes_{edit}_{factor}`.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')
        edit_name = edit
        # Reuse the 'val' data config with the requested root dir.
        edit = 'val'
        self.config.data[edit]['root_dir'] = dataset
        dataset = BaseDataset(**self.config.data[edit])
        val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                                num_workers=0)
        temp_dir = os.path.join(self.work_dir, f'fakes_{edit_name}_{factor}')
        os.makedirs(temp_dir, exist_ok=True)
        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
        args = {'model': 'ffhq',
                'model_dir': '/media/hdd2/adundar/hamza/genforce/editings/GradCtrl/model_ffhq',
                'attribute': edit_name,
                'exclude': 'default',
                'top_channels': 'default',
                'layerwise': 'default'}
        for idx, data in enumerate(val_loader):
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                return_dict = self.forward_pass(data, return_vals='all',
                                                only_enc=True)
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                edit_wp = gradctrl(args, wp_mixed, factor)
                edited_images, gouts_edits = self.runG(
                    edit_wp, "synthesis", highres_outs=eouts, resize=False)
                edited_images = postprocess_image(
                    edited_images.detach().cpu().numpy())
                for j in range(edited_images.shape[0]):
                    save_name = data['name'][j]
                    pil_img = Image.fromarray(
                        edited_images[j]).resize((256, 256))
                    pil_img.save(os.path.join(temp_dir, save_name))
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()

    def measure_time(self, edit, factor, dataset=None, save_latents=False):
        """Measures per-image encoder+generator wall-clock time.

        Prints the mean over all images and the mean excluding the first
        (warm-up) iteration. Runs on rank 0 only.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')
        edit_name = edit
        # NOTE(review): when `dataset` is None no loader is built and the
        # method would fail; callers are expected to pass a dataset path —
        # TODO confirm.
        if dataset is not None:
            edit = 'val'
            self.config.data[edit]['root_dir'] = dataset
            dataset = BaseDataset(**self.config.data[edit])
            val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                                    num_workers=0)
        time_list = []
        for idx, data in enumerate(val_loader):
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                start = time.time()
                return_dict = self.forward_pass(data, return_vals='all',
                                                only_enc=True)
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                edited_images, gouts_edits = self.runG(
                    wp_mixed, "synthesis", highres_outs=eouts, resize=False)
                end = time.time()
                time_list.append(end - start)
        print(np.mean(time_list))
        print(np.mean(time_list[1:]))