import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.cuda as cu
import torch.optim as optim
import pickle
from torchvision import transforms
from torchvision.utils import save_image
from utils import augmentData, RGBtoL, LtoRGB
from PIL import Image
from dataloader import dataloader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import random
from datetime import date
from statistics import mean
from architectures import depth_generator_UNet, \
    depth_discriminator_noclass_UNet


class WiggleGAN(object):
    def __init__(self, args):
        """Set up the WiggleGAN experiment from parsed command-line ``args``.

        Builds the data loaders (train/validation unless in wiggle mode,
        always test), the UNet generator (and discriminator when one is
        trained), their Adam optimizers, the loss criteria, and the fixed
        one-hot sample labels used for visualization.  Everything is moved
        to CUDA when ``gpu_mode`` is set.
        """
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100  # NOTE: overwritten below as class_num ** 2
        self.nCameras = args.cameras
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        self.class_num = (args.cameras - 1) * 2  # a calculation worked out by hand (originally sketched in Paint)
        self.sample_num = self.class_num ** 2
        self.imageDim = args.imageDim
        self.epochVentaja = args.epochV  # "ventaja" = head-start epochs in which only G trains (see train())
        self.cantImages = args.cIm      # number of sample rows drawn per visualization
        self.visdom = args.visdom
        self.lambdaL1 = args.lambdaL1
        self.depth = args.depth
        self.name_wiggle = args.name_wiggle

        # A positive clipping value switches training to WGAN weight clipping.
        self.clipping = args.clipping
        self.WGAN = False
        if (self.clipping > 0):
            self.WGAN = True

        # Fresh random run id; a seedLoad different from the sentinel "-0000"
        # triggers checkpoint loading in load().
        self.seed = str(random.randint(0, 99999))
        self.seed_load = args.seedLoad
        self.toLoad = False
        if (self.seed_load != "-0000"):
            self.toLoad = True

        # Consistency-regularization factors; any positive factor enables
        # the CR code paths (bCR/zCR) in train().
        self.zGenFactor = args.zGF
        self.zDisFactor = args.zDF
        self.bFactor = args.bF
        self.CR = False
        if (self.zGenFactor > 0 or self.zDisFactor > 0 or self.bFactor > 0):
            self.CR = True

        self.expandGen = args.expandGen
        self.expandDis = args.expandDis

        # wiggleDepth > 0 selects inference-only "wiggle" mode (no training).
        self.wiggleDepth = args.wiggleDepth
        self.wiggle = False
        if (self.wiggleDepth > 0):
            self.wiggle = True



        # load dataset

        # A non-positive discriminator learning rate means generator-only training.
        self.onlyGen = args.lrD <= 0

        if not self.wiggle:
            self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='train',
                                      trans=not self.CR)

            self.data_Validation = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
                                          split='validation')

            # First validation batch is kept around for periodic visualization.
            self.dataprint = self.data_Validation.__iter__().__next__()

            # Peek one training batch only to size the discriminator input.
            data = self.data_loader.__iter__().__next__().get('x_im')


            if not self.onlyGen:
              self.D = depth_discriminator_noclass_UNet(input_dim=3, output_dim=1, input_shape=data.shape,
                                                        class_num=self.class_num,
                                                        expand_net=self.expandDis, depth = self.depth, wgan = self.WGAN)
              self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        self.data_Test = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='test')
        self.dataprint_test = self.data_Test.__iter__().__next__()

        # networks init

        self.G = depth_generator_UNet(input_dim=4, output_dim=3, class_num=self.class_num, expand_net=self.expandGen, depth = self.depth)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))


        if self.gpu_mode:
            self.G.cuda()
            if not self.wiggle and not self.onlyGen:
                self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
            self.CE_loss = nn.CrossEntropyLoss().cuda()
            self.L1 = nn.L1Loss().cuda()
            self.MSE = nn.MSELoss().cuda()
            self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()
            self.CE_loss = nn.CrossEntropyLoss()
            self.MSE = nn.MSELoss()
            self.L1 = nn.L1Loss()
            self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        if not self.wiggle and not self.onlyGen:
            utils.print_network(self.D)
        print('-----------------------------------------------')

        # Build fixed one-hot labels for the sample grid: temp enumerates the
        # class ids, temp_y tiles them over sample_num rows, then scatter_
        # turns the ids into one-hot vectors.
        temp = torch.zeros((self.class_num, 1))
        for i in range(self.class_num):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(self.class_num):
            temp_y[i * self.class_num: (i + 1) * self.class_num] = temp

        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
        if self.gpu_mode:
             self.sample_y_ = self.sample_y_.cuda()

        if (self.toLoad):
            self.load()

    def train(self):
        """Run the full training loop.

        For the first ``epochVentaja`` epochs only the generator is trained
        (pure L1 reconstruction); afterwards the discriminator joins in,
        using either standard GAN or WGAN losses, optionally with bCR/zCR
        consistency regularization.  Each epoch runs a validation pass,
        records per-epoch averages, optionally plots to visdom, and saves a
        checkpoint every 10 epochs.
        """

        if self.visdom:
            random.seed(time.time())
            today = date.today()

            vis = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visEpoch = utils.VisdomLineTwoPlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visImages = utils.VisdomImagePlotter(env_name='Cobo_depth_Images_' + str(today) + '_' + self.seed)
            visImagesTest = utils.VisdomImagePlotter(env_name='Cobo_depth_ImagesTest_' + str(today) + '_' + self.seed)

            visLossGTest = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visLossGValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)

            visLossDTest = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visLossDValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)

        # Per-iteration loss traces (train_hist/details_hist) and per-epoch
        # averages (epoch_hist).  The per-iteration lists are truncated at the
        # end of every epoch to keep memory bounded.
        self.train_hist = {}
        self.epoch_hist = {}
        self.details_hist = {}
        self.train_hist['D_loss_train'] = []
        self.train_hist['G_loss_train'] = []
        self.train_hist['D_loss_Validation'] = []
        self.train_hist['G_loss_Validation'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.details_hist['G_T_Comp_im'] = []
        self.details_hist['G_T_BCE_fake_real'] = []
        self.details_hist['G_T_Cycle'] = []
        self.details_hist['G_zCR'] = []

        self.details_hist['G_V_Comp_im'] = []
        self.details_hist['G_V_BCE_fake_real'] = []
        self.details_hist['G_V_Cycle'] = []

        self.details_hist['D_T_BCE_fake_real_R'] = []
        self.details_hist['D_T_BCE_fake_real_F'] = []
        self.details_hist['D_zCR'] = []
        self.details_hist['D_bCR'] = []

        self.details_hist['D_V_BCE_fake_real_R'] = []
        self.details_hist['D_V_BCE_fake_real_F'] = []

        self.epoch_hist['D_loss_train'] = []
        self.epoch_hist['G_loss_train'] = []
        self.epoch_hist['D_loss_Validation'] = []
        self.epoch_hist['G_loss_Validation'] = []

        ## So that the per-epoch average can be taken
        iterIniTrain = 0
        iterFinTrain = 0

        iterIniValidation = 0
        iterFinValidation = 0

        maxIter = self.data_loader.dataset.__len__() // self.batch_size
        maxIterVal = self.data_Validation.dataset.__len__() // self.batch_size

        if (self.WGAN):
            one = torch.tensor(1, dtype=torch.float).cuda()
            mone = one * -1
        else:
            # Soft targets for the BCE-with-logits GAN losses.
            self.y_real_ = torch.ones(self.batch_size, 1)
            self.y_fake_ = torch.zeros(self.batch_size, 1)
            if self.gpu_mode:
                self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        print('training start!!')
        start_time = time.time()

        for epoch in range(self.epoch):

            # "ventaja" (head start): generator-only warm-up epochs.
            if (epoch < self.epochVentaja):
                ventaja = True
            else:
                ventaja = False

            self.G.train()

            if not self.onlyGen:
              self.D.train()
            epoch_start_time = time.time()


            # TRAIN!!!
            for iter, data in enumerate(self.data_loader):

                x_im = data.get('x_im')
                x_dep = data.get('x_dep')
                y_im = data.get('y_im')
                y_dep = data.get('y_dep')
                y_ = data.get('y_')

                # x_im  = normal images
                # x_dep = depth maps of the images
                # y_im  = image with the angle changed
                # y_    = angle of the image = negatives must be handled

                # Augment my data
                if (self.CR):
                    x_im_aug, y_im_aug = augmentData(x_im, y_im)
                    x_im_vanilla = x_im

                    if self.gpu_mode:
                        x_im_aug, y_im_aug = x_im_aug.cuda(), y_im_aug.cuda()

                if iter >= maxIter:
                    break

                if self.gpu_mode:
                    x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

                # update D network
                if not ventaja and not self.onlyGen:

                    for p in self.D.parameters():  # reset requires_grad
                        p.requires_grad = True  # they are set to False below in netG update

                    self.D_optimizer.zero_grad()

                    # Real Images
                    D_real, D_features_real = self.D(y_im, x_im, y_dep, y_)  ## discriminator forward pass `` g(z) x

                    # Fake Images
                    G_, G_dep = self.G( y_, x_im, x_dep)
                    D_fake, D_features_fake = self.D(G_, x_im, G_dep, y_)

                    # Losses
                    #  GAN Loss
                    if (self.WGAN): # WGAN critic loss
                        D_loss_real_fake_R = - torch.mean(D_real)
                        D_loss_real_fake_F = torch.mean(D_fake)
                        #D_loss_real_fake_R = - D_loss_real_fake_R_positive

                    else:       # standard GAN
                        D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                        D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)

                    D_loss = D_loss_real_fake_F + D_loss_real_fake_R

                    if self.CR:

                        # Fake Augmented Images bCR
                        x_im_aug_bCR, G_aug_bCR = augmentData(x_im_vanilla, G_.data.cpu())

                        if self.gpu_mode:
                            G_aug_bCR, x_im_aug_bCR = G_aug_bCR.cuda(), x_im_aug_bCR.cuda()

                        D_fake_bCR, D_features_fake_bCR = self.D(G_aug_bCR, x_im_aug_bCR, G_dep, y_)
                        D_real_bCR, D_features_real_bCR = self.D(y_im_aug, x_im_aug, y_dep, y_)

                        # Fake Augmented Images zCR
                        G_aug_zCR, G_dep_aug_zCR = self.G(y_, x_im_aug, x_dep)
                        D_fake_aug_zCR, D_features_fake_aug_zCR = self.D(G_aug_zCR, x_im_aug, G_dep_aug_zCR, y_)

                        #  bCR Loss (*): D features should match between an image and its augmentation
                        D_loss_real = self.MSE(D_features_real, D_features_real_bCR)
                        D_loss_fake = self.MSE(D_features_fake, D_features_fake_bCR)
                        D_bCR = (D_loss_real + D_loss_fake) * self.bFactor

                        #  zCR Loss
                        D_zCR = self.MSE(D_features_fake, D_features_fake_aug_zCR) * self.zDisFactor

                        D_CR_losses = D_bCR + D_zCR
                        #D_CR_losses.backward(retain_graph=True)

                        D_loss += D_CR_losses

                        self.details_hist['D_bCR'].append(D_bCR.detach().item())
                        self.details_hist['D_zCR'].append(D_zCR.detach().item())
                    else:
                        self.details_hist['D_bCR'].append(0)
                        self.details_hist['D_zCR'].append(0)

                    self.train_hist['D_loss_train'].append(D_loss.detach().item())
                    self.details_hist['D_T_BCE_fake_real_R'].append(D_loss_real_fake_R.detach().item())
                    self.details_hist['D_T_BCE_fake_real_F'].append(D_loss_real_fake_F.detach().item())
                    if self.visdom:
                      visLossDTest.plot('Discriminator_losses',
                                           ['D_T_BCE_fake_real_R','D_T_BCE_fake_real_F', 'D_bCR', 'D_zCR'], 'train',
                                           self.details_hist)
                    #if self.WGAN:
                    #    D_loss_real_fake_F.backward(retain_graph=True)
                    #    D_loss_real_fake_R_positive.backward(mone)
                    #else:
                    #    D_loss_real_fake.backward()
                    D_loss.backward()

                    self.D_optimizer.step()

                    #WGAN
                    if (self.WGAN):
                        for p in self.D.parameters():
                            p.data.clamp_(-self.clipping, self.clipping) # Per the paper, too small a value leads to vanishing gradients
                    # If the improved WGAN variant were applied we would have to remove the batch normalizations from the net


                # update G network
                self.G_optimizer.zero_grad()

                G_, G_dep = self.G(y_, x_im, x_dep)

                if not ventaja and not self.onlyGen:
                    for p in self.D.parameters():
                        p.requires_grad = False  # to avoid computation

                    # Fake images
                    D_fake, _ = self.D(G_, x_im, G_dep, y_)

                    if (self.WGAN):
                        G_loss_fake = -torch.mean(D_fake) # WGAN
                    else:
                        G_loss_fake = self.BCEWithLogitsLoss(D_fake, self.y_real_)

                    # loss between images (*)
                    #G_join = torch.cat((G_, G_dep), 1)
                    #y_join = torch.cat((y_im, y_dep), 1)

                    G_loss_Comp = self.L1(G_, y_im)
                    if self.depth:
                      G_loss_Comp += self.L1(G_dep, y_dep)

                    G_loss_Dif_Comp = G_loss_Comp * self.lambdaL1

                    # Cycle term: generating back with the opposite label should
                    # recover the input image (and depth).
                    reverse_y = - y_ + 1
                    reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                    G_loss_Cycle = self.L1(reverse_G, x_im)
                    if self.depth:
                      G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                    G_loss_Cycle = G_loss_Cycle * self.lambdaL1/2


                    if (self.CR):
                        # Fake images augmented

                        G_aug, G_dep_aug = self.G(y_, x_im_aug, x_dep)
                        D_fake_aug, _ = self.D(G_aug, x_im, G_dep_aug, y_)

                        if (self.WGAN):
                            G_loss_fake = - (torch.mean(D_fake)+torch.mean(D_fake_aug))/2
                        else:
                            G_loss_fake = ( self.BCEWithLogitsLoss(D_fake, self.y_real_) +
                                            self.BCEWithLogitsLoss(D_fake_aug,self.y_real_)) / 2

                        # loss between images (*)
                        #y_aug_join = torch.cat((y_im_aug, y_dep), 1)
                        #G_aug_join = torch.cat((G_aug, G_dep_aug), 1)

                        G_loss_Comp_Aug = self.L1(G_aug, y_im_aug)
                        if self.depth:
                           G_loss_Comp_Aug += self.L1(G_dep_aug, y_dep)
                        G_loss_Dif_Comp = (G_loss_Comp + G_loss_Comp_Aug)/2 * self.lambdaL1


                    G_loss = G_loss_fake + G_loss_Dif_Comp + G_loss_Cycle

                    self.details_hist['G_T_BCE_fake_real'].append(G_loss_fake.detach().item())
                    self.details_hist['G_T_Comp_im'].append(G_loss_Dif_Comp.detach().item())
                    self.details_hist['G_T_Cycle'].append(G_loss_Cycle.detach().item())
                    self.details_hist['G_zCR'].append(0)


                else:

                    # Warm-up / generator-only path: pure L1 reconstruction.
                    G_loss = self.L1(G_, y_im)
                    if self.depth:
                      G_loss += self.L1(G_dep, y_dep)
                    G_loss = G_loss * self.lambdaL1
                    self.details_hist['G_T_Comp_im'].append(G_loss.detach().item())
                    self.details_hist['G_T_BCE_fake_real'].append(0)
                    self.details_hist['G_T_Cycle'].append(0)
                    self.details_hist['G_zCR'].append(0)

                G_loss.backward()
                self.G_optimizer.step()
                self.train_hist['G_loss_train'].append(G_loss.detach().item())
                if self.onlyGen:
                  self.train_hist['D_loss_train'].append(0)

                iterFinTrain += 1

                if self.visdom:
                  visLossGTest.plot('Generator_losses',
                                      ['G_T_Comp_im', 'G_T_BCE_fake_real', 'G_zCR','G_T_Cycle'],
                                       'train', self.details_hist)

                  vis.plot('loss', ['D_loss_train', 'G_loss_train'], 'train', self.train_hist)

            ##################Validation####################################
            with torch.no_grad():

                self.G.eval()
                if not self.onlyGen:
                  self.D.eval()

                for iter, data in enumerate(self.data_Validation):

                    # Unpack the batch
                    x_im = data.get('x_im')
                    x_dep = data.get('x_dep')
                    y_im = data.get('y_im')
                    y_dep = data.get('y_dep')
                    y_ = data.get('y_')
                    # x_im  = normal images
                    # x_dep = depth maps of the images
                    # y_im  = image with the angle changed
                    # y_    = angle of the image = negatives must be handled

                    # x_im  = torch.Tensor(list(x_im))
                    # x_dep = torch.Tensor(x_dep)
                    # y_im  = torch.Tensor(y_im)
                    # print(y_.shape[0])
                    if iter == maxIterVal:
                        # print ("Break")
                        break
                    # print (y_.type(torch.LongTensor).unsqueeze(1))


                    # print("y_vec_", y_vec_)
                    # print ("z_", z_)

                    if self.gpu_mode:
                        x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()
                    # D network

                    if not ventaja and not self.onlyGen:
                        # Real Images
                        D_real, _ = self.D(y_im, x_im, y_dep,y_)  ## discriminator forward pass `` g(z) x

                        # Fake Images
                        G_, G_dep = self.G(y_, x_im, x_dep)
                        D_fake, _ = self.D(G_, x_im, G_dep, y_)
                        # Losses
                        #  GAN Loss
                        if (self.WGAN):  # WGAN critic loss
                            D_loss_real_fake_R = - torch.mean(D_real)
                            D_loss_real_fake_F = torch.mean(D_fake)

                        else:  # standard GAN
                            D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                            D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)

                        D_loss_real_fake = D_loss_real_fake_F + D_loss_real_fake_R

                        D_loss = D_loss_real_fake

                        self.train_hist['D_loss_Validation'].append(D_loss.item())
                        self.details_hist['D_V_BCE_fake_real_R'].append(D_loss_real_fake_R.item())
                        self.details_hist['D_V_BCE_fake_real_F'].append(D_loss_real_fake_F.item())
                        if self.visdom:
                          visLossDValidation.plot('Discriminator_losses',
                                                     ['D_V_BCE_fake_real_R','D_V_BCE_fake_real_F'], 'Validation',
                                                     self.details_hist)

                    # G network

                    G_, G_dep = self.G(y_, x_im, x_dep)

                    if not ventaja and not self.onlyGen:
                        # Fake images
                        D_fake,_ = self.D(G_, x_im, G_dep, y_)

                        # GAN loss
                        if (self.WGAN):
                            G_loss = -torch.mean(D_fake)  # WGAN
                        else:
                            G_loss = self.BCEWithLogitsLoss(D_fake, self.y_real_) # standard GAN

                        self.details_hist['G_V_BCE_fake_real'].append(G_loss.item())

                        # Image-comparison loss
                        #G_join = torch.cat((G_, G_dep), 1)
                        #y_join = torch.cat((y_im, y_dep), 1)

                        G_loss_Comp = self.L1(G_, y_im)
                        if self.depth:
                          G_loss_Comp += self.L1(G_dep, y_dep)
                        G_loss_Comp = G_loss_Comp * self.lambdaL1

                        reverse_y = - y_ + 1
                        reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                        G_loss_Cycle = self.L1(reverse_G, x_im)
                        if self.depth:
                          G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                        G_loss_Cycle = G_loss_Cycle * self.lambdaL1/2

                        G_loss += G_loss_Comp + G_loss_Cycle


                        self.details_hist['G_V_Comp_im'].append(G_loss_Comp.item())
                        self.details_hist['G_V_Cycle'].append(G_loss_Cycle.detach().item())

                    else:
                        G_loss = self.L1(G_, y_im)
                        if self.depth:
                          G_loss += self.L1(G_dep, y_dep)
                        G_loss = G_loss * self.lambdaL1
                        self.details_hist['G_V_Comp_im'].append(G_loss.item())
                        self.details_hist['G_V_BCE_fake_real'].append(0)
                        self.details_hist['G_V_Cycle'].append(0)

                    self.train_hist['G_loss_Validation'].append(G_loss.item())
                    if self.onlyGen:
                      self.train_hist['D_loss_Validation'].append(0)


                    iterFinValidation += 1
                    if self.visdom:
                      visLossGValidation.plot('Generator_losses', ['G_V_Comp_im', 'G_V_BCE_fake_real','G_V_Cycle'],
                                                 'Validation', self.details_hist)
                      visValidation.plot('loss', ['D_loss_Validation', 'G_loss_Validation'], 'Validation',
                                           self.train_hist)

            ## Per-epoch visualization

            if ventaja or self.onlyGen:
                self.epoch_hist['D_loss_train'].append(0)
                self.epoch_hist['D_loss_Validation'].append(0)
            else:
                #inicioTr = (epoch - self.epochVentaja) * (iterFinTrain - iterIniTrain)
                #inicioTe = (epoch - self.epochVentaja) * (iterFinValidation - iterIniValidation)
                self.epoch_hist['D_loss_train'].append(mean(self.train_hist['D_loss_train'][iterIniTrain: -1]))
                self.epoch_hist['D_loss_Validation'].append(mean(self.train_hist['D_loss_Validation'][iterIniValidation: -1]))

            self.epoch_hist['G_loss_train'].append(mean(self.train_hist['G_loss_train'][iterIniTrain:iterFinTrain]))
            self.epoch_hist['G_loss_Validation'].append(
                mean(self.train_hist['G_loss_Validation'][iterIniValidation:iterFinValidation]))
            if self.visdom:
              visEpoch.plot('epoch', epoch,
                               ['D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation'],
                               self.epoch_hist)

            # Keep only the last entry of each trace so memory stays bounded
            # across epochs (the visdom plotters keep their own state).
            self.train_hist['D_loss_train'] = self.train_hist['D_loss_train'][-1:]
            self.train_hist['G_loss_train'] = self.train_hist['G_loss_train'][-1:]
            self.train_hist['D_loss_Validation'] = self.train_hist['D_loss_Validation'][-1:]
            self.train_hist['G_loss_Validation'] = self.train_hist['G_loss_Validation'][-1:]
            self.train_hist['per_epoch_time'] = self.train_hist['per_epoch_time'][-1:]
            self.train_hist['total_time'] = self.train_hist['total_time'][-1:]

            self.details_hist['G_T_Comp_im'] = self.details_hist['G_T_Comp_im'][-1:]
            self.details_hist['G_T_BCE_fake_real'] = self.details_hist['G_T_BCE_fake_real'][-1:]
            self.details_hist['G_T_Cycle'] = self.details_hist['G_T_Cycle'][-1:]
            self.details_hist['G_zCR'] = self.details_hist['G_zCR'][-1:]

            self.details_hist['G_V_Comp_im'] = self.details_hist['G_V_Comp_im'][-1:]
            self.details_hist['G_V_BCE_fake_real'] = self.details_hist['G_V_BCE_fake_real'][-1:]
            self.details_hist['G_V_Cycle'] = self.details_hist['G_V_Cycle'][-1:]

            self.details_hist['D_T_BCE_fake_real_R'] = self.details_hist['D_T_BCE_fake_real_R'][-1:]
            self.details_hist['D_T_BCE_fake_real_F'] = self.details_hist['D_T_BCE_fake_real_F'][-1:]
            self.details_hist['D_zCR'] = self.details_hist['D_zCR'][-1:]
            self.details_hist['D_bCR'] = self.details_hist['D_bCR'][-1:]

            self.details_hist['D_V_BCE_fake_real_R'] = self.details_hist['D_V_BCE_fake_real_R'][-1:]
            self.details_hist['D_V_BCE_fake_real_F'] = self.details_hist['D_V_BCE_fake_real_F'][-1:]
            ## So that the per-epoch average can be taken
            iterIniTrain = 1
            iterFinTrain = 1

            iterIniValidation = 1
            iterFinValidation = 1

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)

            # Checkpoint + sample images every 10 epochs.
            if epoch % 10 == 0:
                self.save(str(epoch))
                with torch.no_grad():
                    if self.visdom:
                      self.visualize_results(epoch, dataprint=self.dataprint, visual=visImages)
                      self.visualize_results(epoch, dataprint=self.dataprint_test, visual=visImagesTest)
                    else:
                      imageName = self.model_name + '_' + 'Train' + '_' + str(self.seed) + '_' + str(epoch)
                      self.visualize_results(epoch, dataprint=self.dataprint, name= imageName)
                      self.visualize_results(epoch, dataprint=self.dataprint_test, name= imageName)


        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
                                                                        self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        #utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
        #                         self.epoch)
        #utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, dataprint, visual="", name= "test"):
        """Render sample rows (input image, input depth, generated outputs)
        for up to ``cantImages`` items of ``dataprint``.

        In wiggle mode the generator is applied repeatedly (``wiggleDepth``
        steps in two directions) to produce a wiggle sequence; otherwise both
        generated views and their depths are shown.  Results go to the visdom
        plotter ``visual`` when ``self.visdom`` is set, else to disk via
        ``utils.save_wiggle`` under ``name``.
        """
        with torch.no_grad():
            self.G.eval()

            #if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            #    os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

            # print("sample z: ",self.sample_z_,"sample y:", self.sample_y_)

            ## Could do this in a loop
            # .zfill(4)
            #newSample = None
            #print(dataprint.shape)

            #newSample = torch.tensor([])

            # I know this is inefficient, but it only runs every 10 epochs
            newSample = []
            iter = 1
            for x_im,x_dep in zip(dataprint.get('x_im'), dataprint.get('x_dep')):
                if (iter > self.cantImages):
                    break

                #x_im = (x_im + 1) / 2
                #imgX = transforms.ToPILImage()(x_im)
                #imgX.show()

                # Duplicate the input so one forward pass yields both label variants.
                x_im_input = x_im.repeat(2, 1, 1, 1)
                x_dep_input = x_dep.repeat(2, 1, 1, 1)

                sizeImage = x_im.shape[2]

                # Label planes: all-ones plane only where i % class_num == 1
                # (NOTE(review): for class_num > 2 that is only index 1 — confirm intended).
                sample_y_ = torch.zeros((self.class_num, 1, sizeImage, sizeImage))
                for i in range(self.class_num):
                    if(int(i % self.class_num) == 1):
                        sample_y_[i] = torch.ones(( 1, sizeImage, sizeImage))

                if self.gpu_mode:
                    sample_y_, x_im_input, x_dep_input = sample_y_.cuda(), x_im_input.cuda(), x_dep_input.cuda()

                G_im, G_dep = self.G(sample_y_, x_im_input, x_dep_input)

                # Row starts with the input image and its depth (depth expanded to 3 channels).
                newSample.append(x_im.squeeze(0))
                newSample.append(x_dep.squeeze(0).expand(3, -1, -1))



                if self.wiggle:
                    # Feed the generator its own output wiggleDepth times in
                    # each of the two label directions.
                    im_aux, im_dep_aux = G_im, G_dep
                    for i in range(0, 2):
                        index = i
                        for j in range(0, self.wiggleDepth):

                            # print(i,j)

                            if (j == 0 and i == 1):
                                # to take the original
                                im_aux, im_dep_aux = G_im, G_dep
                                newSample.append(G_im.cpu()[0].squeeze(0))
                                newSample.append(G_im.cpu()[1].squeeze(0))
                            elif (i == 1):
                                # because of the issue with the following iterations
                                index = 0

                            # generated image


                            x = im_aux[index].unsqueeze(0)
                            x_dep = im_dep_aux[index].unsqueeze(0)

                            y = sample_y_[i].unsqueeze(0)

                            if self.gpu_mode:
                                y, x, x_dep = y.cuda(), x.cuda(), x_dep.cuda()

                            im_aux, im_dep_aux = self.G(y, x, x_dep)

                            newSample.append(im_aux.cpu()[0])
                else:

                    newSample.append(G_im.cpu()[0])
                    newSample.append(G_im.cpu()[1])
                    newSample.append(G_dep.cpu()[0].expand(3, -1, -1))
                    newSample.append(G_dep.cpu()[1].expand(3, -1, -1))

                iter+=1

            if self.visdom:
                visual.plot(epoch, newSample, int(len(newSample) /self.cantImages))
            else:
                utils.save_wiggle(newSample, self.cantImages, name)
        ## TODO: samples should contain at most self.class_num * self.class_num entries

        # utils.save_images(newSample[:, :, :, :], [image_frame_dim * cantidadIm , image_frame_dim * (self.class_num+2)],
        #                  self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%04d' % epoch + '.png')

    def show_plot_images(self, images, cols=1, titles=None):
        """Display a list of images in a single figure with matplotlib.

        Parameters
        ---------
        images: List of np.arrays compatible with plt.imshow.

        cols (Default = 1): Number of columns in figure (number of rows is
                            set to np.ceil(n_images/float(cols))).

        titles: List of titles corresponding to each image. Must have
                the same length as images.
        """
        # assert ((titles is None) or (len(images) == len(titles)))
        n_images = len(images)
        if titles is None:
            titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
        # BUGFIX: np.ceil returns a float and matplotlib's add_subplot
        # requires integer (nrows, ncols, index); cast explicitly.
        rows = int(np.ceil(n_images / float(cols)))
        fig = plt.figure()
        for n, (image, title) in enumerate(zip(images, titles)):
            axis = fig.add_subplot(rows, cols, n + 1)
            # Shift images out of the [-1, 1] training range before display.
            # NOTE(review): (image + 1) * 255 maps [-1, 1] to [0, 510];
            # (image + 1) / 2 * 255 may have been intended — confirm.
            image = (image + 1) * 255.0
            if image.ndim == 2:
                plt.gray()  # use grayscale colormap for single-channel images
            plt.imshow(image)
            axis.set_title(title)
        fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
        plt.show()

    def joinImages(self, data):
        nData = []
        for i in range(self.class_num):
            nData.append(data)
        nData = np.array(nData)
        nData = torch.tensor(nData.tolist())
        nData = nData.type(torch.FloatTensor)

        return nData

    def save(self, epoch=''):
        """Persist generator (and, unless onlyGen, discriminator) weights plus
        the training history under <save_dir>/<dataset>/<model_name>/.

        Parameters
        ----------
        epoch : str, optional
            Tag appended to the checkpoint filenames (empty by default).
        """
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        # exist_ok avoids the check-then-create race of exists() + makedirs()
        os.makedirs(save_dir, exist_ok=True)

        torch.save(self.G.state_dict(),
                   os.path.join(save_dir, self.model_name + '_' + self.seed + '_' + epoch + '_G.pkl'))
        if not self.onlyGen:
          torch.save(self.D.state_dict(),
                   os.path.join(save_dir, self.model_name + '_' + self.seed + '_' + epoch + '_D.pkl'))

        # NOTE(review): '_history_ ' keeps its historical stray space so
        # filenames of existing runs stay consistent.
        with open(os.path.join(save_dir, self.model_name + '_history_ '+self.seed+'.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        """Restore the generator (and the discriminator when not in wiggle
        mode) from the checkpoints identified by ``seed_load``."""
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)
        base = self.model_name + '_' + self.seed_load

        self.G.load_state_dict(torch.load(os.path.join(save_dir, base + '_G.pkl')))
        if not self.wiggle:
            self.D.load_state_dict(torch.load(os.path.join(save_dir, base + '_D.pkl')))

    def wiggleEf(self):
        """Run the wiggle visualization for the loaded seed, either through a
        visdom plotter or by saving images to disk."""
        seed, epoch = self.seed_load.split('_')
        if not self.visdom:
            self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=None, name = self.name_wiggle)
        else:
            plotter = utils.VisdomImagePlotter(env_name='Cobo_depth_wiggle_' + seed)
            self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=plotter)

    def recreate(self):
      """Regenerate the 'score' split of the dataset with the trained generator.

      For each batch, synthesizes the right view (image + depth) from the left
      input and, with a zeroed direction label, the left view from the right
      input, then writes every result as a PNG under
      results/recreate_dataset/CAM1 (right) and CAM0 (left).
      NOTE(review): assumes the output directories already exist and that the
      split size is a multiple of batch_size — confirm.
      """
      dataloader_recreate = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='score')
      with torch.no_grad():
        self.G.eval()
        # running offset so file numbering stays continuous across batches
        accum = 0
        for data_batch in dataloader_recreate.__iter__():
          
          # batch layout, as produced by the dataloader:
          #{'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
          left,left_depth,right,right_depth,direction = data_batch.values()

          if self.gpu_mode:
            left,left_depth,right,right_depth,direction = left.cuda(),left_depth.cuda(),right.cuda(),right_depth.cuda(),direction.cuda()

          # left -> right with the batch's direction label
          G_right, G_right_dep = self.G( direction, left, left_depth)
          
          # a zero tensor encodes the opposite direction for the reverse pass
          reverse_direction = direction * 0 
          G_left, G_left_dep = self.G( reverse_direction, right, right_depth)

          for index in range(0,self.batch_size):
            # generator outputs are in [-1, 1]; map to [0, 1] for save_image
            image_right = (G_right[index] + 1.0)/2.0
            image_right_dep = (G_right_dep[index] + 1.0)/2.0

            image_left = (G_left[index] + 1.0)/2.0
            image_left_dep = (G_left_dep[index] + 1.0)/2.0

            

            # n_* = RGB image, d_* = depth map; zero-padded 4-digit index
            save_image(image_right, os.path.join("results","recreate_dataset","CAM1","n_{num:0{width}}.png".format(num = index+accum, width = 4)))
            save_image(image_right_dep, os.path.join("results","recreate_dataset","CAM1","d_{num:0{width}}.png".format(num = index+accum, width = 4)))

            save_image(image_left, os.path.join("results","recreate_dataset","CAM0","n_{num:0{width}}.png".format(num = index+accum, width = 4)))
            save_image(image_left_dep, os.path.join("results","recreate_dataset","CAM0","d_{num:0{width}}.png".format(num = index+accum, width = 4)))
          accum+= self.batch_size