gaur3009 commited on
Commit
15e6502
·
verified ·
1 Parent(s): e6d6add

Update networks.py

Browse files
Files changed (1) hide show
  1. networks.py +10 -6
networks.py CHANGED
@@ -380,7 +380,7 @@ def load_checkpoint(model, checkpoint_path, strict=True):
380
  # Create a new state dict that matches our model architecture
381
  new_state_dict = {}
382
  for key, value in state_dict.items():
383
- # Handle any name changes here if needed
384
  new_key = key
385
  if 'gridGen' in key:
386
  # Map old parameter names to new ones
@@ -388,6 +388,10 @@ def load_checkpoint(model, checkpoint_path, strict=True):
388
  new_key = key.replace('P_X', 'P_X_base')
389
  elif 'P_Y' in key and 'base' not in key:
390
  new_key = key.replace('P_Y', 'P_Y_base')
 
 
 
 
391
 
392
  # Only include keys that exist in the current model
393
  if new_key in model.state_dict():
@@ -395,15 +399,15 @@ def load_checkpoint(model, checkpoint_path, strict=True):
395
 
396
  # Add missing TPS parameters if needed
397
  tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
398
- 'gridGen.grid_X', 'gridGen.grid_Y']
399
  for param in tps_params:
400
  if param not in new_state_dict and hasattr(model, 'gridGen'):
401
- print(f"Initializing missing TPS parameter: {param}")
402
- # Initialize with current model's value
403
- new_state_dict[param] = model.state_dict()[param]
404
 
405
  # Load the state dict
406
- model.load_state_dict(new_state_dict, strict=False) # Use strict=False to ignore missing keys
407
 
408
  # Print warnings
409
  model_keys = set(model.state_dict().keys())
 
380
  # Create a new state dict that matches our model architecture
381
  new_state_dict = {}
382
  for key, value in state_dict.items():
383
+ # Handle name changes
384
  new_key = key
385
  if 'gridGen' in key:
386
  # Map old parameter names to new ones
 
388
  new_key = key.replace('P_X', 'P_X_base')
389
  elif 'P_Y' in key and 'base' not in key:
390
  new_key = key.replace('P_Y', 'P_Y_base')
391
+ elif 'grid_X' in key and 'base' not in key:
392
+ new_key = key.replace('grid_X', 'grid_X_base')
393
+ elif 'grid_Y' in key and 'base' not in key:
394
+ new_key = key.replace('grid_Y', 'grid_Y_base')
395
 
396
  # Only include keys that exist in the current model
397
  if new_key in model.state_dict():
 
399
 
400
  # Add missing TPS parameters if needed
401
  tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
402
+ 'gridGen.grid_X_base', 'gridGen.grid_Y_base']
403
  for param in tps_params:
404
  if param not in new_state_dict and hasattr(model, 'gridGen'):
405
+ if param in model.state_dict():
406
+ print(f"Initializing missing TPS parameter: {param}")
407
+ new_state_dict[param] = model.state_dict()[param]
408
 
409
  # Load the state dict
410
+ model.load_state_dict(new_state_dict, strict=False)
411
 
412
  # Print warnings
413
  model_keys = set(model.state_dict().keys())