gaur3009 committed
Commit 9862b96 · verified · 1 Parent(s): aa0fb88

Upload 5 files

Files changed (5):
  1. networks.py +541 -0
  2. requirements (2).txt +6 -0
  3. test.py +226 -0
  4. train.py +232 -0
  5. visualization.py +64 -0
networks.py ADDED
@@ -0,0 +1,541 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F  # fixed: needed by AffineGridGen.forward, missing in the upload
+from torch.nn import init
+from torchvision import models
+import os
+
+import numpy as np
+
+
+def weights_init_normal(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.normal_(m.weight.data, 0.0, 0.02)
+    elif classname.find('Linear') != -1:
+        init.normal_(m.weight.data, 0.0, 0.02)  # fixed: init.normal is deprecated
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_xavier(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.xavier_normal_(m.weight.data, gain=0.02)
+    elif classname.find('Linear') != -1:
+        init.xavier_normal_(m.weight.data, gain=0.02)
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def weights_init_kaiming(m):
+    classname = m.__class__.__name__
+    if classname.find('Conv') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('Linear') != -1:
+        init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+    elif classname.find('BatchNorm2d') != -1:
+        init.normal_(m.weight.data, 1.0, 0.02)
+        init.constant_(m.bias.data, 0.0)
+
+
+def init_weights(net, init_type='normal'):
+    print('initialization method [%s]' % init_type)
+    if init_type == 'normal':
+        net.apply(weights_init_normal)
+    elif init_type == 'xavier':
+        net.apply(weights_init_xavier)
+    elif init_type == 'kaiming':
+        net.apply(weights_init_kaiming)
+    else:
+        raise NotImplementedError(
+            'initialization method [%s] is not implemented' % init_type)
+
+
+class FeatureExtraction(nn.Module):
+    def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        super(FeatureExtraction, self).__init__()
+        downconv = nn.Conv2d(input_nc, ngf, kernel_size=4, stride=2, padding=1)
+        model = [downconv, nn.ReLU(True), norm_layer(ngf)]
+        for i in range(n_layers):
+            # channel count doubles each stride-2 stage, capped at 512
+            in_ngf = 2**i * ngf if 2**i * ngf < 512 else 512
+            out_ngf = 2**(i+1) * ngf if 2**i * ngf < 512 else 512
+            downconv = nn.Conv2d(
+                in_ngf, out_ngf, kernel_size=4, stride=2, padding=1)
+            model += [downconv, nn.ReLU(True)]
+            model += [norm_layer(out_ngf)]
+        model += [nn.Conv2d(512, 512, kernel_size=3,
+                            stride=1, padding=1), nn.ReLU(True)]
+        model += [norm_layer(512)]
+        model += [nn.Conv2d(512, 512, kernel_size=3,
+                            stride=1, padding=1), nn.ReLU(True)]
+
+        self.model = nn.Sequential(*model)
+        init_weights(self.model, init_type='normal')
+
+    def forward(self, x):
+        return self.model(x)
+
+
+class FeatureL2Norm(torch.nn.Module):
+    def __init__(self):
+        super(FeatureL2Norm, self).__init__()
+
+    def forward(self, feature):
+        epsilon = 1e-6
+        norm = torch.pow(torch.sum(torch.pow(feature, 2), 1) +
+                         epsilon, 0.5).unsqueeze(1).expand_as(feature)
+        return torch.div(feature, norm)
+
+
+class FeatureCorrelation(nn.Module):
+    def __init__(self):
+        super(FeatureCorrelation, self).__init__()
+
+    def forward(self, feature_A, feature_B):
+        b, c, h, w = feature_A.size()
+        # reshape features for matrix multiplication
+        feature_A = feature_A.transpose(2, 3).contiguous().view(b, c, h*w)
+        feature_B = feature_B.view(b, c, h*w).transpose(1, 2)
+        # perform matrix multiplication
+        feature_mul = torch.bmm(feature_B, feature_A)
+        correlation_tensor = feature_mul.view(
+            b, h, w, h*w).transpose(2, 3).transpose(1, 2)
+        return correlation_tensor
+
+
+class FeatureRegression(nn.Module):
+    def __init__(self, input_nc=512, output_dim=6, use_cuda=True):
+        super(FeatureRegression, self).__init__()
+        self.conv = nn.Sequential(
+            nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(512),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(512, 256, kernel_size=4, stride=2, padding=1),
+            nn.BatchNorm2d(256),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(256, 128, kernel_size=3, padding=1),
+            nn.BatchNorm2d(128),
+            nn.ReLU(inplace=True),
+            nn.Conv2d(128, 64, kernel_size=3, padding=1),
+            nn.BatchNorm2d(64),
+            nn.ReLU(inplace=True),
+        )
+        self.linear = nn.Linear(64 * 4 * 3, output_dim)
+        self.tanh = nn.Tanh()
+        if use_cuda:
+            self.conv.cuda()
+            self.linear.cuda()
+            self.tanh.cuda()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = x.view(x.size(0), -1)
+        x = self.linear(x)
+        x = self.tanh(x)
+        return x
+
+
+class AffineGridGen(nn.Module):
+    def __init__(self, out_h=256, out_w=192, out_ch=3):
+        super(AffineGridGen, self).__init__()
+        self.out_h = out_h
+        self.out_w = out_w
+        self.out_ch = out_ch
+
+    def forward(self, theta):
+        theta = theta.contiguous()
+        batch_size = theta.size()[0]
+        out_size = torch.Size(
+            (batch_size, self.out_ch, self.out_h, self.out_w))
+        return F.affine_grid(theta, out_size)
+
+
+class TpsGridGen(nn.Module):
+    def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
+        super(TpsGridGen, self).__init__()
+        self.out_h, self.out_w = out_h, out_w
+        self.reg_factor = reg_factor
+        self.use_cuda = use_cuda
+
+        # create grid in numpy
+        self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
+        # sampling grid with dim-0 coords (Y)
+        self.grid_X, self.grid_Y = np.meshgrid(
+            np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
+        # grid_X, grid_Y: size [1, H, W, 1]
+        self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
+        self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
+        if use_cuda:
+            self.grid_X = self.grid_X.cuda()
+            self.grid_Y = self.grid_Y.cuda()
+
+        # initialize regular grid for control points P_i
+        if use_regular_grid:
+            axis_coords = np.linspace(-1, 1, grid_size)
+            self.N = grid_size*grid_size
+            P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
+            P_X = np.reshape(P_X, (-1, 1))  # size (N,1)
+            P_Y = np.reshape(P_Y, (-1, 1))  # size (N,1)
+            P_X = torch.FloatTensor(P_X)
+            P_Y = torch.FloatTensor(P_Y)
+            self.P_X_base = P_X.clone()
+            self.P_Y_base = P_Y.clone()
+            self.Li = self.compute_L_inverse(P_X, P_Y).unsqueeze(0)
+            self.P_X = P_X.unsqueeze(2).unsqueeze(
+                3).unsqueeze(4).transpose(0, 4)
+            self.P_Y = P_Y.unsqueeze(2).unsqueeze(
+                3).unsqueeze(4).transpose(0, 4)
+            if use_cuda:
+                self.P_X = self.P_X.cuda()
+                self.P_Y = self.P_Y.cuda()
+                self.P_X_base = self.P_X_base.cuda()
+                self.P_Y_base = self.P_Y_base.cuda()
+
+    def forward(self, theta):
+        warped_grid = self.apply_transformation(
+            theta, torch.cat((self.grid_X, self.grid_Y), 3))
+        return warped_grid
+
+    def compute_L_inverse(self, X, Y):
+        N = X.size()[0]  # num of points (along dim 0)
+        # construct matrix K
+        Xmat = X.expand(N, N)
+        Ymat = Y.expand(N, N)
+        P_dist_squared = torch.pow(
+            Xmat-Xmat.transpose(0, 1), 2)+torch.pow(Ymat-Ymat.transpose(0, 1), 2)
+        # make diagonal 1 to avoid NaN in log computation
+        P_dist_squared[P_dist_squared == 0] = 1
+        K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
+        # construct matrix L
+        O = torch.FloatTensor(N, 1).fill_(1)
+        Z = torch.FloatTensor(3, 3).fill_(0)
+        P = torch.cat((O, X, Y), 1)
+        L = torch.cat((torch.cat((K, P), 1), torch.cat(
+            (P.transpose(0, 1), Z), 1)), 0)
+        Li = torch.inverse(L)
+        if self.use_cuda:
+            Li = Li.cuda()
+        return Li
+
+    def apply_transformation(self, theta, points):
+        if theta.dim() == 2:
+            theta = theta.unsqueeze(2).unsqueeze(3)
+        # points should be in the [B,H,W,2] format,
+        # where points[:,:,:,0] are the X coords
+        # and points[:,:,:,1] are the Y coords
+
+        # theta encodes the offsets of the control points P_i
+        batch_size = theta.size()[0]
+        # split theta into point coordinates
+        Q_X = theta[:, :self.N, :, :].squeeze(3)
+        Q_Y = theta[:, self.N:, :, :].squeeze(3)
+        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
+        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
+
+        # get spatial dimensions of points
+        points_b = points.size()[0]
+        points_h = points.size()[1]
+        points_w = points.size()[2]
+
+        # repeat pre-defined control points along spatial dimensions of points to be transformed
+        P_X = self.P_X.expand((1, points_h, points_w, 1, self.N))
+        P_Y = self.P_Y.expand((1, points_h, points_w, 1, self.N))
+
+        # compute weights for non-linear part
+        W_X = torch.bmm(self.Li[:, :self.N, :self.N].expand(
+            (batch_size, self.N, self.N)), Q_X)
+        W_Y = torch.bmm(self.Li[:, :self.N, :self.N].expand(
+            (batch_size, self.N, self.N)), Q_Y)
+        # reshape
+        # W_X, W_Y: size [B,H,W,1,N]
+        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(
+            1, 4).repeat(1, points_h, points_w, 1, 1)
+        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(
+            1, 4).repeat(1, points_h, points_w, 1, 1)
+        # compute weights for affine part
+        A_X = torch.bmm(self.Li[:, self.N:, :self.N].expand(
+            (batch_size, 3, self.N)), Q_X)
+        A_Y = torch.bmm(self.Li[:, self.N:, :self.N].expand(
+            (batch_size, 3, self.N)), Q_Y)
+        # reshape
+        # A_X, A_Y: size [B,H,W,1,3]
+        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(
+            1, 4).repeat(1, points_h, points_w, 1, 1)
+        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(
+            1, 4).repeat(1, points_h, points_w, 1, 1)
+
+        # compute distance P_i - (grid_X, grid_Y)
+        # grid is expanded in point dim 4, but not in batch dim 0, as points P_X, P_Y are fixed for all batch
+        points_X_for_summation = points[:, :, :, 0].unsqueeze(
+            3).unsqueeze(4).expand(points[:, :, :, 0].size()+(1, self.N))
+        points_Y_for_summation = points[:, :, :, 1].unsqueeze(
+            3).unsqueeze(4).expand(points[:, :, :, 1].size()+(1, self.N))
+
+        if points_b == 1:
+            delta_X = points_X_for_summation-P_X
+            delta_Y = points_Y_for_summation-P_Y
+        else:
+            # use expanded P_X, P_Y in batch dimension
+            delta_X = points_X_for_summation - \
+                P_X.expand_as(points_X_for_summation)
+            delta_Y = points_Y_for_summation - \
+                P_Y.expand_as(points_Y_for_summation)
+
+        dist_squared = torch.pow(delta_X, 2)+torch.pow(delta_Y, 2)
+        # U: size [1,H,W,1,N]
+        dist_squared[dist_squared == 0] = 1  # avoid NaN in log computation
+        U = torch.mul(dist_squared, torch.log(dist_squared))
+
+        # expand grid in batch dimension if necessary
+        points_X_batch = points[:, :, :, 0].unsqueeze(3)
+        points_Y_batch = points[:, :, :, 1].unsqueeze(3)
+        if points_b == 1:
+            points_X_batch = points_X_batch.expand(
+                (batch_size,)+points_X_batch.size()[1:])
+            points_Y_batch = points_Y_batch.expand(
+                (batch_size,)+points_Y_batch.size()[1:])
+
+        points_X_prime = A_X[:, :, :, :, 0] + \
+            torch.mul(A_X[:, :, :, :, 1], points_X_batch) + \
+            torch.mul(A_X[:, :, :, :, 2], points_Y_batch) + \
+            torch.sum(torch.mul(W_X, U.expand_as(W_X)), 4)
+
+        points_Y_prime = A_Y[:, :, :, :, 0] + \
+            torch.mul(A_Y[:, :, :, :, 1], points_X_batch) + \
+            torch.mul(A_Y[:, :, :, :, 2], points_Y_batch) + \
+            torch.sum(torch.mul(W_Y, U.expand_as(W_Y)), 4)
+
+        return torch.cat((points_X_prime, points_Y_prime), 3)
+
+
+# Defines the Unet generator.
+# |num_downs|: number of downsamplings in the UNet. For example,
+# if |num_downs| == 7, an image of size 128x128 will become of size 1x1
+# at the bottleneck.
+class UnetGenerator(nn.Module):
+    def __init__(self, input_nc, output_nc, num_downs, ngf=64,
+                 norm_layer=nn.BatchNorm2d, use_dropout=False):
+        super(UnetGenerator, self).__init__()
+        # construct unet structure, from the innermost block outward
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)
+        for i in range(num_downs - 5):
+            unet_block = UnetSkipConnectionBlock(
+                ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+        unet_block = UnetSkipConnectionBlock(
+            output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)
+
+        self.model = unet_block
+
+    def forward(self, input):
+        return self.model(input)
+
+
+# Defines the submodule with skip connection.
+# X -------------------identity---------------------- X
+#   |-- downsampling -- |submodule| -- upsampling --|
+class UnetSkipConnectionBlock(nn.Module):
+    def __init__(self, outer_nc, inner_nc, input_nc=None,
+                 submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+        super(UnetSkipConnectionBlock, self).__init__()
+        self.outermost = outermost
+        use_bias = norm_layer == nn.InstanceNorm2d
+
+        if input_nc is None:
+            input_nc = outer_nc
+        downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                             stride=2, padding=1, bias=use_bias)
+        downrelu = nn.LeakyReLU(0.2, True)
+        downnorm = norm_layer(inner_nc)
+        uprelu = nn.ReLU(True)
+        upnorm = norm_layer(outer_nc)
+
+        if outermost:
+            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
+            upconv = nn.Conv2d(inner_nc * 2, outer_nc,
+                               kernel_size=3, stride=1, padding=1, bias=use_bias)
+            down = [downconv]
+            up = [uprelu, upsample, upconv, upnorm]
+            model = down + [submodule] + up
+        elif innermost:
+            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
+            upconv = nn.Conv2d(inner_nc, outer_nc, kernel_size=3,
+                               stride=1, padding=1, bias=use_bias)
+            down = [downrelu, downconv]
+            up = [uprelu, upsample, upconv, upnorm]
+            model = down + up
+        else:
+            upsample = nn.Upsample(scale_factor=2, mode='bilinear')
+            upconv = nn.Conv2d(inner_nc*2, outer_nc, kernel_size=3,
+                               stride=1, padding=1, bias=use_bias)
+            down = [downrelu, downconv, downnorm]
+            up = [uprelu, upsample, upconv, upnorm]
+
+            if use_dropout:
+                model = down + [submodule] + up + [nn.Dropout(0.5)]
+            else:
+                model = down + [submodule] + up
+
+        self.model = nn.Sequential(*model)
+
+    def forward(self, x):
+        if self.outermost:
+            return self.model(x)
+        else:  # skip connection: concatenate input with the block's output
+            return torch.cat([x, self.model(x)], 1)
+
+
+class Vgg19(nn.Module):
+    def __init__(self, requires_grad=False):
+        super(Vgg19, self).__init__()
+        vgg_pretrained_features = models.vgg19(pretrained=True).features
+        self.slice1 = torch.nn.Sequential()
+        self.slice2 = torch.nn.Sequential()
+        self.slice3 = torch.nn.Sequential()
+        self.slice4 = torch.nn.Sequential()
+        self.slice5 = torch.nn.Sequential()
+        for x in range(2):
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(2, 7):
+            self.slice2.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(7, 12):
+            self.slice3.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(12, 21):
+            self.slice4.add_module(str(x), vgg_pretrained_features[x])
+        for x in range(21, 30):
+            self.slice5.add_module(str(x), vgg_pretrained_features[x])
+        if not requires_grad:
+            for param in self.parameters():
+                param.requires_grad = False
+
+    def forward(self, X):
+        h_relu1 = self.slice1(X)
+        h_relu2 = self.slice2(h_relu1)
+        h_relu3 = self.slice3(h_relu2)
+        h_relu4 = self.slice4(h_relu3)
+        h_relu5 = self.slice5(h_relu4)
+        out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+        return out
+
+
+class VGGLoss(nn.Module):
+    def __init__(self, layids=None):
+        super(VGGLoss, self).__init__()
+        self.vgg = Vgg19()
+        self.vgg.cuda()
+        self.criterion = nn.L1Loss()
+        self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
+        self.layids = layids
+
+    def forward(self, x, y):
+        x_vgg, y_vgg = self.vgg(x), self.vgg(y)
+        loss = 0
+        if self.layids is None:
+            self.layids = list(range(len(x_vgg)))
+        for i in self.layids:
+            loss += self.weights[i] * \
+                self.criterion(x_vgg[i], y_vgg[i].detach())
+        return loss
+
+
+class DT(nn.Module):
+    def __init__(self):
+        super(DT, self).__init__()
+
+    def forward(self, x1, x2):
+        dt = torch.abs(x1 - x2)
+        return dt
+
+
+class DT2(nn.Module):
+    def __init__(self):
+        super(DT2, self).__init__()  # fixed: was super(DT, self).__init__()
+
+    def forward(self, x1, y1, x2, y2):
+        dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
+                        torch.mul(y1 - y2, y1 - y2))
+        return dt
+
+
+class GicLoss(nn.Module):
+    """Grid regularization: penalizes non-uniform spacing between neighboring grid points."""
+
+    def __init__(self, opt):
+        super(GicLoss, self).__init__()
+        self.dT = DT()
+        self.opt = opt
+
+    def forward(self, grid):
+        Gx = grid[:, :, :, 0]
+        Gy = grid[:, :, :, 1]
+        Gxcenter = Gx[:, 1:self.opt.fine_height - 1, 1:self.opt.fine_width - 1]
+        Gxup = Gx[:, 0:self.opt.fine_height - 2, 1:self.opt.fine_width - 1]
+        Gxdown = Gx[:, 2:self.opt.fine_height, 1:self.opt.fine_width - 1]
+        Gxleft = Gx[:, 1:self.opt.fine_height - 1, 0:self.opt.fine_width - 2]
+        Gxright = Gx[:, 1:self.opt.fine_height - 1, 2:self.opt.fine_width]
+
+        Gycenter = Gy[:, 1:self.opt.fine_height - 1, 1:self.opt.fine_width - 1]
+        Gyup = Gy[:, 0:self.opt.fine_height - 2, 1:self.opt.fine_width - 1]
+        Gydown = Gy[:, 2:self.opt.fine_height, 1:self.opt.fine_width - 1]
+        Gyleft = Gy[:, 1:self.opt.fine_height - 1, 0:self.opt.fine_width - 2]
+        Gyright = Gy[:, 1:self.opt.fine_height - 1, 2:self.opt.fine_width]
+
+        dtleft = self.dT(Gxleft, Gxcenter)
+        dtright = self.dT(Gxright, Gxcenter)
+        dtup = self.dT(Gyup, Gycenter)
+        dtdown = self.dT(Gydown, Gycenter)
+
+        return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))
+
+
+class GMM(nn.Module):
+    """Geometric Matching Module"""
+
+    def __init__(self, opt):
+        super(GMM, self).__init__()
+        self.extractionA = FeatureExtraction(
+            22, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
+        self.extractionB = FeatureExtraction(
+            1, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d)
+        self.l2norm = FeatureL2Norm()
+        self.correlation = FeatureCorrelation()
+        self.regression = FeatureRegression(
+            input_nc=192, output_dim=2*opt.grid_size**2, use_cuda=True)
+        self.gridGen = TpsGridGen(
+            opt.fine_height, opt.fine_width, use_cuda=True, grid_size=opt.grid_size)
+
+    def forward(self, inputA, inputB):
+        featureA = self.extractionA(inputA)
+        featureB = self.extractionB(inputB)
+        featureA = self.l2norm(featureA)
+        featureB = self.l2norm(featureB)
+        correlation = self.correlation(featureA, featureB)
+
+        theta = self.regression(correlation)
+        grid = self.gridGen(theta)
+        return grid, theta
+
+
+def save_checkpoint(model, save_path):
+    if not os.path.exists(os.path.dirname(save_path)):
+        os.makedirs(os.path.dirname(save_path))
+
+    torch.save(model.cpu().state_dict(), save_path)
+    model.cuda()
+
+
+def load_checkpoint(model, checkpoint_path):
+    if not os.path.exists(checkpoint_path):
+        return
+    model.load_state_dict(torch.load(checkpoint_path))
+    model.cuda()
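
A minimal smoke test for the GMM assembled above can confirm tensor shapes end to end. This is a sketch, not part of the upload: it assumes a CUDA device (FeatureRegression and TpsGridGen move themselves to the GPU in __init__) and fabricates random inputs in place of a real CPDataset batch; the SimpleNamespace opt is hypothetical.

    import torch
    from types import SimpleNamespace
    from networks import GMM

    opt = SimpleNamespace(fine_height=256, fine_width=192, grid_size=5)  # hypothetical opt
    model = GMM(opt).cuda()
    agnostic = torch.randn(1, 22, 256, 192).cuda()   # person representation (22 channels)
    cloth_mask = torch.randn(1, 1, 256, 192).cuda()  # cloth mask (1 channel)
    grid, theta = model(agnostic, cloth_mask)
    print(grid.shape)   # [1, 256, 192, 2] sampling grid, consumed by F.grid_sample
    print(theta.shape)  # [1, 50] = 2 * grid_size**2 TPS control-point offsets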
requirements (2).txt ADDED
@@ -0,0 +1,6 @@
+torch>=1.10
+torchvision>=0.11
+tensorboardX
+pillow
+numpy
+opencv-contrib-python
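
An optional sanity check (a sketch) that every pinned package imports and that a GPU is visible, since networks.py moves modules to CUDA unconditionally:

    import torch, torchvision, numpy, cv2, PIL, tensorboardX

    print(torch.__version__, torchvision.__version__)
    assert torch.cuda.is_available(), 'a CUDA device is required by this code'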
test.py ADDED
@@ -0,0 +1,226 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import argparse
+import os
+import time
+from cp_dataset import CPDataset, CPDataLoader
+from networks import GMM, UnetGenerator, load_checkpoint
+
+from tensorboardX import SummaryWriter
+from visualization import board_add_image, board_add_images, save_images
+
+
+def get_opt():
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument("--name", default="GMM")
+    # parser.add_argument("--name", default="TOM")
+
+    parser.add_argument("--gpu_ids", default="")
+    parser.add_argument('-j', '--workers', type=int, default=1)
+    parser.add_argument('-b', '--batch-size', type=int, default=4)
+
+    parser.add_argument("--dataroot", default="data")
+
+    # parser.add_argument("--datamode", default="train")
+    parser.add_argument("--datamode", default="test")
+
+    parser.add_argument("--stage", default="GMM")
+    # parser.add_argument("--stage", default="TOM")
+
+    # parser.add_argument("--data_list", default="train_pairs.txt")
+    parser.add_argument("--data_list", default="test_pairs.txt")
+    # parser.add_argument("--data_list", default="test_pairs_same.txt")
+
+    parser.add_argument("--fine_width", type=int, default=192)
+    parser.add_argument("--fine_height", type=int, default=256)
+    parser.add_argument("--radius", type=int, default=5)
+    parser.add_argument("--grid_size", type=int, default=5)
+
+    parser.add_argument('--tensorboard_dir', type=str,
+                        default='tensorboard', help='save tensorboard infos')
+
+    parser.add_argument('--result_dir', type=str,
+                        default='result', help='save result infos')
+
+    parser.add_argument('--checkpoint', type=str, default='checkpoints/GMM/gmm_final.pth',
+                        help='model checkpoint for test')
+    # parser.add_argument('--checkpoint', type=str, default='checkpoints/TOM/tom_final.pth', help='model checkpoint for test')
+
+    parser.add_argument("--display_count", type=int, default=1)
+    parser.add_argument("--shuffle", action='store_true',
+                        help='shuffle input data')
+
+    opt = parser.parse_args()
+    return opt
+
+
+def test_gmm(opt, test_loader, model, board):
+    model.cuda()
+    model.eval()
+
+    base_name = os.path.basename(opt.checkpoint)
+    name = opt.name
+    save_dir = os.path.join(opt.result_dir, name, opt.datamode)
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    warp_cloth_dir = os.path.join(save_dir, 'warp-cloth')
+    if not os.path.exists(warp_cloth_dir):
+        os.makedirs(warp_cloth_dir)
+    warp_mask_dir = os.path.join(save_dir, 'warp-mask')
+    if not os.path.exists(warp_mask_dir):
+        os.makedirs(warp_mask_dir)
+    result_dir1 = os.path.join(save_dir, 'result_dir')
+    if not os.path.exists(result_dir1):
+        os.makedirs(result_dir1)
+    overlayed_TPS_dir = os.path.join(save_dir, 'overlayed_TPS')
+    if not os.path.exists(overlayed_TPS_dir):
+        os.makedirs(overlayed_TPS_dir)
+    warped_grid_dir = os.path.join(save_dir, 'warped_grid')
+    if not os.path.exists(warped_grid_dir):
+        os.makedirs(warped_grid_dir)
+
+    for step, inputs in enumerate(test_loader.data_loader):
+        iter_start_time = time.time()
+
+        c_names = inputs['c_name']
+        im_names = inputs['im_name']
+        im = inputs['image'].cuda()
+        im_pose = inputs['pose_image'].cuda()
+        im_h = inputs['head'].cuda()
+        shape = inputs['shape'].cuda()
+        agnostic = inputs['agnostic'].cuda()
+        c = inputs['cloth'].cuda()
+        cm = inputs['cloth_mask'].cuda()
+        im_c = inputs['parse_cloth'].cuda()
+        im_g = inputs['grid_image'].cuda()
+        shape_ori = inputs['shape_ori']  # original body shape without blurring
+
+        grid, theta = model(agnostic, cm)
+        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
+        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
+        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
+        overlay = 0.7 * warped_cloth + 0.3 * im
+
+        visuals = [[im_h, shape, im_pose],
+                   [c, warped_cloth, im_c],
+                   [warped_grid, (warped_cloth+im)*0.5, im]]
+
+        # save_images(warped_cloth, c_names, warp_cloth_dir)
+        # save_images(warped_mask*2-1, c_names, warp_mask_dir)
+        save_images(warped_cloth, im_names, warp_cloth_dir)
+        save_images(warped_mask * 2 - 1, im_names, warp_mask_dir)
+        save_images(shape_ori.cuda() * 0.2 + warped_cloth *
+                    0.8, im_names, result_dir1)
+        save_images(warped_grid, im_names, warped_grid_dir)
+        save_images(overlay, im_names, overlayed_TPS_dir)
+
+        if (step+1) % opt.display_count == 0:
+            board_add_images(board, 'combine', visuals, step+1)
+            t = time.time() - iter_start_time
+            print('step: %8d, time: %.3f' % (step+1, t), flush=True)
+
+
+def test_tom(opt, test_loader, model, board):
+    model.cuda()
+    model.eval()
+
+    base_name = os.path.basename(opt.checkpoint)
+    # save_dir = os.path.join(opt.result_dir, base_name, opt.datamode)
+    save_dir = os.path.join(opt.result_dir, opt.name, opt.datamode)
+    if not os.path.exists(save_dir):
+        os.makedirs(save_dir)
+    try_on_dir = os.path.join(save_dir, 'try-on')
+    if not os.path.exists(try_on_dir):
+        os.makedirs(try_on_dir)
+    p_rendered_dir = os.path.join(save_dir, 'p_rendered')
+    if not os.path.exists(p_rendered_dir):
+        os.makedirs(p_rendered_dir)
+    m_composite_dir = os.path.join(save_dir, 'm_composite')
+    if not os.path.exists(m_composite_dir):
+        os.makedirs(m_composite_dir)
+    im_pose_dir = os.path.join(save_dir, 'im_pose')
+    if not os.path.exists(im_pose_dir):
+        os.makedirs(im_pose_dir)
+    shape_dir = os.path.join(save_dir, 'shape')
+    if not os.path.exists(shape_dir):
+        os.makedirs(shape_dir)
+    im_h_dir = os.path.join(save_dir, 'im_h')
+    if not os.path.exists(im_h_dir):
+        os.makedirs(im_h_dir)  # for test data
+
+    print('Dataset size: %05d!' % (len(test_loader.dataset)), flush=True)
+    for step, inputs in enumerate(test_loader.data_loader):
+        iter_start_time = time.time()
+
+        im_names = inputs['im_name']
+        im = inputs['image'].cuda()
+        im_pose = inputs['pose_image']
+        im_h = inputs['head']
+        shape = inputs['shape']
+
+        agnostic = inputs['agnostic'].cuda()
+        c = inputs['cloth'].cuda()
+        cm = inputs['cloth_mask'].cuda()
+
+        # outputs = model(torch.cat([agnostic, c], 1))  # CP-VTON
+        outputs = model(torch.cat([agnostic, c, cm], 1))  # CP-VTON+
+        p_rendered, m_composite = torch.split(outputs, 3, 1)
+        p_rendered = torch.tanh(p_rendered)  # fixed: F.tanh is deprecated
+        m_composite = torch.sigmoid(m_composite)  # fixed: F.sigmoid is deprecated
+        p_tryon = c * m_composite + p_rendered * (1 - m_composite)
+
+        visuals = [[im_h, shape, im_pose],
+                   [c, 2*cm-1, m_composite],
+                   [p_rendered, p_tryon, im]]
+
+        save_images(p_tryon, im_names, try_on_dir)
+        save_images(im_h, im_names, im_h_dir)
+        save_images(shape, im_names, shape_dir)
+        save_images(im_pose, im_names, im_pose_dir)
+        save_images(m_composite, im_names, m_composite_dir)
+        save_images(p_rendered, im_names, p_rendered_dir)  # for test data
+
+        if (step+1) % opt.display_count == 0:
+            board_add_images(board, 'combine', visuals, step+1)
+            t = time.time() - iter_start_time
+            print('step: %8d, time: %.3f' % (step+1, t), flush=True)
+
+
+def main():
+    opt = get_opt()
+    print(opt)
+    print("Start to test stage: %s, named: %s!" % (opt.stage, opt.name))
+
+    # create dataset
+    test_dataset = CPDataset(opt)
+
+    # create dataloader
+    test_loader = CPDataLoader(opt, test_dataset)
+
+    # visualization
+    if not os.path.exists(opt.tensorboard_dir):
+        os.makedirs(opt.tensorboard_dir)
+    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))
+
+    # create model & test
+    if opt.stage == 'GMM':
+        model = GMM(opt)
+        load_checkpoint(model, opt.checkpoint)
+        with torch.no_grad():
+            test_gmm(opt, test_loader, model, board)
+    elif opt.stage == 'TOM':
+        # model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON
+        model = UnetGenerator(
+            26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON+
+        load_checkpoint(model, opt.checkpoint)
+        with torch.no_grad():
+            test_tom(opt, test_loader, model, board)
+    else:
+        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
+
+    print('Finished test %s, named: %s!' % (opt.stage, opt.name))
+
+
+if __name__ == "__main__":
+    main()
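
The core warping call in test_gmm is F.grid_sample(c, grid, ...), where grid holds normalized (x, y) sampling coordinates in [-1, 1], which is exactly what TpsGridGen emits. A standalone sketch of that semantics (not part of the upload): an identity grid reproduces the input image.

    import torch
    import torch.nn.functional as F

    c = torch.randn(1, 3, 256, 192)          # stand-in for a cloth tensor
    theta_id = torch.eye(2, 3).unsqueeze(0)  # identity affine transform
    grid = F.affine_grid(theta_id, c.size(), align_corners=False)  # [1, 256, 192, 2]
    warped = F.grid_sample(c, grid, padding_mode='border', align_corners=False)
    assert torch.allclose(warped, c, atol=1e-4)  # identity grid leaves pixels in place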
train.py ADDED
@@ -0,0 +1,232 @@
+# coding=utf-8
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+import argparse
+import os
+import time
+from cp_dataset import CPDataset, CPDataLoader
+from networks import GicLoss, GMM, UnetGenerator, VGGLoss, load_checkpoint, save_checkpoint
+
+from tensorboardX import SummaryWriter
+from visualization import board_add_image, board_add_images
+
+
+def get_opt():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--name", default="GMM")
+    # parser.add_argument("--name", default="TOM")
+
+    parser.add_argument("--gpu_ids", default="")
+    parser.add_argument('-j', '--workers', type=int, default=1)
+    parser.add_argument('-b', '--batch-size', type=int, default=4)
+
+    parser.add_argument("--dataroot", default="data")
+
+    parser.add_argument("--datamode", default="train")
+
+    parser.add_argument("--stage", default="GMM")
+    # parser.add_argument("--stage", default="TOM")
+
+    parser.add_argument("--data_list", default="train_pairs.txt")
+
+    parser.add_argument("--fine_width", type=int, default=192)
+    parser.add_argument("--fine_height", type=int, default=256)
+    parser.add_argument("--radius", type=int, default=5)
+    parser.add_argument("--grid_size", type=int, default=5)
+    parser.add_argument('--lr', type=float, default=0.0001,
+                        help='initial learning rate for adam')
+    parser.add_argument('--tensorboard_dir', type=str,
+                        default='tensorboard', help='save tensorboard infos')
+    parser.add_argument('--checkpoint_dir', type=str,
+                        default='checkpoints', help='save checkpoint infos')
+    parser.add_argument('--checkpoint', type=str, default='',
+                        help='model checkpoint for initialization')
+    parser.add_argument("--display_count", type=int, default=20)
+    parser.add_argument("--save_count", type=int, default=5000)
+    parser.add_argument("--keep_step", type=int, default=100000)
+    parser.add_argument("--decay_step", type=int, default=100000)
+    parser.add_argument("--shuffle", action='store_true',
+                        help='shuffle input data')
+
+    opt = parser.parse_args()
+    return opt
+
+
+def train_gmm(opt, train_loader, model, board):
+    model.cuda()
+    model.train()
+
+    # criterion
+    criterionL1 = nn.L1Loss()
+    gicloss = GicLoss(opt)
+    # optimizer: constant LR for keep_step iterations, then linear decay over decay_step
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0 -
+                                                  max(0, step - opt.keep_step) / float(opt.decay_step + 1))
+
+    for step in range(opt.keep_step + opt.decay_step):
+        iter_start_time = time.time()
+        inputs = train_loader.next_batch()
+
+        im = inputs['image'].cuda()
+        im_pose = inputs['pose_image'].cuda()
+        im_h = inputs['head'].cuda()
+        shape = inputs['shape'].cuda()
+        agnostic = inputs['agnostic'].cuda()
+        c = inputs['cloth'].cuda()
+        cm = inputs['cloth_mask'].cuda()
+        im_c = inputs['parse_cloth'].cuda()
+        im_g = inputs['grid_image'].cuda()
+
+        grid, theta = model(agnostic, cm)  # c could be added too for new training
+        warped_cloth = F.grid_sample(c, grid, padding_mode='border')
+        warped_mask = F.grid_sample(cm, grid, padding_mode='zeros')
+        warped_grid = F.grid_sample(im_g, grid, padding_mode='zeros')
+
+        visuals = [[im_h, shape, im_pose],
+                   [c, warped_cloth, im_c],
+                   [warped_grid, (warped_cloth+im)*0.5, im]]
+
+        # loss for warped cloth (kept from the previous, known-working code)
+        Lwarp = criterionL1(warped_cloth, im_c)
+        # Loss as given in the paper (comment out the line above and uncomment below to train per the paper):
+        # Lwarp = criterionL1(warped_mask, cm)  # loss for warped mask; thanks @xuxiaochun025 for fixing the git code
+
+        # grid regularization loss, normalized by the number of grid points (batch * height * width)
+        Lgic = gicloss(grid)
+        Lgic = Lgic / (grid.shape[0] * grid.shape[1] * grid.shape[2])
+
+        loss = Lwarp + 40 * Lgic  # total GMM loss
+
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        scheduler.step()  # fixed: the scheduler was created but never stepped, so the LR never decayed
+
+        if (step+1) % opt.display_count == 0:
+            board_add_images(board, 'combine', visuals, step+1)
+            board.add_scalar('loss', loss.item(), step+1)
+            board.add_scalar('40*Lgic', (40*Lgic).item(), step+1)
+            board.add_scalar('Lwarp', Lwarp.item(), step+1)
+            t = time.time() - iter_start_time
+            print('step: %8d, time: %.3f, loss: %4f, (40*Lgic): %.8f, Lwarp: %.6f' %
+                  (step+1, t, loss.item(), (40*Lgic).item(), Lwarp.item()), flush=True)
+
+        if (step+1) % opt.save_count == 0:
+            save_checkpoint(model, os.path.join(
+                opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))
+
+
+def train_tom(opt, train_loader, model, board):
+    model.cuda()
+    model.train()
+
+    # criterion
+    criterionL1 = nn.L1Loss()
+    criterionVGG = VGGLoss()
+    criterionMask = nn.L1Loss()
+
+    # optimizer
+    optimizer = torch.optim.Adam(
+        model.parameters(), lr=opt.lr, betas=(0.5, 0.999))
+    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 1.0 -
+                                                  max(0, step - opt.keep_step) / float(opt.decay_step + 1))
+
+    for step in range(opt.keep_step + opt.decay_step):
+        iter_start_time = time.time()
+        inputs = train_loader.next_batch()
+
+        im = inputs['image'].cuda()
+        im_pose = inputs['pose_image']
+        im_h = inputs['head']
+        shape = inputs['shape']
+
+        agnostic = inputs['agnostic'].cuda()
+        c = inputs['cloth'].cuda()
+        cm = inputs['cloth_mask'].cuda()
+        pcm = inputs['parse_cloth_mask'].cuda()
+
+        # outputs = model(torch.cat([agnostic, c], 1))  # CP-VTON
+        outputs = model(torch.cat([agnostic, c, cm], 1))  # CP-VTON+
+        p_rendered, m_composite = torch.split(outputs, 3, 1)
+        p_rendered = torch.tanh(p_rendered)  # fixed: F.tanh is deprecated
+        m_composite = torch.sigmoid(m_composite)  # fixed: F.sigmoid is deprecated
+        p_tryon = c * m_composite + p_rendered * (1 - m_composite)
+
+        """visuals = [[im_h, shape, im_pose],
+                   [c, cm*2-1, m_composite*2-1],
+                   [p_rendered, p_tryon, im]]"""  # CP-VTON
+
+        visuals = [[im_h, shape, im_pose],
+                   [c, pcm*2-1, m_composite*2-1],
+                   [p_rendered, p_tryon, im]]  # CP-VTON+
+
+        loss_l1 = criterionL1(p_tryon, im)
+        loss_vgg = criterionVGG(p_tryon, im)
+        # loss_mask = criterionMask(m_composite, cm)  # CP-VTON
+        loss_mask = criterionMask(m_composite, pcm)  # CP-VTON+
+        loss = loss_l1 + loss_vgg + loss_mask
+        optimizer.zero_grad()
+        loss.backward()
+        optimizer.step()
+        scheduler.step()  # fixed: advance the LR schedule each iteration
+
+        if (step+1) % opt.display_count == 0:
+            board_add_images(board, 'combine', visuals, step+1)
+            board.add_scalar('metric', loss.item(), step+1)
+            board.add_scalar('L1', loss_l1.item(), step+1)
+            board.add_scalar('VGG', loss_vgg.item(), step+1)
+            board.add_scalar('MaskL1', loss_mask.item(), step+1)
+            t = time.time() - iter_start_time
+            print('step: %8d, time: %.3f, loss: %.4f, l1: %.4f, vgg: %.4f, mask: %.4f'
+                  % (step+1, t, loss.item(), loss_l1.item(),
+                     loss_vgg.item(), loss_mask.item()), flush=True)
+
+        if (step+1) % opt.save_count == 0:
+            save_checkpoint(model, os.path.join(
+                opt.checkpoint_dir, opt.name, 'step_%06d.pth' % (step+1)))
+
+
+def main():
+    opt = get_opt()
+    print(opt)
+    print("Start to train stage: %s, named: %s!" % (opt.stage, opt.name))
+
+    # create dataset
+    train_dataset = CPDataset(opt)
+
+    # create dataloader
+    train_loader = CPDataLoader(opt, train_dataset)
+
+    # visualization
+    if not os.path.exists(opt.tensorboard_dir):
+        os.makedirs(opt.tensorboard_dir)
+    board = SummaryWriter(logdir=os.path.join(opt.tensorboard_dir, opt.name))
+
+    # create model, train, and save the final checkpoint
+    if opt.stage == 'GMM':
+        model = GMM(opt)
+        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
+            load_checkpoint(model, opt.checkpoint)
+        train_gmm(opt, train_loader, model, board)
+        save_checkpoint(model, os.path.join(
+            opt.checkpoint_dir, opt.name, 'gmm_final.pth'))
+    elif opt.stage == 'TOM':
+        # model = UnetGenerator(25, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON
+        model = UnetGenerator(
+            26, 4, 6, ngf=64, norm_layer=nn.InstanceNorm2d)  # CP-VTON+
+        if not opt.checkpoint == '' and os.path.exists(opt.checkpoint):
+            load_checkpoint(model, opt.checkpoint)
+        train_tom(opt, train_loader, model, board)
+        save_checkpoint(model, os.path.join(
+            opt.checkpoint_dir, opt.name, 'tom_final.pth'))
+    else:
+        raise NotImplementedError('Model [%s] is not implemented' % opt.stage)
+
+    print('Finished training %s, named: %s!' % (opt.stage, opt.name))
+
+
+if __name__ == "__main__":
+    main()
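
The LambdaLR lambda used in both training loops keeps the learning rate at opt.lr for keep_step iterations and then decays it linearly toward zero over decay_step iterations, which is why the scheduler.step() call added after optimizer.step() matters. A standalone sketch of the schedule with made-up step counts:

    import torch

    keep_step, decay_step, base_lr = 100, 100, 1e-4  # hypothetical values
    opt_ = torch.optim.Adam([torch.zeros(1, requires_grad=True)], lr=base_lr)
    sched = torch.optim.lr_scheduler.LambdaLR(
        opt_, lr_lambda=lambda step: 1.0 - max(0, step - keep_step) / float(decay_step + 1))
    lrs = []
    for step in range(keep_step + decay_step):
        opt_.step()
        sched.step()
        lrs.append(opt_.param_groups[0]['lr'])
    print(lrs[0], lrs[keep_step], lrs[-1])  # base_lr, ~base_lr, then ~0.01 * base_lr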
visualization.py ADDED
@@ -0,0 +1,64 @@
+from tensorboardX import SummaryWriter
+import torch
+from PIL import Image
+import os
+
+
+def tensor_for_board(img_tensor):
+    # map [-1, 1] into [0, 1]
+    tensor = (img_tensor.clone()+1) * 0.5
+    tensor = tensor.cpu().clamp(0, 1)  # fixed: clamp is not in-place, assign the result
+
+    if tensor.size(1) == 1:
+        tensor = tensor.repeat(1, 3, 1, 1)
+
+    return tensor
+
+
+def tensor_list_for_board(img_tensors_list):
+    grid_h = len(img_tensors_list)
+    grid_w = max(len(img_tensors) for img_tensors in img_tensors_list)
+
+    batch_size, channel, height, width = tensor_for_board(
+        img_tensors_list[0][0]).size()
+    canvas_h = grid_h * height
+    canvas_w = grid_w * width
+    canvas = torch.FloatTensor(
+        batch_size, channel, canvas_h, canvas_w).fill_(0.5)
+    for i, img_tensors in enumerate(img_tensors_list):
+        for j, img_tensor in enumerate(img_tensors):
+            offset_h = i * height
+            offset_w = j * width
+            tensor = tensor_for_board(img_tensor)
+            canvas[:, :, offset_h: offset_h + height,
+                   offset_w: offset_w + width].copy_(tensor)
+
+    return canvas
+
+
+def board_add_image(board, tag_name, img_tensor, step_count):
+    tensor = tensor_for_board(img_tensor)
+
+    for i, img in enumerate(tensor):
+        board.add_image('%s/%03d' % (tag_name, i), img, step_count)
+
+
+def board_add_images(board, tag_name, img_tensors_list, step_count):
+    tensor = tensor_list_for_board(img_tensors_list)
+
+    for i, img in enumerate(tensor):
+        board.add_image('%s/%03d' % (tag_name, i), img, step_count)
+
+
+def save_images(img_tensors, img_names, save_dir):
+    for img_tensor, img_name in zip(img_tensors, img_names):
+        tensor = (img_tensor.clone()+1)*0.5 * 255  # map [-1, 1] to [0, 255]
+        tensor = tensor.cpu().clamp(0, 255)
+
+        array = tensor.numpy().astype('uint8')
+        if array.shape[0] == 1:
+            array = array.squeeze(0)  # grayscale: drop the channel dim
+        elif array.shape[0] == 3:
+            array = array.swapaxes(0, 1).swapaxes(1, 2)  # CHW -> HWC
+
+        Image.fromarray(array).save(os.path.join(save_dir, img_name))
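
A short usage sketch for the helpers above (not part of the upload): it assumes tensors in [-1, 1], as the networks produce, and a hypothetical writable output directory named demo_out.

    import os
    import torch
    from visualization import save_images, tensor_list_for_board

    fake = torch.rand(2, 3, 256, 192) * 2 - 1       # two fake RGB images in [-1, 1]
    mask = torch.rand(2, 1, 256, 192) * 2 - 1       # 1-channel input is tiled to 3 channels
    canvas = tensor_list_for_board([[fake, mask]])  # [2, 3, 256, 384]: one grid row, two columns
    os.makedirs('demo_out', exist_ok=True)
    save_images(fake, ['a.jpg', 'b.jpg'], 'demo_out')  # writes demo_out/a.jpg and demo_out/b.jpg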