gaur3009 committed on
Commit
aa8b396
·
verified ·
1 Parent(s): 3f625d4

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +47 -40
networks.py CHANGED
@@ -182,12 +182,12 @@ class TpsGridGen(nn.Module):
182
  Li_block = self.Li[:self.N, :self.N]
183
 
184
  # Compute weights
185
- W_X = torch.bmm(Li_block.expand(batch_size, self.N, self.N), Q_X)
186
- W_Y = torch.bmm(Li_block.expand(batch_size, self.N, self.N), Q_Y)
187
 
188
  # Prepare grid tensors
189
- grid_X = self.grid_X.expand(batch_size, self.out_h, self.out_w, 1).to(device)
190
- grid_Y = self.grid_Y.expand(batch_size, self.out_h, self.out_w, 1).to(device)
191
 
192
  # Compute transformed coordinates
193
  points_X = self.transform_points(grid_X, W_X, Q_X)
@@ -197,30 +197,36 @@ class TpsGridGen(nn.Module):
197
 
198
  def transform_points(self, grid, W, Q):
199
  batch_size, h, w, _ = grid.size()
 
200
 
201
- # Flatten grid to (batch_size, H*W, 2)
202
- grid_flat = grid.view(batch_size, -1, 1)
203
 
204
  # Prepare control points
205
- P = torch.cat([self.P_X_base, self.P_Y_base], 1).expand(batch_size, -1, -1).to(grid.device)
 
206
 
207
  # Compute distance between grid points and control points
208
- delta = grid_flat - P
 
 
209
 
210
  # Compute U (radial basis function)
211
- dist_squared = torch.sum(torch.pow(delta, 2), 2, keepdim=True)
212
  dist_squared[dist_squared == 0] = 1 # Avoid log(0)
213
  U = torch.mul(dist_squared, torch.log(dist_squared))
214
 
215
  # Compute affine transformation
216
  A = torch.cat([
217
- torch.ones(batch_size, h*w, 1, device=grid.device),
218
- grid_flat[:, :, 0:1],
219
- grid_flat[:, :, 1:2]
220
- ], 2)
221
 
222
  # Combine affine and non-affine components
223
- points = torch.bmm(A, Q.view(batch_size, 3, 1)) + torch.bmm(U, W)
 
 
 
224
  return points.view(batch_size, h, w, 1)
225
 
226
  class GMM(nn.Module):
@@ -361,36 +367,37 @@ def load_checkpoint(model, checkpoint_path, strict=True):
361
 
362
  state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
363
 
364
- # Create mapping for old buffer names to new names
365
- buffer_mapping = {
366
- 'gridGen.P_X': 'gridGen.P_X_base',
367
- 'gridGen.P_Y': 'gridGen.P_Y_base',
368
- 'gridGen.Li': 'gridGen.Li',
369
- 'gridGen.grid_X': 'gridGen.grid_X',
370
- 'gridGen.grid_Y': 'gridGen.grid_Y'
371
- }
372
-
373
- # Update state_dict keys
374
- updated_state_dict = {}
375
- for key, value in state_dict.items():
376
- # Handle buffer name changes
377
- for old_name, new_name in buffer_mapping.items():
378
- if key.startswith(old_name):
379
- key = key.replace(old_name, new_name)
380
 
381
- # Only keep keys that match current model
382
- if key in model.state_dict() and value.size() == model.state_dict()[key].size():
383
- updated_state_dict[key] = value
 
 
 
 
 
 
 
 
384
 
385
- # Load the updated state dict
386
- model.load_state_dict(updated_state_dict, strict=strict)
387
 
388
- # Print warnings about missing keys
389
  model_keys = set(model.state_dict().keys())
390
- checkpoint_keys = set(state_dict.keys())
391
-
392
- missing = model_keys - set(updated_state_dict.keys())
393
- unexpected = checkpoint_keys - set(updated_state_dict.keys())
394
 
395
  if missing:
396
  print(f"Missing keys: {sorted(missing)}")
 
182
  Li_block = self.Li[:self.N, :self.N]
183
 
184
  # Compute weights
185
+ W_X = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_X)
186
+ W_Y = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_Y)
187
 
188
  # Prepare grid tensors
189
+ grid_X = self.grid_X.expand(batch_size, -1, -1, -1)
190
+ grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1)
191
 
192
  # Compute transformed coordinates
193
  points_X = self.transform_points(grid_X, W_X, Q_X)
 
197
 
198
  def transform_points(self, grid, W, Q):
199
  batch_size, h, w, _ = grid.size()
200
+ n_points = h * w
201
 
202
+ # Flatten grid to (batch_size, H*W, 1)
203
+ grid_flat = grid.view(batch_size, n_points, 1)
204
 
205
  # Prepare control points
206
+ P = torch.cat([self.P_X_base, self.P_Y_base], 1).t().unsqueeze(0) # (1, 2, N)
207
+ P = P.expand(batch_size, -1, -1) # (B, 2, N)
208
 
209
  # Compute distance between grid points and control points
210
+ grid_expanded = grid_flat.expand(-1, -1, self.N) # (B, H*W, N)
211
+ P_expanded = P.expand(n_points, -1, -1).permute(1, 0, 2) # (B, H*W, N)
212
+ delta = grid_expanded - P_expanded
213
 
214
  # Compute U (radial basis function)
215
+ dist_squared = torch.sum(torch.pow(delta, 2), dim=1, keepdim=True) # (B, H*W, 1)
216
  dist_squared[dist_squared == 0] = 1 # Avoid log(0)
217
  U = torch.mul(dist_squared, torch.log(dist_squared))
218
 
219
  # Compute affine transformation
220
  A = torch.cat([
221
+ torch.ones(batch_size, n_points, 1, device=grid.device),
222
+ grid_flat.view(batch_size, n_points, 1)
223
+ ], dim=2)
 
224
 
225
  # Combine affine and non-affine components
226
+ affine = torch.bmm(A, Q.view(batch_size, 1, 3).permute(0, 2, 1))
227
+ non_affine = torch.bmm(U.permute(0, 2, 1), W).permute(0, 2, 1)
228
+ points = affine + non_affine
229
+
230
  return points.view(batch_size, h, w, 1)
231
 
232
  class GMM(nn.Module):
 
367
 
368
  state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
369
 
370
+ # Initialize TPS grid parameters if missing
371
+ if 'gridGen.P_X_base' not in state_dict:
372
+ print("Initializing TPS grid parameters...")
373
+ grid_size = model.gridGen.grid_size
374
+ axis_coords = np.linspace(-1, 1, grid_size)
375
+ P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
376
+ P_X = torch.FloatTensor(P_X.reshape(-1, 1))
377
+ P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))
378
+ state_dict['gridGen.P_X_base'] = P_X
379
+ state_dict['gridGen.P_Y_base'] = P_Y
 
 
 
 
 
 
380
 
381
+ # Compute Li
382
+ Li = model.gridGen.compute_L_inverse(P_X, P_Y)
383
+ state_dict['gridGen.Li'] = Li
384
+
385
+ # Create grid
386
+ grid_X, grid_Y = np.meshgrid(
387
+ np.linspace(-1, 1, model.gridGen.out_w),
388
+ np.linspace(-1, 1, model.gridGen.out_h)
389
+ )
390
+ state_dict['gridGen.grid_X'] = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
391
+ state_dict['gridGen.grid_Y'] = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
392
 
393
+ # Load state dict
394
+ model.load_state_dict(state_dict, strict=strict)
395
 
396
+ # Print warnings
397
  model_keys = set(model.state_dict().keys())
398
+ ckpt_keys = set(state_dict.keys())
399
+ missing = model_keys - ckpt_keys
400
+ unexpected = ckpt_keys - model_keys
 
401
 
402
  if missing:
403
  print(f"Missing keys: {sorted(missing)}")