Update networks.py
networks.py CHANGED (+85 -75)
Old version (removed lines are marked '-'; a number of removed lines were truncated by the diff viewer and are reproduced here only as far as they were captured):

@@ -120,6 +120,7 @@ class FeatureRegression(nn.Module):

        x = self.linear(x)
        return self.tanh(x)

class TpsGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, grid_size=5):
        super(TpsGridGen, self).__init__()

@@ -128,106 +129,115 @@ class TpsGridGen(nn.Module):

        self.grid_size = grid_size
        self.N = grid_size * grid_size

-        # Create
        axis_coords = np.linspace(-1, 1, grid_size)
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
-        P_X =
-        P_Y =
-        self.
-        self.register_buffer('

        # Compute inverse matrix L^-1
        Li = self.compute_L_inverse(P_X, P_Y)
-        self.register_buffer('Li', Li)
-
-        # Create sampling grid
-        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
-        self.register_buffer('grid_X', torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3))  # (1,H,W,1)
-        self.register_buffer('grid_Y', torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3))  # (1,H,W,1)

    def compute_L_inverse(self, X, Y):
-        N = X.
        [... additional removed lines truncated in the diff view ...]
-        P_dist_squared =
        [... additional removed lines truncated in the diff view ...]
-        return torch.inverse(L)

    def forward(self, theta):
        batch_size = theta.size(0)
        device = theta.device

-        # Split theta into
-        Q_X = theta[:, :self.N].view(batch_size, self.N, 1)
-        Q_Y = theta[:, self.N:].view(batch_size, self.N, 1)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)

        [... additional removed lines truncated in the diff view ...]
-
-        # Control points P (N, 2)
-        P = torch.cat([self.P_X_base, self.P_Y_base], 1)
-        P = P.unsqueeze(0).expand(batch_size, -1, -1)  # (B, N, 2)
-
-        # Compute U = r^2 * log(r^2)
-        grid_flat = grid.view(batch_size, n_points, 2)  # (B, H*W, 2)
-        dist = grid_flat.unsqueeze(2) - P.unsqueeze(1)  # (B, H*W, N, 2)
-        dist_squared = torch.sum(dist**2, dim=3)  # (B, H*W, N)
-        dist_squared[dist_squared == 0] = 1  # Avoid log(0)
-        U = dist_squared * torch.log(dist_squared)

-        # Compute
        [... additional removed lines truncated in the diff view ...]
-
-        points = affine + non_affine
-        return points.view(batch_size, h, w, 1)

class GMM(nn.Module):
    def __init__(self, opt=None):
New version (added lines are marked '+'):

        x = self.linear(x)
        return self.tanh(x)

+# networks.py - TpsGridGen class replacement
class TpsGridGen(nn.Module):
    def __init__(self, out_h=256, out_w=192, grid_size=5):
        super(TpsGridGen, self).__init__()
        self.grid_size = grid_size
        self.N = grid_size * grid_size

+        # Create grid in numpy
+        self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
+
+        # Sampling grid with dim-0 (Y) and dim-1 (X) coords
+        grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
+        self.grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)  # [1, H, W, 1]
+        self.grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)  # [1, H, W, 1]
+
+        # Register buffers
+        self.register_buffer('grid_X_base', self.grid_X)
+        self.register_buffer('grid_Y_base', self.grid_Y)
+
+        # Initialize regular grid for control points
        axis_coords = np.linspace(-1, 1, grid_size)
        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
+        P_X = np.reshape(P_X, (-1, 1))  # [N, 1]
+        P_Y = np.reshape(P_Y, (-1, 1))  # [N, 1]

+        self.P_X = torch.FloatTensor(P_X)
+        self.P_Y = torch.FloatTensor(P_Y)
+        self.register_buffer('P_X_base', self.P_X)
+        self.register_buffer('P_Y_base', self.P_Y)

        # Compute inverse matrix L^-1
        Li = self.compute_L_inverse(P_X, P_Y)
+        self.register_buffer('Li', torch.FloatTensor(Li))

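As an aside, a quick shape check for the buffers set up above (an illustrative sketch, not part of the commit; it assumes this file is importable as networks and that out_h/out_w are stored by the unchanged lines of __init__):

    import torch
    from networks import TpsGridGen

    tps = TpsGridGen(out_h=256, out_w=192, grid_size=5)
    print(tps.P_X_base.shape)     # torch.Size([25, 1])   -- N = grid_size ** 2 control points
    print(tps.Li.shape)           # torch.Size([28, 28])  -- (N + 3) x (N + 3)
    print(tps.grid_X_base.shape)  # torch.Size([1, 256, 192, 1])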
    def compute_L_inverse(self, X, Y):
+        N = X.shape[0]  # num of points (along dim 0)
+
+        # Construct matrix K
+        Xmat = np.tile(X, (1, N))
+        Ymat = np.tile(Y, (1, N))
+        P_dist_squared = np.power(Xmat - Xmat.T, 2) + np.power(Ymat - Ymat.T, 2)
+        P_dist_squared[P_dist_squared == 0] = 1  # make diagonal 1 to avoid NaN in log computation
+        K = P_dist_squared * np.log(P_dist_squared)
+
+        # Construct matrix L
+        O = np.ones((N, 1))
+        Z = np.zeros((3, 3))
+        P = np.concatenate((O, X, Y), axis=1)
+        L = np.concatenate((np.concatenate((K, P), axis=1),
+                            np.concatenate((P.T, Z), axis=1)), axis=0)
+
+        Li = np.linalg.inv(L)
+        return Li

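For reference, the matrix assembled and inverted above is the standard thin-plate-spline system (a sketch of the usual formulation; K, P and Li match the names in compute_L_inverse, while Q, W and A correspond to the control-point targets and the weights computed in forward below):

    L =
    \begin{bmatrix}
    K & P \\
    P^{\top} & 0_{3 \times 3}
    \end{bmatrix},
    \qquad
    K_{ij} = U\big(\lVert P_i - P_j \rVert\big),
    \qquad
    U(r) = r^{2} \log r^{2}

    \begin{bmatrix} W \\ A \end{bmatrix} = L^{-1} \begin{bmatrix} Q \\ 0 \end{bmatrix},
    \qquad
    f(x, y) = a_{0} + a_{1} x + a_{2} y + \sum_{i=1}^{N} w_{i}\, U\big(\lVert (x, y) - P_{i} \rVert\big)

forward below evaluates f once for the x target coordinates (Q_X) and once for the y target coordinates (Q_Y) at every point of the output sampling grid.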
    def forward(self, theta):
        batch_size = theta.size(0)
        device = theta.device

+        # Split theta into point coordinates; [B, N, 1] so the [N, 1] P_*_base buffers broadcast via expand_as
+        Q_X = theta[:, :self.N].view(batch_size, self.N, 1)
+        Q_Y = theta[:, self.N:].view(batch_size, self.N, 1)
        Q_X = Q_X + self.P_X_base.expand_as(Q_X)
        Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)

+        # Build the sampling grid of points to be transformed: [B, H, W, 2]
+        points = torch.cat((self.grid_X_base.expand(batch_size, -1, -1, -1),
+                            self.grid_Y_base.expand(batch_size, -1, -1, -1)), 3)

+        # Repeat pre-defined control points along spatial dimensions of points to be transformed
+        P_X = self.P_X_base.view(1, 1, 1, self.N).expand(batch_size, 1, 1, self.N)  # [B, 1, 1, N]
+        P_Y = self.P_Y_base.view(1, 1, 1, self.N).expand(batch_size, 1, 1, self.N)  # [B, 1, 1, N]

+        # Compute weights for non-linear part: W = Li[:N, :N] @ Q  -> [B, N, 1]
+        W_X = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X)
+        W_Y = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y)

+        # Reshape to [B, H, W, 1, N]
+        W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
+        W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)

+        # Compute weights for affine part: A = Li[N:, :N] @ Q  -> [B, 3, 1]
+        A_X = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X)
+        A_Y = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y)
+
+        # Reshape to [B, H, W, 1, 3]
+        A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
+        A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)

+        # Compute distance from each grid point to each control point P_i
+        points_X = points[:, :, :, 0].unsqueeze(3)  # [B, H, W, 1]
+        points_Y = points[:, :, :, 1].unsqueeze(3)  # [B, H, W, 1]
+
+        delta_X = points_X - P_X  # [B, H, W, N]
+        delta_Y = points_Y - P_Y  # [B, H, W, N]
+
+        # Compute U (radial basis function): U(r) = r^2 * log(r^2), with U(0) = 0
+        dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
+        dist_squared[dist_squared == 0] = 1  # 1 * log(1) = 0, so this avoids NaN and keeps U(0) = 0
+        U = dist_squared * torch.log(dist_squared)  # [B, H, W, N]

+        # Compute non-affine part: sum_i W_i * U_i
+        points_X_prime = torch.sum(torch.mul(W_X, U.unsqueeze(3)), dim=4)  # [B, H, W, 1]
+        points_Y_prime = torch.sum(torch.mul(W_Y, U.unsqueeze(3)), dim=4)  # [B, H, W, 1]

+        # Compute affine part: a0 + a1 * x + a2 * y
+        A_X0 = A_X[:, :, :, :, 0]  # each [B, H, W, 1]
+        A_X1 = A_X[:, :, :, :, 1]
+        A_X2 = A_X[:, :, :, :, 2]
+        A_Y0 = A_Y[:, :, :, :, 0]
+        A_Y1 = A_Y[:, :, :, :, 1]
+        A_Y2 = A_Y[:, :, :, :, 2]

+        points_X_prime += A_X0 + torch.mul(A_X1, points_X) + torch.mul(A_X2, points_Y)
+        points_Y_prime += A_Y0 + torch.mul(A_Y1, points_X) + torch.mul(A_Y2, points_Y)

+        return torch.cat((points_X_prime, points_Y_prime), 3)  # [B, H, W, 2] sampling grid

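Illustrative usage of the rewritten class (a sketch, not part of the commit; it assumes a [B, 3, 256, 192] image batch and a flat theta of shape [B, 2 * N]):

    import torch
    import torch.nn.functional as F

    tps = TpsGridGen(out_h=256, out_w=192, grid_size=5)
    theta = torch.zeros(4, 2 * tps.N)      # zero offsets keep the control points on the regular grid
    grid = tps(theta)                      # [4, 256, 192, 2] sampling grid in [-1, 1] coordinates
    image = torch.randn(4, 3, 256, 192)
    warped = F.grid_sample(image, grid, align_corners=True)  # approximately an identity warp for theta = 0

Whether to pass align_corners=True is a choice left to the surrounding pipeline.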
class GMM(nn.Module):
    def __init__(self, opt=None):