Stable-X commited on
Commit
3717f04
1 Parent(s): 2c5f88b

fix: Clean pose_utils

Browse files
Files changed (1) hide show
  1. pose_utils.py +7 -34
pose_utils.py CHANGED
@@ -25,29 +25,6 @@ def sRT_to_4x4(scale, R, T, device):
25
  def to_numpy(tensor):
26
  return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
27
 
28
- def calculate_depth_map(pts3d, R, T):
29
- """
30
- Calculate ray depths directly using camera center and 3D points.
31
-
32
- Args:
33
- pts3d (np.array): 3D points in world coordinates, shape (H, W, 3)
34
- R (np.array): Rotation matrix, shape (3, 3)
35
- T (np.array): Translation vector, shape (3, 1)
36
-
37
- Returns:
38
- np.array: Depth map of shape (H, W)
39
- """
40
- # Camera center in world coordinates is simply -T
41
- C = -T.ravel()
42
-
43
- # Calculate ray vectors
44
- ray_vectors = pts3d - C
45
-
46
- # Calculate ray depths
47
- depth_map = np.linalg.norm(ray_vectors, axis=2)
48
-
49
- return depth_map
50
-
51
  def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
52
  # extract camera poses and focals with RANSAC-PnP
53
  if msk.sum() < 4:
@@ -69,7 +46,7 @@ def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
69
  else:
70
  pp = to_numpy(pp)
71
 
72
- best = 0, None, None, None, None
73
  for focal in tentative_focals:
74
  K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
75
 
@@ -81,20 +58,16 @@ def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
81
 
82
  score = len(inliers)
83
  if success and score > best[0]:
84
- depth_map = calculate_depth_map(pts3d, R, T)
85
- best = score, R, T, focal, depth_map
86
 
87
  if not best[0]:
88
  return None
89
 
90
- _, R, T, best_focal, depth_map = best
91
  R = cv2.Rodrigues(R)[0] # world to cam
92
  R, T = map(torch.from_numpy, (R, T))
93
- depth_map = torch.from_numpy(depth_map).to(device)
94
-
95
- cam_to_world = inv(sRT_to_4x4(1, R, T, device)) # cam to world
96
 
97
- return best_focal, cam_to_world, depth_map
98
 
99
  def solve_cemara(pts3d, msk, device, focal=None, pp=None):
100
  # Estimate focal length
@@ -105,9 +78,9 @@ def solve_cemara(pts3d, msk, device, focal=None, pp=None):
105
  result = fast_pnp(pts3d, focal, msk, device, pp)
106
 
107
  if result is None:
108
- return None, focal, None
109
 
110
- best_focal, camera_to_world, depth_map = result
111
 
112
  # Construct K matrix
113
  H, W, _ = pts3d.shape
@@ -123,4 +96,4 @@ def solve_cemara(pts3d, msk, device, focal=None, pp=None):
123
  camera_parameters.intrinsic = intrinsic
124
  camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()
125
 
126
- return camera_parameters, best_focal, depth_map
 
25
  def to_numpy(tensor):
26
  return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
29
  # extract camera poses and focals with RANSAC-PnP
30
  if msk.sum() < 4:
 
46
  else:
47
  pp = to_numpy(pp)
48
 
49
+ best = 0,
50
  for focal in tentative_focals:
51
  K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])
52
 
 
58
 
59
  score = len(inliers)
60
  if success and score > best[0]:
61
+ best = score, R, T, focal
 
62
 
63
  if not best[0]:
64
  return None
65
 
66
+ _, R, T, best_focal = best
67
  R = cv2.Rodrigues(R)[0] # world to cam
68
  R, T = map(torch.from_numpy, (R, T))
 
 
 
69
 
70
+ return best_focal, inv(sRT_to_4x4(1, R, T, device)) # cam to world
71
 
72
  def solve_cemara(pts3d, msk, device, focal=None, pp=None):
73
  # Estimate focal length
 
78
  result = fast_pnp(pts3d, focal, msk, device, pp)
79
 
80
  if result is None:
81
+ return None, focal
82
 
83
+ best_focal, camera_to_world = result
84
 
85
  # Construct K matrix
86
  H, W, _ = pts3d.shape
 
96
  camera_parameters.intrinsic = intrinsic
97
  camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()
98
 
99
+ return camera_parameters, best_focal