Rodrigo_Cobo committed
Commit cc6c676
Parent(s): 9e2cd5a
added thesis
Browse files
- .gitignore +7 -0
- Images/Input-Test/1.png +0 -0
- Images/Input-Test/10.png +0 -0
- Images/Input-Test/11.png +0 -0
- Images/Input-Test/12.png +0 -0
- Images/Input-Test/2.png +0 -0
- Images/Input-Test/3.png +0 -0
- Images/Input-Test/4.png +0 -0
- Images/Input-Test/6.png +0 -0
- Images/Input-Test/7.png +0 -0
- Images/Input-Test/8.png +0 -0
- Images/Input-Test/9.png +0 -0
- WiggleGAN.py +833 -0
- WiggleResults/split.py +91 -0
- app.py +6 -6
- architectures.py +1094 -0
- config.ini +259 -0
- dataloader.py +301 -0
- epochData.pkl +3 -0
- main.py +136 -0
- models/4cam/WiggleGAN/WiggleGAN_31219_110_G.pkl +3 -0
- models/4cam/WiggleGAN/WiggleGAN_66942_110_G.pkl +3 -0
- models/4cam/WiggleGAN/WiggleGAN_70466_110_G.pkl +3 -0
- models/4cam/WiggleGAN/WiggleGAN_70944_110_G.pkl +3 -0
- models/4cam/WiggleGAN/WiggleGAN_74962_110_G.pkl +3 -0
- models/4cam/WiggleGAN/WiggleGAN_82122_110_G.pkl +3 -0
- models/4cam/WiggleGAN/WiggleGAN_92332_110_G.pkl +3 -0
- pyvenv.cfg +3 -0
- requirements.txt +25 -2
- utils.py +369 -0
.gitignore
ADDED
@@ -0,0 +1,7 @@
__pycache__/*
Scripts/*
Include/*
Lib/*
logs/*
WiggleGAN_mod.py
WiggleGAN_noCycle.py
Images/Input-Test/1.png
ADDED
Images/Input-Test/10.png
ADDED
Images/Input-Test/11.png
ADDED
Images/Input-Test/12.png
ADDED
Images/Input-Test/2.png
ADDED
Images/Input-Test/3.png
ADDED
Images/Input-Test/4.png
ADDED
Images/Input-Test/6.png
ADDED
Images/Input-Test/7.png
ADDED
Images/Input-Test/8.png
ADDED
Images/Input-Test/9.png
ADDED
WiggleGAN.py
ADDED
@@ -0,0 +1,833 @@
import utils, torch, time, os, pickle
import numpy as np
import torch.nn as nn
import torch.cuda as cu
import torch.optim as optim
from torchvision import transforms
from torchvision.utils import save_image
from utils import augmentData, RGBtoL, LtoRGB
from PIL import Image
from dataloader import dataloader
from torch.autograd import Variable
import matplotlib.pyplot as plt
import random
from datetime import date
from statistics import mean
from architectures import depth_generator_UNet, \
    depth_discriminator_noclass_UNet


class WiggleGAN(object):
    def __init__(self, args):
        # parameters
        self.epoch = args.epoch
        self.sample_num = 100
        self.nCameras = args.cameras
        self.batch_size = args.batch_size
        self.save_dir = args.save_dir
        self.result_dir = args.result_dir
        self.dataset = args.dataset
        self.log_dir = args.log_dir
        self.gpu_mode = args.gpu_mode
        self.model_name = args.gan_type
        self.input_size = args.input_size
        self.class_num = (args.cameras - 1) * 2  # derived from the camera layout
        self.sample_num = self.class_num ** 2
        self.imageDim = args.imageDim
        self.epochVentaja = args.epochV  # epochs of generator-only head start
        self.cantImages = args.cIm
        self.visdom = args.visdom
        self.lambdaL1 = args.lambdaL1
        self.depth = args.depth
        self.name_wiggle = args.name_wiggle

        self.clipping = args.clipping
        self.WGAN = False
        if (self.clipping > 0):
            self.WGAN = True

        self.seed = str(random.randint(0, 99999))
        self.seed_load = args.seedLoad
        self.toLoad = False
        if (self.seed_load != "-0000"):
            self.toLoad = True

        self.zGenFactor = args.zGF
        self.zDisFactor = args.zDF
        self.bFactor = args.bF
        self.CR = False
        if (self.zGenFactor > 0 or self.zDisFactor > 0 or self.bFactor > 0):
            self.CR = True

        self.expandGen = args.expandGen
        self.expandDis = args.expandDis

        self.wiggleDepth = args.wiggleDepth
        self.wiggle = False
        if (self.wiggleDepth > 0):
            self.wiggle = True

        # load dataset
        self.onlyGen = args.lrD <= 0

        if not self.wiggle:
            self.data_loader = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='train',
                                          trans=not self.CR)
            self.data_Validation = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim,
                                              split='validation')
            self.dataprint = self.data_Validation.__iter__().__next__()

            data = self.data_loader.__iter__().__next__().get('x_im')

            if not self.onlyGen:
                self.D = depth_discriminator_noclass_UNet(input_dim=3, output_dim=1, input_shape=data.shape,
                                                          class_num=self.class_num,
                                                          expand_net=self.expandDis, depth=self.depth, wgan=self.WGAN)
                self.D_optimizer = optim.Adam(self.D.parameters(), lr=args.lrD, betas=(args.beta1, args.beta2))

        self.data_Test = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='test')
        self.dataprint_test = self.data_Test.__iter__().__next__()

        # networks init
        self.G = depth_generator_UNet(input_dim=4, output_dim=3, class_num=self.class_num, expand_net=self.expandGen, depth=self.depth)
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=args.lrG, betas=(args.beta1, args.beta2))

        if self.gpu_mode:
            self.G.cuda()
            if not self.wiggle and not self.onlyGen:
                self.D.cuda()
            self.BCE_loss = nn.BCELoss().cuda()
            self.CE_loss = nn.CrossEntropyLoss().cuda()
            self.L1 = nn.L1Loss().cuda()
            self.MSE = nn.MSELoss().cuda()
            self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss().cuda()
        else:
            self.BCE_loss = nn.BCELoss()
            self.CE_loss = nn.CrossEntropyLoss()
            self.MSE = nn.MSELoss()
            self.L1 = nn.L1Loss()
            self.BCEWithLogitsLoss = nn.BCEWithLogitsLoss()

        print('---------- Networks architecture -------------')
        utils.print_network(self.G)
        if not self.wiggle and not self.onlyGen:
            utils.print_network(self.D)
        print('-----------------------------------------------')

        temp = torch.zeros((self.class_num, 1))
        for i in range(self.class_num):
            temp[i, 0] = i

        temp_y = torch.zeros((self.sample_num, 1))
        for i in range(self.class_num):
            temp_y[i * self.class_num: (i + 1) * self.class_num] = temp

        self.sample_y_ = torch.zeros((self.sample_num, self.class_num)).scatter_(1, temp_y.type(torch.LongTensor), 1)
        if self.gpu_mode:
            self.sample_y_ = self.sample_y_.cuda()

        if (self.toLoad):
            self.load()

    def train(self):

        if self.visdom:
            random.seed(time.time())
            today = date.today()

            vis = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visEpoch = utils.VisdomLineTwoPlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visImages = utils.VisdomImagePlotter(env_name='Cobo_depth_Images_' + str(today) + '_' + self.seed)
            visImagesTest = utils.VisdomImagePlotter(env_name='Cobo_depth_ImagesTest_' + str(today) + '_' + self.seed)

            visLossGTest = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visLossGValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)

            visLossDTest = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)
            visLossDValidation = utils.VisdomLinePlotter(env_name='Cobo_depth_Train-Plots_' + str(today) + '_' + self.seed)

        self.train_hist = {}
        self.epoch_hist = {}
        self.details_hist = {}
        self.train_hist['D_loss_train'] = []
        self.train_hist['G_loss_train'] = []
        self.train_hist['D_loss_Validation'] = []
        self.train_hist['G_loss_Validation'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        self.details_hist['G_T_Comp_im'] = []
        self.details_hist['G_T_BCE_fake_real'] = []
        self.details_hist['G_T_Cycle'] = []
        self.details_hist['G_zCR'] = []

        self.details_hist['G_V_Comp_im'] = []
        self.details_hist['G_V_BCE_fake_real'] = []
        self.details_hist['G_V_Cycle'] = []

        self.details_hist['D_T_BCE_fake_real_R'] = []
        self.details_hist['D_T_BCE_fake_real_F'] = []
        self.details_hist['D_zCR'] = []
        self.details_hist['D_bCR'] = []

        self.details_hist['D_V_BCE_fake_real_R'] = []
        self.details_hist['D_V_BCE_fake_real_F'] = []

        self.epoch_hist['D_loss_train'] = []
        self.epoch_hist['G_loss_train'] = []
        self.epoch_hist['D_loss_Validation'] = []
        self.epoch_hist['G_loss_Validation'] = []

        ## So the per-epoch average can be taken
        iterIniTrain = 0
        iterFinTrain = 0

        iterIniValidation = 0
        iterFinValidation = 0

        maxIter = self.data_loader.dataset.__len__() // self.batch_size
        maxIterVal = self.data_Validation.dataset.__len__() // self.batch_size

        if (self.WGAN):
            one = torch.tensor(1, dtype=torch.float).cuda()
            mone = one * -1
        else:
            self.y_real_ = torch.ones(self.batch_size, 1)
            self.y_fake_ = torch.zeros(self.batch_size, 1)
            if self.gpu_mode:
                self.y_real_, self.y_fake_ = self.y_real_.cuda(), self.y_fake_.cuda()

        print('training start!!')
        start_time = time.time()

        for epoch in range(self.epoch):

            if (epoch < self.epochVentaja):
                ventaja = True  # generator still has its head start
            else:
                ventaja = False

            self.G.train()

            if not self.onlyGen:
                self.D.train()
            epoch_start_time = time.time()

            # TRAIN!!!
            for iter, data in enumerate(self.data_loader):

                x_im = data.get('x_im')
                x_dep = data.get('x_dep')
                y_im = data.get('y_im')
                y_dep = data.get('y_dep')
                y_ = data.get('y_')

                # x_im = source images
                # x_dep = depth maps of those images
                # y_im = image from the shifted viewpoint
                # y_ = viewpoint label (negatives need special handling)

                # Augment the data
                if (self.CR):
                    x_im_aug, y_im_aug = augmentData(x_im, y_im)
                    x_im_vanilla = x_im

                    if self.gpu_mode:
                        x_im_aug, y_im_aug = x_im_aug.cuda(), y_im_aug.cuda()

                if iter >= maxIter:
                    break

                if self.gpu_mode:
                    x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

                # update D network
                if not ventaja and not self.onlyGen:

                    for p in self.D.parameters():  # reset requires_grad
                        p.requires_grad = True  # they are set to False below in netG update

                    self.D_optimizer.zero_grad()

                    # Real Images
                    D_real, D_features_real = self.D(y_im, x_im, y_dep, y_)  # forward pass

                    # Fake Images
                    G_, G_dep = self.G(y_, x_im, x_dep)
                    D_fake, D_features_fake = self.D(G_, x_im, G_dep, y_)

                    # Losses
                    # GAN Loss
                    if (self.WGAN):  # WGAN variant
                        D_loss_real_fake_R = - torch.mean(D_real)
                        D_loss_real_fake_F = torch.mean(D_fake)
                        # D_loss_real_fake_R = - D_loss_real_fake_R_positive
                    else:  # standard GAN
                        D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                        D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)

                    D_loss = D_loss_real_fake_F + D_loss_real_fake_R

                    if self.CR:

                        # Fake Augmented Images bCR
                        x_im_aug_bCR, G_aug_bCR = augmentData(x_im_vanilla, G_.data.cpu())

                        if self.gpu_mode:
                            G_aug_bCR, x_im_aug_bCR = G_aug_bCR.cuda(), x_im_aug_bCR.cuda()

                        D_fake_bCR, D_features_fake_bCR = self.D(G_aug_bCR, x_im_aug_bCR, G_dep, y_)
                        D_real_bCR, D_features_real_bCR = self.D(y_im_aug, x_im_aug, y_dep, y_)

                        # Fake Augmented Images zCR
                        G_aug_zCR, G_dep_aug_zCR = self.G(y_, x_im_aug, x_dep)
                        D_fake_aug_zCR, D_features_fake_aug_zCR = self.D(G_aug_zCR, x_im_aug, G_dep_aug_zCR, y_)

                        # bCR Loss (*)
                        D_loss_real = self.MSE(D_features_real, D_features_real_bCR)
                        D_loss_fake = self.MSE(D_features_fake, D_features_fake_bCR)
                        D_bCR = (D_loss_real + D_loss_fake) * self.bFactor

                        # zCR Loss
                        D_zCR = self.MSE(D_features_fake, D_features_fake_aug_zCR) * self.zDisFactor

                        D_CR_losses = D_bCR + D_zCR
                        # D_CR_losses.backward(retain_graph=True)

                        D_loss += D_CR_losses

                        self.details_hist['D_bCR'].append(D_bCR.detach().item())
                        self.details_hist['D_zCR'].append(D_zCR.detach().item())
                    else:
                        self.details_hist['D_bCR'].append(0)
                        self.details_hist['D_zCR'].append(0)

                    self.train_hist['D_loss_train'].append(D_loss.detach().item())
                    self.details_hist['D_T_BCE_fake_real_R'].append(D_loss_real_fake_R.detach().item())
                    self.details_hist['D_T_BCE_fake_real_F'].append(D_loss_real_fake_F.detach().item())
                    if self.visdom:
                        visLossDTest.plot('Discriminator_losses',
                                          ['D_T_BCE_fake_real_R', 'D_T_BCE_fake_real_F', 'D_bCR', 'D_zCR'], 'train',
                                          self.details_hist)
                    # if self.WGAN:
                    #     D_loss_real_fake_F.backward(retain_graph=True)
                    #     D_loss_real_fake_R_positive.backward(mone)
                    # else:
                    #     D_loss_real_fake.backward()
                    D_loss.backward()

                    self.D_optimizer.step()

                    # WGAN weight clipping
                    if (self.WGAN):
                        for p in self.D.parameters():
                            p.data.clamp_(-self.clipping, self.clipping)  # per the paper, too small a value leads to vanishing gradients
                        # applying the improved-WGAN variant would require removing batch normalization from the network

                # update G network
                self.G_optimizer.zero_grad()

                G_, G_dep = self.G(y_, x_im, x_dep)

                if not ventaja and not self.onlyGen:
                    for p in self.D.parameters():
                        p.requires_grad = False  # to avoid computation

                    # Fake images
                    D_fake, _ = self.D(G_, x_im, G_dep, y_)

                    if (self.WGAN):
                        G_loss_fake = -torch.mean(D_fake)  # WGAN
                    else:
                        G_loss_fake = self.BCEWithLogitsLoss(D_fake, self.y_real_)

                    # loss between images (*)
                    # G_join = torch.cat((G_, G_dep), 1)
                    # y_join = torch.cat((y_im, y_dep), 1)

                    G_loss_Comp = self.L1(G_, y_im)
                    if self.depth:
                        G_loss_Comp += self.L1(G_dep, y_dep)

                    G_loss_Dif_Comp = G_loss_Comp * self.lambdaL1

                    reverse_y = - y_ + 1
                    reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                    G_loss_Cycle = self.L1(reverse_G, x_im)
                    if self.depth:
                        G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                    G_loss_Cycle = G_loss_Cycle * self.lambdaL1 / 2

                    if (self.CR):
                        # Fake images augmented

                        G_aug, G_dep_aug = self.G(y_, x_im_aug, x_dep)
                        D_fake_aug, _ = self.D(G_aug, x_im, G_dep_aug, y_)

                        if (self.WGAN):
                            G_loss_fake = - (torch.mean(D_fake) + torch.mean(D_fake_aug)) / 2
                        else:
                            G_loss_fake = (self.BCEWithLogitsLoss(D_fake, self.y_real_) +
                                           self.BCEWithLogitsLoss(D_fake_aug, self.y_real_)) / 2

                        # loss between images (*)
                        # y_aug_join = torch.cat((y_im_aug, y_dep), 1)
                        # G_aug_join = torch.cat((G_aug, G_dep_aug), 1)

                        G_loss_Comp_Aug = self.L1(G_aug, y_im_aug)
                        if self.depth:
                            G_loss_Comp_Aug += self.L1(G_dep_aug, y_dep)
                        G_loss_Dif_Comp = (G_loss_Comp + G_loss_Comp_Aug) / 2 * self.lambdaL1

                    G_loss = G_loss_fake + G_loss_Dif_Comp + G_loss_Cycle

                    self.details_hist['G_T_BCE_fake_real'].append(G_loss_fake.detach().item())
                    self.details_hist['G_T_Comp_im'].append(G_loss_Dif_Comp.detach().item())
                    self.details_hist['G_T_Cycle'].append(G_loss_Cycle.detach().item())
                    self.details_hist['G_zCR'].append(0)

                else:

                    G_loss = self.L1(G_, y_im)
                    if self.depth:
                        G_loss += self.L1(G_dep, y_dep)
                    G_loss = G_loss * self.lambdaL1
                    self.details_hist['G_T_Comp_im'].append(G_loss.detach().item())
                    self.details_hist['G_T_BCE_fake_real'].append(0)
                    self.details_hist['G_T_Cycle'].append(0)
                    self.details_hist['G_zCR'].append(0)

                G_loss.backward()
                self.G_optimizer.step()
                self.train_hist['G_loss_train'].append(G_loss.detach().item())
                if self.onlyGen:
                    self.train_hist['D_loss_train'].append(0)

                iterFinTrain += 1

                if self.visdom:
                    visLossGTest.plot('Generator_losses',
                                      ['G_T_Comp_im', 'G_T_BCE_fake_real', 'G_zCR', 'G_T_Cycle'],
                                      'train', self.details_hist)

                    vis.plot('loss', ['D_loss_train', 'G_loss_train'], 'train', self.train_hist)

            ################## Validation ####################################
            with torch.no_grad():

                self.G.eval()
                if not self.onlyGen:
                    self.D.eval()

                for iter, data in enumerate(self.data_Validation):

                    x_im = data.get('x_im')
                    x_dep = data.get('x_dep')
                    y_im = data.get('y_im')
                    y_dep = data.get('y_dep')
                    y_ = data.get('y_')
                    # x_im = source images
                    # x_dep = depth maps of those images
                    # y_im = image from the shifted viewpoint
                    # y_ = viewpoint label (negatives need special handling)

                    if iter == maxIterVal:
                        break

                    if self.gpu_mode:
                        x_im, y_, y_im, x_dep, y_dep = x_im.cuda(), y_.cuda(), y_im.cuda(), x_dep.cuda(), y_dep.cuda()

                    # D network
                    if not ventaja and not self.onlyGen:
                        # Real Images
                        D_real, _ = self.D(y_im, x_im, y_dep, y_)  # forward pass

                        # Fake Images
                        G_, G_dep = self.G(y_, x_im, x_dep)
                        D_fake, _ = self.D(G_, x_im, G_dep, y_)
                        # Losses
                        # GAN Loss
                        if (self.WGAN):  # WGAN variant
                            D_loss_real_fake_R = - torch.mean(D_real)
                            D_loss_real_fake_F = torch.mean(D_fake)
                        else:  # standard GAN
                            D_loss_real_fake_R = self.BCEWithLogitsLoss(D_real, self.y_real_)
                            D_loss_real_fake_F = self.BCEWithLogitsLoss(D_fake, self.y_fake_)

                        D_loss_real_fake = D_loss_real_fake_F + D_loss_real_fake_R

                        D_loss = D_loss_real_fake

                        self.train_hist['D_loss_Validation'].append(D_loss.item())
                        self.details_hist['D_V_BCE_fake_real_R'].append(D_loss_real_fake_R.item())
                        self.details_hist['D_V_BCE_fake_real_F'].append(D_loss_real_fake_F.item())
                        if self.visdom:
                            visLossDValidation.plot('Discriminator_losses',
                                                    ['D_V_BCE_fake_real_R', 'D_V_BCE_fake_real_F'], 'Validation',
                                                    self.details_hist)

                    # G network
                    G_, G_dep = self.G(y_, x_im, x_dep)

                    if not ventaja and not self.onlyGen:
                        # Fake images
                        D_fake, _ = self.D(G_, x_im, G_dep, y_)

                        # GAN loss
                        if (self.WGAN):
                            G_loss = -torch.mean(D_fake)  # WGAN
                        else:
                            G_loss = self.BCEWithLogitsLoss(D_fake, self.y_real_)  # standard GAN

                        self.details_hist['G_V_BCE_fake_real'].append(G_loss.item())

                        # Comparison loss
                        # G_join = torch.cat((G_, G_dep), 1)
                        # y_join = torch.cat((y_im, y_dep), 1)

                        G_loss_Comp = self.L1(G_, y_im)
                        if self.depth:
                            G_loss_Comp += self.L1(G_dep, y_dep)
                        G_loss_Comp = G_loss_Comp * self.lambdaL1

                        reverse_y = - y_ + 1
                        reverse_G, reverse_G_dep = self.G(reverse_y, G_, G_dep)
                        G_loss_Cycle = self.L1(reverse_G, x_im)
                        if self.depth:
                            G_loss_Cycle += self.L1(reverse_G_dep, x_dep)
                        G_loss_Cycle = G_loss_Cycle * self.lambdaL1 / 2

                        G_loss += G_loss_Comp + G_loss_Cycle

                        self.details_hist['G_V_Comp_im'].append(G_loss_Comp.item())
                        self.details_hist['G_V_Cycle'].append(G_loss_Cycle.detach().item())

                    else:
                        G_loss = self.L1(G_, y_im)
                        if self.depth:
                            G_loss += self.L1(G_dep, y_dep)
                        G_loss = G_loss * self.lambdaL1
                        self.details_hist['G_V_Comp_im'].append(G_loss.item())
                        self.details_hist['G_V_BCE_fake_real'].append(0)
                        self.details_hist['G_V_Cycle'].append(0)

                    self.train_hist['G_loss_Validation'].append(G_loss.item())
                    if self.onlyGen:
                        self.train_hist['D_loss_Validation'].append(0)

                    iterFinValidation += 1
                    if self.visdom:
                        visLossGValidation.plot('Generator_losses', ['G_V_Comp_im', 'G_V_BCE_fake_real', 'G_V_Cycle'],
                                                'Validation', self.details_hist)
                        visValidation.plot('loss', ['D_loss_Validation', 'G_loss_Validation'], 'Validation',
                                           self.train_hist)

            ## Per-epoch visualization
            if ventaja or self.onlyGen:
                self.epoch_hist['D_loss_train'].append(0)
                self.epoch_hist['D_loss_Validation'].append(0)
            else:
                # inicioTr = (epoch - self.epochVentaja) * (iterFinTrain - iterIniTrain)
                # inicioTe = (epoch - self.epochVentaja) * (iterFinValidation - iterIniValidation)
                self.epoch_hist['D_loss_train'].append(mean(self.train_hist['D_loss_train'][iterIniTrain: -1]))
                self.epoch_hist['D_loss_Validation'].append(mean(self.train_hist['D_loss_Validation'][iterIniValidation: -1]))

            self.epoch_hist['G_loss_train'].append(mean(self.train_hist['G_loss_train'][iterIniTrain:iterFinTrain]))
            self.epoch_hist['G_loss_Validation'].append(
                mean(self.train_hist['G_loss_Validation'][iterIniValidation:iterFinValidation]))
            if self.visdom:
                visEpoch.plot('epoch', epoch,
                              ['D_loss_train', 'G_loss_train', 'D_loss_Validation', 'G_loss_Validation'],
                              self.epoch_hist)

            self.train_hist['D_loss_train'] = self.train_hist['D_loss_train'][-1:]
            self.train_hist['G_loss_train'] = self.train_hist['G_loss_train'][-1:]
            self.train_hist['D_loss_Validation'] = self.train_hist['D_loss_Validation'][-1:]
            self.train_hist['G_loss_Validation'] = self.train_hist['G_loss_Validation'][-1:]
            self.train_hist['per_epoch_time'] = self.train_hist['per_epoch_time'][-1:]
            self.train_hist['total_time'] = self.train_hist['total_time'][-1:]

            self.details_hist['G_T_Comp_im'] = self.details_hist['G_T_Comp_im'][-1:]
            self.details_hist['G_T_BCE_fake_real'] = self.details_hist['G_T_BCE_fake_real'][-1:]
            self.details_hist['G_T_Cycle'] = self.details_hist['G_T_Cycle'][-1:]
            self.details_hist['G_zCR'] = self.details_hist['G_zCR'][-1:]

            self.details_hist['G_V_Comp_im'] = self.details_hist['G_V_Comp_im'][-1:]
            self.details_hist['G_V_BCE_fake_real'] = self.details_hist['G_V_BCE_fake_real'][-1:]
            self.details_hist['G_V_Cycle'] = self.details_hist['G_V_Cycle'][-1:]

            self.details_hist['D_T_BCE_fake_real_R'] = self.details_hist['D_T_BCE_fake_real_R'][-1:]
            self.details_hist['D_T_BCE_fake_real_F'] = self.details_hist['D_T_BCE_fake_real_F'][-1:]
            self.details_hist['D_zCR'] = self.details_hist['D_zCR'][-1:]
            self.details_hist['D_bCR'] = self.details_hist['D_bCR'][-1:]

            self.details_hist['D_V_BCE_fake_real_R'] = self.details_hist['D_V_BCE_fake_real_R'][-1:]
            self.details_hist['D_V_BCE_fake_real_F'] = self.details_hist['D_V_BCE_fake_real_F'][-1:]
            ## So the per-epoch average can be taken
            iterIniTrain = 1
            iterFinTrain = 1

            iterIniValidation = 1
            iterFinValidation = 1

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)

            if epoch % 10 == 0:
                self.save(str(epoch))
                with torch.no_grad():
                    if self.visdom:
                        self.visualize_results(epoch, dataprint=self.dataprint, visual=visImages)
                        self.visualize_results(epoch, dataprint=self.dataprint_test, visual=visImagesTest)
                    else:
                        imageName = self.model_name + '_' + 'Train' + '_' + str(self.seed) + '_' + str(epoch)
                        self.visualize_results(epoch, dataprint=self.dataprint, name=imageName)
                        self.visualize_results(epoch, dataprint=self.dataprint_test, name=imageName)

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
                                                                        self.epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        # utils.generate_animation(self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name,
        #                          self.epoch)
        # utils.loss_plot(self.train_hist, os.path.join(self.save_dir, self.dataset, self.model_name), self.model_name)

    def visualize_results(self, epoch, dataprint, visual="", name="test"):
        with torch.no_grad():
            self.G.eval()

            # if not os.path.exists(self.result_dir + '/' + self.dataset + '/' + self.model_name):
            #     os.makedirs(self.result_dir + '/' + self.dataset + '/' + self.model_name)

            # inefficient, but it only runs every 10 epochs
            newSample = []
            iter = 1
            for x_im, x_dep in zip(dataprint.get('x_im'), dataprint.get('x_dep')):
                if (iter > self.cantImages):
                    break

                x_im_input = x_im.repeat(2, 1, 1, 1)
                x_dep_input = x_dep.repeat(2, 1, 1, 1)

                sizeImage = x_im.shape[2]

                sample_y_ = torch.zeros((self.class_num, 1, sizeImage, sizeImage))
                for i in range(self.class_num):
                    if (int(i % self.class_num) == 1):
                        sample_y_[i] = torch.ones((1, sizeImage, sizeImage))

                if self.gpu_mode:
                    sample_y_, x_im_input, x_dep_input = sample_y_.cuda(), x_im_input.cuda(), x_dep_input.cuda()

                G_im, G_dep = self.G(sample_y_, x_im_input, x_dep_input)

                newSample.append(x_im.squeeze(0))
                newSample.append(x_dep.squeeze(0).expand(3, -1, -1))

                if self.wiggle:
                    im_aux, im_dep_aux = G_im, G_dep
                    for i in range(0, 2):
                        index = i
                        for j in range(0, self.wiggleDepth):

                            if (j == 0 and i == 1):
                                # take the original
                                im_aux, im_dep_aux = G_im, G_dep
                                newSample.append(G_im.cpu()[0].squeeze(0))
                                newSample.append(G_im.cpu()[1].squeeze(0))
                            elif (i == 1):
                                # because of the issue with the following iterations
                                index = 0

                            # generated image
                            x = im_aux[index].unsqueeze(0)
                            x_dep = im_dep_aux[index].unsqueeze(0)

                            y = sample_y_[i].unsqueeze(0)

                            if self.gpu_mode:
                                y, x, x_dep = y.cuda(), x.cuda(), x_dep.cuda()

                            im_aux, im_dep_aux = self.G(y, x, x_dep)

                            newSample.append(im_aux.cpu()[0])
                else:

                    newSample.append(G_im.cpu()[0])
                    newSample.append(G_im.cpu()[1])
                    newSample.append(G_dep.cpu()[0].expand(3, -1, -1))
                    newSample.append(G_dep.cpu()[1].expand(3, -1, -1))

                iter += 1

            if self.visdom:
                visual.plot(epoch, newSample, int(len(newSample) / self.cantImages))
            else:
                utils.save_wiggle(newSample, self.cantImages, name)
            ## TODO: samples should be capped at self.class_num * self.class_num

            # utils.save_images(newSample[:, :, :, :], [image_frame_dim * cantidadIm, image_frame_dim * (self.class_num+2)],
            #                   self.result_dir + '/' + self.dataset + '/' + self.model_name + '/' + self.model_name + '_epoch%04d' % epoch + '.png')

    def show_plot_images(self, images, cols=1, titles=None):
        """Display a list of images in a single figure with matplotlib.

        Parameters
        ---------
        images: List of np.arrays compatible with plt.imshow.

        cols (Default = 1): Number of columns in figure (number of rows is
                            set to np.ceil(n_images/float(cols))).

        titles: List of titles corresponding to each image. Must have
                the same length as images.
        """
        # assert ((titles is None) or (len(images) == len(titles)))
        n_images = len(images)
        if titles is None: titles = ['Image (%d)' % i for i in range(1, n_images + 1)]
        fig = plt.figure()
        for n, (image, title) in enumerate(zip(images, titles)):
            a = fig.add_subplot(int(np.ceil(n_images / float(cols))), cols, n + 1)
            image = (image + 1) * 255.0
            # new_im = Image.fromarray(image)
            if image.ndim == 2:
                plt.gray()
            plt.imshow(image)
            a.set_title(title)
        fig.set_size_inches(np.array(fig.get_size_inches()) * n_images)
        plt.show()

    def joinImages(self, data):
        nData = []
        for i in range(self.class_num):
            nData.append(data)
        nData = np.array(nData)
        nData = torch.tensor(nData.tolist())
        nData = nData.type(torch.FloatTensor)

        return nData

    def save(self, epoch=''):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        if not os.path.exists(save_dir):
            os.makedirs(save_dir)

        torch.save(self.G.state_dict(),
                   os.path.join(save_dir, self.model_name + '_' + self.seed + '_' + epoch + '_G.pkl'))
        if not self.onlyGen:
            torch.save(self.D.state_dict(),
                       os.path.join(save_dir, self.model_name + '_' + self.seed + '_' + epoch + '_D.pkl'))

        with open(os.path.join(save_dir, self.model_name + '_history_ ' + self.seed + '.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        save_dir = os.path.join(self.save_dir, self.dataset, self.model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_G.pkl')))
        if not self.wiggle:
            self.D.load_state_dict(torch.load(os.path.join(save_dir, self.model_name + '_' + self.seed_load + '_D.pkl')))

    def wiggleEf(self):
        seed, epoch = self.seed_load.split('_')
        if self.visdom:
            visWiggle = utils.VisdomImagePlotter(env_name='Cobo_depth_wiggle_' + seed)
            self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=visWiggle)
        else:
            self.visualize_results(epoch=epoch, dataprint=self.dataprint_test, visual=None, name=self.name_wiggle)

    def recreate(self):

        dataloader_recreate = dataloader(self.dataset, self.input_size, self.batch_size, self.imageDim, split='score')
        with torch.no_grad():
            self.G.eval()
            accum = 0
            for data_batch in dataloader_recreate.__iter__():

                # {'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
                left, left_depth, right, right_depth, direction = data_batch.values()

                if self.gpu_mode:
                    left, left_depth, right, right_depth, direction = left.cuda(), left_depth.cuda(), right.cuda(), right_depth.cuda(), direction.cuda()

                G_right, G_right_dep = self.G(direction, left, left_depth)

                reverse_direction = direction * 0
                G_left, G_left_dep = self.G(reverse_direction, right, right_depth)

                for index in range(0, self.batch_size):
                    image_right = (G_right[index] + 1.0) / 2.0
                    image_right_dep = (G_right_dep[index] + 1.0) / 2.0

                    image_left = (G_left[index] + 1.0) / 2.0
                    image_left_dep = (G_left_dep[index] + 1.0) / 2.0

                    save_image(image_right, os.path.join("results", "recreate_dataset", "CAM1", "n_{num:0{width}}.png".format(num=index + accum, width=4)))
                    save_image(image_right_dep, os.path.join("results", "recreate_dataset", "CAM1", "d_{num:0{width}}.png".format(num=index + accum, width=4)))

                    save_image(image_left, os.path.join("results", "recreate_dataset", "CAM0", "n_{num:0{width}}.png".format(num=index + accum, width=4)))
                    save_image(image_left_dep, os.path.join("results", "recreate_dataset", "CAM0", "d_{num:0{width}}.png".format(num=index + accum, width=4)))
                accum += self.batch_size
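Note: WiggleGAN is driven entirely by an `args` namespace, presumably built via argparse in main.py (added in this commit but not shown here). A minimal sketch of how the class above might be driven; every concrete value below is an illustrative assumption, not a project default:

# Hypothetical driver; the attribute names follow what __init__ reads above.
from argparse import Namespace
from WiggleGAN import WiggleGAN

args = Namespace(
    epoch=110, cameras=4, batch_size=16, save_dir='models', result_dir='results',
    dataset='4cam', log_dir='logs', gpu_mode=True, gan_type='WiggleGAN',
    input_size=128, imageDim=128, epochV=0, cIm=4, visdom=False, lambdaL1=100,
    depth=True, name_wiggle='wiggle', clipping=-1, seedLoad='-0000',
    zGF=0, zDF=0, bF=0, expandGen=4, expandDis=4, wiggleDepth=-1,
    lrD=2e-4, lrG=2e-4, beta1=0.5, beta2=0.999,
)
gan = WiggleGAN(args)
gan.train()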
WiggleResults/split.py
ADDED
@@ -0,0 +1,91 @@
import os
from PIL import Image
import argparse

parser = argparse.ArgumentParser(description='change to useful name')
parser.add_argument('--dim', default=128, type=int, help='image dimension')
args = parser.parse_args()

path = "."
dirs = os.listdir(path)

dim = args.dim

def gif_order(data, center=True):
    gif = []
    base = 1

    # first half
    i = int((len(data) - 2) / 2)
    while (i > base):
        gif.append(data[i])
        i -= 1

    # middle-left frame
    gif.append(data[int((len(data) - 2) / 2) + 1])

    # the initial frame
    if center:
        gif.append(data[0])

    # middle-right frame
    gif.append(data[int((len(data) - 2) / 2) + 2])
    # second half
    i = int((len(data) - 2) / 2) + 3
    while (i < len(data)):
        gif.append(data[i])
        i += 1

    invertedgif = gif[::-1]
    invertedgif = invertedgif[1:]

    gif = gif[1:] + invertedgif
    # for image in gif:
    #     image.show()
    return gif


# Walk all the files and directories
for file in dirs:
    if ".jpg" in file or ".png" in file:
        rowImages = []
        im = Image.open("./" + file)
        width, height = im.size
        im = im.convert('RGB')

        # CROP (left, top, right, bottom)

        pointleft = 3
        pointtop = 3
        i = 0
        while (pointtop < height):
            while (pointleft < width):
                im1 = im.crop((pointleft, pointtop, dim + pointleft, dim + pointtop))
                rowImages.append(im1.quantize())
                # im1.show()
                pointleft += dim + 4
            # all crops of this row collected; build the gif here
            rowImages = gif_order(rowImages, center=False)
            name = file[:-4] + "_" + str(i) + '.gif'
            rowImages[0].save(name, save_all=True, format='GIF', append_images=rowImages[1:], optimize=True, duration=100, loop=0)
            pointtop += dim + 4
            pointleft = 3
            rowImages = []
            i += 1
        # im2 = im.crop((width / 2, 0, width, height))
        # im2.show()

        # im1.save("./2" + file[:-4] + ".png")
        # im2.save("./" + file[:-4] + ".png")

# Deleted
# os.remove("data/" + file)
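As a quick check of the back-and-forth frame order gif_order produces, it can be run on plain indices, using the function exactly as defined above:

frames = list(range(8))          # stand-ins for the 8 crops of one row
print(gif_order(frames, center=False))
# -> [2, 4, 5, 6, 7, 6, 5, 4, 2, 3]
# The sequence sweeps out to one edge and back past the middle, which is
# what makes the saved GIF oscillate like a wigglegram.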
app.py
CHANGED
@@ -16,7 +16,6 @@ def calculate_depth(model_type, img):
 
     img.save(filename, "JPEG")
 
-    #model_type = "DPT_Hybrid"
     midas = torch.hub.load("intel-isl/MiDaS", model_type)
 
     device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
@@ -61,18 +60,19 @@ def wiggle_effect(slider):
 
 with gr.Blocks() as demo:
     gr.Markdown("Start typing below and then click **Run** to see the output.")
-
-
+
+
+    ## Depth Estimation
     midas_models = ["DPT_Large","DPT_Hybrid","MiDaS_small"]
-
-    inp.append(gr.inputs.Dropdown(midas_models, default="MiDaS_small", label="Depth estimation model type"))
-
+    inp = [gr.inputs.Dropdown(midas_models, default="MiDaS_small", label="Depth estimation model type")]
     with gr.Row():
         inp.append(gr.Image(type="pil", label="Input"))
         out = gr.Image(type="file", label="depth_estimation")
     btn = gr.Button("Calculate depth")
     btn.click(fn=calculate_depth, inputs=inp, outputs=out)
 
+
+    ## Wigglegram
     inp = [gr.Slider(1,15, default = 2, label='StepCycles',step= 1)]
     with gr.Row():
         out = [ gr.Image(type="file", label="Output_images"), #TODO change to gallery
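For context, the MiDaS model loaded in calculate_depth follows the usual torch.hub pattern. A self-contained sketch of that depth step (the input file name is an assumption; the hub entry points are the ones the MiDaS repository documents):

import cv2
import torch

model_type = "MiDaS_small"
midas = torch.hub.load("intel-isl/MiDaS", model_type)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
midas.to(device).eval()

# The matching input transform also ships on the hub.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
transform = midas_transforms.small_transform  # dpt_transform for the DPT models

img = cv2.cvtColor(cv2.imread("img.jpg"), cv2.COLOR_BGR2RGB)
with torch.no_grad():
    prediction = midas(transform(img).to(device))
    # Upsample the low-resolution inverse-depth map back to the input size.
    depth = torch.nn.functional.interpolate(
        prediction.unsqueeze(1), size=img.shape[:2],
        mode="bicubic", align_corners=False,
    ).squeeze()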
architectures.py
ADDED
@@ -0,0 +1,1094 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import utils, torch
|
3 |
+
from torch.autograd import Variable
|
4 |
+
import torch.nn.functional as F
|
5 |
+
|
6 |
+
|
7 |
+
class generator(nn.Module):
|
8 |
+
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
|
9 |
+
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
|
10 |
+
def __init__(self, input_dim=4, output_dim=1, input_shape=3, class_num=10, height=10, width=10):
|
11 |
+
super(generator, self).__init__()
|
12 |
+
self.input_dim = input_dim
|
13 |
+
self.output_dim = output_dim
|
14 |
+
# print ("self.output_dim", self.output_dim)
|
15 |
+
self.class_num = class_num
|
16 |
+
self.input_shape = list(input_shape)
|
17 |
+
self.toPreDecov = 1024
|
18 |
+
self.toDecov = 1
|
19 |
+
self.height = height
|
20 |
+
self.width = width
|
21 |
+
|
22 |
+
self.input_shape[1] = self.input_dim # esto cambio despues por colores
|
23 |
+
|
24 |
+
# print("input shpe gen",self.input_shape)
|
25 |
+
|
26 |
+
self.conv1 = nn.Sequential(
|
27 |
+
nn.Conv2d(self.input_dim, 10, 4, 2, 1), # para mi el 2 tendria que ser 1
|
28 |
+
nn.Conv2d(10, 4, 4, 2, 1),
|
29 |
+
nn.BatchNorm2d(4),
|
30 |
+
nn.LeakyReLU(0.2),
|
31 |
+
nn.Conv2d(4, 3, 4, 2, 1),
|
32 |
+
nn.BatchNorm2d(3),
|
33 |
+
nn.LeakyReLU(0.2),
|
34 |
+
)
|
35 |
+
|
36 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
37 |
+
# print ("self.n_size",self.n_size)
|
38 |
+
self.cubic = (self.n_size // 8192)
|
39 |
+
# print("self.cubic: ",self.cubic)
|
40 |
+
|
41 |
+
self.fc1 = nn.Sequential(
|
42 |
+
nn.Linear(self.n_size, self.n_size),
|
43 |
+
nn.BatchNorm1d(self.n_size),
|
44 |
+
nn.LeakyReLU(0.2),
|
45 |
+
)
|
46 |
+
|
47 |
+
self.preDeconv = nn.Sequential(
|
48 |
+
##############RED SUPER CHICA PARA QUE ANDE TO DO PORQUE RAM Y MEMORY
|
49 |
+
|
50 |
+
# nn.Linear(self.toPreDecov + self.zdim + self.class_num, 1024),
|
51 |
+
# nn.BatchNorm1d(1024),
|
52 |
+
# nn.LeakyReLU(0.2),
|
53 |
+
# nn.Linear(1024, self.toDecov * self.height // 64 * self.width// 64),
|
54 |
+
# nn.BatchNorm1d(self.toDecov * self.height // 64 * self.width// 64),
|
55 |
+
# nn.LeakyReLU(0.2),
|
56 |
+
# nn.Linear(self.toDecov * self.height // 64 * self.width // 64 , self.toDecov * self.height // 32 * self.width // 32),
|
57 |
+
# nn.BatchNorm1d(self.toDecov * self.height // 32 * self.width // 32),
|
58 |
+
# nn.LeakyReLU(0.2),
|
59 |
+
# nn.Linear(self.toDecov * self.height // 32 * self.width // 32,
|
60 |
+
# 1 * self.height * self.width),
|
61 |
+
# nn.BatchNorm1d(1 * self.height * self.width),
|
62 |
+
# nn.LeakyReLU(0.2),
|
63 |
+
|
64 |
+
nn.Linear(self.n_size + self.class_num, 400),
|
65 |
+
nn.BatchNorm1d(400),
|
66 |
+
nn.LeakyReLU(0.2),
|
67 |
+
nn.Linear(400, 800),
|
68 |
+
nn.BatchNorm1d(800),
|
69 |
+
nn.LeakyReLU(0.2),
|
70 |
+
nn.Linear(800, self.output_dim * self.height * self.width),
|
71 |
+
nn.BatchNorm1d(self.output_dim * self.height * self.width),
|
72 |
+
nn.Tanh(), # Cambio porque hago como que termino ahi
|
73 |
+
|
74 |
+
)
|
75 |
+
|
76 |
+
"""
|
77 |
+
self.deconv = nn.Sequential(
|
78 |
+
nn.ConvTranspose2d(self.toDecov, 2, 4, 2, 0),
|
79 |
+
nn.BatchNorm2d(2),
|
80 |
+
nn.ReLU(),
|
81 |
+
nn.ConvTranspose2d(2, self.output_dim, 4, 2, 1),
|
82 |
+
nn.Tanh(), #esta recomendado que la ultima sea TanH de la Generadora da valores entre -1 y 1
|
83 |
+
)
|
84 |
+
"""
|
85 |
+
utils.initialize_weights(self)
|
86 |
+
|
87 |
+
def _get_conv_output(self, shape):
|
88 |
+
bs = 1
|
89 |
+
input = Variable(torch.rand(bs, *shape))
|
90 |
+
# print("inShape:",input.shape)
|
91 |
+
output_feat = self.conv1(input.squeeze())
|
92 |
+
# print ("output_feat",output_feat.shape)
|
93 |
+
n_size = output_feat.data.view(bs, -1).size(1)
|
94 |
+
# print ("n",n_size // 4)
|
95 |
+
return n_size // 4
|
96 |
+
|
97 |
+
def forward(self, clase, im):
|
98 |
+
##Esto es lo que voy a hacer
|
99 |
+
# Cat entre la imagen y la profundidad
|
100 |
+
# print ("H",self.height,"W",self.width)
|
101 |
+
# imDep = imDep[:, None, :, :]
|
102 |
+
# im = im[:, None, :, :]
|
103 |
+
x = im
|
104 |
+
|
105 |
+
# Ref Conv de ese cat
|
106 |
+
x = self.conv1(x)
|
107 |
+
x = x.view(x.size(0), -1)
|
108 |
+
# print ("x:", x.shape)
|
109 |
+
x = self.fc1(x)
|
110 |
+
# print ("x:",x.shape)
|
111 |
+
|
112 |
+
# cat entre el ruido y la clase
|
113 |
+
y = clase
|
114 |
+
# print("Cat entre input y clase", y.shape) #podria separarlo, unir primero con clase y despues con ruido
|
115 |
+
|
116 |
+
# Red Lineal que une la Conv con el cat anterior
|
117 |
+
x = torch.cat([x, y], 1)
|
118 |
+
x = self.preDeconv(x)
|
119 |
+
# print ("antes de deconv", x.shape)
|
120 |
+
x = x.view(-1, self.output_dim, self.height, self.width)
|
121 |
+
# print("Despues View: ", x.shape)
|
122 |
+
# Red que saca produce la imagen final
|
123 |
+
# x = self.deconv(x)
|
124 |
+
# print("La salida de la generadora es: ",x.shape)
|
125 |
+
|
126 |
+
return x
|
127 |
+
|
128 |
+
|
129 |
+
class discriminator(nn.Module):
|
130 |
+
# Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
|
131 |
+
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
|
132 |
+
def __init__(self, input_dim=1, output_dim=1, input_shape=2, class_num=10):
|
133 |
+
super(discriminator, self).__init__()
|
134 |
+
self.input_dim = input_dim * 2 # ya que le doy el origen
|
135 |
+
self.output_dim = output_dim
|
136 |
+
self.input_shape = list(input_shape)
|
137 |
+
self.class_num = class_num
|
138 |
+
|
139 |
+
self.input_shape[1] = self.input_dim # esto cambio despues por colores
|
140 |
+
# print(self.input_shape)
|
141 |
+
|
142 |
+
"""""
|
143 |
+
in_channels (int): Number of channels in the input image
|
144 |
+
out_channels (int): Number of channels produced by the convolution
|
145 |
+
kernel_size (int or tuple): Size of the convolving kernel - lo que se agarra para la conv
|
146 |
+
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
147 |
+
padding (int or tuple, optional): Zero-padding added to both sides of the input.
|
148 |
+
"""""
|
149 |
+
|
150 |
+
"""
|
151 |
+
nn.Conv2d(self.input_dim, 64, 4, 2, 1), #para mi el 2 tendria que ser 1
|
152 |
+
nn.LeakyReLU(0.2),
|
153 |
+
nn.Conv2d(64, 32, 4, 2, 1),
|
154 |
+
nn.LeakyReLU(0.2),
|
155 |
+
nn.MaxPool2d(4, stride=2),
|
156 |
+
nn.Conv2d(32, 32, 4, 2, 1),
|
157 |
+
nn.LeakyReLU(0.2),
|
158 |
+
nn.MaxPool2d(4, stride=2),
|
159 |
+
nn.Conv2d(32, 20, 4, 2, 1),
|
160 |
+
nn.BatchNorm2d(20),
|
161 |
+
nn.LeakyReLU(0.2),
|
162 |
+
"""
|
163 |
+
|
164 |
+
self.conv = nn.Sequential(
|
165 |
+
|
166 |
+
nn.Conv2d(self.input_dim, 4, 4, 2, 1), # para mi el 2 tendria que ser 1
|
167 |
+
nn.LeakyReLU(0.2),
|
168 |
+
nn.Conv2d(4, 8, 4, 2, 1),
|
169 |
+
nn.BatchNorm2d(8),
|
170 |
+
nn.LeakyReLU(0.2),
|
171 |
+
nn.Conv2d(8, 16, 4, 2, 1),
|
172 |
+
nn.BatchNorm2d(16),
|
173 |
+
|
174 |
+
)
|
175 |
+
|
176 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
177 |
+
|
178 |
+
self.fc1 = nn.Sequential(
|
179 |
+
nn.Linear(self.n_size // 4, 1024),
|
180 |
+
nn.BatchNorm1d(1024),
|
181 |
+
nn.LeakyReLU(0.2),
|
182 |
+
nn.Linear(1024, 512),
|
183 |
+
nn.BatchNorm1d(512),
|
184 |
+
nn.LeakyReLU(0.2),
|
185 |
+
nn.Linear(512, 256),
|
186 |
+
nn.BatchNorm1d(256),
|
187 |
+
nn.LeakyReLU(0.2),
|
188 |
+
nn.Linear(256, 128),
|
189 |
+
nn.BatchNorm1d(128),
|
190 |
+
nn.LeakyReLU(0.2),
|
191 |
+
nn.Linear(128, 64),
|
192 |
+
nn.BatchNorm1d(64),
|
193 |
+
nn.LeakyReLU(0.2),
|
194 |
+
)
|
195 |
+
self.dc = nn.Sequential(
|
196 |
+
nn.Linear(64, self.output_dim),
|
197 |
+
nn.Sigmoid(),
|
198 |
+
)
|
199 |
+
self.cl = nn.Sequential(
|
200 |
+
nn.Linear(64, self.class_num),
|
201 |
+
nn.Sigmoid(),
|
202 |
+
)
|
203 |
+
utils.initialize_weights(self)
|
204 |
+
|
205 |
+
# generate input sample and forward to get shape
|
206 |
+
|
207 |
+
def _get_conv_output(self, shape):
|
208 |
+
bs = 1
|
209 |
+
input = Variable(torch.rand(bs, *shape))
|
210 |
+
output_feat = self.conv(input.squeeze())
|
211 |
+
n_size = output_feat.data.view(bs, -1).size(1)
|
212 |
+
return n_size
|
213 |
+
|
214 |
+
def forward(self, input, origen):
|
215 |
+
# this will change once color is supported
|
216 |
+
# if (len(input.shape) <= 3):
|
217 |
+
# input = input[:, None, :, :]
|
218 |
+
# im = im[:, None, :, :]
|
219 |
+
# print("D in shape",input.shape)
|
220 |
+
|
221 |
+
# print(input.shape)
|
222 |
+
# print("this si X:", x)
|
223 |
+
# print("now shape", x.shape)
|
224 |
+
x = input
|
225 |
+
x = x.type(torch.FloatTensor)
|
226 |
+
x = x.to(device='cuda:0')
|
227 |
+
|
228 |
+
x = torch.cat((x, origen), 1)
|
229 |
+
x = self.conv(x)
|
230 |
+
x = x.view(x.size(0), -1)
|
231 |
+
x = self.fc1(x)
|
232 |
+
d = self.dc(x)
|
233 |
+
c = self.cl(x)
|
234 |
+
|
235 |
+
return d, c
|
236 |
+
|
237 |
+
|
238 |
+
#######################################################################################################################
|
239 |
+
class UnetConvBlock(nn.Module):
|
240 |
+
'''
|
241 |
+
Convolutional block of a U-Net:
|
242 |
+
Conv2d - Batch normalization - LeakyReLU
|
243 |
+
Conv2d - Batch normalization - LeakyReLU
|
244 |
+
Basic Dropout (optional)
|
245 |
+
'''
|
246 |
+
|
247 |
+
def __init__(self, in_size, out_size, dropout=0.0, stride=1, batch_norm = True):
|
248 |
+
'''
|
249 |
+
Constructor of the convolutional block
|
250 |
+
'''
|
251 |
+
super(UnetConvBlock, self).__init__()
|
252 |
+
|
253 |
+
# Convolutional layer with IN_SIZE --> OUT_SIZE
|
254 |
+
conv1 = nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, stride=1,
|
255 |
+
padding=1)  # could apply stride 2 here
|
256 |
+
# Activation unit
|
257 |
+
activ_unit1 = nn.LeakyReLU(0.2)
|
258 |
+
# Add batch normalization if necessary
|
259 |
+
if batch_norm:
|
260 |
+
self.conv1 = nn.Sequential(conv1, nn.BatchNorm2d(out_size), activ_unit1)
|
261 |
+
else:
|
262 |
+
self.conv1 = nn.Sequential(conv1, activ_unit1)
|
263 |
+
|
264 |
+
# Convolutional layer with OUT_SIZE --> OUT_SIZE
|
265 |
+
conv2 = nn.Conv2d(in_channels=out_size, out_channels=out_size, kernel_size=3, stride=stride,
|
266 |
+
padding=1)  # could apply stride 2 here
|
267 |
+
# Activation unit
|
268 |
+
activ_unit2 = nn.LeakyReLU(0.2)
|
269 |
+
|
270 |
+
# Add batch normalization
|
271 |
+
if batch_norm:
|
272 |
+
self.conv2 = nn.Sequential(conv2, nn.BatchNorm2d(out_size), activ_unit2)
|
273 |
+
else:
|
274 |
+
self.conv2 = nn.Sequential(conv2, activ_unit2)
|
275 |
+
# Dropout
|
276 |
+
if dropout > 0.0:
|
277 |
+
self.drop = nn.Dropout(dropout)
|
278 |
+
else:
|
279 |
+
self.drop = None
|
280 |
+
|
281 |
+
def forward(self, inputs):
|
282 |
+
'''
|
283 |
+
Do a forward pass
|
284 |
+
'''
|
285 |
+
outputs = self.conv1(inputs)
|
286 |
+
outputs = self.conv2(outputs)
|
287 |
+
if not (self.drop is None):
|
288 |
+
outputs = self.drop(outputs)
|
289 |
+
return outputs
|
290 |
+
|
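A quick shape check of UnetConvBlock, assuming this module can be imported (the "architectures" module name is a guess at this file's name):

import torch
from architectures import UnetConvBlock  # hypothetical import path

block = UnetConvBlock(in_size=3, out_size=16, dropout=0.1, stride=2)
out = block(torch.rand(4, 3, 64, 64))
print(out.shape)  # torch.Size([4, 16, 32, 32]): the second conv's stride halves H and W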
291 |
+
|
292 |
+
class UnetDeSingleConvBlock(nn.Module):
|
293 |
+
'''
|
294 |
+
Single convolutional block used on the decoder path of a U-Net:
|
295 |
+
Conv2d - Batch normalization - LeakyReLU
|
296 |
+
Basic Dropout (optional)
|
297 |
+
'''
|
298 |
+
|
299 |
+
def __init__(self, in_size, out_size, dropout=0.0, stride=1, padding=1, batch_norm = True ):
|
300 |
+
'''
|
301 |
+
Constructor of the convolutional block
|
302 |
+
'''
|
303 |
+
super(UnetDeSingleConvBlock, self).__init__()
|
304 |
+
|
305 |
+
# Convolutional layer with IN_SIZE --> OUT_SIZE
|
306 |
+
conv1 = nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, stride=stride, padding=1)
|
307 |
+
# Activation unit
|
308 |
+
activ_unit1 = nn.LeakyReLU(0.2)
|
309 |
+
# Add batch normalization if necessary
|
310 |
+
if batch_norm:
|
311 |
+
self.conv1 = nn.Sequential(conv1, nn.BatchNorm2d(out_size), activ_unit1)
|
312 |
+
else:
|
313 |
+
self.conv1 = nn.Sequential(conv1, activ_unit1)
|
314 |
+
|
315 |
+
# Dropout
|
316 |
+
if dropout > 0.0:
|
317 |
+
self.drop = nn.Dropout(dropout)
|
318 |
+
else:
|
319 |
+
self.drop = None
|
320 |
+
|
321 |
+
def forward(self, inputs):
|
322 |
+
'''
|
323 |
+
Do a forward pass
|
324 |
+
'''
|
325 |
+
outputs = self.conv1(inputs)
|
326 |
+
if not (self.drop is None):
|
327 |
+
outputs = self.drop(outputs)
|
328 |
+
return outputs
|
329 |
+
|
330 |
+
|
331 |
+
class UnetDeconvBlock(nn.Module):
|
332 |
+
'''
|
333 |
+
Decoder (skip-merge) block of a U-Net:
|
334 |
+
UnetDeSingleConvBlock (skip_connection)
|
335 |
+
Cat last_layer + skip_connection
|
336 |
+
UnetDeSingleConvBlock ( Cat )
|
337 |
+
Basic Dropout (optional)
|
338 |
+
'''
|
339 |
+
|
340 |
+
def __init__(self, in_size_layer, in_size_skip_con, out_size, dropout=0.0):
|
341 |
+
'''
|
342 |
+
Constructor of the convolutional block
|
343 |
+
'''
|
344 |
+
super(UnetDeconvBlock, self).__init__()
|
345 |
+
|
346 |
+
self.conv1 = UnetDeSingleConvBlock(in_size_skip_con, in_size_skip_con, dropout)
|
347 |
+
self.conv2 = UnetDeSingleConvBlock(in_size_layer + in_size_skip_con, out_size, dropout)
|
348 |
+
|
349 |
+
# Dropout
|
350 |
+
if dropout > 0.0:
|
351 |
+
self.drop = nn.Dropout(dropout)
|
352 |
+
else:
|
353 |
+
self.drop = None
|
354 |
+
|
355 |
+
def forward(self, inputs_layer, inputs_skip):
|
356 |
+
'''
|
357 |
+
Do a forward pass
|
358 |
+
'''
|
359 |
+
|
360 |
+
outputs = self.conv1(inputs_skip)
|
361 |
+
|
362 |
+
#outputs = changeDim(outputs, inputs_layer)
|
363 |
+
|
364 |
+
outputs = torch.cat((inputs_layer, outputs), 1)
|
365 |
+
outputs = self.conv2(outputs)
|
366 |
+
|
367 |
+
return outputs
|
368 |
+
|
369 |
+
|
370 |
+
class UpBlock(nn.Module):
|
371 |
+
"""Upscaling then double conv"""
|
372 |
+
|
373 |
+
def __init__(self, in_size_layer, in_size_skip_con, out_size, bilinear=True):
|
374 |
+
super(UpBlock, self).__init__()
|
375 |
+
|
376 |
+
# if bilinear, use the normal convolutions to reduce the number of channels
|
377 |
+
if bilinear:
|
378 |
+
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
|
379 |
+
else:
|
380 |
+
self.up = nn.ConvTranspose2d(in_size_layer // 2, in_size_layer // 2, kernel_size=2, stride=2)
|
381 |
+
|
382 |
+
self.conv = UnetDeconvBlock(in_size_layer, in_size_skip_con, out_size)
|
383 |
+
|
384 |
+
def forward(self, inputs_layer, inputs_skip):
|
385 |
+
|
386 |
+
inputs_layer = self.up(inputs_layer)
|
387 |
+
|
388 |
+
# input is CHW
|
389 |
+
#inputs_layer = changeDim(inputs_layer, inputs_skip)
|
390 |
+
|
391 |
+
return self.conv(inputs_layer, inputs_skip)
|
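Choosing bilinear=True upsamples with a fixed interpolation and lets the following convolutions learn the channel reduction; a common motivation over ConvTranspose2d is avoiding checkerboard artifacts. A small sketch of what the up step does to shapes:

import torch
import torch.nn as nn

up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
x = torch.rand(1, 32, 16, 16)
print(up(x).shape)  # torch.Size([1, 32, 32, 32]): H and W double, channels unchanged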
392 |
+
|
393 |
+
|
394 |
+
class lastBlock(nn.Module):
|
395 |
+
'''
|
396 |
+
Final output block of a U-Net:
|
397 |
+
Conv2d - Batch normalization - LeakyReLU
|
398 |
+
Basic Dropout (optional)
|
399 |
+
'''
|
400 |
+
|
401 |
+
def __init__(self, in_size, out_size, dropout=0.0):
|
402 |
+
'''
|
403 |
+
Constructor of the convolutional block
|
404 |
+
'''
|
405 |
+
super(lastBlock, self).__init__()
|
406 |
+
|
407 |
+
# Convolutional layer with IN_SIZE --> OUT_SIZE
|
408 |
+
conv1 = nn.Conv2d(in_channels=in_size, out_channels=out_size, kernel_size=3, stride=1, padding=1)
|
409 |
+
# Activation unit
|
410 |
+
activ_unit1 = nn.Tanh()
|
411 |
+
# Add batch normalization if necessary
|
412 |
+
self.conv1 = nn.Sequential(conv1, nn.BatchNorm2d(out_size), activ_unit1)
|
413 |
+
|
414 |
+
# Dropout
|
415 |
+
if dropout > 0.0:
|
416 |
+
self.drop = nn.Dropout(dropout)
|
417 |
+
else:
|
418 |
+
self.drop = None
|
419 |
+
|
420 |
+
def forward(self, inputs):
|
421 |
+
'''
|
422 |
+
Do a forward pass
|
423 |
+
'''
|
424 |
+
outputs = self.conv1(inputs)
|
425 |
+
if not (self.drop is None):
|
426 |
+
outputs = self.drop(outputs)
|
427 |
+
return outputs
|
428 |
+
|
429 |
+
|
430 |
+
################
|
431 |
+
|
432 |
+
class generator_UNet(nn.Module):
|
433 |
+
# Network architecture initially based on infoGAN (https://arxiv.org/abs/1606.03657)
|
434 |
+
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
|
435 |
+
def __init__(self, input_dim=4, output_dim=1, input_shape=3, class_num=2, expand_net=3):
|
436 |
+
super(generator_UNet, self).__init__()
|
437 |
+
self.input_dim = input_dim + 1  # +1 for the class channel
|
438 |
+
self.output_dim = output_dim
|
439 |
+
# print ("self.output_dim", self.output_dim)
|
440 |
+
self.class_num = class_num
|
441 |
+
self.input_shape = list(input_shape)
|
442 |
+
|
443 |
+
self.input_shape[1] = self.input_dim  # this changes later for color channels
|
444 |
+
|
445 |
+
self.expandNet = expand_net # 5
|
446 |
+
|
447 |
+
# Downsampling
|
448 |
+
self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1)
|
449 |
+
# self.maxpool1 = nn.MaxPool2d(kernel_size=2)
|
450 |
+
self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2)
|
451 |
+
# self.maxpool2 = nn.MaxPool2d(kernel_size=2)
|
452 |
+
self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2)
|
453 |
+
# self.maxpool3 = nn.MaxPool2d(kernel_size=2)
|
454 |
+
# Middle ground
|
455 |
+
self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2)
|
456 |
+
# UpSampling
|
457 |
+
self.up1 = UpBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), pow(2, self.expandNet + 1),
|
458 |
+
bilinear=True)
|
459 |
+
self.up2 = UpBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 1), pow(2, self.expandNet),
|
460 |
+
bilinear=True)
|
461 |
+
self.up3 = UpBlock(pow(2, self.expandNet), pow(2, self.expandNet), 8, bilinear=True)
|
462 |
+
self.last = lastBlock(8, self.output_dim)
|
463 |
+
|
464 |
+
utils.initialize_weights(self)
|
465 |
+
|
466 |
+
def _get_conv_output(self, shape):
|
467 |
+
bs = 1
|
468 |
+
input = Variable(torch.rand(bs, *shape))
|
469 |
+
# print("inShape:",input.shape)
|
470 |
+
output_feat = self.conv1(input.squeeze())  ## CHANGE THIS
|
471 |
+
# print ("output_feat",output_feat.shape)
|
472 |
+
n_size = output_feat.data.view(bs, -1).size(1)
|
473 |
+
# print ("n",n_size // 4)
|
474 |
+
return n_size // 4
|
475 |
+
|
476 |
+
def forward(self, clase, im):
|
477 |
+
x = im
|
478 |
+
|
479 |
+
## GET THE SHIFT CLASS
|
480 |
+
cl = ((clase == 1))
|
481 |
+
cl = cl[:, 1]
|
482 |
+
cl = cl.type(torch.FloatTensor)
|
483 |
+
max_class = clase.size(1) - 1  # highest class index; avoids shadowing the builtin max
|
484 |
+
cl = cl / float(max_class)
|
485 |
+
|
486 |
+
## create an image-sized layer encoding the shift class
|
487 |
+
tam = im.size()
|
488 |
+
layerClase = torch.ones([tam[0], tam[2], tam[3]], dtype=torch.float32, device="cuda:0")
|
489 |
+
for idx, item in enumerate(layerClase):
|
490 |
+
layerClase[idx] = item * cl[idx]
|
491 |
+
layerClase = layerClase.unsqueeze(0)
|
492 |
+
layerClase = layerClase.transpose(1, 0)
|
493 |
+
|
494 |
+
## join the class layer with the image RGB
|
495 |
+
x = torch.cat((x, layerClase), 1)
|
496 |
+
|
497 |
+
x1 = self.conv1(x)
|
498 |
+
x2 = self.conv2(x1) # self.maxpool1(x1))
|
499 |
+
x3 = self.conv3(x2) # self.maxpool2(x2))
|
500 |
+
x4 = self.conv4(x3) # self.maxpool3(x3))
|
501 |
+
x = self.up1(x4, x3)
|
502 |
+
x = self.up2(x, x2)
|
503 |
+
x = self.up3(x, x1)
|
504 |
+
x = changeDim(x, im)
|
505 |
+
x = self.last(x)
|
506 |
+
|
507 |
+
return x
|
508 |
+
|
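The per-sample loop that scales layerClase in the forward pass above can also be written with broadcasting; a loop-free sketch that should produce the same class layer (shapes as in the code, CPU for illustration):

import torch

cl = torch.rand(8)                  # one scalar class value per sample
im = torch.rand(8, 3, 128, 128)     # batch of images
layerClase = cl.view(-1, 1, 1, 1).expand(-1, 1, im.size(2), im.size(3))
x = torch.cat((im, layerClase), 1)  # (8, 4, 128, 128), matching the looped version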
509 |
+
|
510 |
+
class discriminator_UNet(nn.Module):
|
511 |
+
# Network architecture initially based on infoGAN (https://arxiv.org/abs/1606.03657)
|
512 |
+
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
|
513 |
+
def __init__(self, input_dim=1, output_dim=1, input_shape=[2, 2], class_num=10, expand_net = 2):
|
514 |
+
super(discriminator_UNet, self).__init__()
|
515 |
+
self.input_dim = input_dim * 2  # doubled: the source view is also given as input
|
516 |
+
self.output_dim = output_dim
|
517 |
+
self.input_shape = list(input_shape)
|
518 |
+
self.class_num = class_num
|
519 |
+
|
520 |
+
self.input_shape[1] = self.input_dim  # this changes later for color channels
|
521 |
+
|
522 |
+
self.expandNet = expand_net # 4
|
523 |
+
|
524 |
+
# Downsampling
|
525 |
+
self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1, dropout=0.3)
|
526 |
+
self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2, dropout=0.5)
|
527 |
+
self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2, dropout=0.4)
|
528 |
+
|
529 |
+
# Middle ground
|
530 |
+
self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2,
|
531 |
+
dropout=0.3)
|
532 |
+
|
533 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
534 |
+
|
535 |
+
self.fc1 = nn.Sequential(
|
536 |
+
nn.Linear(self.n_size // 4, 1024),
|
537 |
+
nn.BatchNorm1d(1024),
|
538 |
+
nn.LeakyReLU(0.2),
|
539 |
+
)
|
540 |
+
|
541 |
+
self.dc = nn.Sequential(
|
542 |
+
nn.Linear(1024, self.output_dim),
|
543 |
+
# nn.Sigmoid(),  # commented out: dc returns raw (unsquashed) scores
|
544 |
+
)
|
545 |
+
self.cl = nn.Sequential(
|
546 |
+
nn.Linear(1024, self.class_num),
|
547 |
+
nn.Softmax(dim=1),  # use the one whose outputs sum to 1
|
548 |
+
)
|
549 |
+
utils.initialize_weights(self)
|
550 |
+
|
551 |
+
# generate input sample and forward to get shape
|
552 |
+
|
553 |
+
def _get_conv_output(self, shape):
|
554 |
+
bs = 1
|
555 |
+
input = Variable(torch.rand(bs, *shape))
|
556 |
+
x = input.squeeze()
|
557 |
+
x = self.conv1(x)
|
558 |
+
x = self.conv2(x)
|
559 |
+
x = self.conv3(x)
|
560 |
+
x = self.conv4(x)
|
561 |
+
n_size = x.data.view(bs, -1).size(1)
|
562 |
+
return n_size
|
563 |
+
|
564 |
+
def forward(self, input, origen):
|
565 |
+
# this will change once color is supported
|
566 |
+
# if (len(input.shape) <= 3):
|
567 |
+
# input = input[:, None, :, :]
|
568 |
+
# im = im[:, None, :, :]
|
569 |
+
# print("D in shape",input.shape)
|
570 |
+
|
571 |
+
# print(input.shape)
|
572 |
+
# print("this si X:", x)
|
573 |
+
# print("now shape", x.shape)
|
574 |
+
x = input
|
575 |
+
x = x.type(torch.FloatTensor)
|
576 |
+
x = x.to(device='cuda:0')
|
577 |
+
|
578 |
+
x = torch.cat((x, origen), 1)
|
579 |
+
x = self.conv1(x)
|
580 |
+
x = self.conv2(x)
|
581 |
+
x = self.conv3(x)
|
582 |
+
x = self.conv4(x)
|
583 |
+
x = x.view(x.size(0), -1)
|
584 |
+
x = self.fc1(x)
|
585 |
+
d = self.dc(x)
|
586 |
+
c = self.cl(x)
|
587 |
+
|
588 |
+
return d, c
|
589 |
+
|
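Both heads of the discriminator come back from one forward call: d is the real/fake score for the candidate view given its source, and c the class prediction over the shift direction. A shape-level usage sketch (sizes follow main.py's defaults of cIm=4 and imageDim=128; the hard-coded .to(device='cuda:0') in forward means this needs a GPU):

import torch

# input_shape's leading 4 is divided back out by the `n_size // 4` in fc1
D = discriminator_UNet(input_dim=3, output_dim=1, input_shape=[4, 7, 128, 128], class_num=2).cuda()
fake = torch.rand(8, 3, 128, 128)                     # forward() moves this to cuda:0 itself
origin = torch.rand(8, 3, 128, 128, device='cuda:0')
d, c = D(fake, origin)
print(d.shape, c.shape)  # torch.Size([8, 1]) torch.Size([8, 2])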
590 |
+
|
591 |
+
def changeDim(x, y):
|
592 |
+
''' Pad x so its spatial dimensions match those of y '''
|
593 |
+
|
594 |
+
diffY = torch.tensor([y.size()[2] - x.size()[2]])
|
595 |
+
diffX = torch.tensor([y.size()[3] - x.size()[3]])
|
596 |
+
x = F.pad(x, [diffX // 2, diffX - diffX // 2,
|
597 |
+
diffY // 2, diffY - diffY // 2])
|
598 |
+
return x
|
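changeDim zero-pads x (split as evenly as possible on each side) until its spatial size matches y; this keeps skip-connection tensors concatenable when strided convolutions round sizes down. A self-contained sketch of the same F.pad call:

import torch
import torch.nn.functional as F

x = torch.rand(1, 8, 31, 31)
y = torch.rand(1, 8, 32, 32)
diffY, diffX = y.size(2) - x.size(2), y.size(3) - x.size(3)
x = F.pad(x, [diffX // 2, diffX - diffX // 2, diffY // 2, diffY - diffY // 2])
print(x.shape)  # torch.Size([1, 8, 32, 32])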
599 |
+
|
600 |
+
|
601 |
+
######################################## ACGAN ###########################################################
|
602 |
+
|
603 |
+
class depth_generator(nn.Module):
|
604 |
+
# Network architecture initially based on infoGAN (https://arxiv.org/abs/1606.03657)
|
605 |
+
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
|
606 |
+
def __init__(self, input_dim=4, output_dim=1, input_shape=3, class_num=10, zdim=1, height=10, width=10):
|
607 |
+
super(depth_generator, self).__init__()
|
608 |
+
self.input_dim = input_dim
|
609 |
+
self.output_dim = output_dim
|
610 |
+
self.class_num = class_num
|
611 |
+
# print ("self.output_dim", self.output_dim)
|
612 |
+
self.input_shape = list(input_shape)
|
613 |
+
self.zdim = zdim
|
614 |
+
self.toPreDecov = 1024
|
615 |
+
self.toDecov = 1
|
616 |
+
self.height = height
|
617 |
+
self.width = width
|
618 |
+
|
619 |
+
self.input_shape[1] = self.input_dim  # this changes later for color channels
|
620 |
+
|
621 |
+
# print("input shpe gen",self.input_shape)
|
622 |
+
|
623 |
+
self.conv1 = nn.Sequential(
|
624 |
+
############## VERY SMALL NETWORK SO EVERYTHING RUNS WITHIN RAM/GPU MEMORY
|
625 |
+
nn.Conv2d(self.input_dim, 2, 4, 2, 1),  # I think the stride of 2 should be 1
|
626 |
+
nn.Conv2d(2, 1, 4, 2, 1),
|
627 |
+
nn.BatchNorm2d(1),
|
628 |
+
nn.LeakyReLU(0.2),
|
629 |
+
)
|
630 |
+
|
631 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
632 |
+
# print ("self.n_size",self.n_size)
|
633 |
+
self.cubic = (self.n_size // 8192)
|
634 |
+
# print("self.cubic: ",self.cubic)
|
635 |
+
|
636 |
+
self.fc1 = nn.Sequential(
|
637 |
+
nn.Linear(self.n_size, self.n_size),
|
638 |
+
nn.BatchNorm1d(self.n_size),
|
639 |
+
nn.LeakyReLU(0.2),
|
640 |
+
)
|
641 |
+
|
642 |
+
self.preDeconv = nn.Sequential(
|
643 |
+
############## VERY SMALL NETWORK SO EVERYTHING RUNS WITHIN RAM/GPU MEMORY
|
644 |
+
|
645 |
+
# nn.Linear(self.toPreDecov + self.zdim + self.class_num, 1024),
|
646 |
+
# nn.BatchNorm1d(1024),
|
647 |
+
# nn.LeakyReLU(0.2),
|
648 |
+
# nn.Linear(1024, self.toDecov * self.height // 64 * self.width// 64),
|
649 |
+
# nn.BatchNorm1d(self.toDecov * self.height // 64 * self.width// 64),
|
650 |
+
# nn.LeakyReLU(0.2),
|
651 |
+
# nn.Linear(self.toDecov * self.height // 64 * self.width // 64 , self.toDecov * self.height // 32 * self.width // 32),
|
652 |
+
# nn.BatchNorm1d(self.toDecov * self.height // 32 * self.width // 32),
|
653 |
+
# nn.LeakyReLU(0.2),
|
654 |
+
# nn.Linear(self.toDecov * self.height // 32 * self.width // 32,
|
655 |
+
# 1 * self.height * self.width),
|
656 |
+
# nn.BatchNorm1d(1 * self.height * self.width),
|
657 |
+
# nn.LeakyReLU(0.2),
|
658 |
+
|
659 |
+
nn.Linear(self.n_size + self.zdim + self.class_num, 50),
|
660 |
+
nn.BatchNorm1d(50),
|
661 |
+
nn.LeakyReLU(0.2),
|
662 |
+
nn.Linear(50, 200),
|
663 |
+
nn.BatchNorm1d(200),
|
664 |
+
nn.LeakyReLU(0.2),
|
665 |
+
nn.Linear(200, self.output_dim * self.height * self.width),
|
666 |
+
nn.BatchNorm1d(self.output_dim * self.height * self.width),
|
667 |
+
nn.Tanh(),  # changed: I treat the network as ending here
|
668 |
+
|
669 |
+
)
|
670 |
+
|
671 |
+
"""
|
672 |
+
self.deconv = nn.Sequential(
|
673 |
+
nn.ConvTranspose2d(self.toDecov, 2, 4, 2, 0),
|
674 |
+
nn.BatchNorm2d(2),
|
675 |
+
nn.ReLU(),
|
676 |
+
nn.ConvTranspose2d(2, self.output_dim, 4, 2, 1),
|
677 |
+
nn.Tanh(), # a Tanh last layer is recommended: the generator outputs values between -1 and 1
|
678 |
+
)
|
679 |
+
"""
|
680 |
+
utils.initialize_weights(self)
|
681 |
+
|
682 |
+
def _get_conv_output(self, shape):
|
683 |
+
bs = 1
|
684 |
+
input = Variable(torch.rand(bs, *shape))
|
685 |
+
# print("inShape:",input.shape)
|
686 |
+
output_feat = self.conv1(input.squeeze())
|
687 |
+
# print ("output_feat",output_feat.shape)
|
688 |
+
n_size = output_feat.data.view(bs, -1).size(1)
|
689 |
+
# print ("n",n_size // 4)
|
690 |
+
return n_size // 4
|
691 |
+
|
692 |
+
def forward(self, input, clase, im, imDep):
|
693 |
+
## This is what I am going to do
|
694 |
+
# concatenate the image and the depth map
|
695 |
+
print ("H", self.height, "W", self.width)
|
696 |
+
# imDep = imDep[:, None, :, :]
|
697 |
+
# im = im[:, None, :, :]
|
698 |
+
print ("imdep", imDep.shape)
|
699 |
+
print ("im", im.shape)
|
700 |
+
x = torch.cat([im, imDep], 1)
|
701 |
+
|
702 |
+
# conv over that concatenation
|
703 |
+
x = self.conv1(x)
|
704 |
+
x = x.view(x.size(0), -1)
|
705 |
+
print ("x:", x.shape)
|
706 |
+
x = self.fc1(x)
|
707 |
+
# print ("x:",x.shape)
|
708 |
+
|
709 |
+
# concatenate the noise and the class
|
710 |
+
y = torch.cat([input, clase], 1)
|
711 |
+
print("Cat entre input y clase", y.shape) # podria separarlo, unir primero con clase y despues con ruido
|
712 |
+
|
713 |
+
# linear network joining the conv features with the previous concatenation
|
714 |
+
x = torch.cat([x, y], 1)
|
715 |
+
x = self.preDeconv(x)
|
716 |
+
print ("antes de deconv", x.shape)
|
717 |
+
x = x.view(-1, self.output_dim, self.height, self.width)
|
718 |
+
print("Despues View: ", x.shape)
|
719 |
+
# network that produces the final image
|
720 |
+
# x = self.deconv(x)
|
721 |
+
print("La salida de la generadora es: ", x.shape)
|
722 |
+
|
723 |
+
return x
|
724 |
+
|
725 |
+
|
726 |
+
class depth_discriminator(nn.Module):
|
727 |
+
# Network architecture initially based on infoGAN (https://arxiv.org/abs/1606.03657)
|
728 |
+
# Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
|
729 |
+
def __init__(self, input_dim=1, output_dim=1, input_shape=2, class_num=10):
|
730 |
+
super(depth_discriminator, self).__init__()
|
731 |
+
self.input_dim = input_dim
|
732 |
+
self.output_dim = output_dim
|
733 |
+
self.input_shape = list(input_shape)
|
734 |
+
self.class_num = class_num
|
735 |
+
|
736 |
+
self.input_shape[1] = self.input_dim  # this changes later for color channels
|
737 |
+
print(self.input_shape)
|
738 |
+
|
739 |
+
"""""
|
740 |
+
in_channels (int): Number of channels in the input image
|
741 |
+
out_channels (int): Number of channels produced by the convolution
|
742 |
+
kernel_size (int or tuple): Size of the convolving kernel - the window the convolution slides over
|
743 |
+
stride (int or tuple, optional): Stride of the convolution. Default: 1
|
744 |
+
padding (int or tuple, optional): Zero-padding added to both sides of the input.
|
745 |
+
"""""
|
746 |
+
|
747 |
+
"""
|
748 |
+
nn.Conv2d(self.input_dim, 64, 4, 2, 1), # I think the stride of 2 should be 1
|
749 |
+
nn.LeakyReLU(0.2),
|
750 |
+
nn.Conv2d(64, 32, 4, 2, 1),
|
751 |
+
nn.LeakyReLU(0.2),
|
752 |
+
nn.MaxPool2d(4, stride=2),
|
753 |
+
nn.Conv2d(32, 32, 4, 2, 1),
|
754 |
+
nn.LeakyReLU(0.2),
|
755 |
+
nn.MaxPool2d(4, stride=2),
|
756 |
+
nn.Conv2d(32, 20, 4, 2, 1),
|
757 |
+
nn.BatchNorm2d(20),
|
758 |
+
nn.LeakyReLU(0.2),
|
759 |
+
"""
|
760 |
+
|
761 |
+
self.conv = nn.Sequential(
|
762 |
+
|
763 |
+
nn.Conv2d(self.input_dim, 4, 4, 2, 1),  # I think the stride of 2 should be 1
|
764 |
+
nn.LeakyReLU(0.2),
|
765 |
+
nn.Conv2d(4, 8, 4, 2, 1),
|
766 |
+
nn.BatchNorm2d(8),
|
767 |
+
nn.LeakyReLU(0.2),
|
768 |
+
nn.Conv2d(8, 16, 4, 2, 1),
|
769 |
+
nn.BatchNorm2d(16),
|
770 |
+
|
771 |
+
)
|
772 |
+
|
773 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
774 |
+
|
775 |
+
self.fc1 = nn.Sequential(
|
776 |
+
nn.Linear(self.n_size // 4, 1024),
|
777 |
+
nn.BatchNorm1d(1024),
|
778 |
+
nn.LeakyReLU(0.2),
|
779 |
+
nn.Linear(1024, 512),
|
780 |
+
nn.BatchNorm1d(512),
|
781 |
+
nn.LeakyReLU(0.2),
|
782 |
+
nn.Linear(512, 256),
|
783 |
+
nn.BatchNorm1d(256),
|
784 |
+
nn.LeakyReLU(0.2),
|
785 |
+
nn.Linear(256, 128),
|
786 |
+
nn.BatchNorm1d(128),
|
787 |
+
nn.LeakyReLU(0.2),
|
788 |
+
nn.Linear(128, 64),
|
789 |
+
nn.BatchNorm1d(64),
|
790 |
+
nn.LeakyReLU(0.2),
|
791 |
+
)
|
792 |
+
self.dc = nn.Sequential(
|
793 |
+
nn.Linear(64, self.output_dim),
|
794 |
+
nn.Sigmoid(),
|
795 |
+
)
|
796 |
+
self.cl = nn.Sequential(
|
797 |
+
nn.Linear(64, self.class_num),
|
798 |
+
nn.Sigmoid(),
|
799 |
+
)
|
800 |
+
utils.initialize_weights(self)
|
801 |
+
|
802 |
+
# generate input sample and forward to get shape
|
803 |
+
|
804 |
+
def _get_conv_output(self, shape):
|
805 |
+
bs = 1
|
806 |
+
input = Variable(torch.rand(bs, *shape))
|
807 |
+
output_feat = self.conv(input.squeeze())
|
808 |
+
n_size = output_feat.data.view(bs, -1).size(1)
|
809 |
+
return n_size
|
810 |
+
|
811 |
+
def forward(self, input, im):
|
812 |
+
# this will change once color is supported
|
813 |
+
# if (len(input.shape) <= 3):
|
814 |
+
# input = input[:, None, :, :]
|
815 |
+
# im = im[:, None, :, :]
|
816 |
+
print("D in shape", input.shape)
|
817 |
+
print("D im shape", im.shape)
|
818 |
+
x = torch.cat([input, im], 1)
|
819 |
+
print(input.shape)
|
820 |
+
# print("this si X:", x)
|
821 |
+
# print("now shape", x.shape)
|
822 |
+
x = x.type(torch.FloatTensor)
|
823 |
+
x = x.to(device='cuda:0')
|
824 |
+
x = self.conv(x)
|
825 |
+
x = x.view(x.size(0), -1)
|
826 |
+
x = self.fc1(x)
|
827 |
+
d = self.dc(x)
|
828 |
+
c = self.cl(x)
|
829 |
+
|
830 |
+
return d, c
|
831 |
+
|
832 |
+
|
833 |
+
class depth_generator_UNet(nn.Module):
|
834 |
+
# Network architecture initially based on infoGAN (https://arxiv.org/abs/1606.03657)
|
835 |
+
# Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
|
836 |
+
def __init__(self, input_dim=4, output_dim=1, class_num=10, expand_net=3, depth=True):
|
837 |
+
super(depth_generator_UNet, self).__init__()
|
838 |
+
|
839 |
+
if depth:
|
840 |
+
self.input_dim = input_dim + 1
|
841 |
+
else:
|
842 |
+
self.input_dim = input_dim
|
843 |
+
self.output_dim = output_dim
|
844 |
+
self.class_num = class_num
|
845 |
+
# print ("self.output_dim", self.output_dim)
|
846 |
+
|
847 |
+
self.expandNet = expand_net # 5
|
848 |
+
self.depth = depth
|
849 |
+
|
850 |
+
# Downsampling
|
851 |
+
self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet))
|
852 |
+
# self.maxpool1 = nn.MaxPool2d(kernel_size=2)
|
853 |
+
self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2)
|
854 |
+
# self.maxpool2 = nn.MaxPool2d(kernel_size=2)
|
855 |
+
self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2)
|
856 |
+
# self.maxpool3 = nn.MaxPool2d(kernel_size=2)
|
857 |
+
# Middle ground
|
858 |
+
self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2)
|
859 |
+
# UpSampling
|
860 |
+
self.up1 = UpBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), pow(2, self.expandNet + 1))
|
861 |
+
self.up2 = UpBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 1), pow(2, self.expandNet))
|
862 |
+
self.up3 = UpBlock(pow(2, self.expandNet), pow(2, self.expandNet), 8)
|
863 |
+
self.last = lastBlock(8, self.output_dim)
|
864 |
+
|
865 |
+
if depth:
|
866 |
+
self.upDep1 = UpBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), pow(2, self.expandNet + 1))
|
867 |
+
self.upDep2 = UpBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 1), pow(2, self.expandNet))
|
868 |
+
self.upDep3 = UpBlock(pow(2, self.expandNet), pow(2, self.expandNet), 8)
|
869 |
+
self.lastDep = lastBlock(8, 1)
|
870 |
+
|
871 |
+
|
872 |
+
|
873 |
+
utils.initialize_weights(self)
|
874 |
+
|
875 |
+
|
876 |
+
def forward(self, clase, im, imDep):
|
877 |
+
## should I do something with z?
|
878 |
+
#print (im.shape)
|
879 |
+
#print (z.shape)
|
880 |
+
#print (z)
|
881 |
+
#imz = torch.repeat_interleave(z, repeats=torch.tensor([2, 2]), dim=1)
|
882 |
+
#print (imz.shape)
|
883 |
+
#print (imz)
|
884 |
+
|
885 |
+
if self.depth:
|
886 |
+
x = torch.cat([im, imDep], 1)
|
887 |
+
x = torch.cat((x, clase), 1)
|
888 |
+
else:
|
889 |
+
x = torch.cat((im, clase), 1)
|
890 |
+
## join the class layer with the image RGB
|
891 |
+
|
892 |
+
|
893 |
+
x1 = self.conv1(x)
|
894 |
+
x2 = self.conv2(x1) # self.maxpool1(x1))
|
895 |
+
x3 = self.conv3(x2) # self.maxpool2(x2))
|
896 |
+
x4 = self.conv4(x3) # self.maxpool3(x3))
|
897 |
+
|
898 |
+
x = self.up1(x4, x3)
|
899 |
+
x = self.up2(x, x2)
|
900 |
+
x = self.up3(x, x1)
|
901 |
+
#x = changeDim(x, im)
|
902 |
+
x = self.last(x)
|
903 |
+
|
904 |
+
#x = x[:, :3, :, :] # theoretical change
|
905 |
+
|
906 |
+
if self.depth:
|
907 |
+
dep = self.upDep1(x4, x3)
|
908 |
+
dep = self.upDep2(dep, x2)
|
909 |
+
dep = self.upDep3(dep, x1)
|
910 |
+
# x = changeDim(x, im)
|
911 |
+
dep = self.lastDep(dep)
|
912 |
+
return x, dep
|
913 |
+
else:
|
914 |
+
return x,imDep
|
915 |
+
|
916 |
+
|
917 |
+
class depth_discriminator_UNet(nn.Module):
|
918 |
+
def __init__(self, input_dim=1, output_dim=1, input_shape=[8, 7, 128, 128], class_num=2, expand_net=2):
|
919 |
+
super(depth_discriminator_UNet, self).__init__()
|
920 |
+
self.input_dim = input_dim * 2 + 1
|
921 |
+
|
922 |
+
#discriminator_UNet.__init__(self, input_dim=self.input_dim, output_dim=output_dim, input_shape=input_shape,
|
923 |
+
# class_num=class_num, expand_net = expand_net)
|
924 |
+
|
925 |
+
self.output_dim = output_dim
|
926 |
+
self.input_shape = list(input_shape)
|
927 |
+
self.class_num = class_num
|
928 |
+
self.expandNet = expand_net
|
929 |
+
|
930 |
+
self.input_dim = input_dim * 2 + 1  # concatenated: source view + depth map
|
931 |
+
self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1, dropout=0.3)
|
932 |
+
self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2, dropout=0.2)
|
933 |
+
self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2, dropout=0.2)
|
934 |
+
self.conv4 = UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2,
|
935 |
+
dropout=0.3)
|
936 |
+
|
937 |
+
self.input_shape[1] = self.input_dim
|
938 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
939 |
+
|
940 |
+
self.fc1 = nn.Sequential(
|
941 |
+
nn.Linear(self.n_size, 1024),
|
942 |
+
)
|
943 |
+
|
944 |
+
self.BnLr = nn.Sequential(
|
945 |
+
nn.BatchNorm1d(1024),
|
946 |
+
nn.LeakyReLU(0.2),
|
947 |
+
)
|
948 |
+
|
949 |
+
self.dc = nn.Sequential(
|
950 |
+
nn.Linear(1024, self.output_dim),
|
951 |
+
#nn.Sigmoid(),
|
952 |
+
)
|
953 |
+
self.cl = nn.Sequential(
|
954 |
+
nn.Linear(1024, self.class_num),
|
955 |
+
# nn.Softmax(dim=1),  # use the one whose outputs sum to 1
|
956 |
+
)
|
957 |
+
|
958 |
+
utils.initialize_weights(self)
|
959 |
+
|
960 |
+
def _get_conv_output(self, shape):
|
961 |
+
bs = 1
|
962 |
+
input = Variable(torch.rand(bs, *shape))
|
963 |
+
x = input.squeeze()
|
964 |
+
x = self.conv1(x)
|
965 |
+
x = self.conv2(x)
|
966 |
+
x = self.conv3(x)
|
967 |
+
x = self.conv4(x)
|
968 |
+
x = x.view(x.size(0), -1)
|
969 |
+
return x.shape[1]
|
970 |
+
|
971 |
+
def forward(self, input, origen, dep):
|
972 |
+
# this will change once color is supported
|
973 |
+
# if (len(input.shape) <= 3):
|
974 |
+
# input = input[:, None, :, :]
|
975 |
+
# im = im[:, None, :, :]
|
976 |
+
# print("D in shape",input.shape)
|
977 |
+
|
978 |
+
# print(input.shape)
|
979 |
+
# print("this si X:", x)
|
980 |
+
# print("now shape", x.shape)
|
981 |
+
x = input
|
982 |
+
|
983 |
+
x = torch.cat((x, origen), 1)
|
984 |
+
x = torch.cat((x, dep), 1)
|
985 |
+
x = self.conv1(x)
|
986 |
+
x = self.conv2(x)
|
987 |
+
x = self.conv3(x)
|
988 |
+
x = self.conv4(x)
|
989 |
+
x = x.view(x.size(0), -1)
|
990 |
+
features = self.fc1(x)
|
991 |
+
x = self.BnLr(features)
|
992 |
+
d = self.dc(x)
|
993 |
+
c = self.cl(x)
|
994 |
+
|
995 |
+
return d, c, features
|
996 |
+
|
997 |
+
class depth_discriminator_noclass_UNet(nn.Module):
|
998 |
+
def __init__(self, input_dim=1, output_dim=1, input_shape=[8, 7, 128, 128], class_num=2, expand_net=2, depth=True, wgan = False):
|
999 |
+
super(depth_discriminator_noclass_UNet, self).__init__()
|
1000 |
+
|
1001 |
+
#discriminator_UNet.__init__(self, input_dim=self.input_dim, output_dim=output_dim, input_shape=input_shape,
|
1002 |
+
# class_num=class_num, expand_net = expand_net)
|
1003 |
+
|
1004 |
+
self.output_dim = output_dim
|
1005 |
+
self.input_shape = list(input_shape)
|
1006 |
+
self.class_num = class_num
|
1007 |
+
self.expandNet = expand_net
|
1008 |
+
self.depth = depth
|
1009 |
+
self.wgan = wgan
|
1010 |
+
|
1011 |
+
if depth:
|
1012 |
+
self.input_dim = input_dim * 2 + 2  # concatenated: source view + depth + class
|
1013 |
+
else:
|
1014 |
+
self.input_dim = input_dim * 2 + 1  # concatenated: source view + class
|
1015 |
+
self.conv1 = UnetConvBlock(self.input_dim, pow(2, self.expandNet), stride=1, dropout=0.0, batch_norm = False )
|
1016 |
+
self.conv2 = UnetConvBlock(pow(2, self.expandNet), pow(2, self.expandNet + 1), stride=2, dropout=0.0, batch_norm = False )
|
1017 |
+
self.conv3 = UnetConvBlock(pow(2, self.expandNet + 1), pow(2, self.expandNet + 2), stride=2, dropout=0.0, batch_norm = False )
|
1018 |
+
self.conv4 = UnetConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 3), stride=2, dropout=0.0, batch_norm = False )
|
1019 |
+
self.conv5 = UnetDeSingleConvBlock(pow(2, self.expandNet + 3), pow(2, self.expandNet + 2), stride=1, dropout=0.0, batch_norm = False )
|
1020 |
+
|
1021 |
+
self.lastconvs = []  # plain Python list: these submodules are not auto-registered, so they are moved to CUDA by hand below
|
1022 |
+
imagesize = self.input_shape[2] / 8
|
1023 |
+
while imagesize > 4:
|
1024 |
+
self.lastconvs.append(UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 2), stride=2, dropout=0.0, batch_norm = False ))
|
1025 |
+
imagesize = imagesize/2
|
1026 |
+
else:  # while-else: runs once after the loop finishes, appending the final reducing block
|
1027 |
+
self.lastconvs.append(UnetDeSingleConvBlock(pow(2, self.expandNet + 2), pow(2, self.expandNet + 1), stride=1, dropout=0.0, batch_norm = False ))
|
1028 |
+
|
1029 |
+
self.input_shape[1] = self.input_dim
|
1030 |
+
self.n_size = self._get_conv_output(self.input_shape)
|
1031 |
+
|
1032 |
+
for layer in self.lastconvs:
|
1033 |
+
layer.cuda()  # .cuda() moves the parameters in place, so no rebinding is needed
|
1034 |
+
|
1035 |
+
self.fc1 = nn.Sequential(
|
1036 |
+
nn.Linear(self.n_size, 256),
|
1037 |
+
)
|
1038 |
+
|
1039 |
+
self.BnLr = nn.Sequential(
|
1040 |
+
nn.BatchNorm1d(256),
|
1041 |
+
nn.LeakyReLU(0.2),
|
1042 |
+
)
|
1043 |
+
|
1044 |
+
self.dc = nn.Sequential(
|
1045 |
+
nn.Linear(256, self.output_dim),
|
1046 |
+
#nn.Sigmoid(),
|
1047 |
+
)
|
1048 |
+
|
1049 |
+
utils.initialize_weights(self)
|
1050 |
+
|
1051 |
+
def _get_conv_output(self, shape):
|
1052 |
+
bs = 1
|
1053 |
+
input = Variable(torch.rand(bs, *shape))
|
1054 |
+
x = input.squeeze()
|
1055 |
+
x = self.conv1(x)
|
1056 |
+
x = self.conv2(x)
|
1057 |
+
x = self.conv3(x)
|
1058 |
+
x = self.conv4(x)
|
1059 |
+
x = self.conv5(x)
|
1060 |
+
for layer in self.lastconvs:
|
1061 |
+
x = layer(x)
|
1062 |
+
x = x.view(x.size(0), -1)
|
1063 |
+
return x.shape[1]
|
1064 |
+
|
1065 |
+
def forward(self, input, origen, dep, clase):
|
1066 |
+
# this will change once color is supported
|
1067 |
+
# if (len(input.shape) <= 3):
|
1068 |
+
# input = input[:, None, :, :]
|
1069 |
+
# im = im[:, None, :, :]
|
1070 |
+
# print("D in shape",input.shape)
|
1071 |
+
|
1072 |
+
# print(input.shape)
|
1073 |
+
# print("this si X:", x)
|
1074 |
+
# print("now shape", x.shape)
|
1075 |
+
x = input
|
1076 |
+
## join the class layer with the image RGB
|
1077 |
+
x = torch.cat((x, clase), 1)
|
1078 |
+
|
1079 |
+
x = torch.cat((x, origen), 1)
|
1080 |
+
if self.depth:
|
1081 |
+
x = torch.cat((x, dep), 1)
|
1082 |
+
x = self.conv1(x)
|
1083 |
+
x = self.conv2(x)
|
1084 |
+
x = self.conv3(x)
|
1085 |
+
x = self.conv4(x)
|
1086 |
+
x = self.conv5(x)
|
1087 |
+
for layer in self.lastconvs:
|
1088 |
+
x = layer(x)
|
1089 |
+
feature_vector = x.view(x.size(0), -1)
|
1090 |
+
x = self.fc1(feature_vector)
|
1091 |
+
x = self.BnLr(x)
|
1092 |
+
d = self.dc(x)
|
1093 |
+
|
1094 |
+
return d, feature_vector
|
config.ini
ADDED
@@ -0,0 +1,259 @@
1 |
+
|
2 |
+
[validation]
|
3 |
+
total = 50
|
4 |
+
0 = 2822
|
5 |
+
1 = 3038
|
6 |
+
2 = 3760
|
7 |
+
3 = 3512
|
8 |
+
4 = 3349
|
9 |
+
5 = 2812
|
10 |
+
6 = 3383
|
11 |
+
7 = 3606
|
12 |
+
8 = 3612
|
13 |
+
9 = 3666
|
14 |
+
10 = 2933
|
15 |
+
11 = 3613
|
16 |
+
12 = 2881
|
17 |
+
13 = 3609
|
18 |
+
14 = 3066
|
19 |
+
15 = 3654
|
20 |
+
16 = 2821
|
21 |
+
17 = 2784
|
22 |
+
18 = 3186
|
23 |
+
19 = 3138
|
24 |
+
20 = 3187
|
25 |
+
21 = 3482
|
26 |
+
22 = 2701
|
27 |
+
23 = 3320
|
28 |
+
24 = 3716
|
29 |
+
25 = 3501
|
30 |
+
26 = 3441
|
31 |
+
27 = 3768
|
32 |
+
28 = 3158
|
33 |
+
29 = 2841
|
34 |
+
30 = 3466
|
35 |
+
31 = 3547
|
36 |
+
32 = 2920
|
37 |
+
33 = 3439
|
38 |
+
34 = 2669
|
39 |
+
35 = 3183
|
40 |
+
36 = 2760
|
41 |
+
37 = 3605
|
42 |
+
38 = 2941
|
43 |
+
39 = 3729
|
44 |
+
40 = 2958
|
45 |
+
41 = 3745
|
46 |
+
42 = 3417
|
47 |
+
43 = 3218
|
48 |
+
44 = 3093
|
49 |
+
45 = 3699
|
50 |
+
46 = 3255
|
51 |
+
47 = 3616
|
52 |
+
48 = 3623
|
53 |
+
49 = 3590
|
54 |
+
50 = 3496
|
55 |
+
[test]
|
56 |
+
total = 1
|
57 |
+
[train]
|
58 |
+
total = 200
|
59 |
+
0 = 3192
|
60 |
+
1 = 3086
|
61 |
+
2 = 3205
|
62 |
+
3 = 3061
|
63 |
+
4 = 2688
|
64 |
+
5 = 3347
|
65 |
+
6 = 2850
|
66 |
+
7 = 3508
|
67 |
+
8 = 3285
|
68 |
+
9 = 3487
|
69 |
+
10 = 3433
|
70 |
+
11 = 2687
|
71 |
+
12 = 2860
|
72 |
+
13 = 3353
|
73 |
+
14 = 3526
|
74 |
+
15 = 3112
|
75 |
+
16 = 3123
|
76 |
+
17 = 3109
|
77 |
+
18 = 2825
|
78 |
+
19 = 3114
|
79 |
+
20 = 3413
|
80 |
+
21 = 2876
|
81 |
+
22 = 2910
|
82 |
+
23 = 3339
|
83 |
+
24 = 3011
|
84 |
+
25 = 2753
|
85 |
+
26 = 3551
|
86 |
+
27 = 2942
|
87 |
+
28 = 2998
|
88 |
+
29 = 3370
|
89 |
+
30 = 3560
|
90 |
+
31 = 3446
|
91 |
+
32 = 3017
|
92 |
+
33 = 3703
|
93 |
+
34 = 3327
|
94 |
+
35 = 3498
|
95 |
+
36 = 2884
|
96 |
+
37 = 2934
|
97 |
+
38 = 2671
|
98 |
+
39 = 2871
|
99 |
+
40 = 2727
|
100 |
+
41 = 3144
|
101 |
+
42 = 3393
|
102 |
+
43 = 3693
|
103 |
+
44 = 2761
|
104 |
+
45 = 2895
|
105 |
+
46 = 3537
|
106 |
+
47 = 3735
|
107 |
+
48 = 2755
|
108 |
+
49 = 2710
|
109 |
+
50 = 3379
|
110 |
+
51 = 3475
|
111 |
+
52 = 2750
|
112 |
+
53 = 3390
|
113 |
+
54 = 3189
|
114 |
+
55 = 2817
|
115 |
+
56 = 3765
|
116 |
+
57 = 3653
|
117 |
+
58 = 2776
|
118 |
+
59 = 3568
|
119 |
+
60 = 2782
|
120 |
+
61 = 3079
|
121 |
+
62 = 3283
|
122 |
+
63 = 2999
|
123 |
+
64 = 3586
|
124 |
+
65 = 2740
|
125 |
+
66 = 3651
|
126 |
+
67 = 3549
|
127 |
+
68 = 3106
|
128 |
+
69 = 3160
|
129 |
+
70 = 3092
|
130 |
+
71 = 2940
|
131 |
+
72 = 3603
|
132 |
+
73 = 3733
|
133 |
+
74 = 3371
|
134 |
+
75 = 3290
|
135 |
+
76 = 3091
|
136 |
+
77 = 2978
|
137 |
+
78 = 3730
|
138 |
+
79 = 2961
|
139 |
+
80 = 2748
|
140 |
+
81 = 3094
|
141 |
+
82 = 2914
|
142 |
+
83 = 3490
|
143 |
+
84 = 3120
|
144 |
+
85 = 3759
|
145 |
+
86 = 2715
|
146 |
+
87 = 3287
|
147 |
+
88 = 3723
|
148 |
+
89 = 3776
|
149 |
+
90 = 3305
|
150 |
+
91 = 2830
|
151 |
+
92 = 3313
|
152 |
+
93 = 3368
|
153 |
+
94 = 2944
|
154 |
+
95 = 2925
|
155 |
+
96 = 3780
|
156 |
+
97 = 2680
|
157 |
+
98 = 3622
|
158 |
+
99 = 3065
|
159 |
+
100 = 2905
|
160 |
+
101 = 3346
|
161 |
+
102 = 3397
|
162 |
+
103 = 2875
|
163 |
+
104 = 3262
|
164 |
+
105 = 2783
|
165 |
+
106 = 3485
|
166 |
+
107 = 3234
|
167 |
+
108 = 3330
|
168 |
+
109 = 3099
|
169 |
+
110 = 3625
|
170 |
+
111 = 3540
|
171 |
+
112 = 3523
|
172 |
+
113 = 3279
|
173 |
+
114 = 3280
|
174 |
+
115 = 3428
|
175 |
+
116 = 3372
|
176 |
+
117 = 3497
|
177 |
+
118 = 3626
|
178 |
+
119 = 2733
|
179 |
+
120 = 3578
|
180 |
+
121 = 3593
|
181 |
+
122 = 3700
|
182 |
+
123 = 3167
|
183 |
+
124 = 2848
|
184 |
+
125 = 2775
|
185 |
+
126 = 3726
|
186 |
+
127 = 3425
|
187 |
+
128 = 3751
|
188 |
+
129 = 3520
|
189 |
+
130 = 3458
|
190 |
+
131 = 3164
|
191 |
+
132 = 3381
|
192 |
+
133 = 2873
|
193 |
+
134 = 2890
|
194 |
+
135 = 3548
|
195 |
+
136 = 3728
|
196 |
+
137 = 2745
|
197 |
+
138 = 3041
|
198 |
+
139 = 3663
|
199 |
+
140 = 3098
|
200 |
+
141 = 3631
|
201 |
+
142 = 3127
|
202 |
+
143 = 3704
|
203 |
+
144 = 3658
|
204 |
+
145 = 3629
|
205 |
+
146 = 3467
|
206 |
+
147 = 2676
|
207 |
+
148 = 3178
|
208 |
+
149 = 3275
|
209 |
+
150 = 3324
|
210 |
+
151 = 2756
|
211 |
+
152 = 3200
|
212 |
+
153 = 3034
|
213 |
+
154 = 3749
|
214 |
+
155 = 3558
|
215 |
+
156 = 3173
|
216 |
+
157 = 3792
|
217 |
+
158 = 2681
|
218 |
+
159 = 3367
|
219 |
+
160 = 3579
|
220 |
+
161 = 3155
|
221 |
+
162 = 3128
|
222 |
+
163 = 2816
|
223 |
+
164 = 2973
|
224 |
+
165 = 3246
|
225 |
+
166 = 3129
|
226 |
+
167 = 3762
|
227 |
+
168 = 2939
|
228 |
+
169 = 2929
|
229 |
+
170 = 3711
|
230 |
+
171 = 3608
|
231 |
+
172 = 2679
|
232 |
+
173 = 3214
|
233 |
+
174 = 3687
|
234 |
+
175 = 3291
|
235 |
+
176 = 2700
|
236 |
+
177 = 3131
|
237 |
+
178 = 3597
|
238 |
+
179 = 3519
|
239 |
+
180 = 3481
|
240 |
+
181 = 2725
|
241 |
+
182 = 3761
|
242 |
+
183 = 3610
|
243 |
+
184 = 3073
|
244 |
+
185 = 3135
|
245 |
+
186 = 2891
|
246 |
+
187 = 3769
|
247 |
+
188 = 3557
|
248 |
+
189 = 2967
|
249 |
+
190 = 2697
|
250 |
+
191 = 2861
|
251 |
+
192 = 2956
|
252 |
+
193 = 3052
|
253 |
+
194 = 2995
|
254 |
+
195 = 3054
|
255 |
+
196 = 3588
|
256 |
+
197 = 2960
|
257 |
+
198 = 2952
|
258 |
+
199 = 2766
|
259 |
+
200 = 2917
|
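Each numbered key in these sections maps a contiguous dataset index to the id embedded in the image filenames (n_<id>.png and d_<id>.png), so a split is just an index table; ImagesDataset.__getitem__ resolves indices through it. Reading it back, as the dataset class does:

from configparser import ConfigParser

parser = ConfigParser()
parser.read('config.ini')
total = parser.getint('train', 'total')   # number of samples in the split
image_id = parser.get('train', '0')       # '3192' -> files n_3192.png and d_3192.png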
dataloader.py
ADDED
@@ -0,0 +1,301 @@
1 |
+
from torch.utils.data import DataLoader
|
2 |
+
from torchvision import datasets, transforms
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
import torch
|
5 |
+
from configparser import ConfigParser
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
import os
|
8 |
+
import torch as th
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
import random
|
12 |
+
from PIL import ImageMath
|
13 |
+
|
14 |
+
|
15 |
+
def dataloader(dataset, input_size, batch_size, dim, split='train', trans=False):
|
16 |
+
transform = transforms.Compose([transforms.Resize((input_size, input_size)), transforms.ToTensor(),
|
17 |
+
                                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
18 |
+
if dataset == 'mnist':
|
19 |
+
data_loader = DataLoader(
|
20 |
+
datasets.MNIST('data/mnist', train=True, download=True, transform=transform),
|
21 |
+
batch_size=batch_size, shuffle=True)
|
22 |
+
elif dataset == 'fashion-mnist':
|
23 |
+
data_loader = DataLoader(
|
24 |
+
datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transform),
|
25 |
+
batch_size=batch_size, shuffle=True)
|
26 |
+
elif dataset == 'cifar10':
|
27 |
+
data_loader = DataLoader(
|
28 |
+
datasets.CIFAR10('data/cifar10', train=True, download=True, transform=transform),
|
29 |
+
batch_size=batch_size, shuffle=True)
|
30 |
+
elif dataset == 'svhn':
|
31 |
+
data_loader = DataLoader(
|
32 |
+
datasets.SVHN('data/svhn', split=split, download=True, transform=transform),
|
33 |
+
batch_size=batch_size, shuffle=True)
|
34 |
+
elif dataset == 'stl10':
|
35 |
+
data_loader = DataLoader(
|
36 |
+
datasets.STL10('data/stl10', split=split, download=True, transform=transform),
|
37 |
+
batch_size=batch_size, shuffle=True)
|
38 |
+
elif dataset == 'lsun-bed':
|
39 |
+
data_loader = DataLoader(
|
40 |
+
datasets.LSUN('data/lsun', classes=['bedroom_train'], transform=transform),
|
41 |
+
batch_size=batch_size, shuffle=True)
|
42 |
+
elif dataset == '4cam':
|
43 |
+
if split == 'score':
|
44 |
+
cams = ScoreDataset(root_dir=os.getcwd() + '/Images/Score-Test', dim=dim, name=split, cant_images=300) #hardcode is bad but quick
|
45 |
+
return DataLoader(cams, batch_size=batch_size, shuffle=False, num_workers=0)
|
46 |
+
if split != 'test':
|
47 |
+
cams = ImagesDataset(root_dir=os.getcwd() + '/Images/ActualDataset', dim=dim, name=split, transform=trans)
|
48 |
+
return DataLoader(cams, batch_size=batch_size, shuffle=True, num_workers=0)
|
49 |
+
else:
|
50 |
+
cams = TestingDataset(root_dir=os.getcwd() + '/Images/Input-Test', dim=dim, name=split)
|
51 |
+
return DataLoader(cams, batch_size=batch_size, shuffle=False, num_workers=0)
|
52 |
+
|
53 |
+
return data_loader
|
54 |
+
|
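For the thesis dataset the call looks like the sketch below, assuming Images/ActualDataset exists on disk with the CAM0/CAM1 layout the dataset classes expect; each batch is a dict of source/target views, their depth maps, and the direction label:

loader = dataloader('4cam', input_size=10, batch_size=16, dim=128, split='train', trans=False)
batch = next(iter(loader))
print(batch['x_im'].shape, batch['x_dep'].shape)  # (16, 3, 128, 128) and (16, 1, 128, 128)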
55 |
+
|
56 |
+
class ImagesDataset(Dataset):
|
57 |
+
"""My dataset."""
|
58 |
+
|
59 |
+
def __init__(self, root_dir, dim, name, transform):
|
60 |
+
"""
|
61 |
+
Args:
|
62 |
+
root_dir (string): Directory with all the images.
|
63 |
+
transform (callable, optional): Optional transform to be applied
|
64 |
+
on a sample.
|
65 |
+
"""
|
66 |
+
self.root_dir = root_dir
|
67 |
+
self.nCameras = 2
|
68 |
+
self.imageDim = dim
|
69 |
+
self.name = name
|
70 |
+
self.parser = ConfigParser()
|
71 |
+
self.parser.read('config.ini')
|
72 |
+
self.transform = transform
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
|
76 |
+
return self.parser.getint(self.name, 'total')
|
77 |
+
#oneCameRoot = self.root_dir + '\CAM1'
|
78 |
+
#return int(len([name for name in os.listdir(oneCameRoot) if os.path.isfile(os.path.join(oneCameRoot, name))])/2) # halved because of the depth images
|
79 |
+
|
80 |
+
|
81 |
+
def __getitem__(self, idx):
|
82 |
+
if th.is_tensor(idx):
|
83 |
+
idx = idx.tolist()
|
84 |
+
idx = self.parser.get(self.name, str(idx))
|
85 |
+
if self.transform:
|
86 |
+
brightness = random.uniform(0.7, 1.2)
|
87 |
+
saturation = random.uniform(0, 2)
|
88 |
+
contrast = random.uniform(0.4, 2)
|
89 |
+
gamma = random.uniform(0.7, 1.3)
|
90 |
+
hue = random.uniform(-0.3, 0.3) # 0.01
|
91 |
+
|
92 |
+
oneCameRoot = self.root_dir + '/CAM0'
|
93 |
+
|
94 |
+
# color image
|
95 |
+
img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
|
96 |
+
img = Image.open(img_name).convert('RGB') # .convert('L')
|
97 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
98 |
+
img = img.resize((self.imageDim, self.imageDim))
|
99 |
+
if self.transform:
|
100 |
+
img = transforms.functional.adjust_gamma(img, gamma)
|
101 |
+
img = transforms.functional.adjust_brightness(img, brightness)
|
102 |
+
img = transforms.functional.adjust_contrast(img, contrast)
|
103 |
+
img = transforms.functional.adjust_saturation(img, saturation)
|
104 |
+
img = transforms.functional.adjust_hue(img, hue)
|
105 |
+
x1 = transforms.ToTensor()(img)
|
106 |
+
x1 = (x1 * 2) - 1
|
107 |
+
|
108 |
+
# depth image
|
109 |
+
img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
|
110 |
+
img = Image.open(img_name).convert('I')
|
111 |
+
img = convert_I_to_L(img)
|
112 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
113 |
+
img = img.resize((self.imageDim, self.imageDim))
|
114 |
+
x1_dep = transforms.ToTensor()(img)
|
115 |
+
x1_dep = (x1_dep * 2) - 1
|
116 |
+
|
117 |
+
oneCameRoot = self.root_dir + '/CAM1'
|
118 |
+
|
119 |
+
# color image
|
120 |
+
img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
|
121 |
+
img = Image.open(img_name).convert('RGB') # .convert('L')
|
122 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
123 |
+
img = img.resize((self.imageDim, self.imageDim))
|
124 |
+
if self.transform:
|
125 |
+
img = transforms.functional.adjust_gamma(img, gamma)
|
126 |
+
img = transforms.functional.adjust_brightness(img, brighness)
|
127 |
+
img = transforms.functional.adjust_contrast(img, contrast)
|
128 |
+
img = transforms.functional.adjust_saturation(img, saturation)
|
129 |
+
img = transforms.functional.adjust_hue(img, hue)
|
130 |
+
x2 = transforms.ToTensor()(img)
|
131 |
+
x2 = (x2 * 2) - 1
|
132 |
+
|
133 |
+
# depth image
|
134 |
+
img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
|
135 |
+
img = Image.open(img_name).convert('I')
|
136 |
+
img = convert_I_to_L(img)
|
137 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
138 |
+
img = img.resize((self.imageDim, self.imageDim))
|
139 |
+
x2_dep = transforms.ToTensor()(img)
|
140 |
+
x2_dep = (x2_dep * 2) - 1
|
141 |
+
|
142 |
+
|
143 |
+
# randomly pick left or right as the source view
|
144 |
+
if (bool(random.getrandbits(1))):
|
145 |
+
sample = {'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
|
146 |
+
else:
|
147 |
+
sample = {'x_im': x2, 'x_dep': x2_dep, 'y_im': x1, 'y_dep': x1_dep, 'y_': torch.zeros(1, self.imageDim, self.imageDim)}
|
148 |
+
|
149 |
+
return sample
|
150 |
+
|
151 |
+
def __iter__(self):
|
152 |
+
|
153 |
+
items = [self.__getitem__(i) for i in range(len(self))]
|
154 |
+
return iter(items)
|
155 |
+
|
156 |
+
|
157 |
+
class TestingDataset(Dataset):
|
158 |
+
"""My dataset."""
|
159 |
+
|
160 |
+
def __init__(self, root_dir, dim, name):
|
161 |
+
"""
|
162 |
+
Args:
|
163 |
+
root_dir (string): Directory with all the images.
|
164 |
+
transform (callable, optional): Optional transform to be applied
|
165 |
+
on a sample.
|
166 |
+
"""
|
167 |
+
self.root_dir = root_dir
|
168 |
+
self.imageDim = dim
|
169 |
+
self.name = name
|
170 |
+
files = os.listdir(self.root_dir)
|
171 |
+
self.files = [ele for ele in files if not ele.endswith('_d.png')]
|
172 |
+
|
173 |
+
def __len__(self):
|
174 |
+
|
175 |
+
#return self.parser.getint(self.name, 'total')
|
176 |
+
#oneCameRoot = self.root_dir + '\CAM1'
|
177 |
+
#return int(len([name for name in os.listdir(self.root_dir) if os.path.isfile(os.path.join(self.root_dir, name))])/2) # halved because of the depth images
|
178 |
+
return len(self.files)
|
179 |
+
|
180 |
+
|
181 |
+
def __getitem__(self, idx):
|
182 |
+
if th.is_tensor(idx):
|
183 |
+
idx = idx.tolist()
|
184 |
+
|
185 |
+
# color image
|
186 |
+
img_name = os.path.join(self.root_dir, self.files[idx])
|
187 |
+
img = Image.open(img_name).convert('RGB') # .convert('L')
|
188 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
189 |
+
img = img.resize((self.imageDim, self.imageDim))
|
190 |
+
x1 = transforms.ToTensor()(img)
|
191 |
+
x1 = (x1 * 2) - 1
|
192 |
+
|
193 |
+
|
194 |
+
# depth image
|
195 |
+
img_name = os.path.join(self.root_dir , self.files[idx][:-4] + "_d.png")
|
196 |
+
img = Image.open(img_name).convert('I')
|
197 |
+
img = convert_I_to_L(img)
|
198 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
199 |
+
img = img.resize((self.imageDim, self.imageDim))
|
200 |
+
x1_dep = transforms.ToTensor()(img)
|
201 |
+
x1_dep = (x1_dep * 2) - 1
|
202 |
+
|
203 |
+
sample = {'x_im': x1, 'x_dep': x1_dep}
|
204 |
+
|
205 |
+
return sample
|
206 |
+
|
207 |
+
def __iter__(self):
|
208 |
+
|
209 |
+
items = [self.__getitem__(i) for i in range(len(self))]
|
210 |
+
return iter(items)
|
211 |
+
|
212 |
+
|
213 |
+
|
214 |
+
def show_image(t_data, grey=False):
|
215 |
+
|
216 |
+
#from numpy
|
217 |
+
t_data2 = t_data.transpose(1, 2, 0)
|
218 |
+
t_data2 = t_data2 * 255.0
|
219 |
+
t_data2 = t_data2.astype(np.uint8)
|
220 |
+
if (not grey):
|
221 |
+
outIm = Image.fromarray(t_data2, mode='RGB')
|
222 |
+
else:
|
223 |
+
t_data2 = np.squeeze(t_data2, axis=2)
|
224 |
+
outIm = Image.fromarray(t_data2, mode='L')
|
225 |
+
outIm.show()
|
226 |
+
|
227 |
+
def convert_I_to_L(img):
|
228 |
+
array = np.uint8(np.array(img) / 256) # the number is right; other values leave black gaps in the image
|
229 |
+
return Image.fromarray(array)
|
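convert_I_to_L keeps only the high byte of a 16-bit ('I'-mode) depth PNG; dividing by 256 rather than a 255-based rescale is what avoids the black gaps mentioned in the comment. A synthetic round trip:

import numpy as np
from PIL import Image

arr16 = (np.arange(64 * 64, dtype=np.int32).reshape(64, 64) * 16) % 65536
depth16 = Image.fromarray(arr16, mode='I')                   # stand-in for a 16-bit depth map
depth8 = Image.fromarray(np.uint8(np.array(depth16) / 256))  # high byte only
print(depth8.mode)  # 'L'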
230 |
+
|
231 |
+
class ScoreDataset(Dataset):
|
232 |
+
"""My dataset."""
|
233 |
+
|
234 |
+
def __init__(self, root_dir, dim, name, cant_images):
|
235 |
+
"""
|
236 |
+
Args:
|
237 |
+
root_dir (string): Directory with all the images.
|
238 |
+
transform (callable, optional): Optional transform to be applied
|
239 |
+
on a sample.
|
240 |
+
"""
|
241 |
+
self.root_dir = root_dir
|
242 |
+
self.nCameras = 2
|
243 |
+
self.imageDim = dim
|
244 |
+
self.name = name
|
245 |
+
self.size = cant_images
|
246 |
+
|
247 |
+
def __len__(self):
|
248 |
+
|
249 |
+
return self.size
|
250 |
+
|
251 |
+
|
252 |
+
def __getitem__(self, idx):
|
253 |
+
|
254 |
+
oneCameRoot = self.root_dir + '/CAM0'
|
255 |
+
|
256 |
+
idx = "{:04d}".format(idx)
|
257 |
+
# color image
|
258 |
+
img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
|
259 |
+
img = Image.open(img_name).convert('RGB') # .convert('L')
|
260 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
261 |
+
img = img.resize((self.imageDim, self.imageDim))
|
262 |
+
x1 = transforms.ToTensor()(img)
|
263 |
+
x1 = (x1 * 2) - 1
|
264 |
+
|
265 |
+
# depth image
|
266 |
+
img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
|
267 |
+
img = Image.open(img_name).convert('I')
|
268 |
+
img = convert_I_to_L(img)
|
269 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
270 |
+
img = img.resize((self.imageDim, self.imageDim))
|
271 |
+
x1_dep = transforms.ToTensor()(img)
|
272 |
+
x1_dep = (x1_dep * 2) - 1
|
273 |
+
|
274 |
+
oneCameRoot = self.root_dir + '/CAM1'
|
275 |
+
|
276 |
+
# color image
|
277 |
+
img_name = os.path.join(oneCameRoot, "n_" + idx + ".png")
|
278 |
+
img = Image.open(img_name).convert('RGB') # .convert('L')
|
279 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
280 |
+
img = img.resize((self.imageDim, self.imageDim))
|
281 |
+
x2 = transforms.ToTensor()(img)
|
282 |
+
x2 = (x2 * 2) - 1
|
283 |
+
|
284 |
+
# depth image
|
285 |
+
img_name = os.path.join(oneCameRoot, "d_" + idx + ".png")
|
286 |
+
img = Image.open(img_name).convert('I')
|
287 |
+
img = convert_I_to_L(img)
|
288 |
+
if (img.size[0] != self.imageDim or img.size[1] != self.imageDim):
|
289 |
+
img = img.resize((self.imageDim, self.imageDim))
|
290 |
+
x2_dep = transforms.ToTensor()(img)
|
291 |
+
x2_dep = (x2_dep * 2) - 1
|
292 |
+
|
293 |
+
|
294 |
+
sample = {'x_im': x1, 'x_dep': x1_dep, 'y_im': x2, 'y_dep': x2_dep, 'y_': torch.ones(1, self.imageDim, self.imageDim)}
|
295 |
+
return sample
|
296 |
+
|
297 |
+
def __iter__(self):
|
298 |
+
|
299 |
+
items = [self.__getitem__(i) for i in range(len(self))]
|
300 |
+
return iter(items)
|
301 |
+
|
epochData.pkl
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:baf9bf7acbc95f817b9f79d9be24fe553e8beeacda79854ebcfe9fc2707df120
|
3 |
+
size 210
|
main.py
ADDED
@@ -0,0 +1,136 @@
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from WiggleGAN import WiggleGAN
|
5 |
+
#from MyACGAN import MyACGAN
|
6 |
+
#from MyGAN import MyGAN
|
7 |
+
|
8 |
+
"""parsing and configuration"""
|
9 |
+
|
10 |
+
|
11 |
+
def parse_args():
|
12 |
+
desc = "Pytorch implementation of GAN collections"
|
13 |
+
parser = argparse.ArgumentParser(description=desc)
|
14 |
+
|
15 |
+
parser.add_argument('--gan_type', type=str, default='WiggleGAN',
|
16 |
+
choices=['MyACGAN', 'MyGAN', 'WiggleGAN'],
|
17 |
+
help='The type of GAN')
|
18 |
+
parser.add_argument('--dataset', type=str, default='4cam',
|
19 |
+
choices=['mnist', 'fashion-mnist', 'cifar10', 'cifar100', 'svhn', 'stl10', 'lsun-bed', '4cam'],
|
20 |
+
help='The name of dataset')
|
21 |
+
parser.add_argument('--split', type=str, default='', help='The split flag for svhn and stl10')
|
22 |
+
parser.add_argument('--epoch', type=int, default=50, help='The number of epochs to run')
|
23 |
+
parser.add_argument('--batch_size', type=int, default=16, help='The size of batch')
|
24 |
+
parser.add_argument('--input_size', type=int, default=10, help='The size of input image')
|
25 |
+
parser.add_argument('--save_dir', type=str, default='models',
|
26 |
+
help='Directory name to save the model')
|
27 |
+
parser.add_argument('--result_dir', type=str, default='results', help='Directory name to save the generated images')
|
28 |
+
parser.add_argument('--log_dir', type=str, default='logs', help='Directory name to save training logs')
|
29 |
+
parser.add_argument('--lrG', type=float, default=0.0002)
|
30 |
+
parser.add_argument('--lrD', type=float, default=0.001)
|
31 |
+
parser.add_argument('--beta1', type=float, default=0.5)
|
32 |
+
parser.add_argument('--beta2', type=float, default=0.999)
|
33 |
+
parser.add_argument('--gpu_mode', type=str2bool, default=True)
|
34 |
+
parser.add_argument('--benchmark_mode', type=str2bool, default=True)
|
35 |
+
parser.add_argument('--cameras', type=int, default=2)
|
36 |
+
parser.add_argument('--imageDim', type=int, default=128)
|
37 |
+
parser.add_argument('--epochV', type=int, default=0)
|
38 |
+
parser.add_argument('--cIm', type=int, default=4)
|
39 |
+
parser.add_argument('--seedLoad', type=str, default="-0000")
|
40 |
+
parser.add_argument('--zGF', type=float, default=0.2)
|
41 |
+
parser.add_argument('--zDF', type=float, default=0.2)
|
42 |
+
parser.add_argument('--bF', type=float, default=0.2)
|
43 |
+
parser.add_argument('--expandGen', type=int, default=3)
|
44 |
+
parser.add_argument('--expandDis', type=int, default=3)
|
45 |
+
parser.add_argument('--wiggleDepth', type=int, default=-1)
|
46 |
+
parser.add_argument('--visdom', type=str2bool, default=True)
|
47 |
+
parser.add_argument('--lambdaL1', type=int, default=100)
|
48 |
+
parser.add_argument('--clipping', type=float, default=-1)
|
49 |
+
parser.add_argument('--depth', type=str2bool, default=True)
|
50 |
+
parser.add_argument('--recreate', type=str2bool, default=False)
|
51 |
+
parser.add_argument('--name_wiggle', type=str, default='wiggle-result')
|
52 |
+
|
53 |
+
return check_args(parser.parse_args())
|
54 |
+
|
55 |
+
|
56 |
+
"""checking arguments"""
|
57 |
+
|
58 |
+
def str2bool(v):
|
59 |
+
if isinstance(v, bool):
|
60 |
+
return v
|
61 |
+
if v.lower() in ('yes', 'true', 't', 'y', '1'):
|
62 |
+
return True
|
63 |
+
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
|
64 |
+
return False
|
65 |
+
else:
|
66 |
+
raise argparse.ArgumentTypeError('Boolean value expected.')
|
67 |
+
|
68 |
+
|
69 |
+
def check_args(args):
|
70 |
+
# --save_dir
|
71 |
+
if not os.path.exists(args.save_dir):
|
72 |
+
os.makedirs(args.save_dir)
|
73 |
+
|
74 |
+
# --result_dir
|
75 |
+
if not os.path.exists(args.result_dir):
|
76 |
+
os.makedirs(args.result_dir)
|
77 |
+
|
78 |
+
# --result_dir
|
79 |
+
if not os.path.exists(args.log_dir):
|
80 |
+
os.makedirs(args.log_dir)
|
81 |
+
|
82 |
+
# --epoch
|
83 |
+
try:
|
84 |
+
assert args.epoch >= 1
|
85 |
+
except:
|
86 |
+
print('number of epochs must be larger than or equal to one')
|
87 |
+
|
88 |
+
# --batch_size
|
89 |
+
try:
|
90 |
+
assert args.batch_size >= 1
|
91 |
+
except:
|
92 |
+
print('batch size must be larger than or equal to one')
|
93 |
+
|
94 |
+
return args
|
95 |
+
|
96 |
+
|
97 |
+
"""main"""
|
98 |
+
|
99 |
+
|
100 |
+
def main():
|
101 |
+
# parse arguments
|
102 |
+
args = parse_args()
|
103 |
+
if args is None:
|
104 |
+
exit()
|
105 |
+
|
106 |
+
if args.benchmark_mode:
|
107 |
+
torch.backends.cudnn.benchmark = True
|
108 |
+
|
109 |
+
# declare instance for GAN
|
110 |
+
if args.gan_type == 'WiggleGAN':
|
111 |
+
gan = WiggleGAN(args)
|
112 |
+
#elif args.gan_type == 'MyACGAN':
|
113 |
+
# gan = MyACGAN(args)
|
114 |
+
#elif args.gan_type == 'MyGAN':
|
115 |
+
# gan = MyGAN(args)
|
116 |
+
else:
|
117 |
+
raise Exception("[!] There is no option for " + args.gan_type)
|
118 |
+
|
119 |
+
# launch the graph in a session
|
120 |
+
if (args.wiggleDepth < 0 and not args.recreate):
|
121 |
+
print(" [*] Training Starting!")
|
122 |
+
gan.train()
|
123 |
+
print(" [*] Training finished!")
|
124 |
+
else:
|
125 |
+
if not args.recreate:
|
126 |
+
print(" [*] Wiggle Started!")
|
127 |
+
gan.wiggleEf()
|
128 |
+
print(" [*] Wiggle finished!")
|
129 |
+
else:
|
130 |
+
print(" [*] Dataset recreation Started")
|
131 |
+
gan.recreate()
|
132 |
+
print(" [*] Dataset recreation finished")
|
133 |
+
|
134 |
+
|
135 |
+
if __name__ == '__main__':
|
136 |
+
main()
|
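For reference, a hypothetical invocation with illustrative values (every flag is defined in parse_args above; the 110-epoch value mirrors the checkpoint names below but is not necessarily the thesis configuration):

    python main.py --gan_type WiggleGAN --dataset 4cam --epoch 110 --batch_size 16 --imageDim 128 --visdom false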
models/4cam/WiggleGAN/WiggleGAN_31219_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d4b39604e99319045e9070632a7aa31cd5adbd0220126515093856f97af622ff
size 1252850
models/4cam/WiggleGAN/WiggleGAN_66942_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:da994f51205701f9754dc1688cffd12b72f593f37c61833ec4b7c8860e152236
size 1252850
models/4cam/WiggleGAN/WiggleGAN_70466_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:310b22bf4f5375174b23347b85b64c9de7934cafec6a61b3d647bfb7f24b5ae7
size 1252850
models/4cam/WiggleGAN/WiggleGAN_70944_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5734a5e102c75e4afde944f2898171fb34373c002b651ca84901ed9f55ae385d
size 1252850
models/4cam/WiggleGAN/WiggleGAN_74962_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d06a0da4295b6b6c5277f3cf987327a60818460780fb3aec42e514cbc3f71c71
size 1252850
models/4cam/WiggleGAN/WiggleGAN_82122_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:170c8e095c66665ef87f199e5308a39d90fe2f5d0f2dfa5d8c789675657e0423
size 1252850
models/4cam/WiggleGAN/WiggleGAN_92332_110_G.pkl
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5a9cbec7ad0978008bcda05a96865b71016663278ed18c935b25875f7b08a979
size 1252850
pyvenv.cfg
ADDED
@@ -0,0 +1,3 @@
home = C:\Program Files (x86)\Microsoft Visual Studio\Shared\Python37_64
include-system-site-packages = false
version = 3.7.8
requirements.txt
CHANGED
@@ -1,4 +1,27 @@
 timm
-
+opencv-python
+certifi==2019.11.28
+chardet==3.0.4
+cycler==0.10.0
+idna==2.8
+imageio==2.5.0
+jsonpatch==1.24
+jsonpointer==2.0
+kiwisolver==1.1.0
+matplotlib==3.1.1
+numpy==1.17.2
+Pillow==6.1.0
+pyparsing==2.4.2
+python-dateutil==2.8.0
+PyYAML==5.1.2
+pyzmq==18.1.1
+requests==2.22.0
+scipy==1.1.0
+six==1.12.0
+urllib3==1.25.7
+visdom==0.1.8.9
+websocket-client==0.56.0
+tornado==6.0.3
 torch
-
+torchfile==0.1.0
+torchvision==0.2.1
utils.py
ADDED
@@ -0,0 +1,369 @@
import os, gzip, torch
import torch.nn as nn
import numpy as np
import scipy.misc
import imageio
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import datasets, transforms
import visdom
import random

def save_wiggle(images, rows=1, name="test"):
    # Tile the given CHW tensors (values in [-1, 1]) into a rows x columns
    # grid with a small margin and save the result as a JPEG.
    width = images[0].shape[1]
    height = images[0].shape[2]
    columns = int(len(images) / rows)
    rows = int(rows)
    margin = 4

    total_width = (width + margin) * columns
    total_height = (height + margin) * rows

    new_im = Image.new('RGB', (total_width, total_height))

    transToPil = transforms.ToPILImage()

    x_offset = 3
    y_offset = 3
    for y in range(rows):
        for x in range(columns):
            im = images[x + y * columns]
            im = transToPil((im + 1) / 2)  # map [-1, 1] back to [0, 1]
            new_im.paste(im, (x_offset, y_offset))
            x_offset += width + margin
        x_offset = 3
        y_offset += height + margin

    new_im.save('./WiggleResults/' + name + '.jpg')
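A quick, hypothetical sanity check (not part of the file; save_wiggle writes to ./WiggleResults/ unconditionally, so the directory must exist first):

    os.makedirs('./WiggleResults', exist_ok=True)
    fake = [torch.rand(3, 64, 64) * 2 - 1 for _ in range(4)]  # four CHW tensors in [-1, 1]
    save_wiggle(fake, rows=2, name='demo')                    # writes ./WiggleResults/demo.jpg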
def load_mnist(dataset):
    data_dir = os.path.join("./data", dataset)

    def extract_data(filename, num_data, head_size, data_size):
        with gzip.open(filename) as bytestream:
            bytestream.read(head_size)
            buf = bytestream.read(data_size * num_data)
            data = np.frombuffer(buf, dtype=np.uint8).astype(np.float)
        return data

    data = extract_data(data_dir + '/train-images-idx3-ubyte.gz', 60000, 16, 28 * 28)
    trX = data.reshape((60000, 28, 28, 1))

    data = extract_data(data_dir + '/train-labels-idx1-ubyte.gz', 60000, 8, 1)
    trY = data.reshape((60000))

    data = extract_data(data_dir + '/t10k-images-idx3-ubyte.gz', 10000, 16, 28 * 28)
    teX = data.reshape((10000, 28, 28, 1))

    data = extract_data(data_dir + '/t10k-labels-idx1-ubyte.gz', 10000, 8, 1)
    teY = data.reshape((10000))

    trY = np.asarray(trY).astype(np.int)
    teY = np.asarray(teY)

    X = np.concatenate((trX, teX), axis=0)
    y = np.concatenate((trY, teY), axis=0).astype(np.int)

    # Shuffle images and labels in unison by reseeding between the two shuffles.
    seed = 547
    np.random.seed(seed)
    np.random.shuffle(X)
    np.random.seed(seed)
    np.random.shuffle(y)

    y_vec = np.zeros((len(y), 10), dtype=np.float)
    for i, label in enumerate(y):
        y_vec[i, y[i]] = 1

    X = X.transpose(0, 3, 1, 2) / 255.
    # y_vec = y_vec.transpose(0, 3, 1, 2)

    X = torch.from_numpy(X).type(torch.FloatTensor)
    y_vec = torch.from_numpy(y_vec).type(torch.FloatTensor)
    return X, y_vec
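Hypothetical usage (requires the four MNIST .gz files under ./data/mnist):

    X, y_vec = load_mnist('mnist')
    print(X.shape)      # torch.Size([70000, 1, 28, 28]), values in [0, 1]
    print(y_vec.shape)  # torch.Size([70000, 10]), one-hot labels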
def load_celebA(dir, transform, batch_size, shuffle):
    # transform = transforms.Compose([
    #     transforms.CenterCrop(160),
    #     transforms.Scale(64),
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    # ])

    # data_dir = 'data/celebA'  # this path depends on your computer
    dset = datasets.ImageFolder(dir, transform)
    data_loader = torch.utils.data.DataLoader(dset, batch_size, shuffle)

    return data_loader


def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

def imsave(images, size, path):
    # scipy.misc.imsave exists in the pinned scipy==1.1.0; it was removed in later releases.
    image = np.squeeze(merge(images, size))
    return scipy.misc.imsave(path, image)

def merge(images, size):
    # Tile a batch of NHWC images row-major into a single size[0] x size[1] grid.
    h, w = images.shape[1], images.shape[2]
    if (images.shape[3] in (3, 4)):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3] == 1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:, :, 0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter must have dimensions: HxW or HxWx3 or HxWx4')
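A small, hypothetical check of merge(): it expects a batch in NHWC layout and tiles it into a size[0] x size[1] grid:

    imgs = np.random.rand(6, 8, 8, 3)  # six 8x8 RGB images
    grid = merge(imgs, size=(2, 3))    # grid.shape == (16, 24, 3)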
def generate_animation(path, num):
    images = []
    for e in range(num):
        img_name = path + '_epoch%04d' % (e + 1) + '.png'
        images.append(imageio.imread(img_name))
    imageio.mimsave(path + '_generate_animation.gif', images, fps=5)

def loss_plot(hist, path='Train_hist.png', model_name=''):
    x1 = range(len(hist['D_loss_train']))
    x2 = range(len(hist['G_loss_train']))

    y1 = hist['D_loss_train']
    y2 = hist['G_loss_train']

    # If the discriminator history is shorter than the generator's,
    # left-pad it with zeros so both curves share the same x axis.
    if (x1 != x2):
        y1 = [0.0] * (len(y2) - len(y1)) + y1
        x1 = x2

    plt.plot(x1, y1, label='D_loss_train')
    plt.plot(x2, y2, label='G_loss_train')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    path = os.path.join(path, model_name + '_loss.png')

    plt.savefig(path)
    plt.close()
def initialize_weights(net):
    # DCGAN-style initialization: weights ~ N(0, 0.02), biases zeroed.
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
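A minimal, hypothetical check (layers must keep the default bias=True, since the function zeroes m.bias unconditionally):

    net = nn.Conv2d(3, 8, kernel_size=3)
    initialize_weights(net)
    print(net.weight.std())  # roughly 0.02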
class VisdomLinePlotter(object):
    """Plots loss curves to Visdom."""

    def __init__(self, env_name='main'):
        self.viz = visdom.Visdom()
        self.env = env_name
        self.ini = False
        self.count = 1

    def plot(self, var_name, names, split_name, hist):
        x = []
        y = []
        for i, name in enumerate(names):
            x.append(self.count)
            y.append(hist[name])
        self.count += 1

        for i, n in enumerate(names):
            x[i] = np.arange(1, x[i] + 1)

        if not self.ini:
            # First call: create the window with the first series, then append the others.
            for i, name in enumerate(names):
                if i == 0:
                    self.win = self.viz.line(X=x[i], Y=np.array(y[i]), env=self.env, name=name,
                                             opts=dict(title=var_name + '_' + split_name, showlegend=True))
                else:
                    self.viz.line(X=x[i], Y=np.array(y[i]), env=self.env, win=self.win, name=name, update='append')
            self.ini = True
        else:
            # Later calls: append only the last segment of each series.
            x[0] = np.array([x[0][-2], x[0][-1]])
            for i, n in enumerate(names):
                y[i] = np.array([y[i][-2], y[i][-1]])
                self.viz.line(X=x[0], Y=np.array(y[i]), env=self.env, win=self.win, name=n, update='append')


class VisdomLineTwoPlotter(VisdomLinePlotter):

    def plot(self, var_name, epoch, names, hist):
        x1 = epoch
        y1 = hist[names[0]]
        y2 = hist[names[1]]
        y3 = hist[names[2]]
        y4 = hist[names[3]]

        if not self.ini:
            self.win = self.viz.line(X=np.array([x1]), Y=np.array(y1), env=self.env, name=names[0],
                                     opts=dict(title=var_name,
                                               showlegend=True,
                                               linecolor=np.array([[0, 0, 255]])))
            self.viz.line(X=np.array([x1]), Y=np.array(y2), env=self.env, win=self.win, name=names[1],
                          update='append', opts=dict(linecolor=np.array([[255, 153, 51]])))
            self.viz.line(X=np.array([x1]), Y=np.array(y3), env=self.env, win=self.win, name=names[2],
                          update='append', opts=dict(linecolor=np.array([[0, 51, 153]])))
            self.viz.line(X=np.array([x1]), Y=np.array(y4), env=self.env, win=self.win, name=names[3],
                          update='append', opts=dict(linecolor=np.array([[204, 51, 0]])))
            self.ini = True
        else:
            # Append the segment between the previous and the current epoch.
            y4 = np.array([y4[-2], y4[-1]])
            y3 = np.array([y3[-2], y3[-1]])
            y2 = np.array([y2[-2], y2[-1]])
            y1 = np.array([y1[-2], y1[-1]])
            x1 = np.array([x1 - 1, x1])
            self.viz.line(X=x1, Y=np.array(y1), env=self.env, win=self.win, name=names[0], update='append')
            self.viz.line(X=x1, Y=np.array(y2), env=self.env, win=self.win, name=names[1], update='append')
            self.viz.line(X=x1, Y=np.array(y3), env=self.env, win=self.win, name=names[2], update='append')
            self.viz.line(X=x1, Y=np.array(y4), env=self.env, win=self.win, name=names[3], update='append')


class VisdomImagePlotter(object):
    """Plots image grids to Visdom."""

    def __init__(self, env_name='main'):
        self.viz = visdom.Visdom()
        self.env = env_name

    def plot(self, epoch, images, rows):
        list_images = []
        for image in images:
            image = (image + 1) / 2                # [-1, 1] -> [0, 1]
            image = image.detach().numpy() * 255   # [0, 1] -> [0, 255]
            list_images.append(image)
        self.viz.images(
            list_images,
            padding=2,
            nrow=rows,
            opts=dict(title="epoch: " + str(epoch)),
            env=self.env
        )
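Hypothetical usage of the epoch-wise plotter (requires a running Visdom server, e.g. python -m visdom.server; the two validation key names are made up for illustration, while the *_train keys match loss_plot above):

    names = ['D_loss_train', 'G_loss_train', 'D_loss_val', 'G_loss_val']
    hist = {n: [] for n in names}
    plotter = VisdomLineTwoPlotter(env_name='WiggleGAN')
    for epoch in range(1, 4):
        for n in names:
            hist[n].append(1.0 / epoch)  # illustrative loss values
        plotter.plot('loss', epoch, names, hist)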
def augmentData(x, y, randomness=1, percent_noise=0.1):
    """
    :param x: image X
    :param y: image Y
    :param randomness: Value of randomness (between 1 and 0)
    :return: data x, y augmented
    """

    sampleX = torch.tensor([])
    sampleY = torch.tensor([])

    for aumX, aumY in zip(x, y):

        # Optional noise injection (disabled):
        # noise = torch.randn(aumX.shape)
        # aumX = noise * percent_noise + aumX * (1 - percent_noise)
        # aumY = noise * percent_noise + aumY * (1 - percent_noise)

        # Map [-1, 1] to [0, 1] before converting to PIL.
        aumX = (aumX + 1) / 2
        aumY = (aumY + 1) / 2

        imgX = transforms.ToPILImage()(aumX)
        imgY = transforms.ToPILImage()(aumY)

        # Draw one set of jitter parameters and apply it to both images,
        # so the X/Y pair stays photometrically consistent.
        brightness = random.uniform(0.7, 1.2) * randomness + (1 - randomness)
        saturation = random.uniform(0, 2) * randomness + (1 - randomness)
        contrast = random.uniform(0.4, 2) * randomness + (1 - randomness)
        gamma = random.uniform(0.7, 1.3) * randomness + (1 - randomness)
        hue = random.uniform(-0.3, 0.3) * randomness

        imgX = transforms.functional.adjust_gamma(imgX, gamma)
        imgX = transforms.functional.adjust_brightness(imgX, brightness)
        imgX = transforms.functional.adjust_contrast(imgX, contrast)
        imgX = transforms.functional.adjust_saturation(imgX, saturation)
        imgX = transforms.functional.adjust_hue(imgX, hue)

        imgY = transforms.functional.adjust_gamma(imgY, gamma)
        imgY = transforms.functional.adjust_brightness(imgY, brightness)
        imgY = transforms.functional.adjust_contrast(imgY, contrast)
        imgY = transforms.functional.adjust_saturation(imgY, saturation)
        imgY = transforms.functional.adjust_hue(imgY, hue)

        # Back to tensors in [-1, 1].
        sx = transforms.ToTensor()(imgX)
        sx = (sx * 2) - 1

        sy = transforms.ToTensor()(imgY)
        sy = (sy * 2) - 1

        sampleX = torch.cat((sampleX, sx.unsqueeze_(0)), 0)
        sampleY = torch.cat((sampleY, sy.unsqueeze_(0)), 0)
    return sampleX, sampleY
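A hypothetical shape check with a batch of two 3x32x32 image pairs in [-1, 1]:

    batch_x = torch.rand(2, 3, 32, 32) * 2 - 1
    batch_y = torch.rand(2, 3, 32, 32) * 2 - 1
    aug_x, aug_y = augmentData(batch_x, batch_y, randomness=0.5)
    print(aug_x.shape, aug_y.shape)  # torch.Size([2, 3, 32, 32]) twice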
def RGBtoL(x):
    # Keep only the first channel of an NCHW batch: (N, 3, H, W) -> (N, 1, H, W).
    return x[:, 0, :, :].unsqueeze(0).transpose(0, 1)

def LtoRGB(x):
    # Replicate a single channel three times: (N, 1, H, W) -> (N, 3, H, W).
    return x.repeat(1, 3, 1, 1)
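Round-trip shape check (hypothetical):

    gray = RGBtoL(torch.rand(4, 3, 16, 16))  # torch.Size([4, 1, 16, 16])
    rgb = LtoRGB(gray)                       # torch.Size([4, 3, 16, 16])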