Spaces:
Sleeping
Sleeping
Update networks.py
Browse files- networks.py +19 -41
networks.py
CHANGED
@@ -195,60 +195,38 @@ class TpsGridGen(nn.Module):
|
|
195 |
|
196 |
return torch.cat((points_X, points_Y), 3)
|
197 |
|
|
|
198 |
def transform_points(self, grid, W, Q):
|
199 |
batch_size, h, w, _ = grid.size()
|
200 |
n_points = h * w
|
201 |
|
202 |
-
#
|
203 |
-
|
204 |
-
|
205 |
-
# Prepare control points
|
206 |
-
P = torch.cat([self.P_X_base, self.P_Y_base], 1) # (N,2)
|
207 |
P = P.unsqueeze(0).expand(batch_size, -1, -1) # (B, N, 2)
|
208 |
|
209 |
-
# Compute
|
210 |
-
|
211 |
-
|
212 |
-
|
213 |
-
|
214 |
-
# Reshape grid for calculation
|
215 |
-
grid_reshaped = grid.view(batch_size, n_points, 1, 1).expand(-1, -1, -1, self.N) # (B, H*W, 1, N)
|
216 |
-
|
217 |
-
# Compute delta
|
218 |
-
delta = grid_reshaped - P_expanded
|
219 |
-
|
220 |
-
# Compute U (radial basis function)
|
221 |
-
dist_squared = torch.sum(torch.pow(delta, 2), dim=2, keepdim=False) # (B, H*W, N)
|
222 |
dist_squared[dist_squared == 0] = 1 # Avoid log(0)
|
223 |
-
U =
|
224 |
|
225 |
-
# Compute affine
|
226 |
-
|
227 |
-
A = torch.cat([
|
228 |
-
torch.ones(batch_size, n_points, 1, device=grid.device),
|
229 |
-
grid_flat.view(batch_size, n_points, 1)
|
230 |
-
], dim=2) # (B, H*W, 3)
|
231 |
|
232 |
-
#
|
233 |
-
|
234 |
-
|
235 |
-
# This is a simplified approach that works for the forward pass
|
236 |
|
237 |
-
#
|
238 |
-
affine_x = Q[:, :, 0].mean(dim=1, keepdim=True) # Simplified affine X
|
239 |
-
affine_y = Q[:, :, 0].mean(dim=1, keepdim=True) # Simplified affine Y
|
240 |
-
affine = torch.cat([
|
241 |
-
torch.ones(batch_size, n_points, 1, device=grid.device),
|
242 |
-
grid_flat.view(batch_size, n_points, 1) * affine_x,
|
243 |
-
grid_flat.view(batch_size, n_points, 1) * affine_y
|
244 |
-
], dim=2)
|
245 |
-
|
246 |
-
# Compute non-affine component
|
247 |
non_affine = torch.bmm(U, W) # (B, H*W, 1)
|
248 |
|
249 |
-
#
|
250 |
-
|
251 |
|
|
|
|
|
252 |
return points.view(batch_size, h, w, 1)
|
253 |
|
254 |
class GMM(nn.Module):
|
|
|
195 |
|
196 |
return torch.cat((points_X, points_Y), 3)
|
197 |
|
198 |
+
# In TpsGridGen class, replace transform_points method with this:
|
199 |
def transform_points(self, grid, W, Q):
|
200 |
batch_size, h, w, _ = grid.size()
|
201 |
n_points = h * w
|
202 |
|
203 |
+
# Control points P (N, 2)
|
204 |
+
P = torch.cat([self.P_X_base, self.P_Y_base], 1)
|
|
|
|
|
|
|
205 |
P = P.unsqueeze(0).expand(batch_size, -1, -1) # (B, N, 2)
|
206 |
|
207 |
+
# Compute U = r^2 * log(r^2)
|
208 |
+
grid_flat = grid.view(batch_size, n_points, 2) # (B, H*W, 2)
|
209 |
+
dist = grid_flat.unsqueeze(2) - P.unsqueeze(1) # (B, H*W, N, 2)
|
210 |
+
dist_squared = torch.sum(dist**2, dim=3) # (B, H*W, N)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
211 |
dist_squared[dist_squared == 0] = 1 # Avoid log(0)
|
212 |
+
U = dist_squared * torch.log(dist_squared)
|
213 |
|
214 |
+
# Compute affine part [1, x, y]
|
215 |
+
ones = torch.ones(batch_size, n_points, 1, device=grid.device)
|
216 |
+
A = torch.cat([ones, grid_flat], dim=2) # (B, H*W, 3)
|
|
|
|
|
|
|
217 |
|
218 |
+
# Warp coefficients
|
219 |
+
W = W.view(batch_size, self.N, 1)
|
220 |
+
Q = Q.view(batch_size, self.N, 1)
|
|
|
221 |
|
222 |
+
# Non-affine part
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
223 |
non_affine = torch.bmm(U, W) # (B, H*W, 1)
|
224 |
|
225 |
+
# Affine part
|
226 |
+
affine = torch.bmm(A, Q) # (B, H*W, 1)
|
227 |
|
228 |
+
# Combine components
|
229 |
+
points = affine + non_affine
|
230 |
return points.view(batch_size, h, w, 1)
|
231 |
|
232 |
class GMM(nn.Module):
|