gaur3009 commited on
Commit
e6d6add
·
verified ·
1 Parent(s): 1964060

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +85 -75
networks.py CHANGED
@@ -120,6 +120,7 @@ class FeatureRegression(nn.Module):
120
  x = self.linear(x)
121
  return self.tanh(x)
122
 
 
123
  class TpsGridGen(nn.Module):
124
  def __init__(self, out_h=256, out_w=192, grid_size=5):
125
  super(TpsGridGen, self).__init__()
@@ -128,106 +129,115 @@ class TpsGridGen(nn.Module):
128
  self.grid_size = grid_size
129
  self.N = grid_size * grid_size
130
 
131
- # Create regular grid of control points
 
 
 
 
 
 
 
 
 
 
 
 
132
  axis_coords = np.linspace(-1, 1, grid_size)
133
  P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
134
- P_X = torch.FloatTensor(P_X.reshape(-1, 1)) # (N,1)
135
- P_Y = torch.FloatTensor(P_Y.reshape(-1, 1)) # (N,1)
136
 
137
- # Register buffers to persist through saving/loading
138
- self.register_buffer('P_X_base', P_X)
139
- self.register_buffer('P_Y_base', P_Y)
 
140
 
141
  # Compute inverse matrix L^-1
142
  Li = self.compute_L_inverse(P_X, P_Y)
143
- self.register_buffer('Li', Li)
144
-
145
- # Create sampling grid
146
- grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
147
- self.register_buffer('grid_X', torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)) # (1,H,W,1)
148
- self.register_buffer('grid_Y', torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)) # (1,H,W,1)
149
 
150
  def compute_L_inverse(self, X, Y):
151
- N = X.size(0)
152
- device = X.device
153
-
154
- # Construct distance matrix
155
- Xmat = X.expand(N, N)
156
- Ymat = Y.expand(N, N)
157
- P_dist_squared = torch.pow(Xmat - Xmat.t(), 2) + torch.pow(Ymat - Ymat.t(), 2)
158
- P_dist_squared[P_dist_squared == 0] = 1 # Avoid log(0)
159
- K = torch.mul(P_dist_squared, torch.log(P_dist_squared))
160
-
161
- # Construct L matrix
162
- O = torch.ones(N, 1, device=device)
163
- Z = torch.zeros(3, 3, device=device)
164
- P = torch.cat((O, X, Y), 1)
165
- L_top = torch.cat((K, P), 1)
166
- L_bottom = torch.cat((P.t(), Z), 1)
167
- L = torch.cat((L_top, L_bottom), 0)
168
-
169
- return torch.inverse(L)
170
 
171
  def forward(self, theta):
172
  batch_size = theta.size(0)
173
  device = theta.device
174
 
175
- # Split theta into x and y components
176
- Q_X = theta[:, :self.N].view(batch_size, self.N, 1)
177
- Q_Y = theta[:, self.N:].view(batch_size, self.N, 1)
178
  Q_X = Q_X + self.P_X_base.expand_as(Q_X)
179
  Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
180
 
181
- # Extract top-left NxN block of Li matrix
182
- Li_block = self.Li[:self.N, :self.N]
 
183
 
184
- # Compute weights
185
- W_X = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_X)
186
- W_Y = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_Y)
187
 
188
- # Prepare grid tensors
189
- grid_X = self.grid_X.expand(batch_size, -1, -1, -1)
190
- grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1)
191
 
192
- # Compute transformed coordinates
193
- points_X = self.transform_points(grid_X, W_X, Q_X)
194
- points_Y = self.transform_points(grid_Y, W_Y, Q_Y)
195
 
196
- return torch.cat((points_X, points_Y), 3)
197
-
198
- # In TpsGridGen class, replace transform_points method with this:
199
- def transform_points(self, grid, W, Q):
200
- batch_size, h, w, _ = grid.size()
201
- n_points = h * w
202
-
203
- # Control points P (N, 2)
204
- P = torch.cat([self.P_X_base, self.P_Y_base], 1)
205
- P = P.unsqueeze(0).expand(batch_size, -1, -1) # (B, N, 2)
206
-
207
- # Compute U = r^2 * log(r^2)
208
- grid_flat = grid.view(batch_size, n_points, 2) # (B, H*W, 2)
209
- dist = grid_flat.unsqueeze(2) - P.unsqueeze(1) # (B, H*W, N, 2)
210
- dist_squared = torch.sum(dist**2, dim=3) # (B, H*W, N)
211
- dist_squared[dist_squared == 0] = 1 # Avoid log(0)
212
- U = dist_squared * torch.log(dist_squared)
213
 
214
- # Compute affine part [1, x, y]
215
- ones = torch.ones(batch_size, n_points, 1, device=grid.device)
216
- A = torch.cat([ones, grid_flat], dim=2) # (B, H*W, 3)
 
 
 
 
 
 
 
 
217
 
218
- # Warp coefficients
219
- W = W.view(batch_size, self.N, 1)
220
- Q = Q.view(batch_size, self.N, 1)
221
 
222
- # Non-affine part
223
- non_affine = torch.bmm(U, W) # (B, H*W, 1)
 
 
 
 
 
224
 
225
- # Affine part
226
- affine = torch.bmm(A, Q) # (B, H*W, 1)
227
 
228
- # Combine components
229
- points = affine + non_affine
230
- return points.view(batch_size, h, w, 1)
231
 
232
  class GMM(nn.Module):
233
  def __init__(self, opt=None):
 
120
  x = self.linear(x)
121
  return self.tanh(x)
122
 
123
+ # networks.py - TpsGridGen class replacement
124
  class TpsGridGen(nn.Module):
125
  def __init__(self, out_h=256, out_w=192, grid_size=5):
126
  super(TpsGridGen, self).__init__()
 
129
  self.grid_size = grid_size
130
  self.N = grid_size * grid_size
131
 
132
+ # Create grid in numpy
133
+ self.grid = np.zeros([self.out_h, self.out_w, 3], dtype=np.float32)
134
+
135
+ # Sampling grid with dim-0 (Y) and dim-1 (X) coords
136
+ grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
137
+ self.grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3) # [1, H, W, 1]
138
+ self.grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3) # [1, H, W, 1]
139
+
140
+ # Register buffers
141
+ self.register_buffer('grid_X_base', self.grid_X)
142
+ self.register_buffer('grid_Y_base', self.grid_Y)
143
+
144
+ # Initialize regular grid for control points
145
  axis_coords = np.linspace(-1, 1, grid_size)
146
  P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
147
+ P_X = np.reshape(P_X, (-1, 1)) # [N, 1]
148
+ P_Y = np.reshape(P_Y, (-1, 1)) # [N, 1]
149
 
150
+ self.P_X = torch.FloatTensor(P_X)
151
+ self.P_Y = torch.FloatTensor(P_Y)
152
+ self.register_buffer('P_X_base', self.P_X)
153
+ self.register_buffer('P_Y_base', self.P_Y)
154
 
155
  # Compute inverse matrix L^-1
156
  Li = self.compute_L_inverse(P_X, P_Y)
157
+ self.register_buffer('Li', torch.FloatTensor(Li))
 
 
 
 
 
158
 
159
  def compute_L_inverse(self, X, Y):
160
+ N = X.shape[0] # num of points (along dim 0)
161
+
162
+ # Construct matrix K
163
+ Xmat = np.tile(X, (1, N))
164
+ Ymat = np.tile(Y, (1, N))
165
+ P_dist_squared = np.power(Xmat - Xmat.T, 2) + np.power(Ymat - Ymat.T, 2)
166
+ P_dist_squared[P_dist_squared == 0] = 1 # make diagonal 1 to avoid NaN in log computation
167
+ K = P_dist_squared * np.log(P_dist_squared)
168
+
169
+ # Construct matrix L
170
+ O = np.ones((N, 1))
171
+ Z = np.zeros((3, 3))
172
+ P = np.concatenate((O, X, Y), axis=1)
173
+ L = np.concatenate((np.concatenate((K, P), axis=1),
174
+ np.concatenate((P.T, Z), axis=1)), axis=0)
175
+
176
+ Li = np.linalg.inv(L)
177
+ return Li
 
178
 
179
  def forward(self, theta):
180
  batch_size = theta.size(0)
181
  device = theta.device
182
 
183
+ # Split theta into point coordinates
184
+ Q_X = theta[:, :self.N].view(batch_size, self.N, 1, 1)
185
+ Q_Y = theta[:, self.N:].view(batch_size, self.N, 1, 1)
186
  Q_X = Q_X + self.P_X_base.expand_as(Q_X)
187
  Q_Y = Q_Y + self.P_Y_base.expand_as(Q_Y)
188
 
189
+ # Get spatial dimensions of points
190
+ points = torch.cat((self.grid_X_base.expand(batch_size, -1, -1, -1),
191
+ self.grid_Y_base.expand(batch_size, -1, -1, -1)), 3)
192
 
193
+ # Repeat pre-defined control points along spatial dimensions of points to be transformed
194
+ P_X = self.P_X_base.expand(batch_size, 1, 1, self.N)
195
+ P_Y = self.P_Y_base.expand(batch_size, 1, 1, self.N)
196
 
197
+ # Compute weights for non-linear part
198
+ W_X = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X.squeeze(-1))
199
+ W_Y = torch.bmm(self.Li[:self.N, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y.squeeze(-1))
200
 
201
+ # Reshape to [B, H, W, N]
202
+ W_X = W_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
203
+ W_Y = W_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
204
 
205
+ # Compute weights for affine part
206
+ A_X = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_X.squeeze(-1))
207
+ A_Y = torch.bmm(self.Li[self.N:, :self.N].unsqueeze(0).expand(batch_size, -1, -1), Q_Y.squeeze(-1))
208
+
209
+ # Reshape to [B, H, W, 1, 3]
210
+ A_X = A_X.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
211
+ A_Y = A_Y.unsqueeze(3).unsqueeze(4).transpose(1, 4).repeat(1, self.out_h, self.out_w, 1, 1)
 
 
 
 
 
 
 
 
 
 
212
 
213
+ # Compute distance P_i - (grid_X, grid_Y)
214
+ points_X = points[:, :, :, 0].unsqueeze(3) # [B, H, W, 1]
215
+ points_Y = points[:, :, :, 1].unsqueeze(3) # [B, H, W, 1]
216
+
217
+ delta_X = points_X - P_X
218
+ delta_Y = points_Y - P_Y
219
+
220
+ # Compute U (radial basis function)
221
+ dist_squared = torch.pow(delta_X, 2) + torch.pow(delta_Y, 2)
222
+ dist_squared[dist_squared == 0] = 1 # avoid NaN in log computation
223
+ U = dist_squared * torch.log(dist_squared)
224
 
225
+ # Compute non-affine part
226
+ points_X_prime = torch.sum(torch.mul(W_X, U), dim=4)
227
+ points_Y_prime = torch.sum(torch.mul(W_Y, U), dim=4)
228
 
229
+ # Compute affine part
230
+ A_X0 = A_X[:, :, :, :, 0]
231
+ A_X1 = A_X[:, :, :, :, 1]
232
+ A_X2 = A_X[:, :, :, :, 2]
233
+ A_Y0 = A_Y[:, :, :, :, 0]
234
+ A_Y1 = A_Y[:, :, :, :, 1]
235
+ A_Y2 = A_Y[:, :, :, :, 2]
236
 
237
+ points_X_prime += A_X0 + torch.mul(A_X1, points_X.squeeze(3)) + torch.mul(A_X2, points_Y.squeeze(3))
238
+ points_Y_prime += A_Y0 + torch.mul(A_Y1, points_X.squeeze(3)) + torch.mul(A_Y2, points_Y.squeeze(3))
239
 
240
+ return torch.cat((points_X_prime.unsqueeze(3), points_Y_prime.unsqueeze(3)), 3)
 
 
241
 
242
  class GMM(nn.Module):
243
  def __init__(self, opt=None):