# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# Base class for colmap / kapture
# --------------------------------------------------------
import collections
import os
import pickle

import numpy as np
import PIL.Image
import torch
import torchvision.transforms as tvf

from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import colmap_to_opencv_intrinsics
from dust3r_visloc.datasets.base_dataset import BaseVislocDataset
from dust3r_visloc.datasets.utils import (
    cam_to_world_from_kapture,
    get_resize_function,
    rescale_points3d,
)
from kapture.core import CameraType
from kapture.io.csv import kapture_from_dir
from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file
from scipy.spatial.transform import Rotation
from tqdm import tqdm

KaptureSensor = collections.namedtuple("Sensor", "sensor_params camera_params")


def kapture_to_opencv_intrinsics(sensor):
    """
    Convert from Kapture to OpenCV parameters.
    Warning: we assume that the camera and pixel coordinates follow Colmap conventions here.
    Args:
        sensor: Kapture sensor
    """
    sensor_type = sensor.sensor_params[0]
    if sensor_type == "SIMPLE_PINHOLE":
        # Simple pinhole model.
        # We still call the OpenCV undistortion here for code simplicity.
        w, h, f, cx, cy = sensor.camera_params
        k1 = 0
        k2 = 0
        p1 = 0
        p2 = 0
        fx = fy = f
    elif sensor_type == "PINHOLE":
        w, h, fx, fy, cx, cy = sensor.camera_params
        k1 = 0
        k2 = 0
        p1 = 0
        p2 = 0
    elif sensor_type == "SIMPLE_RADIAL":
        w, h, f, cx, cy, k1 = sensor.camera_params
        k2 = 0
        p1 = 0
        p2 = 0
        fx = fy = f
    elif sensor_type == "RADIAL":
        w, h, f, cx, cy, k1, k2 = sensor.camera_params
        p1 = 0
        p2 = 0
        fx = fy = f
    elif sensor_type == "OPENCV":
        w, h, fx, fy, cx, cy, k1, k2, p1, p2 = sensor.camera_params
    else:
        raise NotImplementedError(f"Sensor type {sensor_type} is not supported yet.")

    cameraMatrix = np.asarray([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)

    # We assume that Kapture data comes from Colmap: the origin is different.
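    # Convention note (added for clarity): COLMAP places the center of the
    # top-left pixel at (0.5, 0.5) while OpenCV places it at (0, 0), so the
    # conversion below shifts the principal point (cx, cy) by half a pixel.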
    cameraMatrix = colmap_to_opencv_intrinsics(cameraMatrix)
    distCoeffs = np.asarray([k1, k2, p1, p2], dtype=np.float32)
    return cameraMatrix, distCoeffs, (w, h)


def K_from_colmap(elems):
    sensor = KaptureSensor(elems, tuple(map(float, elems[1:])))
    cameraMatrix, distCoeffs, (w, h) = kapture_to_opencv_intrinsics(sensor)
    res = dict(resolution=(w, h), intrinsics=cameraMatrix, distortion=distCoeffs)
    return res


def pose_from_qwxyz_txyz(elems):
    qw, qx, qy, qz, tx, ty, tz = map(float, elems)
    pose = np.eye(4)
    pose[:3, :3] = Rotation.from_quat((qx, qy, qz, qw)).as_matrix()
    pose[:3, 3] = (tx, ty, tz)
    return np.linalg.inv(pose)  # returns cam2world


class BaseVislocColmapDataset(BaseVislocDataset):
    def __init__(
        self, image_path, map_path, query_path, pairsfile_path, topk=1, cache_sfm=False
    ):
        super().__init__()
        self.topk = topk
        self.num_views = self.topk + 1
        self.image_path = image_path
        self.cache_sfm = cache_sfm

        self._load_sfm(map_path)

        kdata_query = kapture_from_dir(query_path)
        assert (
            kdata_query.records_camera is not None
            and kdata_query.trajectories is not None
        )
        kdata_query_searchindex = {
            kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)
            for timestamp, sensor_id in kdata_query.records_camera.key_pairs()
        }
        self.query_data = {"kdata": kdata_query, "searchindex": kdata_query_searchindex}

        self.pairs = get_ordered_pairs_from_file(pairsfile_path)
        self.scenes = kdata_query.records_camera.data_list()

    def _load_sfm(self, sfm_dir):
        sfm_cache_path = os.path.join(sfm_dir, "dust3r_cache.pkl")
        if os.path.isfile(sfm_cache_path) and self.cache_sfm:
            with open(sfm_cache_path, "rb") as f:
                data = pickle.load(f)
                self.img_infos = data["img_infos"]
                self.points3D = data["points3D"]
            return

        # load cameras
        with open(os.path.join(sfm_dir, "cameras.txt"), "r") as f:
            raw = f.read().splitlines()[3:]  # skip header
        intrinsics = {}
        for camera in tqdm(raw):
            camera = camera.split(" ")
            intrinsics[int(camera[0])] = K_from_colmap(camera[1:])

        # load images
        with open(os.path.join(sfm_dir, "images.txt"), "r") as f:
            raw = f.read().splitlines()
            raw = [line for line in raw if not line.startswith("#")]  # skip header
        self.img_infos = {}
        for image, points in tqdm(zip(raw[0::2], raw[1::2]), total=len(raw) // 2):
            image = image.split(" ")
            points = points.split(" ")
            img_name = image[-1]
            current_points2D = {
                int(i): (float(x), float(y))
                for i, x, y in zip(points[2::3], points[0::3], points[1::3])
                if i != "-1"
            }
            self.img_infos[img_name] = dict(
                intrinsics[int(image[-2])],
                path=img_name,
                camera_pose=pose_from_qwxyz_txyz(image[1:-2]),
                sparse_pts2d=current_points2D,
            )

        # load 3D points
        with open(os.path.join(sfm_dir, "points3D.txt"), "r") as f:
            raw = f.read().splitlines()
            raw = [line for line in raw if not line.startswith("#")]  # skip header
        self.points3D = {}
        for point in tqdm(raw):
            point = point.split()
            self.points3D[int(point[0])] = tuple(map(float, point[1:4]))

        if self.cache_sfm:
            to_save = {"img_infos": self.img_infos, "points3D": self.points3D}
            with open(sfm_cache_path, "wb") as f:
                pickle.dump(to_save, f)

    def __len__(self):
        return len(self.scenes)

    def _get_view_query(self, imgname):
        kdata, searchindex = map(self.query_data.get, ["kdata", "searchindex"])
        timestamp, camera_id = searchindex[imgname]

        camera_params = kdata.sensors[camera_id].camera_params
        if kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_PINHOLE:
            W, H, f, cx, cy = camera_params
            k1 = 0
            fx = fy = f
        elif kdata.sensors[camera_id].camera_type == CameraType.SIMPLE_RADIAL:
            W, H, f, cx, cy, k1 = camera_params
            fx = fy = f
        else:
            raise NotImplementedError(
                f"Camera type {kdata.sensors[camera_id].camera_type} is not supported yet."
            )
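        # Note on the distortion vector built below: it follows OpenCV's
        # (k1, k2, p1, p2) ordering. SIMPLE_PINHOLE and SIMPLE_RADIAL expose at
        # most a single radial term k1, so k2, p1 and p2 are left at zero.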
        W, H = int(W), int(H)
        intrinsics = np.float32([(fx, 0, cx), (0, fy, cy), (0, 0, 1)])
        intrinsics = colmap_to_opencv_intrinsics(intrinsics)
        distortion = [k1, 0, 0, 0]

        if (
            kdata.trajectories is not None
            and (timestamp, camera_id) in kdata.trajectories
        ):
            cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)
        else:
            cam_to_world = np.eye(4, dtype=np.float32)

        # Load RGB image
        rgb_image = PIL.Image.open(os.path.join(self.image_path, imgname)).convert(
            "RGB"
        )
        rgb_image.load()

        resize_func, _, to_orig = get_resize_function(
            self.maxdim, self.patch_size, H, W
        )
        rgb_tensor = resize_func(ImgNorm(rgb_image))

        view = {
            "intrinsics": intrinsics,
            "distortion": distortion,
            "cam_to_world": cam_to_world,
            "rgb": rgb_image,
            "rgb_rescaled": rgb_tensor,
            "to_orig": to_orig,
            "idx": 0,
            "image_name": imgname,
        }
        return view

    def _get_view_map(self, imgname, idx):
        infos = self.img_infos[imgname]

        rgb_image = PIL.Image.open(
            os.path.join(self.image_path, infos["path"])
        ).convert("RGB")
        rgb_image.load()
        W, H = rgb_image.size

        intrinsics = infos["intrinsics"]
        intrinsics = colmap_to_opencv_intrinsics(intrinsics)
        distortion_coefs = infos["distortion"]

        pts2d = infos["sparse_pts2d"]
        sparse_pos2d = np.float32(list(pts2d.values())).reshape(
            (-1, 2)
        )  # pts2d from colmap
        sparse_pts3d = np.float32([self.points3D[i] for i in pts2d]).reshape((-1, 3))

        # store full resolution 2D->3D
        sparse_pos2d_cv2 = sparse_pos2d.copy()
        sparse_pos2d_cv2[:, 0] -= 0.5
        sparse_pos2d_cv2[:, 1] -= 0.5
        sparse_pos2d_int = sparse_pos2d_cv2.round().astype(np.int64)
        valid = (
            (sparse_pos2d_int[:, 0] >= 0)
            & (sparse_pos2d_int[:, 0] < W)
            & (sparse_pos2d_int[:, 1] >= 0)
            & (sparse_pos2d_int[:, 1] < H)
        )
        sparse_pos2d_int = sparse_pos2d_int[valid]
        # nan => invalid
        pts3d = np.full((H, W, 3), np.nan, dtype=np.float32)
        pts3d[sparse_pos2d_int[:, 1], sparse_pos2d_int[:, 0]] = sparse_pts3d[valid]
        pts3d = torch.from_numpy(pts3d)

        cam_to_world = infos["camera_pose"]  # cam2world

        # also store resized resolution 2D->3D
        resize_func, to_resize, to_orig = get_resize_function(
            self.maxdim, self.patch_size, H, W
        )
        rgb_tensor = resize_func(ImgNorm(rgb_image))

        HR, WR = rgb_tensor.shape[1:]
        _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(
            sparse_pos2d_cv2, sparse_pts3d, to_resize, HR, WR
        )
        pts3d_rescaled = torch.from_numpy(pts3d_rescaled)
        valid_rescaled = torch.from_numpy(valid_rescaled)

        view = {
            "intrinsics": intrinsics,
            "distortion": distortion_coefs,
            "cam_to_world": cam_to_world,
            "rgb": rgb_image,
            "pts3d": pts3d,
            "valid": pts3d.sum(dim=-1).isfinite(),
            "rgb_rescaled": rgb_tensor,
            "pts3d_rescaled": pts3d_rescaled,
            "valid_rescaled": valid_rescaled,
            "to_orig": to_orig,
            "idx": idx,
            "image_name": imgname,
        }
        return view

    def __getitem__(self, idx):
        assert self.maxdim is not None and self.patch_size is not None
        query_image = self.scenes[idx]
        map_images = [p[0] for p in self.pairs[query_image][: self.topk]]
        views = []
        views.append(self._get_view_query(query_image))
        for map_idx, map_image in enumerate(map_images):
            views.append(self._get_view_map(map_image, map_idx + 1))
        return views
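

# --------------------------------------------------------
# Usage sketch (illustrative, not part of the original module). It shows how
# this dataset is typically constructed and indexed. All paths below are
# hypothetical placeholders; `maxdim` and `patch_size` are set directly here,
# whereas in practice the evaluation script configures them before indexing.
# --------------------------------------------------------
if __name__ == "__main__":
    dataset = BaseVislocColmapDataset(
        image_path="data/scene/images",  # hypothetical image root
        map_path="data/scene/sfm",  # COLMAP text model (cameras/images/points3D.txt)
        query_path="data/scene/kapture/query",  # kapture query sensors + trajectories
        pairsfile_path="data/scene/pairs.txt",  # retrieval pairs (query -> map images)
        topk=1,
    )
    # __getitem__ asserts that these are set; values must match the network input.
    dataset.maxdim = 512
    dataset.patch_size = 16
    views = dataset[0]  # [query_view, map_view_1, ..., map_view_topk]
    print(views[0]["image_name"], views[0]["rgb_rescaled"].shape)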