Spaces:
Runtime error
Runtime error
# python3.7 | |
"""Contains the base class for Encoder (GAN Inversion) runner.""" | |
import os | |
import shutil | |
import torch | |
import torch.distributed as dist | |
import torchvision.transforms as T | |
from utils.visualizer import HtmlPageVisualizer | |
from utils.visualizer import get_grid_shape | |
from utils.visualizer import postprocess_image | |
from utils.visualizer import save_image | |
from utils.visualizer import load_image | |
from utils.visualizer import postprocess_tensor | |
from metrics.inception import build_inception_model | |
from metrics.fid import extract_feature | |
from metrics.fid import compute_fid | |
from metrics.MSSIM import MSSSIM | |
from metrics.LPIPS import LPIPS | |
import numpy as np | |
from .base_runner import BaseRunner | |
from datasets import BaseDataset | |
from torch.utils.data import DataLoader | |
from PIL import Image | |
from runners.controllers.summary_writer import log_image | |
import torchvision | |
from editings.latent_editor import LatentEditor | |
from editings.styleclip.edit_hfgi import styleclip_edit, load_stylegan_generator,load_direction_calculator | |
from editings.GradCtrl.manipulate import main as gradctrl | |
import torch.nn.functional as F | |
import time | |
__all__ = ['BaseEncoderRunner'] | |
class BaseEncoderRunner(BaseRunner): | |
"""Defines the base class for Encoder runner.""" | |
def __init__(self, config, logger):
    """Set up the runner through the parent class.

    Args:
        config: Passed through unchanged to the parent constructor.
        logger: Passed through unchanged to the parent constructor.
    """
    super().__init__(config, logger)
    # Left empty here; `fid()` and `fid_attribute()` fill it in on first use.
    self.inception_model = None
def build_models(self):
    """Build the models via the parent class and cache per-module kwargs.

    After the parent has populated `self.models`, this checks that the
    encoder, the smoothed generator and the discriminator exist, stores the
    generator resolution, and copies the optional `kwargs_train` /
    `kwargs_val` entries of each module config onto the runner (an empty
    dict is used when an entry is absent).

    Raises:
        AssertionError: If one of the three required models is missing.
    """
    super().build_models()
    for required in ('encoder', 'generator_smooth', 'discriminator'):
        assert required in self.models

    self.resolution = self.models['generator_smooth'].resolution

    def phase_kwargs(module_name):
        # Return the (train, val) kwargs pair configured for one module.
        section = self.config.modules[module_name]
        return (section.get('kwargs_train', dict()),
                section.get('kwargs_val', dict()))

    self.G_kwargs_train, self.G_kwargs_val = phase_kwargs('generator_smooth')
    self.D_kwargs_train, self.D_kwargs_val = phase_kwargs('discriminator')

    # Optional modules: only read their config section when enabled.
    if self.config.use_disc2:
        self.D2_kwargs_train, self.D2_kwargs_val = phase_kwargs(
            'discriminator2')
    if self.config.mapping_method != 'pretrained':
        self.M_kwargs_train, self.M_kwargs_val = phase_kwargs('mapping')
    if self.config.create_mixing_network:
        self.MIX_kwargs_train, self.MIX_kwargs_val = phase_kwargs('mixer')
def train_step(self, data, **train_kwargs):
    """Run one training step; concrete runners must override this.

    Raises:
        NotImplementedError: Always, in this base class.
    """
    message = 'Should be implemented in derived class.'
    raise NotImplementedError(message)
def mse(self, mse_num):
    """Evaluate reconstructions on the validation set (MSE, LPIPS, MS-SSIM).

    Every rank reconstructs its share of the samples; per-sample scores are
    exchanged with `dist.all_gather` and accumulated on rank 0.

    Args:
        mse_num: Number of validation samples to evaluate. `0` disables the
            evaluation; the string `"auto"` uses the whole validation
            dataset.

    Returns:
        `-1` when `mse_num` is 0. Otherwise a dict with the keys `'mse'`,
        `'lpips'` and `'ssim'`, each holding a `(mean, std)` tuple of floats.
        Scores are only accumulated on rank 0, so other ranks return
        statistics of an all-zero buffer.
    """
    if mse_num == 0:
        return -1

    self.set_mode('val')
    if self.val_loader is None:
        self.build_dataset('val')
    if mse_num == "auto":
        mse_num = len(self.val_loader.dataset)

    indices = list(range(self.rank, mse_num, self.world_size))
    # Rank 0 owns the most samples. Every rank must run the same number of
    # steps, otherwise the collective `all_gather` below would hang.
    max_rank_samples = len(range(0, mse_num, self.world_size))

    self.logger.init_pbar()
    task1 = self.logger.add_pbar_task('MSE-LPIPS-SSIM', total=mse_num)

    lpips = LPIPS()
    ssim = MSSSIM(size_average=False)
    n_evals = 3
    device = torch.cuda.current_device()
    # One extra column flags which rows hold real scores, so that partial
    # (or empty) batches can be exchanged through fixed-size buffers.
    buffer_shape = (self.val_batch_size, n_evals + 1)
    gather_list = [torch.zeros(buffer_shape, device=device)
                   for _ in range(self.world_size)]
    shared_tensor = torch.zeros(buffer_shape, device=device)
    all_errors = np.zeros((mse_num, n_evals), dtype=np.float64)
    gather_idx = 0

    for batch_idx in range(0, max_rank_samples, self.val_batch_size):
        sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
        batch_size = len(sub_indices)

        shared_tensor.zero_()
        if batch_size:
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)

            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(
                    data, return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
                # Only the first `batch_size` rows are written; the final
                # batch may be smaller than `val_batch_size`.
                shared_tensor[:batch_size, 0] = torch.mean(
                    (fakes - real_images) ** 2, dim=(1, 2, 3))
                shared_tensor[:batch_size, 1] = lpips(real_images, fakes)
                shared_tensor[:batch_size, 2] = ssim(real_images, fakes)
            shared_tensor[:batch_size, n_evals] = 1

        dist.all_gather(gather_list, shared_tensor)

        if self.rank == 0:
            for gathered in gather_list:
                keep = gathered[:, n_evals] > 0
                rows = gathered[keep][:, :n_evals].cpu().numpy()
                all_errors[gather_idx:gather_idx + len(rows)] = rows
                gather_idx += len(rows)
        self.logger.update_pbar(task1, batch_size * self.world_size)
    self.logger.close_pbar()

    mean_lst = np.mean(all_errors, axis=0)
    std_lst = np.std(all_errors, axis=0)
    return {
        'mse': (mean_lst[0].item(), std_lst[0].item()),
        'lpips': (mean_lst[1].item(), std_lst[1].item()),
        'ssim': (mean_lst[2].item(), std_lst[2].item()),
    }
def fid_attribute(self,
                  fid_num,
                  z=None,
                  ignore_cache=False,
                  align_tf=True,
                  attribute='smile', factor=1, direction=None):
    """Compute FID between attribute-edited images and a real image set.

    Images from the `'smile'` dataset config are encoded, shifted by
    `factor * direction` in latent space and re-synthesised; their inception
    features are compared against features of real images loaded from the
    opposite root directory.

    NOTE(review): the dataset config key (`'smile'`) and both root
    directories are hard-coded absolute paths regardless of `attribute`;
    they should probably come from the config.

    Args:
        fid_num: Overwritten by the size of the loaded dataset; the passed
            value is not used.
        z: Optional `np.ndarray` of shape `(N, self.z_space_dim)`. It is
            validated and caps the fake-sample count at `N`; it is not
            otherwise used.
        ignore_cache: Recompute the real features even when a cached file
            exists in `self.work_dir`.
        align_tf: Forwarded to `build_inception_model`.
        attribute: Names the direction file, the progress-bar tasks and the
            cache file.
        factor: Edit strength. Its sign selects which root directory feeds
            the edited images and which feeds the real images.
        direction: Optional latent direction. When `None`, it is loaded from
            `editings/interfacegan_directions/{attribute}.pt`.

    Returns:
        `-1` when there is nothing to evaluate or on ranks other than 0;
        otherwise the value returned by `compute_fid`.
    """
    self.set_mode('val')
    # Honour an explicitly supplied direction instead of always reloading it.
    if direction is None:
        direction = torch.load(
            f'editings/interfacegan_directions/{attribute}.pt').cuda()

    if factor < 0:
        self.config.data['smile']['root_dir'] = '/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original'
    elif factor > 0:
        self.config.data['smile']['root_dir'] = '/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original'
    fake_loader = self.build_dataset("smile")
    fid_num = len(fake_loader.dataset)

    if self.inception_model is None:
        if align_tf:
            self.logger.info('Building inception model '
                             '(aligned with TensorFlow) ...')
        else:
            self.logger.info('Building inception model '
                             '(using torchvision) ...')
        self.inception_model = build_inception_model(align_tf).cuda()
        self.logger.info('Finish building inception model.')

    if z is not None:
        assert isinstance(z, np.ndarray)
        assert z.ndim == 2 and z.shape[1] == self.z_space_dim
        fid_num = min(fid_num, z.shape[0])
        z = torch.from_numpy(z).type(torch.FloatTensor)
    if not fid_num:
        return -1

    def merge_rank_features(prefix):
        # Load, delete and merge the per-rank feature dumps for `prefix`.
        per_rank = []
        for rank in range(self.world_size):
            path = f'{self.work_dir}/{prefix}_fid_features_{rank}.npy'
            per_rank.append(np.load(path))
            os.remove(path)
        total = sum(len(feats) for feats in per_rank)
        expected = [len(range(rank, total, self.world_size))
                    for rank in range(self.world_size)]
        if [len(feats) for feats in per_rank] == expected:
            # Rank r processed samples r, r + world_size, ...; a strided
            # write restores dataset order without padding rows.
            merged = np.empty((total, per_rank[0].shape[1]),
                              dtype=per_rank[0].dtype)
            for rank, feats in enumerate(per_rank):
                merged[rank::self.world_size] = feats
            return merged
        # Unexpected layout: keep every row, in rank order. FID statistics
        # are presumably order-insensitive -- confirm in `compute_fid`.
        return np.concatenate(per_rank, axis=0)

    indices = list(range(self.rank, fid_num, self.world_size))
    self.logger.init_pbar()

    # Extract features from the edited (fake) images.
    fake_feature_list = []
    task1 = self.logger.add_pbar_task(f'FID-{attribute}_fake', total=fid_num)
    for batch_idx in range(0, len(indices), self.val_batch_size):
        sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
        batch_size = len(sub_indices)
        data = next(fake_loader)
        for key in data:
            if key != 'name':
                data[key] = data[key][:batch_size].cuda(
                    torch.cuda.current_device(), non_blocking=True)
        with torch.no_grad():
            return_dict = self.forward_pass(
                data, return_vals='all', only_enc=True)
            wp = return_dict['wp_mixed']
            eouts = return_dict['eouts']
            edit_wp = wp + factor * direction
            edited_images, _ = self.runG(
                edit_wp, "synthesis", highres_outs=eouts)
            fake_feature_list.append(
                extract_feature(self.inception_model, edited_images))
        self.logger.update_pbar(task1, batch_size * self.world_size)
    np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
            np.concatenate(fake_feature_list, axis=0))
    self.logger.close_pbar()

    # Extract features from real images unless a cached copy can be reused.
    cached_fid_file = f'{self.work_dir}/real_{attribute}_{factor}_fid.npy'
    do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
    if do_real_test:
        real_feature_list = []
        self.logger.init_pbar()
        # The real set is the opposite class of the one that was edited.
        if factor < 0:
            self.config.data['smile']['root_dir'] = '/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original'
        elif factor > 0:
            self.config.data['smile']['root_dir'] = '/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original'
        real_loader = self.build_dataset("smile")
        fid_num = len(real_loader.dataset)
        indices = list(range(self.rank, fid_num, self.world_size))
        task2 = self.logger.add_pbar_task(f"{attribute}_real", total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(real_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                real_images = data['image']
                real_feature_list.append(
                    extract_feature(self.inception_model, real_images))
            self.logger.update_pbar(task2, batch_size * self.world_size)
        np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                np.concatenate(real_feature_list, axis=0))

    # Wait until every rank has written its feature files.
    dist.barrier()
    if self.rank != 0:
        return -1
    self.logger.close_pbar()

    fake_features = merge_rank_features('fake')
    if do_real_test:
        real_features = merge_rank_features('real')
        np.save(cached_fid_file, real_features)
    else:
        real_features = np.load(cached_fid_file)

    return compute_fid(fake_features, real_features)
def fid(self,
        fid_num,
        z=None,
        ignore_cache=False,
        align_tf=True):
    """Compute FID between reconstructions and real validation images.

    Args:
        fid_num: Maximum number of validation samples to use; capped at the
            validation dataset size.
        z: Optional `np.ndarray` of shape `(N, self.z_space_dim)`. It is
            validated and caps `fid_num` at `N`; it is not otherwise used.
        ignore_cache: Recompute the real features even when a cached file
            exists in `self.work_dir`.
        align_tf: Forwarded to `build_inception_model`.

    Returns:
        `-1` when `fid_num` ends up being 0 or on ranks other than 0;
        otherwise the value returned by `compute_fid`.
    """
    self.set_mode('val')
    if self.val_loader is None:
        self.build_dataset('val')
    fid_num = min(fid_num, len(self.val_loader.dataset))

    if self.inception_model is None:
        if align_tf:
            self.logger.info('Building inception model '
                             '(aligned with TensorFlow) ...')
        else:
            self.logger.info('Building inception model '
                             '(using torchvision) ...')
        self.inception_model = build_inception_model(align_tf).cuda()
        self.logger.info('Finish building inception model.')

    if z is not None:
        assert isinstance(z, np.ndarray)
        assert z.ndim == 2 and z.shape[1] == self.z_space_dim
        fid_num = min(fid_num, z.shape[0])
        z = torch.from_numpy(z).type(torch.FloatTensor)
    if not fid_num:
        return -1

    def merge_rank_features(prefix):
        # Load, delete and merge the per-rank feature dumps for `prefix`.
        per_rank = []
        for rank in range(self.world_size):
            path = f'{self.work_dir}/{prefix}_fid_features_{rank}.npy'
            per_rank.append(np.load(path))
            os.remove(path)
        total = sum(len(feats) for feats in per_rank)
        expected = [len(range(rank, total, self.world_size))
                    for rank in range(self.world_size)]
        if [len(feats) for feats in per_rank] == expected:
            # Rank r processed samples r, r + world_size, ...; a strided
            # write restores dataset order without padding rows.
            merged = np.empty((total, per_rank[0].shape[1]),
                              dtype=per_rank[0].dtype)
            for rank, feats in enumerate(per_rank):
                merged[rank::self.world_size] = feats
            return merged
        # Unexpected layout: keep every row, in rank order. FID statistics
        # are presumably order-insensitive -- confirm in `compute_fid`.
        return np.concatenate(per_rank, axis=0)

    indices = list(range(self.rank, fid_num, self.world_size))
    self.logger.init_pbar()

    # Real features are only extracted when no usable cache exists. This
    # does not change between batches, so decide once up front.
    cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
    do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)

    fake_feature_list = []
    real_feature_list = []
    task1 = self.logger.add_pbar_task('FID', total=fid_num)
    for batch_idx in range(0, len(indices), self.val_batch_size):
        sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
        batch_size = len(sub_indices)
        data = next(self.val_loader)
        for key in data:
            if key != 'name':
                data[key] = data[key][:batch_size].cuda(
                    torch.cuda.current_device(), non_blocking=True)

        with torch.no_grad():
            real_images = data['image']
            return_dict = self.forward_pass(
                data, return_vals='fakes, wp_mixed')
            fakes = return_dict['fakes']
        # Optional test-time refinement runs with autograd enabled.
        if self.config.test_time_optims != 0:
            wp_mixed = return_dict['wp_mixed']
            fakes = self.optimize(data, wp_mixed)
        with torch.no_grad():
            fake_feature_list.append(
                extract_feature(self.inception_model, fakes))
            if do_real_test:
                real_feature_list.append(
                    extract_feature(self.inception_model, real_images))
        self.logger.update_pbar(task1, batch_size * self.world_size)

    np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
            np.concatenate(fake_feature_list, axis=0))
    if do_real_test:
        np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                np.concatenate(real_feature_list, axis=0))

    # Wait until every rank has written its feature files.
    dist.barrier()
    if self.rank != 0:
        return -1
    self.logger.close_pbar()

    fake_features = merge_rank_features('fake')
    assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
    feature_dim = fake_features.shape[1]

    if do_real_test:
        real_features = merge_rank_features('real')
        assert real_features.shape == (fid_num, feature_dim)
        np.save(cached_fid_file, real_features)
    else:
        real_features = np.load(cached_fid_file)
        assert real_features.shape == (fid_num, feature_dim)

    return compute_fid(fake_features, real_features)
def val(self, **val_kwargs):
    """Run validation by forwarding every keyword argument to `synthesize`.

    The result of `synthesize` is discarded; this method returns `None`.
    """
    synthesize = self.synthesize
    synthesize(**val_kwargs)
def synthesize(self,
               num,
               html_name=None,
               save_raw_synthesis=False):
    """Log reconstructions and InterfaceGAN edits of validation images.

    Only rank 0 does any work. For each of `num` validation batches the
    real images, their reconstructions and the `age` / `pose` / `smile`
    edits at factors +3 and -3 are sent to `log_image` as image grids under
    the name `image/<tag>`.

    Args:
        num: Number of validation batches to draw and log.
        html_name: Only checked for truthiness: when both this and
            `save_raw_synthesis` are falsy, nothing is done.
            (default: None)
        save_raw_synthesis: See `html_name`. (default: False)
    """
    dist.barrier()
    if self.rank != 0:
        return
    if not html_name and not save_raw_synthesis:
        return

    self.set_mode('val')
    if self.val_loader is None:
        self.build_dataset('val')
    if not num:
        return

    # The edit directions do not depend on the batch: read them from disk
    # once instead of once per logged batch.
    directions = {
        edit: torch.load(
            f'editings/interfacegan_directions/{edit}.pt').cuda()
        for edit in ('age', 'pose', 'smile')
    }
    factors = (+3, -3)

    # TODO: Use same z during the entire training process.
    self.logger.init_pbar()
    task = self.logger.add_pbar_task('Synthesis', total=num)
    for _ in range(num):
        data = next(self.val_loader)
        for key in data:
            if key != 'name':
                data[key] = data[key].cuda(
                    torch.cuda.current_device(), non_blocking=True)

        with torch.no_grad():
            real_images = data['image']
            return_dict = self.forward_pass(data, return_vals='all')
            fakes = return_dict['fakes']
            wp_mixed = return_dict['wp_mixed']
            eouts = return_dict['eouts']

            log_list_gpu = {"real": real_images, "fake": fakes}
            for edit, direction in directions.items():
                for factor in factors:
                    edit_wp = wp_mixed + factor * direction
                    edited_images, _ = self.runG(
                        edit_wp, "synthesis", highres_outs=eouts)
                    log_list_gpu[f"{edit}_{factor}"] = edited_images

            # An entry is either a tensor or an (image, min_val) tuple.
            for log_name, log_val in log_list_gpu.items():
                if isinstance(log_val, tuple):
                    log_im, min_val = log_val[0], log_val[1]
                else:
                    log_im, min_val = log_val, -1
                cpu_img = postprocess_tensor(
                    log_im.detach().cpu(), min_val=min_val)
                grid = torchvision.utils.make_grid(cpu_img, nrow=5)
                log_image(name=f"image/{log_name}", grid=grid,
                          iter=self.iter)
        self.logger.update_pbar(task, 1)
    self.logger.close_pbar()
def save_edited_images(self, opts):
    """Invert every image of a dataset, optionally edit it, and save it.

    Only rank 0 does any work. Images are read one at a time from
    `opts.dataset`, encoded, edited according to `opts.method`,
    re-synthesised, resized to 256x256 and written to
    `<work_dir>/<opts.output>/` under their original file name.

    Args:
        opts: Options object. Fields read here:
            `method` -- `'inversion'`, `'interfacegan'`, `'ganspace'` or
                `'styleclip'`; any other value behaves like `'inversion'`.
            `edit` -- direction name (InterfaceGAN file stem or GANSpace
                table key).
            `factor` -- InterfaceGAN edit strength.
            `model_path`, `calculator_args`, `edit_args` -- StyleCLIP
                settings.
            `dataset` -- root directory of the input images.
            `output` -- output directory name inside `self.work_dir`.
    """
    dist.barrier()
    if self.rank != 0:
        return
    self.set_mode('val')

    method = opts.method
    if method == 'interfacegan':
        direction = torch.load(
            f'editings/interfacegan_directions/{opts.edit}.pt').cuda()
    elif method == 'ganspace':
        ganspace_pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')
        # Each tuple is handed to `LatentEditor.apply_ganspace_` as-is.
        ganspace_directions = {
            'eye_openness': (54, 7, 8, 20),
            'smile': (46, 4, 5, -20),
            'beard': (58, 7, 9, -20),
            'white_hair': (57, 7, 10, -24),
            'lipstick': (34, 10, 11, 20),
            'overexposed': (27, 8, 18, 15),
            'screaming': (35, 3, 7, -10),
            'head_angle_up': (11, 1, 4, 10),
        }
        editor = LatentEditor()
    elif method == 'styleclip':
        stylegan_model = load_stylegan_generator(opts.model_path)
        global_direction_calculator = load_direction_calculator(
            opts.calculator_args)

    self.config.data['val']['root_dir'] = opts.dataset
    dataset = BaseDataset(**self.config.data['val'])
    val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=1)

    temp_dir = os.path.join(self.work_dir, opts.output)
    os.makedirs(temp_dir, exist_ok=True)

    # Built once: the output size is the same for every image.
    to_output_size = T.Resize((256, 256))

    self.logger.init_pbar()
    task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
    for data in val_loader:
        for key in data:
            if key != 'name':
                data[key] = data[key].cuda(
                    torch.cuda.current_device(), non_blocking=True)

        with torch.no_grad():
            return_dict = self.forward_pass(
                data, return_vals='all', only_enc=True)
            wp_mixed = return_dict['wp_mixed']
            eouts = return_dict['eouts']

            if method == 'interfacegan':
                wp_mixed = wp_mixed + opts.factor * direction
            elif method == 'ganspace':
                wp_mixed = editor.apply_ganspace_(
                    wp_mixed, ganspace_pca, [ganspace_directions[opts.edit]])

            edited_images, gouts_edits = self.runG(
                wp_mixed, "synthesis", highres_outs=eouts, resize=False)
            if method == 'styleclip':
                # StyleCLIP edits on top of the unedited inversion outputs.
                edited_images = styleclip_edit(
                    wp_mixed, gouts_edits['additions'], stylegan_model,
                    global_direction_calculator, opts.edit_args)
            edited_images = to_output_size(edited_images)
            edited_images = postprocess_image(
                edited_images.detach().cpu().numpy())

        for j in range(edited_images.shape[0]):
            save_name = data['name'][j]
            pil_img = Image.fromarray(edited_images[j])
            pil_img.save(os.path.join(temp_dir, save_name))
        self.logger.update_pbar(task, 1)
    self.logger.close_pbar()
# def forward_pass(self, data, return_vals='all', only_enc=False): | |
# encoder_type = self.config.encoder_type | |
# forward_func = getattr(self, f'{encoder_type}_forward') | |
# return_dict = forward_func(data,only_enc) | |
# return return_dict | |
# # if return_vals == 'all': | |
# # return return_dict | |
# # requested = return_vals.split(',') | |
# # modified_dict = {} | |
# # for request in requested: | |
# # stripped_request = request.strip() | |
# # modified_dict[stripped_request] = return_dict[stripped_request] | |
# # return modified_dict | |
# def base_forward(self, data): | |
# reals = data['image'] | |
# valids = data['valid'] | |
# z_rand = data['z_rand'] | |
# wp_rand = self.runM(z_rand) | |
# wp_enc, blender = self.runE(reals, valids) | |
# wp_mixed = self.mix(wp_enc, wp_rand, blender) | |
# fakes = self.runG(wp_mixed, 'synthesis') | |
# return_dict = {'fakes': fakes, 'wp_enc': wp_enc, 'blender': blender, 'wp_mixed':wp_mixed} | |
# return return_dict | |
# def pSp_forward(self, data, only_enc): | |
# return self.e4e_forward(data, only_enc) | |
# def train_forward(self, data, iscycle=False): | |
# reals = data['image'] | |
# direction = data['direction'] | |
# edit_name = data['edit_name'] | |
# factor = data['factor'] | |
# E = self.models['encoder'] | |
# with torch.no_grad(): | |
# wp, eouts = E(reals) | |
# #wp = wp + self.meanw.repeat(reals.shape[0], 1, 1) | |
# edit = torch.zeros_like(wp) | |
# for i in range (edit.shape[0]): | |
# if edit_name[i] is None: | |
# edit[i] = 0 | |
# elif edit_name[i] == 'randw': | |
# diff = direction[i] - wp[i] | |
# # one_hot = [1] * 8 + [0] * 10 | |
# # one_hot = torch.tensor(one_hot, device=diff.device).unsqueeze(1) | |
# # diff = diff * one_hot | |
# #norm = torch.linalg.norm(diff, dim=1, keepdim=True) | |
# edit[i] = (diff * factor[i]) / 10 | |
# elif edit_name[i] == 'interface': | |
# edit[i] = (factor[i] * direction[i]) | |
# # # Debug | |
# # with torch.no_grad(): | |
# # fakes,_ =self.runG(wp, 'synthesis', highres_outs=None) | |
# # fakes = postprocess_image(fakes.detach().cpu().numpy()) | |
# # for i in range(fakes.shape[0]): | |
# # pil_img = Image.fromarray(fakes[i]).resize((256,256)) | |
# # pil_img.save(f'{self.iter}_orig.png') | |
# # fakes,_ =self.runG(wp+edit, 'synthesis', highres_outs=None) | |
# # fakes = postprocess_image(fakes.detach().cpu().numpy()) | |
# # for i in range(fakes.shape[0]): | |
# # pil_img = Image.fromarray(fakes[i]).resize((256,256)) | |
# # pil_img.save(f'{self.iter}_edit.png') | |
# # fakes,_ =self.runG(direction.unsqueeze(1).repeat(1,18,1), 'synthesis', highres_outs=None) | |
# # fakes = postprocess_image(fakes.detach().cpu().numpy()) | |
# # for i in range(fakes.shape[0]): | |
# # pil_img = Image.fromarray(fakes[i]).resize((256,256)) | |
# # pil_img.save(f'{self.iter}_rand.png') | |
# with torch.no_grad(): | |
# eouts['inversion'] = self.runG(wp, 'synthesis', highres_outs=None, return_f=True) | |
# wp = wp + edit | |
# fakes, gouts = self.runG(wp, 'synthesis', highres_outs=eouts) | |
# #fakes = F.adaptive_avg_pool2d(fakes, (256,256)) | |
# fakes_cycle = None | |
# if iscycle: | |
# # wp_cycle = wp_cycle + self.meanw.repeat(reals.shape[0], 1, 1) | |
# with torch.no_grad(): | |
# wp_cycle, eout_cycle = E(fakes) | |
# eout_cycle['inversion'] = self.runG(wp_cycle, 'synthesis', highres_outs=None, return_f=True) | |
# #wp_cycle = wp_cycle - edit | |
# wp_cycle = wp_cycle - edit | |
# #wp_cycle = wp_cycle - (data['factor'] * data['direction']).unsqueeze(1) | |
# fakes_cycle, _ = self.runG(wp_cycle, 'synthesis', highres_outs=eout_cycle) | |
# #fakes_cycle = F.adaptive_avg_pool2d(fakes, (256,256)) | |
# #cycle = F.mse_loss(fakes_cycle, reals, reduction='mean') | |
# return_dict = {'fakes': fakes, 'wp_mixed':wp, 'gouts':gouts, 'eouts': eouts, 'cycle': fakes_cycle} | |
# return return_dict | |
# def e4e_forward(self, data, only_enc=False): | |
# #return self.base_forward(data) | |
# reals = data['image'] | |
# #valids = data['valid'] | |
# E = self.models['encoder'] | |
# wp_mixed, eouts = E(reals) | |
# #wp_mixed = wp_mixed + self.meanw.repeat(reals.shape[0], 1, 1) | |
# eouts['inversion'] = self.runG(wp_mixed, 'synthesis', highres_outs=None, return_f=True) | |
# if only_enc: | |
# return_dict = {'wp_mixed':wp_mixed,'eouts': eouts} | |
# return return_dict | |
# fakes, gouts = self.runG(wp_mixed, 'synthesis', highres_outs=eouts) | |
# #fakes = self.runG(wp_mixed, 'synthesis', highres_outs=None) | |
# #fakes = F.adaptive_avg_pool2d(fakes, (256,256)) | |
# return_dict = {'fakes': fakes, 'wp_mixed':wp_mixed, 'gouts':gouts, 'eouts': eouts} | |
# return return_dict | |
# def hyperstyle_forward(self, data): | |
# return_dict = self.base_forward(data) | |
# E = self.models['encoder'] | |
# reals = data['image'] | |
# valids = data['valid'] | |
# #HyperNetwork | |
# weight_deltas = E(reals, valids, mode='hyper', gouts=return_dict['fakes']) | |
# fakes = self.runG(return_dict['wp_mixed'], 'synthesis', weight_deltas=weight_deltas) | |
# return_dict['fakes'] = fakes | |
# return return_dict | |
def interface_generate(self, num, edit, factor):
    """Sample random faces and save them with and without an edit.

    For each of `num` indices (split across ranks) a random latent is
    mapped and synthesised three times: unedited, shifted by
    `+factor * direction` and shifted by `-factor * direction`. The
    256x256 results are written as `<index>.png` into three directories
    under `self.work_dir`: `interfacegan_gt`, `interfacegan_{edit}_{factor}`
    and `interfacegan_{edit}_-{factor}`.

    Args:
        num: Total number of samples to generate across all ranks.
        edit: Stem of the direction file in
            `editings/interfacegan_directions/`.
        factor: Edit strength applied in both directions.
    """
    direction = torch.load(
        f'editings/interfacegan_directions/{edit}.pt').cuda()
    indices = list(range(self.rank, num, self.world_size))

    gt_path = os.path.join(self.work_dir, 'interfacegan_gt')
    smile_add_path = os.path.join(
        self.work_dir, f'interfacegan_{edit}_{factor}')
    smile_rm_path = os.path.join(
        self.work_dir, f'interfacegan_{edit}_-{factor}')
    if self.rank == 0:
        os.makedirs(gt_path, exist_ok=True)
        os.makedirs(smile_add_path, exist_ok=True)
        os.makedirs(smile_rm_path, exist_ok=True)
    # Nobody may start writing before rank 0 has created the directories.
    dist.barrier()

    def synthesize_and_save(latents, out_dir, file_names):
        # Render `latents` and store one 256x256 PNG per sample in `out_dir`.
        images, _ = self.runG(latents, resize=False)
        images = postprocess_image(images.detach().cpu().numpy())
        for image, file_name in zip(images, file_names):
            pil_img = Image.fromarray(image).resize((256, 256))
            pil_img.save(os.path.join(out_dir, file_name))

    self.logger.init_pbar()
    task = self.logger.add_pbar_task('Interfacegan', total=num)
    for batch_idx in range(0, len(indices), self.val_batch_size):
        sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
        batch_size = len(sub_indices)
        file_names = [f'{index}.png' for index in sub_indices]

        # Pure inference: the outputs are detached right away, so there is
        # no reason to record autograd graphs for these passes.
        with torch.no_grad():
            z = torch.randn((batch_size, 512),
                            device=torch.cuda.current_device())
            w_r = self.runM(z, repeat_w=True)
            synthesize_and_save(w_r, gt_path, file_names)
            synthesize_and_save(
                w_r + factor * direction, smile_add_path, file_names)
            synthesize_and_save(
                w_r - factor * direction, smile_rm_path, file_names)
        self.logger.update_pbar(task, batch_size * self.world_size)
    self.logger.close_pbar()
def grad_edit(self, edit, factor, dataset=None):
    """Edit every image of a dataset with GradCtrl and save the results.

    Only rank 0 does any work. Each image is encoded, its latent is passed
    to the GradCtrl `main` entry point together with `factor`, and the
    re-synthesised image is saved at 256x256 in
    `<work_dir>/fakes_{edit}_{factor}/` under its original file name.

    Args:
        edit: Attribute name handed to GradCtrl; also part of the output
            directory name.
        factor: Edit strength forwarded to GradCtrl.
        dataset: Root directory written into the `'val'` dataset config
            before the dataset is built.
    """
    dist.barrier()
    if self.rank != 0:
        return
    self.set_mode('val')

    attribute = edit
    self.config.data['val']['root_dir'] = dataset
    image_set = BaseDataset(**self.config.data['val'])
    loader = DataLoader(image_set, batch_size=1, shuffle=False,
                        num_workers=0)

    out_dir = os.path.join(self.work_dir, f'fakes_{attribute}_{factor}')
    os.makedirs(out_dir, exist_ok=True)

    self.logger.init_pbar()
    task = self.logger.add_pbar_task('Synthesis', total=len(loader))

    # NOTE(review): `model_dir` is a hard-coded absolute path.
    gradctrl_args = {
        'model': 'ffhq',
        'model_dir': '/media/hdd2/adundar/hamza/genforce/editings/GradCtrl/model_ffhq',
        'attribute': attribute,
        'exclude': 'default',
        'top_channels': 'default',
        'layerwise': 'default',
    }

    for data in loader:
        for key in data:
            if key == 'name':
                continue
            data[key] = data[key].cuda(
                torch.cuda.current_device(), non_blocking=True)

        with torch.no_grad():
            encoded = self.forward_pass(
                data, return_vals='all', only_enc=True)
            wp_mixed = encoded['wp_mixed']
            eouts = encoded['eouts']

        # NOTE(review): kept outside `no_grad` on the assumption that
        # GradCtrl needs autograd -- confirm against the upstream file.
        edit_wp = gradctrl(gradctrl_args, wp_mixed, factor)
        edited_images, _ = self.runG(
            edit_wp, "synthesis", highres_outs=eouts, resize=False)

        edited_images = postprocess_image(
            edited_images.detach().cpu().numpy())
        for position in range(edited_images.shape[0]):
            file_name = data['name'][position]
            resized = Image.fromarray(
                edited_images[position]).resize((256, 256))
            resized.save(os.path.join(out_dir, file_name))
        self.logger.update_pbar(task, 1)
    self.logger.close_pbar()
def measure_time(self, edit, factor, dataset=None, save_latents=False):
    """Time one encode + synthesis pass per image and print the averages.

    Only rank 0 does any work. Two numbers are printed: the mean latency in
    seconds over all images, and the mean over all but the first image
    (the first pass typically includes one-off warm-up cost).

    Args:
        edit: Key into `self.config.data` selecting the dataset config. It
            is replaced by `'val'` when `dataset` is given.
        factor: Unused.
        dataset: Optional root directory; when given it is written into the
            `'val'` dataset config, which is then used.
        save_latents: Unused.
    """
    dist.barrier()
    if self.rank != 0:
        return
    self.set_mode('val')

    if dataset is not None:
        edit = 'val'
        self.config.data[edit]['root_dir'] = dataset
    dataset = BaseDataset(**self.config.data[edit])
    val_loader = DataLoader(dataset, batch_size=1, shuffle=False,
                            num_workers=0)

    time_list = []
    for data in val_loader:
        for key in data:
            if key != 'name':
                data[key] = data[key].cuda(
                    torch.cuda.current_device(), non_blocking=True)

        with torch.no_grad():
            # CUDA kernels are launched asynchronously. Without these
            # synchronisation points the clock would only capture launch
            # overhead, not the actual GPU execution time.
            torch.cuda.synchronize()
            start = time.perf_counter()
            return_dict = self.forward_pass(
                data, return_vals='all', only_enc=True)
            wp_mixed = return_dict['wp_mixed']
            eouts = return_dict['eouts']
            self.runG(wp_mixed, "synthesis", highres_outs=eouts,
                      resize=False)
            torch.cuda.synchronize()
            end = time.perf_counter()
        time_list.append(end - start)

    print(np.mean(time_list))
    print(np.mean(time_list[1:]))