# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections.abc
from collections import defaultdict
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation

from ... import DATASETS_PATH, logger
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .utils import get_camera_calibration, parse_gps_file, parse_split_file


class KittiDataModule(pl.LightningDataModule):
    default_cfg = {
        **MapLocDataset.default_cfg,
        "name": "kitti",
        # paths and fetch
        "data_dir": DATASETS_PATH / "kitti",
        "tiles_filename": "tiles.pkl",
        "splits": {
            "train": "train_files.txt",
            "val": "test1_files.txt",
            "test": "test2_files.txt",
        },
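        # Per-stage loader settings. "???" is OmegaConf's mandatory-value
        # marker (the train settings must come from the user config) and
        # "${.test}" reuses the test settings for validation via relative
        # interpolation.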
"loading": { | |
"train": "???", | |
"val": "${.test}", | |
"test": {"batch_size": 1, "num_workers": 0}, | |
}, | |
"max_num_val": 500, | |
"selection_subset_val": "furthest", | |
"drop_train_too_close_to_val": 5.0, | |
"skip_frames": 1, | |
"camera_index": 2, | |
# overwrite | |
"crop_size_meters": 64, | |
"max_init_error": 20, | |
"max_init_error_rotation": 10, | |
"add_map_mask": True, | |
"mask_pad": 2, | |
"target_focal_length": 256, | |
} | |
    dummy_scene_name = "kitti"  # KITTI has no scene hierarchy: one dummy scene

    def __init__(self, cfg, tile_manager: Optional[TileManager] = None):
        super().__init__()
        default_cfg = OmegaConf.create(self.default_cfg)
        OmegaConf.set_struct(default_cfg, True)  # cannot add new keys
        self.cfg = OmegaConf.merge(default_cfg, cfg)
        self.root = Path(self.cfg.data_dir)
        self.tile_manager = tile_manager
        if self.cfg.crop_size_meters < self.cfg.max_init_error:
            raise ValueError(
                "The ground-truth location could fall outside the map crop."
            )
        assert self.cfg.selection_subset_val in ["random", "furthest"]
        self.splits = {}
        self.shifts = {}
        self.calibrations = {}
        self.data = {}
        self.image_paths = {}

    def prepare_data(self):
        if not (self.root.exists() and (self.root / ".downloaded").exists()):
            raise FileNotFoundError(
                "Cannot find the KITTI dataset, run maploc.data.kitti.prepare"
            )

    def parse_split(self, split_arg):
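        # `split_arg` is either the name of a split file (parsed by
        # parse_split_file) or a sequence of "date/drive" strings, e.g.
        # ["2011_09_26/2011_09_26_drive_0001_sync"]; the latter enumerates
        # every skip_frames-th image of each drive found on disk.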
        if isinstance(split_arg, str):
            names, shifts = parse_split_file(self.root / split_arg)
        elif isinstance(split_arg, collections.abc.Sequence):
            names = []
            shifts = None
            for date_drive in split_arg:
                data_dir = (
                    self.root / date_drive / f"image_{self.cfg.camera_index:02}/data"
                )
                assert data_dir.exists(), data_dir
                date_drive = tuple(date_drive.split("/"))
                n = sorted(date_drive + (p.name,) for p in data_dir.glob("*.png"))
                names.extend(n[:: self.cfg.skip_frames])
        else:
            raise ValueError(split_arg)
        return names, shifts

    def setup(self, stage: Optional[str] = None):
        if stage == "fit":
            stages = ["train", "val"]
        elif stage is None:
            stages = ["train", "val", "test"]
        else:
            stages = [stage]
        for stage in stages:
            self.splits[stage], self.shifts[stage] = self.parse_split(
                self.cfg.splits[stage]
            )
        do_val_subset = "val" in stages and self.cfg.max_num_val is not None
        if do_val_subset and self.cfg.selection_subset_val == "random":
            select = np.random.RandomState(self.cfg.seed).choice(
                len(self.splits["val"]), self.cfg.max_num_val, replace=False
            )
            self.splits["val"] = [self.splits["val"][i] for i in select]
            if self.shifts["val"] is not None:
                self.shifts["val"] = self.shifts["val"][select]
        dates = {d for ns in self.splits.values() for d, _, _ in ns}
        for d in dates:
            self.calibrations[d] = get_camera_calibration(
                self.root / d, self.cfg.camera_index
            )
        if self.tile_manager is None:
            logger.info("Loading the tile manager...")
            self.tile_manager = TileManager.load(self.root / self.cfg.tiles_filename)
        self.cfg.num_classes = {k: len(g) for k, g in self.tile_manager.groups.items()}
        self.cfg.pixel_per_meter = self.tile_manager.ppm
        # pack all attributes in a single tensor to optimize memory access
        self.pack_data(stages)
        dists = None
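        # "furthest" keeps the max_num_val validation frames whose nearest
        # training frame is farthest away, reducing train/val spatial overlap.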
        if do_val_subset and self.cfg.selection_subset_val == "furthest":
            dists = torch.cdist(
                self.data["val"]["t_c2w"][:, :2].double(),
                self.data["train"]["t_c2w"][:, :2].double(),
            )
            min_dists = dists.min(1).values
            select = torch.argsort(min_dists)[-self.cfg.max_num_val :]
            dists = dists[select]
            self.splits["val"] = [self.splits["val"][i] for i in select]
            if self.shifts["val"] is not None:
                self.shifts["val"] = self.shifts["val"][select]
            for k in list(self.data["val"]):
                if k != "cameras":
                    self.data["val"][k] = self.data["val"][k][select]
            self.image_paths["val"] = self.image_paths["val"][select]
        if "train" in stages and self.cfg.drop_train_too_close_to_val is not None:
            if dists is None:
                dists = torch.cdist(
                    self.data["val"]["t_c2w"][:, :2].double(),
                    self.data["train"]["t_c2w"][:, :2].double(),
                )
            drop = torch.any(dists < self.cfg.drop_train_too_close_to_val, 0)
            select = torch.where(~drop)[0]
            logger.info(
                "Dropping %d (%.1f%%) images that are too close to validation images.",
                drop.sum(),
                100 * drop.float().mean(),
            )
            self.splits["train"] = [self.splits["train"][i] for i in select]
            if self.shifts["train"] is not None:
                self.shifts["train"] = self.shifts["train"][select]
            for k in list(self.data["train"]):
                if k != "cameras":
                    self.data["train"][k] = self.data["train"][k][select]
            self.image_paths["train"] = self.image_paths["train"][select]

    def pack_data(self, stages):
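        # Stack the per-frame attributes (e.g. "t_c2w", "roll_pitch_yaw") into
        # contiguous tensors and keep the image paths alongside as
        # (scene, date/drive, path) triplets.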
        for stage in stages:
            names = []
            data = {}
            for i, (date, drive, index) in enumerate(self.splits[stage]):
                d = self.get_frame_data(date, drive, index)
                for k, v in d.items():
                    if i == 0:
                        data[k] = []
                    data[k].append(v)
                path = f"{date}/{drive}/image_{self.cfg.camera_index:02}/data/{index}"
                names.append((self.dummy_scene_name, f"{date}/{drive}", path))
            for k in list(data):
                data[k] = torch.from_numpy(np.stack(data[k]))
            data["camera_id"] = np.full(len(names), self.cfg.camera_index)
            sequences = {date_drive for _, date_drive, _ in names}
            data["cameras"] = {
                self.dummy_scene_name: {
                    seq: {
                        self.cfg.camera_index: self.calibrations[seq.split("/")[0]][0]
                    }
                    for seq in sequences
                }
            }
            shifts = self.shifts[stage]
            if shifts is not None:
                data["shifts"] = torch.from_numpy(shifts.astype(np.float32))
            self.data[stage] = data
            self.image_paths[stage] = np.array(names)

    def get_frame_data(self, date, drive, index):
        _, R_cam_gps, t_cam_gps = self.calibrations[date]
        # Transform the GPS pose to the camera pose
        gps_path = (
            self.root / date / drive / "oxts/data" / Path(index).with_suffix(".txt")
        )
        _, R_world_gps, t_world_gps = parse_gps_file(
            gps_path, self.tile_manager.projection
        )
        R_world_cam = R_world_gps @ R_cam_gps.T
        t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps
        # Some voodoo to extract correct Euler angles from R_world_cam
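        # (Assumed convention: R_cv_xyz re-expresses the OpenCV camera axes
        # (x right, y down, z forward) in a forward-left-up frame, so that the
        # ZYX Euler angles below read as heading/pitch/roll.)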
        R_cv_xyz = Rotation.from_euler("YX", [-90, 90], degrees=True).as_matrix()
        R_world_cam_xyz = R_world_cam @ R_cv_xyz
        y, p, r = Rotation.from_matrix(R_world_cam_xyz).as_euler("ZYX", degrees=True)
        roll, pitch, yaw = r, -p, 90 - y
        roll_pitch_yaw = np.array([-roll, -pitch, yaw], np.float32)  # for some reason
        return {
            "t_c2w": t_world_cam.astype(np.float32),
            "roll_pitch_yaw": roll_pitch_yaw,
            "index": int(index.split(".")[0]),
        }

    def dataset(self, stage: str):
        return MapLocDataset(
            stage,
            self.cfg,
            self.image_paths[stage],
            self.data[stage],
            {self.dummy_scene_name: self.root},
            {self.dummy_scene_name: self.tile_manager},
        )

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        dataset = self.dataset(stage)
        cfg = self.cfg["loading"][stage]
        num_workers = cfg["num_workers"] if num_workers is None else num_workers
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=shuffle or (stage == "train"),
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)

    def sequence_dataset(self, stage: str, **kwargs):
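        # Group the frames of each drive into chunks of consecutive frames and
        # return the dataset together with a mapping
        # ("date/drive", chunk_index) -> frame indices.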
        keys = self.image_paths[stage]
        # group images by sequence (date/drive)
        seq2indices = defaultdict(list)
        for index, (_, date_drive, _) in enumerate(keys):
            seq2indices[date_drive].append(index)
        # chunk the sequences to the required length
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(
                self.data[stage], indices, names=self.image_paths[stage], **kwargs
            )
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices
        # store the index of each chunk in its sequence
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
        dataset = self.dataset(stage)
        return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
        dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
        seq_keys = sorted(chunk2idx)
        if shuffle:
            # shuffle whole chunks rather than individual frames, so that each
            # chunk stays contiguous in the stream
            perm = torch.randperm(len(seq_keys))
            seq_keys = [seq_keys[i] for i in perm]
        key_indices = [i for key in seq_keys for i in chunk2idx[key]]
        num_workers = self.cfg["loading"][stage]["num_workers"]
        loader = torchdata.DataLoader(
            dataset,
            batch_size=None,
            sampler=key_indices,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
        return loader, seq_keys, chunk2idx
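

# Minimal usage sketch (illustrative, not part of the module; assumes the
# KITTI data has been prepared and the user config supplies the mandatory
# `loading.train` entry):
#
#   from omegaconf import OmegaConf
#
#   cfg = OmegaConf.create(
#       {"loading": {"train": {"batch_size": 4, "num_workers": 2}}}
#   )
#   dm = KittiDataModule(cfg)
#   dm.prepare_data()
#   dm.setup("fit")
#   train_loader = dm.train_dataloader()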