# python3.7
"""Contains the base class for Encoder (GAN Inversion) runner."""

import os
import shutil

import torch
import torch.distributed as dist
import torchvision.transforms as T

from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import get_grid_shape
from utils.visualizer import postprocess_image
from utils.visualizer import save_image
from utils.visualizer import load_image
from utils.visualizer import postprocess_tensor

from metrics.inception import build_inception_model
from metrics.fid import extract_feature
from metrics.fid import compute_fid
from metrics.MSSIM import MSSSIM
from metrics.LPIPS import LPIPS
import numpy as np
from .base_runner import BaseRunner
from datasets import BaseDataset
from torch.utils.data import DataLoader
from PIL import Image
from runners.controllers.summary_writer import log_image
import torchvision
from editings.latent_editor import LatentEditor
from editings.styleclip.edit_hfgi import styleclip_edit, load_stylegan_generator,load_direction_calculator
from editings.GradCtrl.manipulate import main as gradctrl
import torch.nn.functional as F
import time


# Names exported by `from <this module> import *`.
__all__ = [
    'BaseEncoderRunner',
]


class BaseEncoderRunner(BaseRunner):
    """Defines the base class for Encoder runner."""

    def __init__(self, config, logger):
        """Initializes the runner.

        Args:
            config: Runner configuration, forwarded to the base runner.
            logger: Logger instance, forwarded to the base runner.
        """
        super().__init__(config, logger)
        # The Inception network used by the FID metrics is created lazily, the
        # first time a FID method finds this attribute still set to None.
        self.inception_model = None

    def build_models(self):
        """Builds all modules and caches their train/val keyword arguments.

        The ``encoder``, ``generator_smooth`` and ``discriminator`` modules are
        mandatory. The keyword arguments of ``discriminator2``, ``mapping`` and
        ``mixer`` are only read when ``config.use_disc2`` is truthy,
        ``config.mapping_method`` is not ``'pretrained'`` and
        ``config.create_mixing_network`` is truthy, respectively.
        """
        super().build_models()
        for required_module in ('encoder', 'generator_smooth', 'discriminator'):
            assert required_module in self.models

        def train_val_kwargs(module_name):
            # Missing `kwargs_train` / `kwargs_val` entries default to {}.
            module_config = self.config.modules[module_name]
            return (module_config.get('kwargs_train', dict()),
                    module_config.get('kwargs_val', dict()))

        self.resolution = self.models['generator_smooth'].resolution
        self.G_kwargs_train, self.G_kwargs_val = train_val_kwargs(
            'generator_smooth')
        self.D_kwargs_train, self.D_kwargs_val = train_val_kwargs(
            'discriminator')
        if self.config.use_disc2:
            self.D2_kwargs_train, self.D2_kwargs_val = train_val_kwargs(
                'discriminator2')
        if self.config.mapping_method != 'pretrained':
            self.M_kwargs_train, self.M_kwargs_val = train_val_kwargs('mapping')
        if self.config.create_mixing_network:
            self.MIX_kwargs_train, self.MIX_kwargs_val = train_val_kwargs(
                'mixer')

        

    def train_step(self, data, **train_kwargs):
        """Runs a single training iteration; subclasses must override this."""
        message = 'Should be implemented in derived class.'
        raise NotImplementedError(message)

    def mse(self, mse_num):
        """Computes reconstruction errors (MSE, LPIPS, MS-SSIM) on the val set.

        Samples are split across ranks (rank ``r`` handles samples
        ``r, r + world_size, ...``) and per-sample errors are gathered on
        rank 0.

        Args:
            mse_num: Number of validation samples to evaluate. ``0`` skips the
                evaluation; ``"auto"`` evaluates the whole validation dataset.

        Returns:
            ``-1`` when ``mse_num`` is 0. Otherwise a dict
            ``{'mse': (mean, std), 'lpips': (mean, std), 'ssim': (mean, std)}``.
            Only rank 0 collects the errors; on the other ranks the error
            buffer stays zero, so their returned statistics are all zero.
        """
        if mse_num == 0:
            return -1
        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')
        if mse_num == "auto":
            mse_num = len(self.val_loader.dataset)

        device = torch.cuda.current_device()
        indices = list(range(self.rank, mse_num, self.world_size))
        # Every iteration ends in collective `all_gather` calls, so all ranks
        # must run the same number of iterations. Rank 0 owns the most samples
        # (ceil(mse_num / world_size)); ranks with fewer samples run the extra
        # iterations with zero valid rows instead of leaving the loop early.
        max_per_rank = -(-mse_num // self.world_size)
        num_batches = -(-max_per_rank // self.val_batch_size)

        self.logger.init_pbar()
        task1 = self.logger.add_pbar_task('MSE-LPIPS-SSIM', total=mse_num)

        lpips = LPIPS()
        ssim = MSSSIM(size_average=False)
        n_evals = 3
        # Fixed-size buffers for all_gather. Only the first `count` rows of a
        # rank's buffer are valid in a given iteration; the counts are
        # gathered alongside the errors.
        shared_tensor = torch.zeros((self.val_batch_size, n_evals),
                                    device=device)
        gather_list = [torch.zeros_like(shared_tensor)
                       for _ in range(self.world_size)]
        local_count = torch.zeros(1, dtype=torch.long, device=device)
        count_list = [torch.zeros_like(local_count)
                      for _ in range(self.world_size)]
        all_errors = np.zeros((mse_num, n_evals), dtype=np.float64)
        gather_idx = 0
        for batch_num in range(num_batches):
            start = batch_num * self.val_batch_size
            batch_size = len(indices[start:start + self.val_batch_size])
            # Run on at least one sample so every rank performs the same model
            # calls; rows beyond `batch_size` are never reported as valid.
            rows = max(batch_size, 1)
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:rows].cuda(device, non_blocking=True)

            # Clear stale values left over from the previous iteration.
            shared_tensor.zero_()
            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
                produced = fakes.shape[0]
                shared_tensor[:produced, 0] = torch.mean(
                    (fakes - real_images) ** 2, dim=(1, 2, 3))  # MSE error
                shared_tensor[:produced, 1] = lpips(real_images, fakes).reshape(-1)
                shared_tensor[:produced, 2] = ssim(real_images, fakes).reshape(-1)

            local_count.fill_(min(batch_size, produced))
            dist.all_gather(gather_list, shared_tensor)
            dist.all_gather(count_list, local_count)
            counts = [int(count.item()) for count in count_list]
            if self.rank == 0:
                for errors, count in zip(gather_list, counts):
                    all_errors[gather_idx:gather_idx + count] = (
                        errors[:count].cpu().numpy())
                    gather_idx += count
            self.logger.update_pbar(task1, sum(counts))
        self.logger.close_pbar()

        if self.rank == 0:
            # Ignore rows that were never filled (e.g. a short loader batch).
            all_errors = all_errors[:gather_idx]
        mean_lst, std_lst = np.mean(all_errors, axis=0), np.std(all_errors, axis=0)
        mse_mean, lpips_mean, ssim_mean = mean_lst[0].item(), mean_lst[1].item(), mean_lst[2].item()
        mse_std, lpips_std, ssim_std = std_lst[0].item(), std_lst[1].item(), std_lst[2].item()
        return_vals = {'mse': (mse_mean, mse_std),
                       'lpips': (lpips_mean, lpips_std),
                       'ssim': (ssim_mean, ssim_std)}
        return return_vals

    def fid_attribute(self,
                      fid_num,
                      z=None,
                      ignore_cache=False,
                      align_tf=True,
                      attribute='smile', factor=1, direction=None):
        """Computes the FID of InterFaceGAN-edited reconstructions.

        Images from one "smile" split are encoded, edited as
        ``wp + factor * direction`` and compared against real images of the
        opposite split. For ``factor < 0`` the inputs come from
        ``smile_with_original`` and the references from
        ``smile_without_original``. For ``factor > 0`` the roles are swapped.
        For ``factor == 0`` the configured ``smile`` root is left unchanged.

        NOTE(review): both dataset roots are hard-coded absolute paths, and the
        ``'smile'`` data config is used whatever ``attribute`` is.

        Args:
            fid_num: Not used to size the evaluation; the number of samples is
                taken from the datasets (optionally clipped by ``z``).
            z: Optional ``(N, z_space_dim)`` numpy array; only used to clip the
                number of edited samples.
            ignore_cache: Recompute the real features even if
                ``real_{attribute}_{factor}_fid.npy`` exists in ``work_dir``.
            align_tf: Build the TensorFlow-aligned Inception model.
            attribute: Direction name; the direction is loaded from
                ``editings/interfacegan_directions/{attribute}.pt`` when
                ``direction`` is None.
            factor: Edit strength.
            direction: Optional editing-direction tensor. It is used instead of
                the file named by ``attribute`` when given.

        Returns:
            The FID value on rank 0; ``-1`` on other ranks or when there is
            nothing to evaluate.
        """
        with_root = '/media/hdd2/adundar/hamza/genforce/data/temp/smile_with_original'
        without_root = '/media/hdd2/adundar/hamza/genforce/data/temp/smile_without_original'

        def _stack(feature_list):
            # A rank without samples still writes a (0, 0) placeholder so that
            # rank 0 finds exactly one dump per rank.
            if not feature_list:
                return np.zeros((0, 0), dtype=np.float32)
            return np.concatenate(feature_list, axis=0)

        def _gather(prefix):
            # Loads and removes every rank's dump, then restores the global
            # order `rank + i * world_size`. Ranks can hold different numbers
            # of rows, so each one is padded individually before interleaving
            # and the padding is dropped afterwards.
            parts = []
            for rank in range(self.world_size):
                path = f'{self.work_dir}/{prefix}_{rank}.npy'
                part = np.load(path)
                os.remove(path)
                if part.size:
                    parts.append(part)
            longest = max(len(part) for part in parts)
            feature_dim = parts[0].shape[1]
            padded = np.zeros((len(parts), longest, feature_dim),
                              dtype=parts[0].dtype)
            valid = np.zeros((len(parts), longest), dtype=bool)
            for rank, part in enumerate(parts):
                padded[rank, :len(part)] = part
                valid[rank, :len(part)] = True
            return padded.transpose(1, 0, 2)[valid.T]

        self.set_mode('val')
        if direction is None:
            direction = torch.load(
                f'editings/interfacegan_directions/{attribute}.pt')
        direction = direction.cuda()
        if factor < 0:
            self.config.data['smile']['root_dir'] = with_root
        elif factor > 0:
            self.config.data['smile']['root_dir'] = without_root

        fake_loader = self.build_dataset('smile')
        fid_num = len(fake_loader.dataset)
        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info(f'Finish building inception model.')

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
        if not fid_num:
            return -1

        # Extract features from the edited reconstructions.
        indices = list(range(self.rank, fid_num, self.world_size))
        self.logger.init_pbar()
        fake_feature_list = []
        task1 = self.logger.add_pbar_task(f'FID-{attribute}_fake', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(fake_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
                edit_wp = return_dict['wp_mixed'] + factor * direction
                edited_images, _ = self.runG(
                    edit_wp, "synthesis", highres_outs=return_dict['eouts'])
                fake_feature_list.append(
                    extract_feature(self.inception_model, edited_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)

        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                _stack(fake_feature_list))
        self.logger.close_pbar()

        # Extract features from the real reference images if not cached.
        cached_fid_file = f'{self.work_dir}/real_{attribute}_{factor}_fid.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)
        if do_real_test:
            self.logger.init_pbar()
            if factor < 0:
                self.config.data['smile']['root_dir'] = without_root
            elif factor > 0:
                self.config.data['smile']['root_dir'] = with_root
            real_loader = self.build_dataset('smile')

            real_num = len(real_loader.dataset)
            real_indices = list(range(self.rank, real_num, self.world_size))
            real_feature_list = []
            task2 = self.logger.add_pbar_task(f"{attribute}_real", total=real_num)
            for batch_idx in range(0, len(real_indices), self.val_batch_size):
                sub_indices = real_indices[batch_idx:batch_idx + self.val_batch_size]
                batch_size = len(sub_indices)
                data = next(real_loader)
                for key in data:
                    if key != 'name':
                        data[key] = data[key][:batch_size].cuda(
                            torch.cuda.current_device(), non_blocking=True)
                with torch.no_grad():
                    real_feature_list.append(
                        extract_feature(self.inception_model, data['image']))
                self.logger.update_pbar(task2, batch_size * self.world_size)
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    _stack(real_feature_list))
            self.logger.close_pbar()

        dist.barrier()
        if self.rank != 0:
            return -1

        fake_features = _gather('fake_fid_features')
        if do_real_test:
            real_features = _gather('real_fid_features')
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)

        fid_value = compute_fid(fake_features, real_features)
        return fid_value

    def fid(self,
            fid_num,
            z=None,
            ignore_cache=False,
            align_tf=True):
        """Computes the FID between reconstructions and real validation images.

        Each rank processes its share of the samples, dumps its Inception
        features to ``work_dir``, and rank 0 merges the dumps and computes the
        FID. Real-image features are cached in ``real_fid{fid_num}.npy``.

        Args:
            fid_num: Number of samples to evaluate (clipped to the size of the
                validation dataset).
            z: Optional ``(N, z_space_dim)`` numpy array; only used to clip
                ``fid_num`` to ``N``.
            ignore_cache: Recompute the real features even if they are cached.
            align_tf: Build the TensorFlow-aligned Inception model.

        Returns:
            The FID value on rank 0; ``-1`` on other ranks or when there is
            nothing to evaluate.
        """
        def _stack(feature_list):
            # A rank without samples still writes a (0, 0) placeholder so that
            # rank 0 finds exactly one dump per rank.
            if not feature_list:
                return np.zeros((0, 0), dtype=np.float32)
            return np.concatenate(feature_list, axis=0)

        def _gather(prefix):
            # Loads and removes every rank's dump, then restores the global
            # order `rank + i * world_size`. Ranks can hold different numbers
            # of rows, so each one is padded individually before interleaving
            # and the padding is dropped afterwards.
            parts = []
            for rank in range(self.world_size):
                path = f'{self.work_dir}/{prefix}_{rank}.npy'
                part = np.load(path)
                os.remove(path)
                if part.size:
                    parts.append(part)
            longest = max(len(part) for part in parts)
            feature_dim = parts[0].shape[1]
            padded = np.zeros((len(parts), longest, feature_dim),
                              dtype=parts[0].dtype)
            valid = np.zeros((len(parts), longest), dtype=bool)
            for rank, part in enumerate(parts):
                padded[rank, :len(part)] = part
                valid[rank, :len(part)] = True
            return padded.transpose(1, 0, 2)[valid.T]

        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')
        fid_num = min(fid_num, len(self.val_loader.dataset))

        if self.inception_model is None:
            if align_tf:
                self.logger.info(f'Building inception model '
                                 f'(aligned with TensorFlow) ...')
            else:
                self.logger.info(f'Building inception model '
                                 f'(using torchvision) ...')
            self.inception_model = build_inception_model(align_tf).cuda()
            self.logger.info(f'Finish building inception model.')

        if z is not None:
            assert isinstance(z, np.ndarray)
            assert z.ndim == 2 and z.shape[1] == self.z_space_dim
            fid_num = min(fid_num, z.shape[0])
        if not fid_num:
            return -1

        indices = list(range(self.rank, fid_num, self.world_size))
        # Decide once, before the loop, whether real features are needed, so
        # the flag is defined even on ranks that own no samples.
        cached_fid_file = f'{self.work_dir}/real_fid{fid_num}.npy'
        do_real_test = (not os.path.exists(cached_fid_file) or ignore_cache)

        self.logger.init_pbar()
        fake_feature_list = []
        real_feature_list = []
        task1 = self.logger.add_pbar_task('FID', total=fid_num)
        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key][:batch_size].cuda(
                        torch.cuda.current_device(), non_blocking=True)

            with torch.no_grad():
                real_images = data['image']
                return_dict = self.forward_pass(data, return_vals='fakes, wp_mixed')
                fakes = return_dict['fakes']
            # Test-time optimisation runs outside no_grad (it may need grads).
            if self.config.test_time_optims != 0:
                fakes = self.optimize(data, return_dict['wp_mixed'])
            with torch.no_grad():
                fake_feature_list.append(
                    extract_feature(self.inception_model, fakes))
                if do_real_test:
                    real_feature_list.append(
                        extract_feature(self.inception_model, real_images))
            self.logger.update_pbar(task1, batch_size * self.world_size)

        np.save(f'{self.work_dir}/fake_fid_features_{self.rank}.npy',
                _stack(fake_feature_list))
        if do_real_test:
            np.save(f'{self.work_dir}/real_fid_features_{self.rank}.npy',
                    _stack(real_feature_list))

        dist.barrier()
        if self.rank != 0:
            return -1
        self.logger.close_pbar()

        # Collect fake features.
        fake_features = _gather('fake_fid_features')
        assert fake_features.ndim == 2 and fake_features.shape[0] == fid_num
        feature_dim = fake_features.shape[1]

        # Collect (or load) real features.
        if do_real_test:
            real_features = _gather('real_fid_features')
            assert real_features.shape == (fid_num, feature_dim)
            np.save(cached_fid_file, real_features)
        else:
            real_features = np.load(cached_fid_file)
            assert real_features.shape == (fid_num, feature_dim)

        fid_value = compute_fid(fake_features, real_features)
        return fid_value

    def val(self, **val_kwargs):
        """Runs validation by forwarding all keyword arguments to `synthesize`."""
        self.synthesize(**val_kwargs)
        return None

    def synthesize(self,
                   num,
                   html_name=None,
                   save_raw_synthesis=False):
        """Logs reconstructions and InterFaceGAN edits of validation batches.

        Runs on rank 0 only, after a barrier. For each of ``num`` validation
        batches it logs image grids (via ``log_image`` at ``self.iter``) of:
        the real images, their reconstructions, and ``age``/``pose``/``smile``
        edits with factors +3 and -3.

        Args:
            num: Number of validation batches to process.
            html_name: Only acts as an enable switch. Nothing is logged unless
                this or ``save_raw_synthesis`` is truthy. No HTML page is
                written by this method.
            save_raw_synthesis: Same enable switch. No raw images are written
                to disk by this method.
        """
        dist.barrier()
        if self.rank != 0:
            return

        if not html_name and not save_raw_synthesis:
            return

        self.set_mode('val')

        if self.val_loader is None:
            self.build_dataset('val')

        if not num:
            return

        # The editing directions do not change between batches, so load them
        # once instead of re-reading every file for every batch.
        edit_factors = (+3, -3)
        directions = {
            edit: torch.load(f'editings/interfacegan_directions/{edit}.pt').cuda()
            for edit in ('age', 'pose', 'smile')
        }

        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=num)
        for _ in range(num):
            data = next(self.val_loader)
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)

            with torch.no_grad():
                return_dict = self.forward_pass(data, return_vals='all')
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']

                log_list_gpu = {"real": data['image'], "fake": return_dict['fakes']}
                for edit, direction in directions.items():
                    for factor in edit_factors:
                        edit_wp = wp_mixed + factor * direction
                        edited_images, _ = self.runG(
                            edit_wp, "synthesis", highres_outs=eouts)
                        log_list_gpu[f"{edit}_{factor}"] = edited_images

                # A value may be an (image, min_val) tuple; plain tensors are
                # post-processed with min_val=-1.
                for log_name, log_val in log_list_gpu.items():
                    if isinstance(log_val, tuple):
                        log_im, min_val = log_val[0], log_val[1]
                    else:
                        log_im, min_val = log_val, -1
                    cpu_img = postprocess_tensor(log_im.detach().cpu(), min_val=min_val)
                    grid = torchvision.utils.make_grid(cpu_img, nrow=5)
                    log_image(name=f"image/{log_name}", grid=grid, iter=self.iter)
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()

    def save_edited_images(self, opts):
        """Encodes every image of ``opts.dataset``, optionally edits it, and
        saves the result resized to 256x256.

        Runs on rank 0 only, after a barrier. Results are written to
        ``os.path.join(self.work_dir, opts.output)``. Each file keeps the name
        that the dataset reports under ``'name'``. Note that this overwrites
        ``self.config.data['val']['root_dir']`` with ``opts.dataset``.

        Args:
            opts: Options namespace. Reads ``method`` (``'inversion'``,
                ``'interfacegan'``, ``'ganspace'`` or ``'styleclip'``),
                ``dataset`` and ``output``. It also reads the method-specific
                fields: ``edit`` and ``factor`` (interfacegan), ``edit``
                (ganspace), and ``model_path``, ``calculator_args`` and
                ``edit_args`` (styleclip).

        Raises:
            ValueError: If ``opts.method`` is not one of the supported methods.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')

        method = opts.method
        if method == 'inversion':
            pass
        elif method == 'interfacegan':
            direction = torch.load(f'editings/interfacegan_directions/{opts.edit}.pt').cuda()
        elif method == 'ganspace':
            ganspace_pca = torch.load('editings/ganspace_pca/ffhq_pca.pt')
            # NOTE(review): tuples are presumably (PCA component, first layer,
            # last layer, strength); confirm against
            # LatentEditor.apply_ganspace_.
            direction = {
                'eye_openness':            (54,  7,  8,  20),
                'smile':                   (46,  4,  5, -20),
                'beard':                   (58,  7,  9, -20),
                'white_hair':              (57,  7, 10, -24),
                'lipstick':                (34, 10, 11,  20),
                'overexposed':             (27,  8, 18,  15),
                'screaming':               (35,  3,  7, -10),
                'head_angle_up':           (11,  1,  4,  10),
            }
            editor = LatentEditor()
        elif method == 'styleclip':
            stylegan_model = load_stylegan_generator(opts.model_path)
            global_direction_calculator = load_direction_calculator(opts.calculator_args)
        else:
            # Previously an unknown method silently produced plain inversions.
            raise ValueError(f'Unsupported editing method: {method!r}')

        self.config.data['val']['root_dir'] = opts.dataset
        dataset = BaseDataset(**self.config.data['val'])
        val_loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1)

        temp_dir = os.path.join(self.work_dir, opts.output)
        os.makedirs(temp_dir, exist_ok=True)
        resize = T.Resize((256, 256))

        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
        for data in val_loader:
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                return_dict = self.forward_pass(data, return_vals='all', only_enc=True)
                wp_mixed = return_dict['wp_mixed']
                eouts = return_dict['eouts']

                # Latent-space edits are applied before synthesis.
                if method == 'interfacegan':
                    wp_mixed = wp_mixed + opts.factor * direction
                elif method == 'ganspace':
                    wp_mixed = editor.apply_ganspace_(
                        wp_mixed, ganspace_pca, [direction[opts.edit]])

                edited_images, gouts_edits = self.runG(
                    wp_mixed, "synthesis", highres_outs=eouts, resize=False)
                # StyleCLIP edits replace the synthesized images entirely.
                if method == 'styleclip':
                    edited_images = styleclip_edit(
                        wp_mixed, gouts_edits['additions'], stylegan_model,
                        global_direction_calculator, opts.edit_args)
                edited_images = resize(edited_images)
                edited_images = postprocess_image(edited_images.detach().cpu().numpy())

            for save_name, image in zip(data['name'], edited_images):
                Image.fromarray(image).save(os.path.join(temp_dir, save_name))
            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()

    # def forward_pass(self, data, return_vals='all', only_enc=False):
    #     encoder_type = self.config.encoder_type
    #     forward_func = getattr(self, f'{encoder_type}_forward')
    #     return_dict = forward_func(data,only_enc)
    #     return return_dict
    #     # if return_vals == 'all':
    #     #     return return_dict
    #     # requested = return_vals.split(',')
    #     # modified_dict = {}
    #     # for request in requested:
    #     #     stripped_request = request.strip()
    #     #     modified_dict[stripped_request] = return_dict[stripped_request]
    #     # return modified_dict
    
    # def base_forward(self, data):
    #     reals = data['image']
    #     valids = data['valid']
    #     z_rand = data['z_rand']

    #     wp_rand = self.runM(z_rand)
    #     wp_enc, blender = self.runE(reals, valids)
    #     wp_mixed = self.mix(wp_enc, wp_rand, blender)
    #     fakes = self.runG(wp_mixed, 'synthesis')
    #     return_dict = {'fakes': fakes, 'wp_enc': wp_enc, 'blender': blender, 'wp_mixed':wp_mixed}
    #     return return_dict

    # def pSp_forward(self, data, only_enc):
    #     return self.e4e_forward(data, only_enc)

    # def train_forward(self, data, iscycle=False):
    #     reals = data['image']
    #     direction = data['direction']
    #     edit_name = data['edit_name']
    #     factor = data['factor']
    #     E = self.models['encoder']
    #     with torch.no_grad():
    #         wp, eouts = E(reals)
    #     #wp = wp + self.meanw.repeat(reals.shape[0], 1, 1)
    #     edit = torch.zeros_like(wp)
    #     for i in range (edit.shape[0]):
    #         if edit_name[i] is None:
    #             edit[i] = 0
    #         elif edit_name[i] == 'randw':
    #             diff = direction[i] - wp[i]
    #             # one_hot = [1] * 8 + [0] * 10
    #             # one_hot = torch.tensor(one_hot, device=diff.device).unsqueeze(1)
    #             # diff = diff * one_hot   
    #             #norm = torch.linalg.norm(diff, dim=1, keepdim=True)
    #             edit[i] = (diff * factor[i]) / 10
    #         elif edit_name[i] == 'interface':
    #             edit[i] = (factor[i] * direction[i])

    #     # # Debug
    #     # with torch.no_grad():
    #     #     fakes,_ =self.runG(wp, 'synthesis', highres_outs=None)
    #     #     fakes = postprocess_image(fakes.detach().cpu().numpy())
    #     #     for i in range(fakes.shape[0]):
    #     #         pil_img = Image.fromarray(fakes[i]).resize((256,256))
    #     #         pil_img.save(f'{self.iter}_orig.png')

    #     #     fakes,_ =self.runG(wp+edit, 'synthesis', highres_outs=None)
    #     #     fakes = postprocess_image(fakes.detach().cpu().numpy())
    #     #     for i in range(fakes.shape[0]):
    #     #         pil_img = Image.fromarray(fakes[i]).resize((256,256))
    #     #         pil_img.save(f'{self.iter}_edit.png')

    #     #     fakes,_ =self.runG(direction.unsqueeze(1).repeat(1,18,1), 'synthesis', highres_outs=None)
    #     #     fakes = postprocess_image(fakes.detach().cpu().numpy())
    #     #     for i in range(fakes.shape[0]):
    #     #         pil_img = Image.fromarray(fakes[i]).resize((256,256))
    #     #         pil_img.save(f'{self.iter}_rand.png')
    #     with torch.no_grad():
    #         eouts['inversion'] = self.runG(wp, 'synthesis', highres_outs=None, return_f=True)
    #     wp = wp + edit
    #     fakes, gouts = self.runG(wp, 'synthesis', highres_outs=eouts)
    #     #fakes = F.adaptive_avg_pool2d(fakes, (256,256))
    #     fakes_cycle = None
    #     if iscycle:
    #         # wp_cycle = wp_cycle + self.meanw.repeat(reals.shape[0], 1, 1)
    #         with torch.no_grad():
    #             wp_cycle, eout_cycle = E(fakes)
    #             eout_cycle['inversion'] = self.runG(wp_cycle, 'synthesis', highres_outs=None, return_f=True)

    #         #wp_cycle = wp_cycle - edit
    #         wp_cycle = wp_cycle - edit
    #         #wp_cycle = wp_cycle - (data['factor'] * data['direction']).unsqueeze(1)
    #         fakes_cycle, _ = self.runG(wp_cycle, 'synthesis', highres_outs=eout_cycle)
    #         #fakes_cycle = F.adaptive_avg_pool2d(fakes, (256,256))
    #         #cycle = F.mse_loss(fakes_cycle, reals, reduction='mean')
    #     return_dict = {'fakes': fakes, 'wp_mixed':wp, 'gouts':gouts, 'eouts': eouts, 'cycle': fakes_cycle}
    #     return return_dict

    # def e4e_forward(self, data, only_enc=False):
    #     #return self.base_forward(data)
    #     reals = data['image']
    #     #valids = data['valid']
    #     E = self.models['encoder']
    #     wp_mixed, eouts = E(reals)
    #     #wp_mixed = wp_mixed + self.meanw.repeat(reals.shape[0], 1, 1)
    #     eouts['inversion'] = self.runG(wp_mixed, 'synthesis', highres_outs=None, return_f=True)
    #     if only_enc:
    #         return_dict = {'wp_mixed':wp_mixed,'eouts': eouts}
    #         return return_dict
    #     fakes, gouts = self.runG(wp_mixed, 'synthesis', highres_outs=eouts)
    #     #fakes = self.runG(wp_mixed, 'synthesis', highres_outs=None)
    #     #fakes = F.adaptive_avg_pool2d(fakes, (256,256))
    #     return_dict = {'fakes': fakes, 'wp_mixed':wp_mixed, 'gouts':gouts, 'eouts': eouts}
    #     return return_dict


    # def hyperstyle_forward(self, data):
    #     return_dict = self.base_forward(data)
        
    #     E = self.models['encoder']
    #     reals = data['image']
    #     valids = data['valid']

    #     #HyperNetwork 
    #     weight_deltas = E(reals, valids, mode='hyper', gouts=return_dict['fakes'])
    #     fakes = self.runG(return_dict['wp_mixed'], 'synthesis', weight_deltas=weight_deltas)
    #     return_dict['fakes'] = fakes

    #     return return_dict

    def interface_generate(self, num, edit, factor):
        """Generates random samples together with their InterFaceGAN edits.

        The ``num`` sample indices are split across ranks (rank ``r`` handles
        indices ``r, r + world_size, ...``). For every index, three images are
        synthesised and saved as 256x256 PNGs named ``<index>.png`` under
        ``self.work_dir``:

        * ``interfacegan_gt``: the unedited sample,
        * ``interfacegan_{edit}_{factor}``: latent moved by
          ``+factor * direction``,
        * ``interfacegan_{edit}_-{factor}``: latent moved by
          ``-factor * direction``.

        Args:
            num: Total number of samples to generate across all ranks.
            edit: Name of the InterFaceGAN direction; the direction tensor is
                loaded from ``editings/interfacegan_directions/{edit}.pt``.
            factor: Step size applied along the direction.
        """
        direction = torch.load(
            f'editings/interfacegan_directions/{edit}.pt').cuda()
        # Strided split so every rank writes a disjoint set of file names.
        indices = list(range(self.rank, num, self.world_size))
        gt_path = os.path.join(self.work_dir, 'interfacegan_gt')
        add_path = os.path.join(self.work_dir,
                                f'interfacegan_{edit}_{factor}')
        rm_path = os.path.join(self.work_dir,
                               f'interfacegan_{edit}_-{factor}')
        if self.rank == 0:
            for path in (gt_path, add_path, rm_path):
                os.makedirs(path, exist_ok=True)
        # Wait until rank 0 has created the output directories.
        dist.barrier()

        def _save(images, names, out_dir):
            # Convert a generator output batch to uint8 images and write each
            # one, resized to 256x256, as ``out_dir/<name>``.
            images = postprocess_image(images.detach().cpu().numpy())
            for image, name in zip(images, names):
                Image.fromarray(image).resize((256, 256)).save(
                    os.path.join(out_dir, name))

        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Interfacegan', total=num)

        for batch_idx in range(0, len(indices), self.val_batch_size):
            sub_indices = indices[batch_idx:batch_idx + self.val_batch_size]
            batch_size = len(sub_indices)
            names = [f'{idx}.png' for idx in sub_indices]

            # Pure inference: every output is detached and written to disk,
            # so no autograd graph needs to be recorded.
            with torch.no_grad():
                z = torch.randn((batch_size, 512),
                                device=torch.cuda.current_device())
                w_r = self.runM(z, repeat_w=True)

                gt_imgs, _ = self.runG(w_r, resize=False)
                _save(gt_imgs, names, gt_path)

                added, _ = self.runG(w_r + factor * direction, resize=False)
                _save(added, names, add_path)

                removed, _ = self.runG(w_r - factor * direction,
                                       resize=False)
                _save(removed, names, rm_path)

            self.logger.update_pbar(task, batch_size * self.world_size)
        self.logger.close_pbar()

    def grad_edit(self, edit, factor, dataset=None,
                  model_dir='/media/hdd2/adundar/hamza/genforce/editings/'
                            'GradCtrl/model_ffhq'):
        """Inverts validation images and edits them with GradCtrl.

        All ranks hit a barrier first; only rank 0 then does the work. Each
        image of the ``val`` data split is encoded (``only_enc=True``), its
        ``wp_mixed`` code is edited by ``gradctrl`` with ``factor``, and the
        result is re-synthesised with the encoder outputs (``eouts``) as
        high-resolution features. The output is saved, resized to 256x256,
        under ``self.work_dir/fakes_{edit}_{factor}`` with the image's
        original file name.

        Args:
            edit: GradCtrl attribute to manipulate.
            factor: Edit strength forwarded to ``gradctrl``.
            dataset: Optional directory that overrides
                ``self.config.data['val']['root_dir']``. When ``None``, the
                configured root is kept. NOTE: the override is written into
                ``self.config`` and persists after this call.
            model_dir: Directory holding the GradCtrl FFHQ models.
        """
        dist.barrier()
        if self.rank != 0:
            return
        self.set_mode('val')
        split = 'val'
        # Only override the configured root when a dataset is actually given;
        # writing ``None`` here would break dataset construction.
        if dataset is not None:
            self.config.data[split]['root_dir'] = dataset

        val_set = BaseDataset(**self.config.data[split])
        val_loader = DataLoader(val_set, batch_size=1, shuffle=False,
                                num_workers=0)

        temp_dir = os.path.join(self.work_dir, f'fakes_{edit}_{factor}')
        os.makedirs(temp_dir, exist_ok=True)

        args = {'model': 'ffhq', 'model_dir': model_dir,
                'attribute': edit, 'exclude': 'default',
                'top_channels': 'default', 'layerwise': 'default'}

        self.logger.init_pbar()
        task = self.logger.add_pbar_task('Synthesis', total=len(val_loader))
        for data in val_loader:
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                return_dict = self.forward_pass(data, return_vals='all',
                                                only_enc=True)
            wp_mixed = return_dict['wp_mixed']
            eouts = return_dict['eouts']

            # Kept outside torch.no_grad() in case gradctrl relies on autograd
            # internally. NOTE(review): confirm against GradCtrl's manipulate.
            edit_wp = gradctrl(args, wp_mixed, factor)
            with torch.no_grad():
                edited_images, _ = self.runG(edit_wp, 'synthesis',
                                             highres_outs=eouts,
                                             resize=False)
            edited_images = postprocess_image(
                edited_images.detach().cpu().numpy())
            for image, name in zip(edited_images, data['name']):
                Image.fromarray(image).resize((256, 256)).save(
                    os.path.join(temp_dir, name))

            self.logger.update_pbar(task, 1)
        self.logger.close_pbar()

    def measure_time(self, edit, factor, dataset=None, save_latents=False):
        """Benchmarks per-image inversion latency (encoder + synthesis).

        All ranks hit a barrier; only rank 0 runs the benchmark. Each image is
        encoded with ``forward_pass(..., only_enc=True)`` and re-synthesised
        with ``runG``. The running mean latency is printed after every image,
        and the final mean (excluding the first image) is printed at the end.

        Args:
            edit: Config data key to benchmark on when ``dataset`` is
                ``None``.
            factor: Unused; kept for interface compatibility.
            dataset: Optional directory. When given, the ``val`` split is used
                and its ``root_dir`` is overwritten (persistently) in
                ``self.config``.
            save_latents: Unused; kept for interface compatibility.

        Returns:
            On rank 0, the mean latency in seconds. The first image is
            excluded when more than one image was timed (it likely includes
            one-off warm-up cost). Returns ``nan`` if no image was timed.
            Other ranks return ``None``.
        """
        dist.barrier()
        if self.rank != 0:
            return None
        self.set_mode('val')
        split = edit
        if dataset is not None:
            split = 'val'
            self.config.data[split]['root_dir'] = dataset

        val_set = BaseDataset(**self.config.data[split])
        val_loader = DataLoader(val_set, batch_size=1, shuffle=False,
                                num_workers=0)

        latencies = []
        for data in val_loader:
            for key in data:
                if key != 'name':
                    data[key] = data[key].cuda(
                        torch.cuda.current_device(), non_blocking=True)
            with torch.no_grad():
                # CUDA work (including the non_blocking copies above) is
                # asynchronous; synchronize on both sides so the timer covers
                # the actual GPU execution, not just kernel launches.
                torch.cuda.synchronize()
                start = time.perf_counter()
                return_dict = self.forward_pass(data, return_vals='all',
                                                only_enc=True)
                self.runG(return_dict['wp_mixed'], 'synthesis',
                          highres_outs=return_dict['eouts'], resize=False)
                torch.cuda.synchronize()
                latencies.append(time.perf_counter() - start)
            print(np.mean(latencies))

        steady = latencies[1:] if len(latencies) > 1 else latencies
        mean_latency = float(np.mean(steady)) if steady else float('nan')
        print(mean_latency)
        return mean_latency