# Uploaded by sengerchen via huggingface_hub ("Upload folder using huggingface_hub", commit 1bb1365).
# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# InLoc dataloader
# --------------------------------------------------------
import os
import kapture
import numpy as np
import PIL.Image
import scipy.io
import torch
from dust3r.datasets.utils.transforms import ImgNorm
from dust3r.utils.geometry import geotrf, xy_grid
from dust3r_visloc.datasets.base_dataset import BaseVislocDataset
from dust3r_visloc.datasets.utils import (
cam_to_world_from_kapture,
get_resize_function,
rescale_points3d,
)
from kapture.io.csv import kapture_from_dir
from kapture_localization.utils.pairsfile import get_ordered_pairs_from_file
def read_alignments(path_to_alignment):
    """Parse an InLoc ``all_transformations.txt`` alignment file.

    The file is scanned line by line. A line holding exactly three characters
    (e.g. ``"001"``) starts a new entry and becomes its key. For each entry,
    lines are skipped until one reading ``After general icp:``; the next four
    lines are parsed as the rows of a 4x4 matrix of whitespace-separated
    floats. Any other lines are ignored.

    Args:
        path_to_alignment: path to the alignment text file.

    Returns:
        dict mapping each 3-character key to a ``(4, 4)`` float64 numpy array.

    Raises:
        ValueError: if the file ends before an entry's ``After general icp:``
            marker, or an entry's matrix does not contain exactly 16 numbers.
    """
    aligns = {}
    with open(path_to_alignment, "r") as fid:
        while True:
            line = fid.readline()
            if not line:
                break  # EOF
            key = line.rstrip("\n")
            if len(key) != 3:
                continue  # not an entry header line

            # Skip forward to the refined (post-ICP) transformation. Guard
            # against EOF: readline() returns "" forever once exhausted.
            while line.strip() != "After general icp:":
                line = fid.readline()
                if not line:
                    raise ValueError(
                        f"{path_to_alignment}: entry {key!r} has no "
                        "'After general icp:' section"
                    )

            # Read exactly four rows so the line following the matrix (which
            # may be the next entry's key) is left for the outer loop.
            values = []
            for _ in range(4):
                values.extend(float(tok) for tok in fid.readline().split())
            if len(values) != 16:
                raise ValueError(
                    f"{path_to_alignment}: entry {key!r} expected 16 matrix "
                    f"values, got {len(values)}"
                )
            aligns[key] = np.array(values).reshape(4, 4)
    return aligns
class VislocInLoc(BaseVislocDataset):
    """InLoc visual-localization dataset read from a kapture export.

    Layout read under ``root``: ``query/`` and ``mapping/`` kapture
    directories, pair files in ``pairfiles/query/``, and alignment files
    ``mapping/DUC1_alignment/all_transformations.txt`` and
    ``mapping/DUC2_alignment/all_transformations.txt``.

    Item ``i`` is a list of view dicts: the i-th query image followed by its
    first ``topk`` paired database (mapping) images. Only database views get
    3D points. ``maxdim`` and ``patch_size`` start as ``None`` and must be
    assigned before indexing.
    """

    def __init__(self, root, pairsfile, topk=1):
        """
        Args:
            root: dataset root directory.
            pairsfile: pair-file stem; read from
                ``<root>/pairfiles/query/<pairsfile>.txt``.
            topk: number of database images paired with each query.
        """
        super().__init__()
        self.root = root
        self.topk = topk
        self.num_views = self.topk + 1
        # Must be set by the caller before __getitem__ is used.
        self.maxdim = None
        self.patch_size = None

        query_path = os.path.join(self.root, "query")
        kdata_query = kapture_from_dir(query_path)
        assert kdata_query.records_camera is not None
        self.query_data = self._index_kapture(query_path, kdata_query)

        map_path = os.path.join(self.root, "mapping")
        kdata_map = kapture_from_dir(map_path)
        assert (
            kdata_map.records_camera is not None and kdata_map.trajectories is not None
        )
        self.map_data = self._index_kapture(map_path, kdata_map)

        self.pairs = self._read_pairs(
            os.path.join(self.root, "pairfiles/query", pairsfile + ".txt")
        )
        self.scenes = kdata_query.records_camera.data_list()
        self.aligns_DUC1 = read_alignments(
            os.path.join(self.root, "mapping/DUC1_alignment/all_transformations.txt")
        )
        self.aligns_DUC2 = read_alignments(
            os.path.join(self.root, "mapping/DUC2_alignment/all_transformations.txt")
        )

    @staticmethod
    def _index_kapture(path, kdata):
        """Bundle a kapture root with an image-name -> (timestamp, sensor_id) index."""
        searchindex = {
            kdata.records_camera[(timestamp, sensor_id)]: (timestamp, sensor_id)
            for timestamp, sensor_id in kdata.records_camera.key_pairs()
        }
        return {"path": path, "kdata": kdata, "searchindex": searchindex}

    @staticmethod
    def _read_pairs(pairs_path):
        """Read query -> [(map image, score), ...] pairs from ``pairs_path``.

        The kapture pairs reader is tried first. If it raises, the file is
        parsed as space-separated ``query/<name> database/cutouts/<name>``
        lines (the comment in the original code attributes this format to
        hloc); every such pair gets score 1.0 and keeps file order.
        """
        try:
            return get_ordered_pairs_from_file(pairs_path)
        except Exception:
            pass
        # if using pairs from hloc
        pairs = {}
        with open(pairs_path, "r") as fid:
            for line in fid:
                if not line.strip():
                    continue  # tolerate blank (e.g. trailing) lines
                splits = line.rstrip("\n\r").split(" ")
                pairs.setdefault(splits[0].replace("query/", ""), []).append(
                    (splits[1].replace("database/cutouts/", ""), 1.0)
                )
        return pairs

    def _load_aligned_pts3d(self, imgpath, imgname):
        """Load ``XYZcut`` from ``<imgname>.mat`` and apply its DUC1/DUC2 alignment.

        The alignment key is the second path component of ``imgname``; the
        DUC1 table is used when ``imgname`` starts with ``"DUC1"``, else DUC2.
        """
        depthmap = scipy.io.loadmat(
            os.path.join(imgpath, "sensors/records_data", imgname + ".mat")
        )
        scene_id = imgname.replace("\\", "/").split("/")[1]
        aligns = self.aligns_DUC1 if imgname.startswith("DUC1") else self.aligns_DUC2
        return geotrf(aligns[scene_id], depthmap["XYZcut"])

    def __len__(self):
        """Number of query images."""
        return len(self.scenes)

    def __getitem__(self, idx):
        """Return the list of views for the ``idx``-th query image.

        Every view dict holds ``intrinsics`` (3x3 float32 built from the
        SIMPLE_PINHOLE ``f, cx, cy``), ``distortion`` (all zeros),
        ``cam_to_world`` (from the kapture trajectory, identity if absent),
        ``rgb`` (PIL image), ``rgb_rescaled`` (``ImgNorm``-transformed and
        resized tensor), ``to_orig``, ``idx`` (position of the view in the
        returned list) and ``image_name``. Database views additionally hold
        ``pts3d``/``valid`` at image resolution and
        ``pts3d_rescaled``/``valid_rescaled`` at the resized resolution.
        """
        assert self.maxdim is not None and self.patch_size is not None
        query_image = self.scenes[idx]
        map_images = [p[0] for p in self.pairs[query_image][: self.topk]]
        # (image name, kapture bundle, whether to load 3D points)
        dataarray = [(query_image, self.query_data, False)] + [
            (map_image, self.map_data, True) for map_image in map_images
        ]

        views = []
        for view_idx, (imgname, data, should_load_depth) in enumerate(dataarray):
            imgpath, kdata = data["path"], data["kdata"]
            timestamp, camera_id = data["searchindex"][imgname]

            # for InLoc, SIMPLE_PINHOLE: (W, H, f, cx, cy). W/H are taken from
            # the loaded image below, so the stored values are not used.
            _, _, f, cx, cy = kdata.sensors[camera_id].camera_params
            distortion = [0, 0, 0, 0]
            intrinsics = np.float32([(f, 0, cx), (0, f, cy), (0, 0, 1)])

            if (
                kdata.trajectories is not None
                and (timestamp, camera_id) in kdata.trajectories
            ):
                cam_to_world = cam_to_world_from_kapture(kdata, timestamp, camera_id)
            else:
                cam_to_world = np.eye(4, dtype=np.float32)

            # Load RGB image
            rgb_image = PIL.Image.open(
                os.path.join(imgpath, "sensors/records_data", imgname)
            ).convert("RGB")
            rgb_image.load()
            W, H = rgb_image.size
            resize_func, to_resize, to_orig = get_resize_function(
                self.maxdim, self.patch_size, H, W
            )
            rgb_tensor = resize_func(ImgNorm(rgb_image))

            view = {
                "intrinsics": intrinsics,
                "distortion": distortion,
                "cam_to_world": cam_to_world,
                "rgb": rgb_image,
                "rgb_rescaled": rgb_tensor,
                "to_orig": to_orig,
                "idx": view_idx,
                "image_name": imgname,
            }

            if should_load_depth:
                pts3d_full = self._load_aligned_pts3d(imgpath, imgname)
                # a point is valid when all of its coordinates are finite
                pts3d_valid = np.isfinite(pts3d_full.sum(axis=-1))
                pts3d = pts3d_full[pts3d_valid]
                pts2d = xy_grid(W, H)[pts3d_valid].astype(np.float64)
                # nan => invalid
                pts3d_full[~pts3d_valid] = np.nan
                pts3d_full = torch.from_numpy(pts3d_full)
                view["pts3d"] = pts3d_full
                view["valid"] = pts3d_full.sum(dim=-1).isfinite()

                HR, WR = rgb_tensor.shape[1:]
                _, _, pts3d_rescaled, valid_rescaled = rescale_points3d(
                    pts2d, pts3d, to_resize, HR, WR
                )
                view["pts3d_rescaled"] = torch.from_numpy(pts3d_rescaled)
                view["valid_rescaled"] = torch.from_numpy(valid_rescaled)

            views.append(view)
        return views