Update networks.py
networks.py  CHANGED  (+47 -40)
@@ -182,12 +182,12 @@ class TpsGridGen(nn.Module):
         Li_block = self.Li[:self.N, :self.N]
 
         # Compute weights
-        W_X = torch.bmm(Li_block.expand(batch_size,
-        W_Y = torch.bmm(Li_block.expand(batch_size,
+        W_X = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_X)
+        W_Y = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_Y)
 
         # Prepare grid tensors
-        grid_X = self.grid_X.expand(batch_size,
-        grid_Y = self.grid_Y.expand(batch_size,
+        grid_X = self.grid_X.expand(batch_size, -1, -1, -1)
+        grid_Y = self.grid_Y.expand(batch_size, -1, -1, -1)
 
         # Compute transformed coordinates
         points_X = self.transform_points(grid_X, W_X, Q_X)
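Note on the hunk above: `expand(batch_size, -1, -1)` broadcasts the shared `Li` block across the batch as a view (no copy), and `-1` keeps each existing dimension's size, which is what `torch.bmm` needs. A minimal standalone sketch with made-up sizes (`batch_size` and `N` are placeholders, not the repo's values):

# Standalone sketch, toy sizes: broadcasting a shared (N, N) matrix for torch.bmm.
import torch

batch_size, N = 4, 25
Li_block = torch.randn(N, N)          # stands in for self.Li[:self.N, :self.N]
Q_X = torch.randn(batch_size, N, 1)   # stands in for per-sample control-point coordinates

# expand() adds the batch dimension as a broadcast view; -1 keeps the existing sizes.
W_X = torch.bmm(Li_block.expand(batch_size, -1, -1), Q_X)
print(W_X.shape)  # torch.Size([4, 25, 1])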
@@ -197,30 +197,36 @@ class TpsGridGen(nn.Module):
 
     def transform_points(self, grid, W, Q):
         batch_size, h, w, _ = grid.size()
+        n_points = h * w
 
-        # Flatten grid to (batch_size, H*W,
-        grid_flat = grid.view(batch_size,
+        # Flatten grid to (batch_size, H*W, 1)
+        grid_flat = grid.view(batch_size, n_points, 1)
 
         # Prepare control points
-        P = torch.cat([self.P_X_base, self.P_Y_base], 1).
+        P = torch.cat([self.P_X_base, self.P_Y_base], 1).t().unsqueeze(0) # (1, 2, N)
+        P = P.expand(batch_size, -1, -1) # (B, 2, N)
 
         # Compute distance between grid points and control points
+        grid_expanded = grid_flat.expand(-1, -1, self.N) # (B, H*W, N)
+        P_expanded = P.expand(n_points, -1, -1).permute(1, 0, 2) # (B, H*W, N)
+        delta = grid_expanded - P_expanded
 
         # Compute U (radial basis function)
-        dist_squared = torch.sum(torch.pow(delta, 2),
+        dist_squared = torch.sum(torch.pow(delta, 2), dim=1, keepdim=True) # (B, H*W, 1)
         dist_squared[dist_squared == 0] = 1 # Avoid log(0)
         U = torch.mul(dist_squared, torch.log(dist_squared))
 
         # Compute affine transformation
         A = torch.cat([
-            torch.ones(batch_size,
-            grid_flat
-        ], 2)
+            torch.ones(batch_size, n_points, 1, device=grid.device),
+            grid_flat.view(batch_size, n_points, 1)
+        ], dim=2)
 
         # Combine affine and non-affine components
+        affine = torch.bmm(A, Q.view(batch_size, 1, 3).permute(0, 2, 1))
+        non_affine = torch.bmm(U.permute(0, 2, 1), W).permute(0, 2, 1)
+        points = affine + non_affine
+
         return points.view(batch_size, h, w, 1)
 
 class GMM(nn.Module):
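Note on `transform_points` above: U = d^2 * log(d^2) is the thin-plate-spline radial basis, and the `dist_squared[dist_squared == 0] = 1` guard works because log(1) = 0, which matches the limit of d^2 * log(d^2) as d goes to 0. A tiny standalone check with made-up values:

# Tiny sketch, made-up values: TPS kernel U = d^2 * log(d^2) with the log(0) guard.
import torch

d2 = torch.tensor([[0.0, 0.25, 1.0, 4.0]])  # squared distances, including an exact zero
d2[d2 == 0] = 1                             # log(1) = 0, so U(0) = 0, the correct TPS limit
U = torch.mul(d2, torch.log(d2))
print(U)  # tensor([[ 0.0000, -0.3466,  0.0000,  5.5452]])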
@@ -361,36 +367,37 @@ def load_checkpoint(model, checkpoint_path, strict=True):
 
     state_dict = torch.load(checkpoint_path, map_location=torch.device('cpu'))
 
-    #
-    updated_state_dict = {}
-    for key, value in state_dict.items():
-        # Handle buffer name changes
-        for old_name, new_name in buffer_mapping.items():
-            if key.startswith(old_name):
-                key = key.replace(old_name, new_name)
+    # Initialize TPS grid parameters if missing
+    if 'gridGen.P_X_base' not in state_dict:
+        print("Initializing TPS grid parameters...")
+        grid_size = model.gridGen.grid_size
+        axis_coords = np.linspace(-1, 1, grid_size)
+        P_Y, P_X = np.meshgrid(axis_coords, axis_coords)
+        P_X = torch.FloatTensor(P_X.reshape(-1, 1))
+        P_Y = torch.FloatTensor(P_Y.reshape(-1, 1))
+        state_dict['gridGen.P_X_base'] = P_X
+        state_dict['gridGen.P_Y_base'] = P_Y
 
-    #
+        # Compute Li
+        Li = model.gridGen.compute_L_inverse(P_X, P_Y)
+        state_dict['gridGen.Li'] = Li
+
+        # Create grid
+        grid_X, grid_Y = np.meshgrid(
+            np.linspace(-1, 1, model.gridGen.out_w),
+            np.linspace(-1, 1, model.gridGen.out_h)
+        )
+        state_dict['gridGen.grid_X'] = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
+        state_dict['gridGen.grid_Y'] = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
 
-    # Load
-    model.load_state_dict(
+    # Load state dict
+    model.load_state_dict(state_dict, strict=strict)
 
-    # Print warnings
+    # Print warnings
     model_keys = set(model.state_dict().keys())
-    unexpected = checkpoint_keys - set(updated_state_dict.keys())
+    ckpt_keys = set(state_dict.keys())
+    missing = model_keys - ckpt_keys
+    unexpected = ckpt_keys - model_keys
 
     if missing:
         print(f"Missing keys: {sorted(missing)}")
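Note on `load_checkpoint` above: when the checkpoint predates the TPS buffers, the new code rebuilds them and writes them into `state_dict` before calling `load_state_dict`. A small standalone sketch with placeholder sizes (`out_h` and `out_w` are made-up numbers standing in for `model.gridGen.out_h` / `out_w`) showing the `(1, H, W, 1)` shape produced for `gridGen.grid_X` / `gridGen.grid_Y`:

# Standalone sketch, placeholder sizes: the (1, H, W, 1) grid buffers built above.
import numpy as np
import torch

out_h, out_w = 4, 3  # placeholder output resolution
grid_X, grid_Y = np.meshgrid(np.linspace(-1, 1, out_w), np.linspace(-1, 1, out_h))
grid_X = torch.FloatTensor(grid_X).unsqueeze(0).unsqueeze(3)
grid_Y = torch.FloatTensor(grid_Y).unsqueeze(0).unsqueeze(3)
print(grid_X.shape, grid_Y.shape)  # torch.Size([1, 4, 3, 1]) torch.Size([1, 4, 3, 1])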