import numpy as np
import torch
import cv2
import open3d as o3d

from dust3r.post_process import estimate_focal_knowing_depth
from dust3r.utils.geometry import inv


def estimate_focal(pts3d_i, pp=None):
    """Estimate the focal length from a (H, W, 3) pointmap via Weiszfeld iterations."""
    if pp is None:
        H, W, THREE = pts3d_i.shape
        assert THREE == 3
        pp = torch.tensor((W / 2, H / 2), device=pts3d_i.device)
    focal = estimate_focal_knowing_depth(
        pts3d_i.unsqueeze(0), pp.unsqueeze(0), focal_mode='weiszfeld'
    ).ravel()
    return float(focal)


def pixel_grid(H, W):
    """Return a (H, W, 2) grid of (x, y) pixel coordinates."""
    return np.mgrid[:W, :H].T.astype(np.float32)


def sRT_to_4x4(scale, R, T, device):
    """Pack a scaled rotation and a translation into a 4x4 homogeneous transform."""
    trf = torch.eye(4, device=device)
    trf[:3, :3] = R * scale
    trf[:3, 3] = T.ravel()  # translation doesn't need scaling
    return trf


def to_numpy(tensor):
    return tensor.cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor


def calculate_depth_map(pts3d, R, T):
    """
    Calculate ray depths directly from the camera center to the 3D points.

    Args:
        pts3d (np.ndarray): 3D points in world coordinates, shape (H, W, 3)
        R (np.ndarray): World-to-camera rotation matrix, shape (3, 3)
        T (np.ndarray): World-to-camera translation vector, shape (3, 1)

    Returns:
        np.ndarray: Depth map of shape (H, W)
    """
    # Camera center in world coordinates. For a world-to-camera pose [R | T],
    # the center is C = -R^T @ T (it equals -T only when R is the identity).
    C = (-R.T @ T).ravel()

    # Distance from the camera center to every 3D point
    ray_vectors = pts3d - C
    depth_map = np.linalg.norm(ray_vectors, axis=2)

    return depth_map


def fast_pnp(pts3d, focal, msk, device, pp=None, niter_PnP=10):
    """Extract the camera pose (and optionally the focal) with RANSAC-PnP."""
    if msk.sum() < 4:
        return None  # we need at least 4 points for PnP

    pts3d, msk = map(to_numpy, (pts3d, msk))

    H, W, THREE = pts3d.shape
    assert THREE == 3
    pixels = pixel_grid(H, W)

    if focal is None:
        # No focal given: sweep a log-spaced range of plausible focals
        S = max(W, H)
        tentative_focals = np.geomspace(S / 2, S * 3, 21)
    else:
        tentative_focals = [focal]

    if pp is None:
        pp = (W / 2, H / 2)
    else:
        pp = to_numpy(pp)

    best = 0, None, None, None, None
    for focal in tentative_focals:
        K = np.float32([(focal, 0, pp[0]), (0, focal, pp[1]), (0, 0, 1)])

        success, rvec, T, inliers = cv2.solvePnPRansac(
            pts3d[msk], pixels[msk], K, None,
            iterationsCount=niter_PnP, reprojectionError=5,
            flags=cv2.SOLVEPNP_SQPNP)
        if not success:
            continue

        score = len(inliers)
        if score > best[0]:
            # Convert the Rodrigues rotation vector to a world-to-cam matrix
            # before using it to compute the depth map.
            R = cv2.Rodrigues(rvec)[0]
            depth_map = calculate_depth_map(pts3d, R, T)
            best = score, R, T, focal, depth_map

    if not best[0]:
        return None

    _, R, T, best_focal, depth_map = best
    R, T = map(torch.from_numpy, (R, T))
    depth_map = torch.from_numpy(depth_map).to(device)

    cam_to_world = inv(sRT_to_4x4(1, R, T, device))  # world-to-cam -> cam-to-world
    return best_focal, cam_to_world, depth_map


def solve_cemara(pts3d, msk, device, focal=None, pp=None):
    # Estimate the focal length if it isn't provided
    if focal is None:
        focal = estimate_focal(pts3d, pp)

    # Compute the camera pose using RANSAC-PnP
    result = fast_pnp(pts3d, focal, msk, device, pp)
    if result is None:
        return None, focal, None

    best_focal, camera_to_world, depth_map = result

    # Pack intrinsics and extrinsics into Open3D camera parameters
    H, W, _ = pts3d.shape
    if pp is None:
        pp = (W / 2, H / 2)

    camera_parameters = o3d.camera.PinholeCameraParameters()
    intrinsic = o3d.camera.PinholeCameraIntrinsic()
    intrinsic.set_intrinsics(W, H, best_focal, best_focal, pp[0], pp[1])
    camera_parameters.intrinsic = intrinsic
    camera_parameters.extrinsic = torch.inverse(camera_to_world).cpu().numpy()

    return camera_parameters, best_focal, depth_map
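

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal smoke test showing how solve_cemara might be driven. The synthetic
# pointmap below is an assumption for illustration: a gently wavy surface
# back-projected through a known pinhole camera with focal f_true, so the
# expected recovery is a pose near the identity and a focal near f_true.
# Real inputs would be a DUSt3R-style pointmap and confidence mask.
if __name__ == "__main__":
    H, W = 240, 320
    f_true = 300.0  # hypothetical ground-truth focal for the synthetic camera
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Back-project per-pixel depths through the pinhole model
    ys, xs = np.mgrid[:H, :W]
    z = 2.0 + 0.1 * np.sin(xs / 30) * np.sin(ys / 30)  # avoid a degenerate plane
    pts3d = np.stack([(xs - W / 2) * z / f_true,
                      (ys - H / 2) * z / f_true,
                      z], axis=-1).astype(np.float32)
    pts3d = torch.from_numpy(pts3d)
    msk = torch.ones(H, W, dtype=torch.bool)  # all pixels considered valid

    camera_parameters, focal, depth_map = solve_cemara(pts3d, msk, device)
    if camera_parameters is not None:
        print('estimated focal:', focal)            # should be close to f_true
        print('extrinsic:\n', camera_parameters.extrinsic)  # near identity
        print('depth range:', float(depth_map.min()), float(depth_map.max()))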