gaur3009 commited on
Commit
4da5a6a
·
verified ·
1 Parent(s): 198f320

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +19 -41
networks.py CHANGED
@@ -195,60 +195,38 @@ class TpsGridGen(nn.Module):
195
 
196
  return torch.cat((points_X, points_Y), 3)
197
 
 
198
  def transform_points(self, grid, W, Q):
199
  batch_size, h, w, _ = grid.size()
200
  n_points = h * w
201
 
202
- # Flatten grid to (batch_size, H*W, 1)
203
- grid_flat = grid.view(batch_size, n_points, 1)
204
-
205
- # Prepare control points
206
- P = torch.cat([self.P_X_base, self.P_Y_base], 1) # (N,2)
207
  P = P.unsqueeze(0).expand(batch_size, -1, -1) # (B, N, 2)
208
 
209
- # Compute distance between grid points and control points
210
- grid_expanded = grid_flat.expand(-1, -1, self.N) # (B, H*W, N)
211
- P_expanded = P.permute(0, 2, 1).unsqueeze(1) # (B, 1, 2, N)
212
- P_expanded = P_expanded.expand(-1, n_points, -1, -1) # (B, H*W, 2, N)
213
-
214
- # Reshape grid for calculation
215
- grid_reshaped = grid.view(batch_size, n_points, 1, 1).expand(-1, -1, -1, self.N) # (B, H*W, 1, N)
216
-
217
- # Compute delta
218
- delta = grid_reshaped - P_expanded
219
-
220
- # Compute U (radial basis function)
221
- dist_squared = torch.sum(torch.pow(delta, 2), dim=2, keepdim=False) # (B, H*W, N)
222
  dist_squared[dist_squared == 0] = 1 # Avoid log(0)
223
- U = torch.mul(dist_squared, torch.log(dist_squared))
224
 
225
- # Compute affine transformation
226
- # Create affine matrix [1, x, y]
227
- A = torch.cat([
228
- torch.ones(batch_size, n_points, 1, device=grid.device),
229
- grid_flat.view(batch_size, n_points, 1)
230
- ], dim=2) # (B, H*W, 3)
231
 
232
- # Reshape Q to include affine parameters
233
- # Q has shape (B, N, 1) - we need to extract affine parameters
234
- # Instead, we'll use the full Li matrix for the affine part
235
- # This is a simplified approach that works for the forward pass
236
 
237
- # Compute affine component directly from Q
238
- affine_x = Q[:, :, 0].mean(dim=1, keepdim=True) # Simplified affine X
239
- affine_y = Q[:, :, 0].mean(dim=1, keepdim=True) # Simplified affine Y
240
- affine = torch.cat([
241
- torch.ones(batch_size, n_points, 1, device=grid.device),
242
- grid_flat.view(batch_size, n_points, 1) * affine_x,
243
- grid_flat.view(batch_size, n_points, 1) * affine_y
244
- ], dim=2)
245
-
246
- # Compute non-affine component
247
  non_affine = torch.bmm(U, W) # (B, H*W, 1)
248
 
249
- # Combine components
250
- points = affine[:, :, :1] + non_affine # Only use the affine bias for X/Y
251
 
 
 
252
  return points.view(batch_size, h, w, 1)
253
 
254
  class GMM(nn.Module):
 
195
 
196
  return torch.cat((points_X, points_Y), 3)
197
 
198
+ # In TpsGridGen class, replace transform_points method with this:
199
  def transform_points(self, grid, W, Q):
200
  batch_size, h, w, _ = grid.size()
201
  n_points = h * w
202
 
203
+ # Control points P (N, 2)
204
+ P = torch.cat([self.P_X_base, self.P_Y_base], 1)
 
 
 
205
  P = P.unsqueeze(0).expand(batch_size, -1, -1) # (B, N, 2)
206
 
207
+ # Compute U = r^2 * log(r^2)
208
+ grid_flat = grid.view(batch_size, n_points, 2) # (B, H*W, 2)
209
+ dist = grid_flat.unsqueeze(2) - P.unsqueeze(1) # (B, H*W, N, 2)
210
+ dist_squared = torch.sum(dist**2, dim=3) # (B, H*W, N)
 
 
 
 
 
 
 
 
 
211
  dist_squared[dist_squared == 0] = 1 # Avoid log(0)
212
+ U = dist_squared * torch.log(dist_squared)
213
 
214
+ # Compute affine part [1, x, y]
215
+ ones = torch.ones(batch_size, n_points, 1, device=grid.device)
216
+ A = torch.cat([ones, grid_flat], dim=2) # (B, H*W, 3)
 
 
 
217
 
218
+ # Warp coefficients
219
+ W = W.view(batch_size, self.N, 1)
220
+ Q = Q.view(batch_size, self.N, 1)
 
221
 
222
+ # Non-affine part
 
 
 
 
 
 
 
 
 
223
  non_affine = torch.bmm(U, W) # (B, H*W, 1)
224
 
225
+ # Affine part
226
+ affine = torch.bmm(A, Q) # (B, H*W, 1)
227
 
228
+ # Combine components
229
+ points = affine + non_affine
230
  return points.view(batch_size, h, w, 1)
231
 
232
  class GMM(nn.Module):