Update networks.py
networks.py  CHANGED  (+8, -49)
@@ -1,13 +1,11 @@
-# coding=utf-8
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch.nn import init
 from torchvision import models
 import os
-
 import numpy as np

-
 def weights_init_normal(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
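Note: the one addition in this hunk, import torch.nn.functional as F, backs the F.affine_grid call in AffineGridGen further down; no other import of F is visible in the old file head, so that call appears to have relied on an undefined name before. A minimal sketch of the F.affine_grid / F.grid_sample API (illustrative, not code from this commit):

    import torch
    import torch.nn.functional as F

    # Identity affine transform for a batch of one 3x256x192 image.
    theta = torch.tensor([[[1.0, 0.0, 0.0],
                           [0.0, 1.0, 0.0]]])               # (N, 2, 3)
    grid = F.affine_grid(theta, (1, 3, 256, 192), align_corners=False)
    out = F.grid_sample(torch.rand(1, 3, 256, 192), grid,
                        align_corners=False)                # image unchanged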
@@ -18,7 +16,6 @@ def weights_init_normal(m):
         init.normal_(m.weight.data, 1.0, 0.02)
         init.constant_(m.bias.data, 0.0)

-
 def weights_init_xavier(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
@@ -29,7 +26,6 @@ def weights_init_xavier(m):
         init.normal_(m.weight.data, 1.0, 0.02)
         init.constant_(m.bias.data, 0.0)

-
 def weights_init_kaiming(m):
     classname = m.__class__.__name__
     if classname.find('Conv') != -1:
@@ -40,7 +36,6 @@ def weights_init_kaiming(m):
         init.normal_(m.weight.data, 1.0, 0.02)
         init.constant_(m.bias.data, 0.0)

-
 def init_weights(net, init_type='normal'):
     print('initialization method [%s]' % init_type)
     if init_type == 'normal':
@@ -53,7 +48,6 @@ def init_weights(net, init_type='normal'):
         raise NotImplementedError(
             'initialization method [%s] is not implemented' % init_type)

-
 class FeatureExtraction(nn.Module):
     def __init__(self, input_nc, ngf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_dropout=False):
         super(FeatureExtraction, self).__init__()
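Note: init_weights is the entry point for the three initializers above (typically dispatched via net.apply, though the dispatch body is elided in this hunk). Hypothetical usage, assuming this file is importable as networks:

    import torch.nn as nn
    from networks import init_weights   # this file

    net = nn.Sequential(nn.Conv2d(3, 64, kernel_size=3, padding=1),
                        nn.BatchNorm2d(64))
    init_weights(net, init_type='normal')   # prints the chosen method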
@@ -78,7 +72,6 @@ class FeatureExtraction(nn.Module):
     def forward(self, x):
         return self.model(x)

-
 class FeatureL2Norm(torch.nn.Module):
     def __init__(self):
         super(FeatureL2Norm, self).__init__()
@@ -89,7 +82,6 @@ class FeatureL2Norm(torch.nn.Module):
                          epsilon, 0.5).unsqueeze(1).expand_as(feature)
         return torch.div(feature, norm)

-
 class FeatureCorrelation(nn.Module):
     def __init__(self):
         super(FeatureCorrelation, self).__init__()
@@ -105,9 +97,8 @@ class FeatureCorrelation(nn.Module):
             b, h, w, h*w).transpose(2, 3).transpose(1, 2)
         return correlation_tensor

-
 class FeatureRegression(nn.Module):
-    def __init__(self, input_nc=512, output_dim=6, use_cuda=True):
+    def __init__(self, input_nc=512, output_dim=6):
         super(FeatureRegression, self).__init__()
         self.conv = nn.Sequential(
             nn.Conv2d(input_nc, 512, kernel_size=4, stride=2, padding=1),
@@ -125,10 +116,6 @@ class FeatureRegression(nn.Module):
         )
         self.linear = nn.Linear(64 * 4 * 3, output_dim)
         self.tanh = nn.Tanh()
-        if use_cuda:
-            self.conv.cuda()
-            self.linear.cuda()
-            self.tanh.cuda()

     def forward(self, x):
         x = self.conv(x)
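Note: deleting the constructor-time .cuda() calls makes FeatureRegression device-neutral; placement is now the caller's job. A sketch of the standard pattern, assuming the (elided) conv stack downsamples a 16x12 input twice to the 4x3 map implied by nn.Linear(64 * 4 * 3, ...):

    import torch
    from networks import FeatureRegression   # this file

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    regression = FeatureRegression(input_nc=512, output_dim=6).to(device)
    theta = regression(torch.rand(1, 512, 16, 12, device=device))  # (1, 6)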
@@ -137,7 +124,6 @@ class FeatureRegression(nn.Module):
         x = self.tanh(x)
         return x

-
 class AffineGridGen(nn.Module):
     def __init__(self, out_h=256, out_w=192, out_ch=3):
         super(AffineGridGen, self).__init__()
@@ -152,13 +138,11 @@ class AffineGridGen(nn.Module):
             (batch_size, self.out_ch, self.out_h, self.out_w))
         return F.affine_grid(theta, out_size)

-
 class TpsGridGen(nn.Module):
-    def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0, use_cuda=True):
+    def __init__(self, out_h=256, out_w=192, use_regular_grid=True, grid_size=3, reg_factor=0):
         super(TpsGridGen, self).__init__()
         self.out_h, self.out_w = out_h, out_w
         self.reg_factor = reg_factor
-        self.use_cuda = use_cuda

         # create grid in numpy
         self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
@@ -168,9 +152,6 @@ class TpsGridGen(nn.Module):
         # grid_X,grid_Y: size [1,H,W,1,1]
         self.grid_X = torch.FloatTensor(self.grid_X).unsqueeze(0).unsqueeze(3)
         self.grid_Y = torch.FloatTensor(self.grid_Y).unsqueeze(0).unsqueeze(3)
-        if use_cuda:
-            self.grid_X = self.grid_X.cuda()
-            self.grid_Y = self.grid_Y.cuda()

         # initialize regular grid for control points P_i
         if use_regular_grid:
@@ -188,11 +169,6 @@ class TpsGridGen(nn.Module):
                 3).unsqueeze(4).transpose(0, 4)
             self.P_Y = P_Y.unsqueeze(2).unsqueeze(
                 3).unsqueeze(4).transpose(0, 4)
-            if use_cuda:
-                self.P_X = self.P_X.cuda()
-                self.P_Y = self.P_Y.cuda()
-                self.P_X_base = self.P_X_base.cuda()
-                self.P_Y_base = self.P_Y_base.cuda()

     def forward(self, theta):
         warped_grid = self.apply_transformation(
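Note: with use_cuda gone, self.grid_X, self.P_X and friends are plain tensor attributes, and module.to(device) will not move them (they are neither parameters nor buffers), so this class is now effectively CPU-only. If GPU support were wanted back without hard-coding .cuda(), register_buffer is the idiomatic fix; a sketch of the idea, not part of this commit:

    import numpy as np
    import torch
    import torch.nn as nn

    class GridHolder(nn.Module):
        # Buffers travel with .to(device), .cuda() and state_dict automatically.
        def __init__(self, out_h=256, out_w=192):
            super(GridHolder, self).__init__()
            gx, gy = np.meshgrid(np.linspace(-1, 1, out_w),
                                 np.linspace(-1, 1, out_h))
            self.register_buffer(
                'grid_X', torch.FloatTensor(gx).unsqueeze(0).unsqueeze(3))
            self.register_buffer(
                'grid_Y', torch.FloatTensor(gy).unsqueeze(0).unsqueeze(3))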
@@ -217,8 +193,6 @@ class TpsGridGen(nn.Module):
         L = torch.cat((torch.cat((K, P), 1), torch.cat(
             (P.transpose(0, 1), Z), 1)), 0)
         Li = torch.inverse(L)
-        if self.use_cuda:
-            Li = Li.cuda()
         return Li

     def apply_transformation(self, theta, points):
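Note: the deleted Li = Li.cuda() was redundant even on GPU, because torch.inverse returns its result on the same device as its input; Li already lives wherever L does. Quick check:

    import torch

    L = torch.rand(5, 5)                      # almost surely invertible
    assert torch.inverse(L).device == L.device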
@@ -315,8 +289,6 @@ class TpsGridGen(nn.Module):
 # |num_downs|: number of downsamplings in UNet. For example,
 # if |num_downs| == 7, image of size 128x128 will become of size 1x1
 # at the bottleneck
-
-
 class UnetGenerator(nn.Module):
     def __init__(self, input_nc, output_nc, num_downs, ngf=64,
                  norm_layer=nn.BatchNorm2d, use_dropout=False):
@@ -341,7 +313,6 @@ class UnetGenerator(nn.Module):
     def forward(self, input):
         return self.model(input)

-
 # Defines the submodule with skip connection.
 # X -------------------identity---------------------- X
 # |-- downsampling -- |submodule| -- upsampling --|
@@ -395,7 +366,6 @@ class UnetSkipConnectionBlock(nn.Module):
         else:
             return torch.cat([x, self.model(x)], 1)

-
 class Vgg19(nn.Module):
     def __init__(self, requires_grad=False):
         super(Vgg19, self).__init__()
@@ -428,12 +398,10 @@ class Vgg19(nn.Module):
         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
         return out

-
 class VGGLoss(nn.Module):
     def __init__(self, layids=None):
         super(VGGLoss, self).__init__()
         self.vgg = Vgg19()
-        self.vgg.cuda()
         self.criterion = nn.L1Loss()
         self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
         self.layids = layids
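Note: dropping self.vgg.cuda() leaves VGGLoss on the CPU by default, which is what a CUDA-less host needs; GPU users can still opt in explicitly. Hypothetical usage (the first run downloads torchvision's pretrained VGG19 weights):

    import torch
    from networks import VGGLoss   # this file

    criterion = VGGLoss()            # CPU by default after this change
    # criterion = VGGLoss().cuda()   # explicit opt-in on a GPU host
    loss = criterion(torch.rand(1, 3, 256, 192), torch.rand(1, 3, 256, 192))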
@@ -448,7 +416,6 @@ class VGGLoss(nn.Module):
                 self.criterion(x_vgg[i], y_vgg[i].detach())
         return loss

-
 class DT(nn.Module):
     def __init__(self):
         super(DT, self).__init__()
@@ -457,17 +424,15 @@ class DT(nn.Module):
         dt = torch.abs(x1 - x2)
         return dt

-
 class DT2(nn.Module):
     def __init__(self):
-        super(
+        super(DT2, self).__init__()

     def forward(self, x1, y1, x2, y2):
         dt = torch.sqrt(torch.mul(x1 - x2, x1 - x2) +
                         torch.mul(y1 - y2, y1 - y2))
         return dt

-
 class GicLoss(nn.Module):
     def __init__(self, opt):
         super(GicLoss, self).__init__()
@@ -496,7 +461,6 @@ class GicLoss(nn.Module):

         return torch.sum(torch.abs(dtleft - dtright) + torch.abs(dtup - dtdown))

-
 class GMM(nn.Module):
     """ Geometric Matching Module
     """
@@ -510,9 +474,9 @@ class GMM(nn.Module):
         self.l2norm = FeatureL2Norm()
         self.correlation = FeatureCorrelation()
         self.regression = FeatureRegression(
-            input_nc=192, output_dim=2*opt.grid_size**2, use_cuda=True)
+            input_nc=192, output_dim=2*opt.grid_size**2)
         self.gridGen = TpsGridGen(
-            opt.fine_height, opt.fine_width, use_cuda=True, grid_size=opt.grid_size)
+            opt.fine_height, opt.fine_width, grid_size=opt.grid_size)

     def forward(self, inputA, inputB):
         featureA = self.extractionA(inputA)
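Note: the constants here are easy to sanity-check. FeatureCorrelation reshapes its output to (b, h*w, h, w), so with 16x12 feature maps from 256x192 inputs (an assumption; the extraction layers are elided in this diff) the correlation tensor has 16*12 = 192 channels, matching input_nc=192; and a TPS with a grid_size x grid_size control grid needs two coordinates per point, hence output_dim=2*opt.grid_size**2 (18 for the default 3x3 grid):

    fh, fw = 16, 12            # assumed feature-map size
    assert fh * fw == 192      # channels into FeatureRegression
    assert 2 * 3 ** 2 == 18    # theta size for a 3x3 TPS grid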
@@ -525,17 +489,12 @@ class GMM(nn.Module):
         grid = self.gridGen(theta)
         return grid, theta

-
 def save_checkpoint(model, save_path):
     if not os.path.exists(os.path.dirname(save_path)):
         os.makedirs(os.path.dirname(save_path))
-
-    torch.save(model.cpu().state_dict(), save_path)
-    model.cuda()
-
+    torch.save(model.state_dict(), save_path)

 def load_checkpoint(model, checkpoint_path):
     if not os.path.exists(checkpoint_path):
         return
-    model.load_state_dict(torch.load(checkpoint_path))
-    model.cuda()
+    model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
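Note: the old save_checkpoint bounced the whole model to CPU and back just to serialize it, and both helpers force-called model.cuda(), which crashes on CUDA-less hosts. Saving the state_dict directly and loading with map_location=torch.device('cpu') makes both paths device-safe, and a CUDA-saved checkpoint still loads. Hypothetical usage with a stand-in module and an illustrative path:

    import torch.nn as nn
    from networks import save_checkpoint, load_checkpoint   # this file

    model = nn.Linear(4, 2)                         # stand-in for GMM / UnetGenerator
    save_checkpoint(model, 'checkpoints/demo.pth')  # creates the directory if needed
    load_checkpoint(model, 'checkpoints/demo.pth')  # CPU-safe via map_location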