gaur3009 committed
Commit 427172d · verified · 1 Parent(s): aa8b396

Update networks.py

Files changed (1)
  1. networks.py +53 -37
networks.py CHANGED
@@ -203,30 +203,43 @@ class TpsGridGen(nn.Module):
         grid_flat = grid.view(batch_size, n_points, 1)
 
         # Prepare control points
-        P = torch.cat([self.P_X_base, self.P_Y_base], 1).t().unsqueeze(0)  # (1, 2, N)
-        P = P.expand(batch_size, -1, -1)  # (B, 2, N)
+        P = torch.cat([self.P_X_base, self.P_Y_base], 1)  # (N, 2)
+        P = P.unsqueeze(0).expand(batch_size, -1, -1)  # (B, N, 2)
 
         # Compute distance between grid points and control points
         grid_expanded = grid_flat.expand(-1, -1, self.N)  # (B, H*W, N)
-        P_expanded = P.expand(n_points, -1, -1).permute(1, 0, 2)  # (B, H*W, N)
-        delta = grid_expanded - P_expanded
+        P_expanded = P.permute(0, 2, 1).unsqueeze(1)  # (B, 1, 2, N)
+        P_expanded = P_expanded.expand(-1, n_points, -1, -1)  # (B, H*W, 2, N)
+
+        # Reshape grid for calculation
+        grid_reshaped = grid.view(batch_size, n_points, 1, 1).expand(-1, -1, -1, self.N)  # (B, H*W, 1, N)
+
+        # Compute delta
+        delta = grid_reshaped - P_expanded
 
         # Compute U (radial basis function)
-        dist_squared = torch.sum(torch.pow(delta, 2), dim=1, keepdim=True)  # (B, H*W, 1)
+        dist_squared = torch.sum(torch.pow(delta, 2), dim=2, keepdim=False)  # (B, H*W, N)
         dist_squared[dist_squared == 0] = 1  # Avoid log(0)
         U = torch.mul(dist_squared, torch.log(dist_squared))
 
         # Compute affine transformation
+        # Create affine matrix [1, x, y]
         A = torch.cat([
             torch.ones(batch_size, n_points, 1, device=grid.device),
             grid_flat.view(batch_size, n_points, 1)
-        ], dim=2)
+        ], dim=2)  # (B, H*W, 3)
 
-        # Combine affine and non-affine components
-        affine = torch.bmm(A, Q.view(batch_size, 1, 3).permute(0, 2, 1))
-        non_affine = torch.bmm(U.permute(0, 2, 1), W).permute(0, 2, 1)
-        points = affine + non_affine
+        # Extract affine parameters from Q
+        affine_params = Q.view(batch_size, 3, 1)  # (B, 3, 1)
+
+        # Compute affine component
+        affine = torch.bmm(A, affine_params)  # (B, H*W, 1)
+
+        # Compute non-affine component
+        non_affine = torch.bmm(U, W)  # (B, H*W, 1)
+
+        # Combine components
+        points = affine + non_affine
         return points.view(batch_size, h, w, 1)
 
 class GMM(nn.Module):
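
For reference, the broadcasting layout introduced in this hunk can be checked in isolation. The sketch below is not part of the commit: it reproduces the new control-point/grid reshaping and the r^2 * log(r^2) radial-basis step with made-up sizes (batch_size, n_points, and N here are illustrative), purely to confirm the tensor shapes noted in the comments above.

# Standalone shape check (not from the commit); all sizes are made up.
import torch

batch_size, n_points, N = 2, 6, 5                      # B, H*W, number of control points
grid = torch.rand(batch_size, n_points)                # one coordinate per output location
P = torch.rand(N, 2)                                   # control points laid out as (N, 2)
P = P.unsqueeze(0).expand(batch_size, -1, -1)          # (B, N, 2)

P_expanded = P.permute(0, 2, 1).unsqueeze(1)           # (B, 1, 2, N)
P_expanded = P_expanded.expand(-1, n_points, -1, -1)   # (B, H*W, 2, N)
grid_reshaped = grid.view(batch_size, n_points, 1, 1).expand(-1, -1, -1, N)  # (B, H*W, 1, N)

delta = grid_reshaped - P_expanded                     # broadcasts over the coordinate dim
dist_squared = torch.sum(delta ** 2, dim=2)            # (B, H*W, N)
dist_squared[dist_squared == 0] = 1                    # avoid log(0), as in the commit
U = dist_squared * torch.log(dist_squared)             # TPS radial basis r^2 * log(r^2)
print(U.shape)                                         # torch.Size([2, 6, 5])
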
@@ -367,37 +380,40 @@ def load_checkpoint(model, checkpoint_path, strict=True):
 
     state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
 
-    # Initialize TPS grid parameters if missing
-    if 'gridGen.P_X_base' not in state_dict:
-        print("Initializing TPS grid parameters...")
-        grid_size = model.gridGen.grid_size
-        axis_coords = np.linspace(-1, 1, grid_size)
-        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
-        P_X = torch.FloatTensor(P_X.reshape(-1, 1))
-        P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))
-        state_dict['gridGen.P_X_base'] = P_X
-        state_dict['gridGen.P_Y_base'] = P_Y
-
-        # Compute Li
-        Li = model.gridGen.compute_L_inverse(P_X, P_Y)
-        state_dict['gridGen.Li'] = Li
-
-        # Create grid
-        grid_X, grid_Y = np.meshgrid(
-            np.linspace(-1, 1, model.gridGen.out_w),
-            np.linspace(-1, 1, model.gridGen.out_h)
-        )
-        state_dict['gridGen.grid_X'] = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
-        state_dict['gridGen.grid_Y'] = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
+    # Create a new state dict that matches our model architecture
+    new_state_dict = {}
+    for key, value in state_dict.items():
+        # Handle any name changes here if needed
+        new_key = key
+        if 'gridGen' in key:
+            # Map old parameter names to new ones
+            if 'P_X' in key:
+                new_key = key.replace('P_X', 'P_X_base')
+            elif 'P_Y' in key:
+                new_key = key.replace('P_Y', 'P_Y_base')
+
+        # Only include keys that exist in the current model
+        if new_key in model.state_dict() and value.size() == model.state_dict()[new_key].size():
+            new_state_dict[new_key] = value
+
+    # Add missing TPS parameters if needed
+    tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
+                  'gridGen.grid_X', 'gridGen.grid_Y']
+    for param in tps_params:
+        if param not in new_state_dict and hasattr(model, 'gridGen'):
+            print(f"Initializing missing TPS parameter: {param}")
+            # Initialize with current model's value
+            new_state_dict[param] = model.state_dict()[param]
 
-    # Load state dict
-    model.load_state_dict(state_dict, strict=strict)
+    # Load the state dict
+    model.load_state_dict(new_state_dict, strict=strict)
 
     # Print warnings
     model_keys = set(model.state_dict().keys())
-    ckpt_keys = set(state_dict.keys())
-    missing = model_keys - ckpt_keys
-    unexpected = ckpt_keys - model_keys
+    loaded_keys = set(new_state_dict.keys())
+
+    missing = model_keys - loaded_keys
+    unexpected = set(state_dict.keys()) - set(new_state_dict.keys())
 
     if missing:
         print(f"Missing keys: {sorted(missing)}")
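
As a usage note rather than part of the commit, the core idea of the rewritten load_checkpoint — rename known keys, keep only entries whose key exists in the current model with a matching shape, then load and report the difference — can be reproduced on a toy module. TinyNet and the checkpoint contents below are hypothetical and exist only to illustrate the filtering pattern.

# Standalone sketch (not from the commit) of the checkpoint-filtering pattern.
import torch
import torch.nn as nn

class TinyNet(nn.Module):                      # hypothetical stand-in for the real model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

model = TinyNet()

# Pretend this came from torch.load(...): one good key, one shape mismatch, one stale key.
checkpoint = {
    "fc.weight": torch.zeros(2, 4),            # matches the model -> kept
    "fc.bias": torch.zeros(3),                 # wrong shape -> dropped
    "old_head.weight": torch.zeros(1, 4),      # not in the model -> dropped
}

model_state = model.state_dict()
filtered = {k: v for k, v in checkpoint.items()
            if k in model_state and v.size() == model_state[k].size()}

# strict=False here because only some keys survive filtering;
# the commit's load_checkpoint forwards its own strict argument instead.
model.load_state_dict(filtered, strict=False)

missing = set(model_state) - set(filtered)
unexpected = set(checkpoint) - set(filtered)
print(sorted(missing))                         # ['fc.bias']
print(sorted(unexpected))                      # ['fc.bias', 'old_head.weight']
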