Spaces:
Sleeping
Sleeping
Update networks.py
Browse files- networks.py +10 -6
networks.py
CHANGED
@@ -380,7 +380,7 @@ def load_checkpoint(model, checkpoint_path, strict=True):
|
|
380 |
# Create a new state dict that matches our model architecture
|
381 |
new_state_dict = {}
|
382 |
for key, value in state_dict.items():
|
383 |
-
# Handle
|
384 |
new_key = key
|
385 |
if 'gridGen' in key:
|
386 |
# Map old parameter names to new ones
|
@@ -388,6 +388,10 @@ def load_checkpoint(model, checkpoint_path, strict=True):
|
|
388 |
new_key = key.replace('P_X', 'P_X_base')
|
389 |
elif 'P_Y' in key and 'base' not in key:
|
390 |
new_key = key.replace('P_Y', 'P_Y_base')
|
|
|
|
|
|
|
|
|
391 |
|
392 |
# Only include keys that exist in the current model
|
393 |
if new_key in model.state_dict():
|
@@ -395,15 +399,15 @@ def load_checkpoint(model, checkpoint_path, strict=True):
|
|
395 |
|
396 |
# Add missing TPS parameters if needed
|
397 |
tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
|
398 |
-
'gridGen.
|
399 |
for param in tps_params:
|
400 |
if param not in new_state_dict and hasattr(model, 'gridGen'):
|
401 |
-
|
402 |
-
|
403 |
-
|
404 |
|
405 |
# Load the state dict
|
406 |
-
model.load_state_dict(new_state_dict, strict=False)
|
407 |
|
408 |
# Print warnings
|
409 |
model_keys = set(model.state_dict().keys())
|
|
|
380 |
# Create a new state dict that matches our model architecture
|
381 |
new_state_dict = {}
|
382 |
for key, value in state_dict.items():
|
383 |
+
# Handle name changes
|
384 |
new_key = key
|
385 |
if 'gridGen' in key:
|
386 |
# Map old parameter names to new ones
|
|
|
388 |
new_key = key.replace('P_X', 'P_X_base')
|
389 |
elif 'P_Y' in key and 'base' not in key:
|
390 |
new_key = key.replace('P_Y', 'P_Y_base')
|
391 |
+
elif 'grid_X' in key and 'base' not in key:
|
392 |
+
new_key = key.replace('grid_X', 'grid_X_base')
|
393 |
+
elif 'grid_Y' in key and 'base' not in key:
|
394 |
+
new_key = key.replace('grid_Y', 'grid_Y_base')
|
395 |
|
396 |
# Only include keys that exist in the current model
|
397 |
if new_key in model.state_dict():
|
|
|
399 |
|
400 |
# Add missing TPS parameters if needed
|
401 |
tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
|
402 |
+
'gridGen.grid_X_base', 'gridGen.grid_Y_base']
|
403 |
for param in tps_params:
|
404 |
if param not in new_state_dict and hasattr(model, 'gridGen'):
|
405 |
+
if param in model.state_dict():
|
406 |
+
print(f"Initializing missing TPS parameter: {param}")
|
407 |
+
new_state_dict[param] = model.state_dict()[param]
|
408 |
|
409 |
# Load the state dict
|
410 |
+
model.load_state_dict(new_state_dict, strict=False)
|
411 |
|
412 |
# Print warnings
|
413 |
model_keys = set(model.state_dict().keys())
|