Spaces:
Sleeping
Sleeping
Update networks.py
Browse files- networks.py +53 -37
networks.py
CHANGED
@@ -203,30 +203,43 @@ class TpsGridGen(nn.Module):
|
|
203 |
grid_flat = grid.view(batch_size, n_points, 1)
|
204 |
|
205 |
# Prepare control points
|
206 |
-
P = torch.cat([self.P_X_base, self.P_Y_base], 1)
|
207 |
-
P = P.expand(batch_size, -1, -1) # (B,
|
208 |
|
209 |
# Compute distance between grid points and control points
|
210 |
grid_expanded = grid_flat.expand(-1, -1, self.N) # (B, H*W, N)
|
211 |
-
P_expanded = P.
|
212 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
213 |
|
214 |
# Compute U (radial basis function)
|
215 |
-
dist_squared = torch.sum(torch.pow(delta, 2), dim=
|
216 |
dist_squared[dist_squared == 0] = 1 # Avoid log(0)
|
217 |
U = torch.mul(dist_squared, torch.log(dist_squared))
|
218 |
|
219 |
# Compute affine transformation
|
|
|
220 |
A = torch.cat([
|
221 |
torch.ones(batch_size, n_points, 1, device=grid.device),
|
222 |
grid_flat.view(batch_size, n_points, 1)
|
223 |
-
], dim=2)
|
224 |
|
225 |
-
#
|
226 |
-
|
227 |
-
non_affine = torch.bmm(U.permute(0, 2, 1), W).permute(0, 2, 1)
|
228 |
-
points = affine + non_affine
|
229 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
230 |
return points.view(batch_size, h, w, 1)
|
231 |
|
232 |
class GMM(nn.Module):
|
@@ -367,37 +380,40 @@ def load_checkpoint(model, checkpoint_path, strict=True):
|
|
367 |
|
368 |
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
369 |
|
370 |
-
#
|
371 |
-
|
372 |
-
|
373 |
-
|
374 |
-
|
375 |
-
|
376 |
-
|
377 |
-
|
378 |
-
|
379 |
-
|
380 |
-
|
381 |
-
|
382 |
-
|
383 |
-
state_dict
|
384 |
-
|
385 |
-
|
386 |
-
|
387 |
-
|
388 |
-
|
389 |
-
|
390 |
-
|
391 |
-
|
|
|
|
|
392 |
|
393 |
-
# Load state dict
|
394 |
-
model.load_state_dict(
|
395 |
|
396 |
# Print warnings
|
397 |
model_keys = set(model.state_dict().keys())
|
398 |
-
|
399 |
-
|
400 |
-
|
|
|
401 |
|
402 |
if missing:
|
403 |
print(f"Missing keys: {sorted(missing)}")
|
|
|
203 |
grid_flat = grid.view(batch_size, n_points, 1)
|
204 |
|
205 |
# Prepare control points
|
206 |
+
P = torch.cat([self.P_X_base, self.P_Y_base], 1) # (N,2)
|
207 |
+
P = P.unsqueeze(0).expand(batch_size, -1, -1) # (B, N, 2)
|
208 |
|
209 |
# Compute distance between grid points and control points
|
210 |
grid_expanded = grid_flat.expand(-1, -1, self.N) # (B, H*W, N)
|
211 |
+
P_expanded = P.permute(0, 2, 1).unsqueeze(1) # (B, 1, 2, N)
|
212 |
+
P_expanded = P_expanded.expand(-1, n_points, -1, -1) # (B, H*W, 2, N)
|
213 |
+
|
214 |
+
# Reshape grid for calculation
|
215 |
+
grid_reshaped = grid.view(batch_size, n_points, 1, 1).expand(-1, -1, -1, self.N) # (B, H*W, 1, N)
|
216 |
+
|
217 |
+
# Compute delta
|
218 |
+
delta = grid_reshaped - P_expanded
|
219 |
|
220 |
# Compute U (radial basis function)
|
221 |
+
dist_squared = torch.sum(torch.pow(delta, 2), dim=2, keepdim=False) # (B, H*W, N)
|
222 |
dist_squared[dist_squared == 0] = 1 # Avoid log(0)
|
223 |
U = torch.mul(dist_squared, torch.log(dist_squared))
|
224 |
|
225 |
# Compute affine transformation
|
226 |
+
# Create affine matrix [1, x, y]
|
227 |
A = torch.cat([
|
228 |
torch.ones(batch_size, n_points, 1, device=grid.device),
|
229 |
grid_flat.view(batch_size, n_points, 1)
|
230 |
+
], dim=2) # (B, H*W, 3)
|
231 |
|
232 |
+
# Extract affine parameters from Q
|
233 |
+
affine_params = Q.view(batch_size, 3, 1) # (B, 3, 1)
|
|
|
|
|
234 |
|
235 |
+
# Compute affine component
|
236 |
+
affine = torch.bmm(A, affine_params) # (B, H*W, 1)
|
237 |
+
|
238 |
+
# Compute non-affine component
|
239 |
+
non_affine = torch.bmm(U, W) # (B, H*W, 1)
|
240 |
+
|
241 |
+
# Combine components
|
242 |
+
points = affine + non_affine
|
243 |
return points.view(batch_size, h, w, 1)
|
244 |
|
245 |
class GMM(nn.Module):
|
|
|
380 |
|
381 |
state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
|
382 |
|
383 |
+
# Create a new state dict that matches our model architecture
|
384 |
+
new_state_dict = {}
|
385 |
+
for key, value in state_dict.items():
|
386 |
+
# Handle any name changes here if needed
|
387 |
+
new_key = key
|
388 |
+
if 'gridGen' in key:
|
389 |
+
# Map old parameter names to new ones
|
390 |
+
if 'P_X' in key:
|
391 |
+
new_key = key.replace('P_X', 'P_X_base')
|
392 |
+
elif 'P_Y' in key:
|
393 |
+
new_key = key.replace('P_Y', 'P_Y_base')
|
394 |
+
|
395 |
+
# Only include keys that exist in the current model
|
396 |
+
if new_key in model.state_dict() and value.size() == model.state_dict()[new_key].size():
|
397 |
+
new_state_dict[new_key] = value
|
398 |
+
|
399 |
+
# Add missing TPS parameters if needed
|
400 |
+
tps_params = ['gridGen.P_X_base', 'gridGen.P_Y_base', 'gridGen.Li',
|
401 |
+
'gridGen.grid_X', 'gridGen.grid_Y']
|
402 |
+
for param in tps_params:
|
403 |
+
if param not in new_state_dict and hasattr(model, 'gridGen'):
|
404 |
+
print(f"Initializing missing TPS parameter: {param}")
|
405 |
+
# Initialize with current model's value
|
406 |
+
new_state_dict[param] = model.state_dict()[param]
|
407 |
|
408 |
+
# Load the state dict
|
409 |
+
model.load_state_dict(new_state_dict, strict=strict)
|
410 |
|
411 |
# Print warnings
|
412 |
model_keys = set(model.state_dict().keys())
|
413 |
+
loaded_keys = set(new_state_dict.keys())
|
414 |
+
|
415 |
+
missing = model_keys - loaded_keys
|
416 |
+
unexpected = set(state_dict.keys()) - set(new_state_dict.keys())
|
417 |
|
418 |
if missing:
|
419 |
print(f"Missing keys: {sorted(missing)}")
|