gaur3009 commited on
Commit
ce16d6e
·
verified ·
1 Parent(s): 08cfba0

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +8 -49
networks.py CHANGED
@@ -1,13 +1,11 @@
1
- # coding=utf-8
2
  import torch
3
  import torch.nn as nn
 
4
  from torch.nn import init
5
  from torchvision import models
6
  import os
7
-
8
  import numpy as np
9
 
10
-
11
  def weights_init_normal(m):
12
  classname = m.__class__.__name__
13
  if classname.find('Conv') != -1:
@@ -18,7 +16,6 @@ def weights_init_normal(m):
18
  init.normal_(m.weight.data, 1.0, 0.02)
19
  init.constant_(m.bias.data, 0.0)
20
 
21
-
22
  def weights_init_xavier(m):
23
  classname = m.__class__.__name__
24
  if classname.find('Conv') != -1:
@@ -29,7 +26,6 @@ def weights_init_xavier(m):
29
  init.normal_(m.weight.data, 1.0, 0.02)
30
  init.constant_(m.bias.data, 0.0)
31
 
32
-
33
  def weights_init_kaiming(m):
34
  classname = m.__class__.__name__
35
  if classname.find('Conv') != -1:
@@ -40,7 +36,6 @@ def weights_init_kaiming(m):
40
  init.normal_(m.weight.data, 1.0, 0.02)
41
  init.constant_(m.bias.data, 0.0)
42
 
43
-
44
  def init_weights(net, init_type='normal'):
45
  print('initialization method [%s]' % init_type)
46
  if init_type == 'normal':
@@ -53,7 +48,6 @@ def init_weights(net, init_type='normal'):
53
  raise NotImplementedError(
54
  'initialization method [%s] is not implemented' % init_type)
55
 
56
-
57
  class FeatureExtraction(nn.Module):
58
  def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
59
  super(FeatureExtraction, self).__init__()
@@ -78,7 +72,6 @@ class FeatureExtraction(nn.Module):
78
  def forward(self, x):
79
  return self.model(x)
80
 
81
-
82
  class FeatureL2Norm(torch.nn.Module):
83
  def __init__(self):
84
  super(FeatureL2Norm, self).__init__()
@@ -89,7 +82,6 @@ class FeatureL2Norm(torch.nn.Module):
89
  epsilon, 0.5).unsqueeze(1).expand_as(feature)
90
  return torch.div(feature, norm)
91
 
92
-
93
  class FeatureCorrelation(nn.Module):
94
  def __init__(self):
95
  super(FeatureCorrelation, self).__init__()
@@ -105,9 +97,8 @@ class FeatureCorrelation(nn.Module):
105
  b, h, w, h*w).transpose(2, 3).transpose(1, 2)
106
  return correlation_tensor
107
 
108
-
109
  class FeatureRegression(nn.Module):
110
- def __init__(self, input_nc=512, output_dim=6, use_cuda=True):
111
  super(FeatureRegression, self).__init__()
112
  self.conv = nn.Sequential(
113
  nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
@@ -125,10 +116,6 @@ class FeatureRegression(nn.Module):
125
  )
126
  self.linear = nn.Linear(64 * 4 * 3, output_dim)
127
  self.tanh = nn.Tanh()
128
- if use_cuda:
129
- self.conv.cuda()
130
- self.linear.cuda()
131
- self.tanh.cuda()
132
 
133
  def forward(self, x):
134
  x = self.conv(x)
@@ -137,7 +124,6 @@ class FeatureRegression(nn.Module):
137
  x = self.tanh(x)
138
  return x
139
 
140
-
141
  class AffineGridGen(nn.Module):
142
  def __init__(self, out_h=256, out_w=192, out_ch=3):
143
  super(AffineGridGen, self).__init__()
@@ -152,13 +138,11 @@ class AffineGridGen(nn.Module):
152
  (batch_size, self.out_ch, self.out_h, self.out_w))
153
  return F.affine_grid(theta, out_size)
154
 
155
-
156
  class TpsGridGen(nn.Module):
157
- def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
158
  super(TpsGridGen, self).__init__()
159
  self.out_h, self.out_w = out_h, out_w
160
  self.reg_factor = reg_factor
161
- self.use_cuda = use_cuda
162
 
163
  # create grid in numpy
164
  self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
@@ -168,9 +152,6 @@ class TpsGridGen(nn.Module):
168
  # grid_X,grid_Y: size [1,H,W,1,1]
169
  self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
170
  self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
171
- if use_cuda:
172
- self.grid_X = self.grid_X.cuda()
173
- self.grid_Y = self.grid_Y.cuda()
174
 
175
  # initialize regular grid for control points P_i
176
  if use_regular_grid:
@@ -188,11 +169,6 @@ class TpsGridGen(nn.Module):
188
  3).unsqueeze(4).transpose(0, 4)
189
  self.P_Y = P_Y.unsqueeze(2).unsqueeze(
190
  3).unsqueeze(4).transpose(0, 4)
191
- if use_cuda:
192
- self.P_X = self.P_X.cuda()
193
- self.P_Y = self.P_Y.cuda()
194
- self.P_X_base = self.P_X_base.cuda()
195
- self.P_Y_base = self.P_Y_base.cuda()
196
 
197
  def forward(self, theta):
198
  warped_grid = self.apply_transformation(
@@ -217,8 +193,6 @@ class TpsGridGen(nn.Module):
217
  L = torch.cat((torch.cat((K, P), 1), torch.cat(
218
  (P.transpose(0, 1), Z), 1)), 0)
219
  Li = torch.inverse(L)
220
- if self.use_cuda:
221
- Li = Li.cuda()
222
  return Li
223
 
224
  def apply_transformation(self, theta, points):
@@ -315,8 +289,6 @@ class TpsGridGen(nn.Module):
315
  # |num_downs|: number of downsamplings in UNet. For example,
316
  # if |num_downs| == 7, image of size 128x128 will become of size 1x1
317
  # at the bottleneck
318
-
319
-
320
  class UnetGenerator(nn.Module):
321
  def __init__(self, input_nc, output_nc, num_downs, ngf=64,
322
  norm_layer=nn.BatchNorm2d, use_dropout=False):
@@ -341,7 +313,6 @@ class UnetGenerator(nn.Module):
341
  def forward(self, input):
342
  return self.model(input)
343
 
344
-
345
  # Defines the submodule with skip connection.
346
  # X -------------------identity---------------------- X
347
  # |-- downsampling -- |submodule| -- upsampling --|
@@ -395,7 +366,6 @@ class UnetSkipConnectionBlock(nn.Module):
395
  else:
396
  return torch.cat([x, self.model(x)], 1)
397
 
398
-
399
  class Vgg19(nn.Module):
400
  def __init__(self, requires_grad=False):
401
  super(Vgg19, self).__init__()
@@ -428,12 +398,10 @@ class Vgg19(nn.Module):
428
  out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
429
  return out
430
 
431
-
432
  class VGGLoss(nn.Module):
433
  def __init__(self, layids=None):
434
  super(VGGLoss, self).__init__()
435
  self.vgg = Vgg19()
436
- self.vgg.cuda()
437
  self.criterion = nn.L1Loss()
438
  self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
439
  self.layids = layids
@@ -448,7 +416,6 @@ class VGGLoss(nn.Module):
448
  self.criterion(x_vgg[i], y_vgg[i].detach())
449
  return loss
450
 
451
-
452
  class DT(nn.Module):
453
  def __init__(self):
454
  super(DT, self).__init__()
@@ -457,17 +424,15 @@ class DT(nn.Module):
457
  dt = torch.abs(x1 - x2)
458
  return dt
459
 
460
-
461
  class DT2(nn.Module):
462
  def __init__(self):
463
- super(DT, self).__init__()
464
 
465
  def forward(self, x1, y1, x2, y2):
466
  dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
467
  torch.mul(y1 - y2, y1 - y2))
468
  return dt
469
 
470
-
471
  class GicLoss(nn.Module):
472
  def __init__(self, opt):
473
  super(GicLoss, self).__init__()
@@ -496,7 +461,6 @@ class GicLoss(nn.Module):
496
 
497
  return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))
498
 
499
-
500
  class GMM(nn.Module):
501
  """ Geometric Matching Module
502
  """
@@ -510,9 +474,9 @@ class GMM(nn.Module):
510
  self.l2norm = FeatureL2Norm()
511
  self.correlation = FeatureCorrelation()
512
  self.regression = FeatureRegression(
513
- input_nc=192, output_dim=2*opt.grid_size**2, use_cuda=True)
514
  self.gridGen = TpsGridGen(
515
- opt.fine_height, opt.fine_width, use_cuda=True, grid_size=opt.grid_size)
516
 
517
  def forward(self, inputA, inputB):
518
  featureA = self.extractionA(inputA)
@@ -525,17 +489,12 @@ class GMM(nn.Module):
525
  grid = self.gridGen(theta)
526
  return grid, theta
527
 
528
-
529
  def save_checkpoint(model, save_path):
530
  if not os.path.exists(os.path.dirname(save_path)):
531
  os.makedirs(os.path.dirname(save_path))
532
-
533
- torch.save(model.cpu().state_dict(), save_path)
534
- model.cuda()
535
-
536
 
537
  def load_checkpoint(model, checkpoint_path):
538
  if not os.path.exists(checkpoint_path):
539
  return
540
- model.load_state_dict(torch.load(checkpoint_path))
541
- model.cuda()
 
 
1
  import torch
2
  import torch.nn as nn
3
+ import torch.nn.functional as F
4
  from torch.nn import init
5
  from torchvision import models
6
  import os
 
7
  import numpy as np
8
 
 
9
  def weights_init_normal(m):
10
  classname = m.__class__.__name__
11
  if classname.find('Conv') != -1:
 
16
  init.normal_(m.weight.data, 1.0, 0.02)
17
  init.constant_(m.bias.data, 0.0)
18
 
 
19
  def weights_init_xavier(m):
20
  classname = m.__class__.__name__
21
  if classname.find('Conv') != -1:
 
26
  init.normal_(m.weight.data, 1.0, 0.02)
27
  init.constant_(m.bias.data, 0.0)
28
 
 
29
  def weights_init_kaiming(m):
30
  classname = m.__class__.__name__
31
  if classname.find('Conv') != -1:
 
36
  init.normal_(m.weight.data, 1.0, 0.02)
37
  init.constant_(m.bias.data, 0.0)
38
 
 
39
  def init_weights(net, init_type='normal'):
40
  print('initialization method [%s]' % init_type)
41
  if init_type == 'normal':
 
48
  raise NotImplementedError(
49
  'initialization method [%s] is not implemented' % init_type)
50
 
 
51
  class FeatureExtraction(nn.Module):
52
  def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
53
  super(FeatureExtraction, self).__init__()
 
72
  def forward(self, x):
73
  return self.model(x)
74
 
 
75
  class FeatureL2Norm(torch.nn.Module):
76
  def __init__(self):
77
  super(FeatureL2Norm, self).__init__()
 
82
  epsilon, 0.5).unsqueeze(1).expand_as(feature)
83
  return torch.div(feature, norm)
84
 
 
85
  class FeatureCorrelation(nn.Module):
86
  def __init__(self):
87
  super(FeatureCorrelation, self).__init__()
 
97
  b, h, w, h*w).transpose(2, 3).transpose(1, 2)
98
  return correlation_tensor
99
 
 
100
  class FeatureRegression(nn.Module):
101
+ def __init__(self, input_nc=512, output_dim=6):
102
  super(FeatureRegression, self).__init__()
103
  self.conv = nn.Sequential(
104
  nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
 
116
  )
117
  self.linear = nn.Linear(64 * 4 * 3, output_dim)
118
  self.tanh = nn.Tanh()
 
 
 
 
119
 
120
  def forward(self, x):
121
  x = self.conv(x)
 
124
  x = self.tanh(x)
125
  return x
126
 
 
127
  class AffineGridGen(nn.Module):
128
  def __init__(self, out_h=256, out_w=192, out_ch=3):
129
  super(AffineGridGen, self).__init__()
 
138
  (batch_size, self.out_ch, self.out_h, self.out_w))
139
  return F.affine_grid(theta, out_size)
140
 
 
141
  class TpsGridGen(nn.Module):
142
+ def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0):
143
  super(TpsGridGen, self).__init__()
144
  self.out_h, self.out_w = out_h, out_w
145
  self.reg_factor = reg_factor
 
146
 
147
  # create grid in numpy
148
  self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
 
152
  # grid_X,grid_Y: size [1,H,W,1,1]
153
  self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
154
  self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
 
 
 
155
 
156
  # initialize regular grid for control points P_i
157
  if use_regular_grid:
 
169
  3).unsqueeze(4).transpose(0, 4)
170
  self.P_Y = P_Y.unsqueeze(2).unsqueeze(
171
  3).unsqueeze(4).transpose(0, 4)
 
 
 
 
 
172
 
173
  def forward(self, theta):
174
  warped_grid = self.apply_transformation(
 
193
  L = torch.cat((torch.cat((K, P), 1), torch.cat(
194
  (P.transpose(0, 1), Z), 1)), 0)
195
  Li = torch.inverse(L)
 
 
196
  return Li
197
 
198
  def apply_transformation(self, theta, points):
 
289
  # |num_downs|: number of downsamplings in UNet. For example,
290
  # if |num_downs| == 7, image of size 128x128 will become of size 1x1
291
  # at the bottleneck
 
 
292
  class UnetGenerator(nn.Module):
293
  def __init__(self, input_nc, output_nc, num_downs, ngf=64,
294
  norm_layer=nn.BatchNorm2d, use_dropout=False):
 
313
  def forward(self, input):
314
  return self.model(input)
315
 
 
316
  # Defines the submodule with skip connection.
317
  # X -------------------identity---------------------- X
318
  # |-- downsampling -- |submodule| -- upsampling --|
 
366
  else:
367
  return torch.cat([x, self.model(x)], 1)
368
 
 
369
  class Vgg19(nn.Module):
370
  def __init__(self, requires_grad=False):
371
  super(Vgg19, self).__init__()
 
398
  out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
399
  return out
400
 
 
401
  class VGGLoss(nn.Module):
402
  def __init__(self, layids=None):
403
  super(VGGLoss, self).__init__()
404
  self.vgg = Vgg19()
 
405
  self.criterion = nn.L1Loss()
406
  self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
407
  self.layids = layids
 
416
  self.criterion(x_vgg[i], y_vgg[i].detach())
417
  return loss
418
 
 
419
  class DT(nn.Module):
420
  def __init__(self):
421
  super(DT, self).__init__()
 
424
  dt = torch.abs(x1 - x2)
425
  return dt
426
 
 
427
  class DT2(nn.Module):
428
  def __init__(self):
429
+ super(DT2, self).__init__()
430
 
431
  def forward(self, x1, y1, x2, y2):
432
  dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
433
  torch.mul(y1 - y2, y1 - y2))
434
  return dt
435
 
 
436
  class GicLoss(nn.Module):
437
  def __init__(self, opt):
438
  super(GicLoss, self).__init__()
 
461
 
462
  return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))
463
 
 
464
  class GMM(nn.Module):
465
  """ Geometric Matching Module
466
  """
 
474
  self.l2norm = FeatureL2Norm()
475
  self.correlation = FeatureCorrelation()
476
  self.regression = FeatureRegression(
477
+ input_nc=192, output_dim=2*opt.grid_size**2)
478
  self.gridGen = TpsGridGen(
479
+ opt.fine_height, opt.fine_width, grid_size=opt.grid_size)
480
 
481
  def forward(self, inputA, inputB):
482
  featureA = self.extractionA(inputA)
 
489
  grid = self.gridGen(theta)
490
  return grid, theta
491
 
 
492
  def save_checkpoint(model, save_path):
493
  if not os.path.exists(os.path.dirname(save_path)):
494
  os.makedirs(os.path.dirname(save_path))
495
+ torch.save(model.state_dict(), save_path)
 
 
 
496
 
497
  def load_checkpoint(model, checkpoint_path):
498
  if not os.path.exists(checkpoint_path):
499
  return
500
+ model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))