# StyleRes / editings / base_encoder_runner.py
# (Repository-viewer page header converted to comments; original metadata:
#  author hamzapehlivan, commit "Intial Commit" 6709fc9, file size 40.1 kB.)
# python3.7
"""Contains the base class for Encoder (GAN Inversion) runner."""
import os
import shutil
import torch
import torch.distributed as dist
import torchvision.transforms as T
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import get_grid_shape
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image
from utils.visualizer import postprocess_tensor
from metrics.inception import build_inception_model
from metrics.fid import extract_feature
from metrics.fid import compute_fid
from metrics.MSSIM import MSSSIM
from metrics.LPIPS import LPIPS
import numpy as np
from .base_runner import BaseRunner
from datasets import BaseDataset
from torch.utils.data import DataLoader
from PIL import Image
from runners.controllers.summary_writer import log_image
import torchvision
from editings.latent_editor import LatentEditor
from editings.styleclip.edit_hfgi import styleclip_edit, load_stylegan_generator,load_direction_calculator
from editings.GradCtrl.manipulate import main as gradctrl
import torch.nn.functional as F
import time
__all__ = ['BaseEncoderRunner']
class BaseEncoderRunner(BaseRunner):
    """Defines the base class for Encoder (GAN inversion) runners.

    Provides evaluation utilities (MSE/LPIPS/SSIM, FID, attribute-edit FID),
    image synthesis/editing helpers, and timing measurement.  Subclasses must
    implement `train_step`.
    """

    def __init__(self, config, logger):
        """Initializes the runner.

        Args:
            config: Runner configuration object (consumed by BaseRunner).
            logger: Logger used for progress bars and info messages.
        """
        super().__init__(config, logger)
        # Inception feature extractor for FID; built lazily on first use.
        self.inception_model = None
def build_models(self):
super().build_models()
assert 'encoder' in self.models
assert 'generator_smooth' in self.models
assert 'discriminator' in self.models
self.resolution = self.models['generator_smooth'].resolution
self.G_kwargs_train = self.config.modules['generator_smooth'].get(
'kwargs_train', dict())
self.G_kwargs_val = self.config.modules['generator_smooth'].get(
'kwargs_val', dict())
self.D_kwargs_train = self.config.modules['discriminator'].get(
'kwargs_train', dict())
self.D_kwargs_val = self.config.modules['discriminator'].get(
'kwargs_val', dict())
if self.config.use_disc2:
self.D2_kwargs_train = self.config.modules['discriminator2'].get(
'kwargs_train', dict())
self.D2_kwargs_val = self.config.modules['discriminator2'].get(
'kwargs_val', dict())
if self.config.mapping_method != 'pretrained':
self.M_kwargs_train = self.config.modules['mapping'].get(
'kwargs_train', dict())
self.M_kwargs_val = self.config.modules['mapping'].get(
'kwargs_val', dict())
if self.config.create_mixing_network:
self.MIX_kwargs_train = self.config.modules['mixer'].get(
'kwargs_train', dict())
self.MIX_kwargs_val = self.config.modules['mixer'].get(
'kwargs_val', dict())
    def train_step(self, data, **train_kwargs):
        """Runs a single training step on one batch.

        Args:
            data: A batch dict produced by the data loader.
            **train_kwargs: Runner-specific training options.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError('Should be implemented in derived class.')
    def mse(self, mse_num):
        """Evaluates reconstruction quality (MSE, LPIPS, MS-SSIM) on val data.

        Work is sharded across distributed ranks; per-batch metric rows are
        exchanged with `dist.all_gather` and accumulated on rank 0 only.

        Args:
            mse_num: Number of validation images to evaluate, or the string
                "auto" to use the whole validation set.  0 skips evaluation.

        Returns:
            -1 when `mse_num` is 0; otherwise a dict mapping 'mse', 'lpips',
            'ssim' to (mean, std) tuples.  The aggregated values are only
            meaningful on rank 0, since only rank 0 fills `all_errors`.
        """
        if mse_num == 0:
            return -1
        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')
        if mse_num == "auto":
            mse_num = len(self.val_loader.dataset)
        # Round-robin shard of sample indices for this rank.
        indices = list(range(self.rank, mse_num, self.world_size))
        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('MSE-LPIPS-SSIM', total=mse_num)
        lpips = LPIPS()
        ssim = MSSSIM(size_average=False)
        n_evals = 3  # Metric columns: 0=MSE, 1=LPIPS, 2=MS-SSIM.
        # One receive buffer per rank for all_gather, plus this rank's
        # send buffer (`shared_tensor`).
        gather_list = [torch.zeros( (self.val_batch_size, n_evals), device=torch.cuda.current_device()) for i in range(self.world_size)]
        all_errors = np.zeros( (mse_num, n_evals), dtype=np.float64)
        shared_tensor = torch.zeros((self.val_batch_size, n_evals), device=torch.cuda.current_device())
        gather_idx = 0
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
                shared_tensor[:, 0] = torch.mean((fakes - real_images)**2, dim=(1,2,3)) #MSE Error
                shared_tensor[:, 1]= lpips(real_images, fakes)
                shared_tensor[:, 2]= ssim(real_images, fakes)
            dist.all_gather(gather_list, shared_tensor)
            if self.rank == 0:
                # NOTE(review): the gathered tensors always hold
                # `val_batch_size` rows; the slice assignment below assumes
                # every rank's batch is full (batch_size == val_batch_size).
                # Confirm mse_num is divisible, otherwise the final ragged
                # batch would mis-shape this assignment.
                for t in gather_list:
                    all_errors[gather_idx:gather_idx+batch_size, 0] = t[:,0].cpu().numpy()
                    all_errors[gather_idx:gather_idx+batch_size, 1] = t[:,1].cpu().numpy()
                    all_errors[gather_idx:gather_idx+batch_size, 2] = t[:,2].cpu().numpy()
                gather_idx = gather_idx+batch_size
            self.logger.update_pbar(task1, batch_size * self.world_size)
        self.logger.close_pbar()
        # Reduce per-image rows into (mean, std) per metric.
        mean_lst, std_lst = np.mean(all_errors, axis=0), np.std(all_errors, axis=0)
        mse_mean, lpips_mean, ssim_mean = mean_lst[0].item(), mean_lst[1].item(), mean_lst[2].item()
        mse_std, lpips_std, ssim_std = std_lst[0].item(), std_lst[1].item(), std_lst[2].item()
        return_vals = {'mse': (mse_mean,mse_std), 'lpips': (lpips_mean, lpips_std), 'ssim':(ssim_mean, ssim_std)}
        return return_vals
    def fid_attribute(self,
                      fid_num,
                      z=None,
                      ignore_cache=False,
                      align_tf=True,
                      attribute='smile', factor=1, direction=None):
        """Computes FID between attribute-edited fakes and real references.

        Inverts images from an attribute split, shifts the latents along an
        InterfaceGAN direction by `factor`, and compares Inception features
        of the edited syntheses against features of the opposite real split.

        Args:
            fid_num: Requested image count; overwritten below by the size of
                the attribute dataset.
            z: Optional latent codes (np.ndarray); caps `fid_num`.
            ignore_cache: If True, recompute cached real-image features.
            align_tf: Whether the Inception model is TF-aligned.
            attribute: InterfaceGAN direction name to load (e.g. 'smile').
            factor: Edit magnitude; its sign selects which hard-coded dataset
                split acts as the fake source vs. the real reference.
            direction: Unused; the direction is always re-loaded from disk.

        Returns:
            FID value on rank 0; -1 on other ranks or when `fid_num` is falsy.
        """
        self.set_mode('val')
        direction = torch.load(f'editings/interfacegan_directions/{attribute}.pt').cuda()
        # NOTE(review): machine-local hard-coded dataset paths below — this
        # method only works on the original author's machine as-is.
        if factor < 0:
            self.config.data['smile']['root_dir'] = f'/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original'
        elif factor > 0:
            self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original"
        fake_loader = self.build_dataset(f"smile")
        fid_num = len(fake_loader.dataset)
        # Lazily build the shared Inception feature extractor.
        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info(f'Finish building inception model.')
        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not fid_num:
            return -1
        # Round-robin shard of sample indices for this rank.
        indices = list(range(self.rank, fid_num, self.world_size))
        self.logger.init_pbar()
        # Extract features from fake (edited) images.
        fake_feature_list = []
        task1 = self.logger.add_pbar_task(f'FID-{attribute}_fake', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(fake_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='all', only_enc = True)
                wp = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                # Apply the latent-space edit, then synthesize.
                edit_wp = wp + factor * direction
                edited_images, _ = self.runG(edit_wp, "synthesis", highres_outs=eouts)
                fake_feature_list.append(
                    extract_feature(self.inception_model, edited_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        # Persist this rank's features so rank 0 can aggregate them.
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                np.concatenate(fake_feature_list, axis=0))
        self.logger.close_pbar()
        # Extract features from real images if not cached (opposite split).
        cached_fid_file = f'{self.work_dir}/real_{attribute}_{factor}_fid.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        if do_real_test:
            real_feature_list = []
            self.logger.init_pbar()
            if factor < 0:
                self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original"
            elif factor > 0:
                self.config.data['smile']['root_dir'] = f"/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original"
            real_loader = self.build_dataset(f"smile")
            fid_num = len(real_loader.dataset)
            indices = list(range(self.rank, fid_num, self.world_size))
            task2 = self.logger.add_pbar_task(f"{attribute}_real", total=fid_num)
            for batch_idx in range(0, len(indices), self.val_batch_size):
                sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
                batch_size = len(sub_indices)
                data = next(real_loader)
                for key in data:
                    if key != 'name':
                        data[key] = data[key][:batch_size].cuda(
                            torch.cuda.current_device(), non_blocking=True)
                with torch.no_grad():
                    real_images = data['image']
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
                self.logger.update_pbar(task2, batch_size * self.world_size)
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    np.concatenate(real_feature_list, axis=0))
        dist.barrier()
        # Only rank 0 aggregates and reports; other ranks return early.
        if self.rank != 0:
            return -1
        self.logger.close_pbar()
        # Collect fake features from all ranks; the pad/reshape/transpose
        # dance interleaves the per-rank round-robin shards back into the
        # original dataset order.
        fake_feature_list.clear()
        for rank in range(self.world_size):
            fake_feature_list.append(
                np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
            os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
        fake_features = np.concatenate(fake_feature_list, axis=0)
        feature_dim = fake_features.shape[1]
        feature_num = fake_features.shape[0]
        # Pad to a multiple of world_size so the reshape below is valid.
        pad = feature_num % self.world_size
        if pad:
            pad = self.world_size - pad
            fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
        fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
        fake_features = fake_features.transpose(1, 0, 2)
        fake_features = fake_features.reshape(-1, feature_dim)[:feature_num]
        # Collect (or load) real features, same interleaving.
        if do_real_test:
            real_feature_list.clear()
            for rank in range(self.world_size):
                real_feature_list.append(
                    np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
                os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
            real_features = np.concatenate(real_feature_list, axis=0)
            feature_dim = real_features.shape[1]
            feature_num = real_features.shape[0]
            pad = feature_num % self.world_size
            if pad:
                pad = self.world_size - pad
                real_features = np.pad(real_features, ((0, pad), (0, 0)))
            real_features = real_features.reshape(
                self.world_size, -1, feature_dim)
            real_features = real_features.transpose(1, 0, 2)
            real_features = real_features.reshape(-1, feature_dim)[:feature_num]
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
        fid_value = compute_fid(fake_features, real_features)
        return fid_value
    def fid(self,
            fid_num,
            z=None,
            ignore_cache=False,
            align_tf=True):
        """Computes the FID metric between reconstructions and real images.

        Args:
            fid_num: Number of validation images to use (capped by the
                dataset size).
            z: Optional latent codes (np.ndarray); caps `fid_num`.
            ignore_cache: If True, recompute the cached real-image features.
            align_tf: Whether the Inception model is TF-aligned.

        Returns:
            FID value on rank 0; -1 on other ranks or when `fid_num` is falsy.
        """
        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')
        fid_num = min(fid_num, len(self.val_loader.dataset))
        # Lazily build the shared Inception feature extractor.
        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info(f'Finish building inception model.')
        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
            z = torch.from_numpy(z).type(torch.FloatTensor)
        if not fid_num:
            return -1
        # Round-robin shard of sample indices for this rank.
        indices = list(range(self.rank, fid_num, self.world_size))
        self.logger.init_pbar()
        # Extract features from fake (reconstructed) images; real-image
        # features are extracted in the same pass when not cached.
        fake_feature_list = []
        real_feature_list = []
        task1 = self.logger.add_pbar_task('FID', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
            # Optional test-time optimization refines the reconstruction
            # (runs with gradients enabled, hence outside no_grad).
            if self.config.test_time_optims != 0:
                wp_mixed = return_dict['wp_mixed']
                fakes = self.optimize(data, wp_mixed)
            with torch.no_grad():
                fake_feature_list.append(
                    extract_feature(self.inception_model, fakes))
            # Extract features from real images if needed.
            cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
            do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
            if do_real_test:
                with torch.no_grad():
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)
        # Persist this rank's features so rank 0 can aggregate them.
        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                np.concatenate(fake_feature_list, axis=0))
        if (do_real_test):
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    np.concatenate(real_feature_list, axis=0))
        dist.barrier()
        # Only rank 0 aggregates and reports; other ranks return early.
        if self.rank != 0:
            return -1
        self.logger.close_pbar()
        # Collect fake features from all ranks; the pad/reshape/transpose
        # dance interleaves the per-rank round-robin shards back into the
        # original dataset order.
        fake_feature_list.clear()
        for rank in range(self.world_size):
            fake_feature_list.append(
                np.load(f'{self.work_dir}/fake_fid_features_{rank}.npy'))
            os.remove(f'{self.work_dir}/fake_fid_features_{rank}.npy')
        fake_features = np.concatenate(fake_feature_list, axis=0)
        assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
        feature_dim = fake_features.shape[1]
        # Pad to a multiple of world_size so the reshape below is valid.
        pad = fid_num % self.world_size
        if pad:
            pad = self.world_size - pad
            fake_features = np.pad(fake_features, ((0, pad), (0, 0)))
        fake_features = fake_features.reshape(self.world_size, -1, feature_dim)
        fake_features = fake_features.transpose(1, 0, 2)
        fake_features = fake_features.reshape(-1, feature_dim)[:fid_num]
        # Collect (or load) real features, same interleaving.
        if do_real_test:
            real_feature_list.clear()
            for rank in range(self.world_size):
                real_feature_list.append(
                    np.load(f'{self.work_dir}/real_fid_features_{rank}.npy'))
                os.remove(f'{self.work_dir}/real_fid_features_{rank}.npy')
            real_features = np.concatenate(real_feature_list, axis=0)
            assert real_features.shape == (fid_num, feature_dim)
            # NOTE(review): reuses `pad` computed for the fake features;
            # valid only because both feature sets have fid_num rows.
            real_features = np.pad(real_features, ((0, pad), (0, 0)))
            real_features = real_features.reshape(
                self.world_size, -1, feature_dim)
            real_features = real_features.transpose(1, 0, 2)
            real_features = real_features.reshape(-1, feature_dim)[:fid_num]
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
            assert real_features.shape == (fid_num, feature_dim)
        fid_value = compute_fid(fake_features, real_features)
        return fid_value
    def val(self, **val_kwargs):
        """Runs validation by delegating to `synthesize`."""
        self.synthesize(**val_kwargs)
    def synthesize(self,
                   num,
                   html_name=None,
                   save_raw_synthesis=False):
        """Synthesizes and logs reconstruction/editing grids.

        For `num` validation batches: reconstructs the images, applies
        InterfaceGAN edits (age / pose / smile at factors +3 and -3), and
        logs each result as an image grid via the summary writer.  Runs on
        rank 0 only, and only when an output is requested.

        Args:
            num: Number of validation batches to synthesize.
            html_name: Name of the output html page for visualization. If not
                specified, no visualization page will be saved. (default: None)
            save_raw_synthesis: Whether to save raw synthesis on the disk.
                (default: False)
        """
        dist.barrier()
        if self.rank != 0:
            return
        if not html_name and not save_raw_synthesis:
            return
        self.set_mode('val')
        if self.val_loader is None:
            self.build_dataset('val')
        if not num:
            return
        # TODO: Use same z during the entire training process.
        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=num)
        for i in range(num):
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='all')
                fakes = return_dict['fakes']
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                log_list_gpu = {"real": real_images, "fake": fakes}
                # Add InterfaceGAN editings to the log list.
                editings = ['age', 'pose', 'smile']
                for edit in editings:
                    direction = torch.load(f'editings/interfacegan_directions/{edit}.pt').cuda()
                    factors = [+3, -3]
                    for factor in factors:
                        name = f"{edit}_{factor}"
                        edit_wp = wp_mixed + factor * direction
                        edited_images, _ = self.runG(edit_wp, "synthesis", highres_outs=eouts)
                        log_list_gpu[name] = edited_images
                # Log images.  Entries may be (tensor, min_val) tuples; a
                # plain tensor defaults to min_val=-1 for [-1, 1] outputs.
                for log_name, log_val in log_list_gpu.items():
                    log_im = log_val[0] if type(log_val) is tuple else log_val
                    min_val = log_val[1] if type(log_val) is tuple else -1
                    cpu_img = postprocess_tensor(log_im.detach().cpu(), min_val=min_val)
                    grid = torchvision.utils.make_grid(cpu_img, nrow=5)
                    log_image( name = f"image/{log_name}", grid=grid, iter=self.iter)
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()
    def save_edited_images(self, opts):
        """Inverts a dataset and saves edited (or plain inverted) images.

        Supported `opts.method` values: 'inversion' (no edit),
        'interfacegan', 'ganspace', and 'styleclip'.  Results are written as
        PNGs to `<work_dir>/<opts.output>`, one per input image.  Runs on
        rank 0 only.

        Args:
            opts: Options namespace; fields used here: `method`, `edit`,
                `factor`, `dataset`, `output`, and for styleclip also
                `model_path`, `calculator_args`, `edit_args`.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')
        # Per-method setup: load the edit direction / models once up front.
        if opts.method == 'inversion':
            pass
        elif opts.method == 'interfacegan':
            direction = torch.load(f'editings/interfacegan_directions/{opts.edit}.pt').cuda()
        elif opts.method == 'ganspace':
            ganspace_pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')
            # Each entry: (pca_index, start_layer, end_layer, strength).
            direction = {
                'eye_openness': (54, 7, 8, 20),
                'smile': (46, 4, 5, -20),
                'beard': (58, 7, 9, -20),
                'white_hair': (57, 7, 10, -24),
                'lipstick': (34, 10, 11, 20),
                'overexposed': (27, 8, 18, 15),
                'screaming': (35, 3, 7, -10),
                'head_angle_up': (11, 1, 4, 10),
            }
            editor = LatentEditor()
        elif opts.method == 'styleclip':
            stylegan_model = load_stylegan_generator(opts.model_path)
            global_direction_calculator = load_direction_calculator(opts.calculator_args)
        self.config.data['val']['root_dir'] = opts.dataset
        dataset = BaseDataset(**self.config.data['val'])
        val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)
        temp_dir = os.path.join(self.work_dir, opts.output)
        os.makedirs(temp_dir, exist_ok=True)
        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
        global_i = 0
        all_latents = {}
        for idx, data in enumerate(val_loader):
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                # Encoder-only pass: latents + encoder features for runG.
                return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                factors = np.linspace(0, 3, 100)  # NOTE(review): unused leftover.
                global_i = 0
                # Latent-space edits happen before synthesis.
                if opts.method == 'interfacegan':
                    wp_mixed = wp_mixed + opts.factor * direction
                if opts.method == 'ganspace':
                    wp_mixed = editor.apply_ganspace_(wp_mixed, ganspace_pca, [direction[opts.edit]])
                edited_images, gouts_edits = self.runG(wp_mixed, "synthesis", highres_outs=eouts, resize=False)
                # StyleCLIP edits in style space, using the generator's
                # feature additions from the inversion pass.
                if opts.method == 'styleclip':
                    edited_images = styleclip_edit(wp_mixed, gouts_edits['additions'], stylegan_model, global_direction_calculator, opts.edit_args)
                # NOTE(review): resize applied to all methods here; the
                # original file's indentation was lost, confirm it was not
                # meant to be styleclip-only.
                edited_images = T.Resize((256,256))(edited_images)
                edited_images = postprocess_image(edited_images.detach().cpu().numpy())
                for j in range(edited_images.shape[0]):
                    save_name = data['name'][j]
                    pil_img = Image.fromarray(edited_images[j])
                    pil_img.save(os.path.join(temp_dir, save_name ))
                    global_i += 1
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()
# def forward_pass(self, data, return_vals='all', only_enc=False):
# encoder_type = self.config.encoder_type
# forward_func = getattr(self, f'{encoder_type}_forward')
# return_dict = forward_func(data,only_enc)
# return return_dict
# # if return_vals == 'all':
# # return return_dict
# # requested = return_vals.split(',')
# # modified_dict = {}
# # for request in requested:
# # stripped_request = request.strip()
# # modified_dict[stripped_request] = return_dict[stripped_request]
# # return modified_dict
# def base_forward(self, data):
# reals = data['image']
# valids = data['valid']
# z_rand = data['z_rand']
# wp_rand = self.runM(z_rand)
# wp_enc, blender = self.runE(reals, valids)
# wp_mixed = self.mix(wp_enc, wp_rand, blender)
# fakes = self.runG(wp_mixed, 'synthesis')
# return_dict = {'fakes': fakes, 'wp_enc': wp_enc, 'blender': blender, 'wp_mixed':wp_mixed}
# return return_dict
# def pSp_forward(self, data, only_enc):
# return self.e4e_forward(data, only_enc)
# def train_forward(self, data, iscycle=False):
# reals = data['image']
# direction = data['direction']
# edit_name = data['edit_name']
# factor = data['factor']
# E = self.models['encoder']
# with torch.no_grad():
# wp, eouts = E(reals)
# #wp = wp + self.meanw.repeat(reals.shape[0], 1, 1)
# edit = torch.zeros_like(wp)
# for i in range (edit.shape[0]):
# if edit_name[i] is None:
# edit[i] = 0
# elif edit_name[i] == 'randw':
# diff = direction[i] - wp[i]
# # one_hot = [1] * 8 + [0] * 10
# # one_hot = torch.tensor(one_hot, device=diff.device).unsqueeze(1)
# # diff = diff * one_hot
# #norm = torch.linalg.norm(diff, dim=1, keepdim=True)
# edit[i] = (diff * factor[i]) / 10
# elif edit_name[i] == 'interface':
# edit[i] = (factor[i] * direction[i])
# # # Debug
# # with torch.no_grad():
# # fakes,_ =self.runG(wp, 'synthesis', highres_outs=None)
# # fakes = postprocess_image(fakes.detach().cpu().numpy())
# # for i in range(fakes.shape[0]):
# # pil_img = Image.fromarray(fakes[i]).resize((256,256))
# # pil_img.save(f'{self.iter}_orig.png')
# # fakes,_ =self.runG(wp+edit, 'synthesis', highres_outs=None)
# # fakes = postprocess_image(fakes.detach().cpu().numpy())
# # for i in range(fakes.shape[0]):
# # pil_img = Image.fromarray(fakes[i]).resize((256,256))
# # pil_img.save(f'{self.iter}_edit.png')
# # fakes,_ =self.runG(direction.unsqueeze(1).repeat(1,18,1), 'synthesis', highres_outs=None)
# # fakes = postprocess_image(fakes.detach().cpu().numpy())
# # for i in range(fakes.shape[0]):
# # pil_img = Image.fromarray(fakes[i]).resize((256,256))
# # pil_img.save(f'{self.iter}_rand.png')
# with torch.no_grad():
# eouts['inversion'] = self.runG(wp, 'synthesis', highres_outs=None, return_f=True)
# wp = wp + edit
# fakes, gouts = self.runG(wp, 'synthesis', highres_outs=eouts)
# #fakes = F.adaptive_avg_pool2d(fakes, (256,256))
# fakes_cycle = None
# if iscycle:
# # wp_cycle = wp_cycle + self.meanw.repeat(reals.shape[0], 1, 1)
# with torch.no_grad():
# wp_cycle, eout_cycle = E(fakes)
# eout_cycle['inversion'] = self.runG(wp_cycle, 'synthesis', highres_outs=None, return_f=True)
# #wp_cycle = wp_cycle - edit
# wp_cycle = wp_cycle - edit
# #wp_cycle = wp_cycle - (data['factor'] * data['direction']).unsqueeze(1)
# fakes_cycle, _ = self.runG(wp_cycle, 'synthesis', highres_outs=eout_cycle)
# #fakes_cycle = F.adaptive_avg_pool2d(fakes, (256,256))
# #cycle = F.mse_loss(fakes_cycle, reals, reduction='mean')
# return_dict = {'fakes': fakes, 'wp_mixed':wp, 'gouts':gouts, 'eouts': eouts, 'cycle': fakes_cycle}
# return return_dict
# def e4e_forward(self, data, only_enc=False):
# #return self.base_forward(data)
# reals = data['image']
# #valids = data['valid']
# E = self.models['encoder']
# wp_mixed, eouts = E(reals)
# #wp_mixed = wp_mixed + self.meanw.repeat(reals.shape[0], 1, 1)
# eouts['inversion'] = self.runG(wp_mixed, 'synthesis', highres_outs=None, return_f=True)
# if only_enc:
# return_dict = {'wp_mixed':wp_mixed,'eouts': eouts}
# return return_dict
# fakes, gouts = self.runG(wp_mixed, 'synthesis', highres_outs=eouts)
# #fakes = self.runG(wp_mixed, 'synthesis', highres_outs=None)
# #fakes = F.adaptive_avg_pool2d(fakes, (256,256))
# return_dict = {'fakes': fakes, 'wp_mixed':wp_mixed, 'gouts':gouts, 'eouts': eouts}
# return return_dict
# def hyperstyle_forward(self, data):
# return_dict = self.base_forward(data)
# E = self.models['encoder']
# reals = data['image']
# valids = data['valid']
# #HyperNetwork
# weight_deltas = E(reals, valids, mode='hyper', gouts=return_dict['fakes'])
# fakes = self.runG(return_dict['wp_mixed'], 'synthesis', weight_deltas=weight_deltas)
# return_dict['fakes'] = fakes
# return return_dict
def interface_generate(self, num, edit, factor):
direction = torch.load(f'editings/interfacegan_directions/{edit}.pt').cuda()
indices = list(range(self.rank, num, self.world_size))
gt_path = os.path.join(self.work_dir, f'interfacegan_gt')
smile_add_path = os.path.join(self.work_dir, f'interfacegan_{edit}_{factor}')
smile_rm_path = os.path.join(self.work_dir, f'interfacegan_{edit}_-{factor}')
if self.rank == 0:
os.makedirs(gt_path, exist_ok=True)
os.makedirs(smile_add_path, exist_ok=True)
os.makedirs(smile_rm_path, exist_ok=True)
dist.barrier()
self.logger.init_pbar()
task = self.logger.add_pbar_task('Interfacegan', total=num)
for batch_idx in range(0, len(indices), self.val_batch_size):
sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
batch_size = len(sub_indices)
z = torch.randn((batch_size,512), device=torch.cuda.current_device())
w_r = self.runM(z, repeat_w=True)
gt_imgs,_ = self.runG(w_r, resize=False)
gt_imgs = postprocess_image(gt_imgs.detach().cpu().numpy())
for i in range(gt_imgs.shape[0]):
save_name = str(sub_indices[i]) + ".png"
pil_img = Image.fromarray(gt_imgs[i]).resize((256,256))
pil_img.save(os.path.join(gt_path, save_name ))
smile_added, _ = self.runG(w_r + factor*direction, resize=False)
smile_added = postprocess_image(smile_added.detach().cpu().numpy())
for i in range(gt_imgs.shape[0]):
save_name = str(sub_indices[i]) + ".png"
pil_img = Image.fromarray(smile_added[i]).resize((256,256))
pil_img.save(os.path.join(smile_add_path, save_name ))
smile_removed, _= self.runG(w_r - factor*direction, resize=False)
smile_removed = postprocess_image(smile_removed.detach().cpu().numpy())
for i in range(gt_imgs.shape[0]):
save_name = str(sub_indices[i]) + ".png"
pil_img = Image.fromarray(smile_removed[i]).resize((256,256))
pil_img.save(os.path.join(smile_rm_path, save_name ))
self.logger.update_pbar(task, batch_size * self.world_size)
self.logger.close_pbar()
def grad_edit(self, edit, factor, dataset=None):
dist.barrier()
if self.rank != 0:
return
self.set_mode('val')
edit_name = edit
edit = 'val'
self.config.data[edit]['root_dir'] = dataset
dataset = BaseDataset(**self.config.data[edit])
val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
temp_dir = os.path.join(self.work_dir, f'fakes_{edit_name}_{factor}')
os.makedirs(temp_dir, exist_ok=True)
self.logger.init_pbar()
task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
global_i = 0
args = {'model': 'ffhq', 'model_dir': '/media/hdd2/adundar/hamza/genforce/editings/GradCtrl/model_ffhq',
'attribute': edit_name, 'exclude': 'default', 'top_channels': 'default', 'layerwise': 'default' }
for idx, data in enumerate(val_loader):
for key in data:
if key != 'name':
data[key] = data[key].cuda(
torch.cuda.current_device(), non_blocking=True)
with torch.no_grad():
return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
wp_mixed = return_dict['wp_mixed']
eouts = return_dict['eouts']
#fakes = return_dict['fakes']
edit_wp = gradctrl(args, wp_mixed, factor)
edited_images, gouts_edits = self.runG(edit_wp, "synthesis", highres_outs=eouts, resize=False)
#edited_images, gouts_edits = self.runG(wp_mixed, "synthesis", highres_outs=eouts, resize=False)
edited_images = postprocess_image(edited_images.detach().cpu().numpy())
for j in range(edited_images.shape[0]):
save_name = data['name'][j]
pil_img = Image.fromarray(edited_images[j]).resize((256,256))
pil_img.save(os.path.join(temp_dir, save_name ))
global_i += 1
self.logger.update_pbar(task, 1)
self.logger.close_pbar()
    def measure_time(self, edit, factor, dataset=None, save_latents=False):
        """Measures per-image inversion latency (encoder + generator).

        Runs the full inversion pipeline once per image with batch size 1
        and prints the mean wall-clock time, both including and excluding
        the first iteration (which pays CUDA warm-up / lazy-init costs).
        Runs on rank 0 only.

        Args:
            edit: Only stored in a local label; not used in the measurement.
            factor: Unused here; kept for interface parity with `grad_edit`.
            dataset: Optional path overriding the 'val' dataset root.
            save_latents: Unused here.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')
        edit_name = edit
        if dataset is not None:
            edit = 'val'
            self.config.data[edit]['root_dir'] = dataset
        dataset = BaseDataset(**self.config.data[edit])
        val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=0)
        global_i = 0
        time_list = []
        for idx, data in enumerate(val_loader):
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                start = time.time()
                return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']
                edited_images, gouts_edits = self.runG(wp_mixed, "synthesis", highres_outs=eouts, resize=False)
                end = time.time()
                # NOTE(review): CUDA kernels launch asynchronously; without a
                # torch.cuda.synchronize() before reading the clock, these
                # timings may under-report GPU work — confirm before trusting
                # the reported numbers.
                time_list.append(end-start)
        print(np.mean(time_list))
        # Mean excluding the first (warm-up) iteration.
        print(np.mean(time_list[1:]))