# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# InLoc dataloader
# --------------------------------------------------------
import os

import kapture
import numpy as np
import PIL.Image
import scipy.io
import torch
from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import geotrf, xy_grid
from dust3r_visloc.datasets.base_dataset import BaseVislocDataset
from dust3r_visloc.datasets.utils import (
    cam_to_world_from_kapture,
    get_resize_function,
    rescale_points3d,
)
from kapture.io.csv import kapture_from_dir
from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file


def read_alignments(path_to_alignment):
    """Parse an InLoc alignment file into per-scene 4x4 transforms.

    The file interleaves 3-character scene ids (e.g. "024") with blocks of
    text; for each scene id, the 4x4 matrix printed after the line
    "After general icp:" is the transform we keep.

    Args:
        path_to_alignment: path to an `all_transformations.txt` file.

    Returns:
        dict mapping scene-id string -> (4, 4) numpy array.
    """
    aligns = {}
    with open(path_to_alignment, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break
            # Scene-id lines are exactly 3 characters plus the newline.
            if len(line) == 4:
                trans_nr = line[:-1]
                # Skip ahead to the ICP result marker, then read the
                # following 4 lines as a whitespace-separated 4x4 matrix.
                while line != "After general icp:\n":
                    line = fid.readline()
                line = fid.readline()
                p = []
                for _ in range(4):
                    elems = line.split(" ")
                    line = fid.readline()
                    for e in elems:
                        if len(e) != 0:
                            p.append(float(e))
                P = np.array(p).reshape(4, 4)
                aligns[trans_nr] = P
    return aligns


class VislocInLoc(BaseVislocDataset):
    """Visual-localization dataset for InLoc.

    Each item is a list of ``topk + 1`` views: the query image first
    (without depth), followed by its top-k retrieved mapping images
    (with 3D points loaded from the InLoc ``.mat`` depth cutouts and
    aligned to the global frame via the DUC1/DUC2 transforms).
    """

    def __init__(self, root, pairsfile, topk=1):
        """
        Args:
            root: dataset root containing `query/`, `mapping/` and
                `pairfiles/query/` subdirectories.
            pairsfile: basename (without extension) of the retrieval
                pairs file under `pairfiles/query/`.
            topk: number of retrieved mapping images per query.
        """
        super().__init__()
        self.root = root
        self.topk = topk
        self.num_views = self.topk + 1
        # Set by the caller (BaseVislocDataset contract) before __getitem__.
        self.maxdim = None
        self.patch_size = None

        # Query kapture: map image name -> (timestamp, sensor_id) for lookup.
        query_path = os.path.join(self.root, "query")
        kdata_query = kapture_from_dir(query_path)
        assert kdata_query.records_camera is not None
        kdata_query_searchindex = {
            kdata_query.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)
            for timestamp, sensor_id in kdata_query.records_camera.key_pairs()
        }
        self.query_data = {
            "path": query_path,
            "kdata": kdata_query,
            "searchindex": kdata_query_searchindex,
        }

        # Mapping kapture: same search index; poses (trajectories) required.
        map_path = os.path.join(self.root, "mapping")
        kdata_map = kapture_from_dir(map_path)
        assert kdata_map.records_camera is not None and kdata_map.trajectories is not None
        kdata_map_searchindex = {
            kdata_map.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)
            for timestamp, sensor_id in kdata_map.records_camera.key_pairs()
        }
        self.map_data = {
            "path": map_path,
            "kdata": kdata_map,
            "searchindex": kdata_map_searchindex,
        }

        # Retrieval pairs: try the kapture-localization format first,
        # then fall back to the hloc pairs format (space-separated names
        # with "query/" / "database/cutouts/" prefixes, implicit score 1.0).
        try:
            self.pairs = get_ordered_pairs_from_file(
                os.path.join(self.root, "pairfiles/query", pairsfile + ".txt")
            )
        except Exception:
            self.pairs = {}
            with open(
                os.path.join(self.root, "pairfiles/query", pairsfile + ".txt"), "r"
            ) as fid:
                lines = fid.readlines()
                for line in lines:
                    splits = line.rstrip("\n\r").split(" ")
                    self.pairs.setdefault(splits[0].replace("query/", ""), []).append(
                        (splits[1].replace("database/cutouts/", ""), 1.0)
                    )

        self.scenes = kdata_query.records_camera.data_list()

        # Per-building ICP alignments bringing cutout 3D points into a
        # common world frame.
        self.aligns_DUC1 = read_alignments(
            os.path.join(self.root, "mapping/DUC1_alignment/all_transformations.txt")
        )
        self.aligns_DUC2 = read_alignments(
            os.path.join(self.root, "mapping/DUC2_alignment/all_transformations.txt")
        )

    def __len__(self):
        return len(self.scenes)

    def __getitem__(self, idx):
        """Return the query view followed by its top-k mapping views."""
        assert self.maxdim is not None and self.patch_size is not None
        query_image = self.scenes[idx]
        map_images = [p[0] for p in self.pairs[query_image][: self.topk]]
        views = []
        dataarray = [(query_image, self.query_data, False)] + [
            (map_image, self.map_data, True) for map_image in map_images
        ]
        # NOTE: use a distinct loop variable so we do not shadow the
        # dataset index `idx` (the stored "idx" is the view position:
        # 0 for the query, 1..topk for the mapping images).
        for view_idx, (imgname, data, should_load_depth) in enumerate(dataarray):
            imgpath, kdata, searchindex = map(
                data.get, ["path", "kdata", "searchindex"]
            )
            timestamp, camera_id = searchindex[imgname]

            # For InLoc, cameras are SIMPLE_PINHOLE: (W, H, f, cx, cy).
            camera_params = kdata.sensors[camera_id].camera_params
            W, H, f, cx, cy = camera_params
            distortion = [0, 0, 0, 0]
            intrinsics = np.float32([(f, 0, cx), (0, f, cy), (0, 0, 1)])

            # Queries have no pose; fall back to identity cam-to-world.
            if (
                kdata.trajectories is not None
                and (timestamp, camera_id) in kdata.trajectories
            ):
                cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)
            else:
                cam_to_world = np.eye(4, dtype=np.float32)

            # Load RGB image
            rgb_image = PIL.Image.open(
                os.path.join(imgpath, "sensors/records_data", imgname)
            ).convert("RGB")
            rgb_image.load()
            W, H = rgb_image.size

            resize_func, to_resize, to_orig = get_resize_function(
                self.maxdim, self.patch_size, H, W
            )
            rgb_tensor = resize_func(ImgNorm(rgb_image))

            view = {
                "intrinsics": intrinsics,
                "distortion": distortion,
                "cam_to_world": cam_to_world,
                "rgb": rgb_image,
                "rgb_rescaled": rgb_tensor,
                "to_orig": to_orig,
                "idx": view_idx,
                "image_name": imgname,
            }

            # Load depthmap (mapping images only): the InLoc .mat cutouts
            # store per-pixel 3D points ("XYZcut") in the scan frame; they
            # are brought into the world frame via the DUC1/DUC2 alignments.
            if should_load_depth:
                depthmap_filename = os.path.join(
                    imgpath, "sensors/records_data", imgname + ".mat"
                )
                depthmap = scipy.io.loadmat(depthmap_filename)

                pt3d_cut = depthmap["XYZcut"]
                scene_id = imgname.replace("\\", "/").split("/")[1]
                if imgname.startswith("DUC1"):
                    pts3d_full = geotrf(self.aligns_DUC1[scene_id], pt3d_cut)
                else:
                    pts3d_full = geotrf(self.aligns_DUC2[scene_id], pt3d_cut)

                # Keep only pixels with finite 3D coordinates.
                pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1))
                pts3d = pts3d_full[pts3d_valid]
                pts2d_int = xy_grid(W, H)[pts3d_valid]
                pts2d = pts2d_int.astype(np.float64)

                # nan => invalid
                pts3d_full[~pts3d_valid] = np.nan
                pts3d_full = torch.from_numpy(pts3d_full)
                view["pts3d"] = pts3d_full
                view["valid"] = pts3d_full.sum(dim=-1).isfinite()

                # Also provide points rescaled to the network input size.
                HR, WR = rgb_tensor.shape[1:]
                _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(
                    pts2d, pts3d, to_resize, HR, WR
                )
                pts3d_rescaled = torch.from_numpy(pts3d_rescaled)
                valid_rescaled = torch.from_numpy(valid_rescaled)
                view["pts3d_rescaled"] = pts3d_rescaled
                view["valid_rescaled"] = valid_rescaled
            views.append(view)
        return views