# Copyright (c) Meta Platforms, Inc. and affiliates.
import collections.abc
from collections import defaultdict
from pathlib import Path
from typing import Optional
import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation
from ... import DATASETS_PATH, logger
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .utils import get_camera_calibration, parse_gps_file, parse_split_file
class KittiDataModule(pl.LightningDataModule):
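    """Lightning data module for the KITTI dataset.

    Handles split parsing, camera calibration, map tiles, and the packing of
    per-frame poses into tensors for efficient loading.
    """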
default_cfg = {
**MapLocDataset.default_cfg,
"name": "kitti",
# paths and fetch
"data_dir": DATASETS_PATH / "kitti",
"tiles_filename": "tiles.pkl",
"splits": {
"train": "train_files.txt",
"val": "test1_files.txt",
"test": "test2_files.txt",
},
"loading": {
"train": "???",
"val": "${.test}",
"test": {"batch_size": 1, "num_workers": 0},
},
"max_num_val": 500,
"selection_subset_val": "furthest",
"drop_train_too_close_to_val": 5.0,
"skip_frames": 1,
"camera_index": 2,
        # overwrite MapLocDataset defaults
"crop_size_meters": 64,
"max_init_error": 20,
"max_init_error_rotation": 10,
"add_map_mask": True,
"mask_pad": 2,
"target_focal_length": 256,
}
dummy_scene_name = "kitti"
def __init__(self, cfg, tile_manager: Optional[TileManager] = None):
super().__init__()
default_cfg = OmegaConf.create(self.default_cfg)
OmegaConf.set_struct(default_cfg, True) # cannot add new keys
self.cfg = OmegaConf.merge(default_cfg, cfg)
self.root = Path(self.cfg.data_dir)
self.tile_manager = tile_manager
        if self.cfg.crop_size_meters < self.cfg.max_init_error:
            raise ValueError(
                "crop_size_meters must be at least max_init_error, "
                "otherwise the ground-truth location can fall outside the map crop."
            )
assert self.cfg.selection_subset_val in ["random", "furthest"]
self.splits = {}
self.shifts = {}
self.calibrations = {}
self.data = {}
self.image_paths = {}
def prepare_data(self):
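        """Check that the dataset has been downloaded (see maploc.data.kitti.prepare)."""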
if not (self.root.exists() and (self.root / ".downloaded").exists()):
raise FileNotFoundError(
"Cannot find the KITTI dataset, run maploc.data.kitti.prepare"
)
def parse_split(self, split_arg):
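        """Resolve a split argument into (names, shifts).

        The split is either the name of a *_files.txt list or a sequence of
        date/drive strings, in which case all frames of each drive are listed
        (subsampled by cfg.skip_frames) and no shifts are available.
        """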
if isinstance(split_arg, str):
names, shifts = parse_split_file(self.root / split_arg)
elif isinstance(split_arg, collections.abc.Sequence):
names = []
shifts = None
for date_drive in split_arg:
data_dir = (
self.root / date_drive / f"image_{self.cfg.camera_index:02}/data"
)
assert data_dir.exists(), data_dir
date_drive = tuple(date_drive.split("/"))
n = sorted(date_drive + (p.name,) for p in data_dir.glob("*.png"))
names.extend(n[:: self.cfg.skip_frames])
else:
raise ValueError(split_arg)
return names, shifts
def setup(self, stage: Optional[str] = None):
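        """Parse the splits, subsample the validation set, load the camera
        calibrations and map tiles, and pack the frame data into tensors.

        The validation subset is picked either at random or as the frames
        furthest from the training set; training frames that are too close
        to a validation frame can optionally be dropped.
        """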
if stage == "fit":
stages = ["train", "val"]
elif stage is None:
stages = ["train", "val", "test"]
else:
stages = [stage]
for stage in stages:
self.splits[stage], self.shifts[stage] = self.parse_split(
self.cfg.splits[stage]
)
do_val_subset = "val" in stages and self.cfg.max_num_val is not None
if do_val_subset and self.cfg.selection_subset_val == "random":
select = np.random.RandomState(self.cfg.seed).choice(
len(self.splits["val"]), self.cfg.max_num_val, replace=False
)
self.splits["val"] = [self.splits["val"][i] for i in select]
if self.shifts["val"] is not None:
self.shifts["val"] = self.shifts["val"][select]
dates = {d for ns in self.splits.values() for d, _, _ in ns}
for d in dates:
self.calibrations[d] = get_camera_calibration(
self.root / d, self.cfg.camera_index
)
if self.tile_manager is None:
logger.info("Loading the tile manager...")
self.tile_manager = TileManager.load(self.root / self.cfg.tiles_filename)
self.cfg.num_classes = {k: len(g) for k, g in self.tile_manager.groups.items()}
self.cfg.pixel_per_meter = self.tile_manager.ppm
# pack all attributes in a single tensor to optimize memory access
self.pack_data(stages)
dists = None
if do_val_subset and self.cfg.selection_subset_val == "furthest":
dists = torch.cdist(
self.data["val"]["t_c2w"][:, :2].double(),
self.data["train"]["t_c2w"][:, :2].double(),
)
min_dists = dists.min(1).values
select = torch.argsort(min_dists)[-self.cfg.max_num_val :]
dists = dists[select]
self.splits["val"] = [self.splits["val"][i] for i in select]
if self.shifts["val"] is not None:
self.shifts["val"] = self.shifts["val"][select]
for k in list(self.data["val"]):
if k != "cameras":
self.data["val"][k] = self.data["val"][k][select]
self.image_paths["val"] = self.image_paths["val"][select]
if "train" in stages and self.cfg.drop_train_too_close_to_val is not None:
if dists is None:
dists = torch.cdist(
self.data["val"]["t_c2w"][:, :2].double(),
self.data["train"]["t_c2w"][:, :2].double(),
)
drop = torch.any(dists < self.cfg.drop_train_too_close_to_val, 0)
select = torch.where(~drop)[0]
            logger.info(
                "Dropping %d (%.1f%%) images that are too close to validation images.",
                drop.sum(),
                100 * drop.float().mean(),
            )
self.splits["train"] = [self.splits["train"][i] for i in select]
if self.shifts["train"] is not None:
self.shifts["train"] = self.shifts["train"][select]
for k in list(self.data["train"]):
if k != "cameras":
self.data["train"][k] = self.data["train"][k][select]
self.image_paths["train"] = self.image_paths["train"][select]
def pack_data(self, stages):
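        """Stack the per-frame attributes of each split into single tensors
        to optimize memory access, and build the camera dictionary."""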
for stage in stages:
names = []
data = {}
for i, (date, drive, index) in enumerate(self.splits[stage]):
d = self.get_frame_data(date, drive, index)
for k, v in d.items():
if i == 0:
data[k] = []
data[k].append(v)
path = f"{date}/{drive}/image_{self.cfg.camera_index:02}/data/{index}"
names.append((self.dummy_scene_name, f"{date}/{drive}", path))
for k in list(data):
data[k] = torch.from_numpy(np.stack(data[k]))
data["camera_id"] = np.full(len(names), self.cfg.camera_index)
sequences = {date_drive for _, date_drive, _ in names}
data["cameras"] = {
self.dummy_scene_name: {
seq: {
self.cfg.camera_index: self.calibrations[seq.split("/")[0]][0]
}
for seq in sequences
}
}
shifts = self.shifts[stage]
if shifts is not None:
data["shifts"] = torch.from_numpy(shifts.astype(np.float32))
self.data[stage] = data
self.image_paths[stage] = np.array(names)
def get_frame_data(self, date, drive, index):
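        """Compute the camera-to-world pose of one frame from its OXTS GPS
        measurement and the GPS-to-camera calibration."""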
_, R_cam_gps, t_cam_gps = self.calibrations[date]
# Transform the GPS pose to the camera pose
gps_path = (
self.root / date / drive / "oxts/data" / Path(index).with_suffix(".txt")
)
_, R_world_gps, t_world_gps = parse_gps_file(
gps_path, self.tile_manager.projection
)
R_world_cam = R_world_gps @ R_cam_gps.T
t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps
        # Extract yaw/pitch/roll Euler angles from R_world_cam: first rotate
        # from the OpenCV camera frame to an x-y-z frame, then decompose the
        # result as ZYX Euler angles.
        R_cv_xyz = Rotation.from_euler("YX", [-90, 90], degrees=True).as_matrix()
        R_world_cam_xyz = R_world_cam @ R_cv_xyz
        y, p, r = Rotation.from_matrix(R_world_cam_xyz).as_euler("ZYX", degrees=True)
        roll, pitch, yaw = r, -p, 90 - y
        # Flip the signs of roll and pitch to match the angle convention
        # expected downstream.
        roll_pitch_yaw = np.array([-roll, -pitch, yaw], np.float32)
return {
"t_c2w": t_world_cam.astype(np.float32),
"roll_pitch_yaw": roll_pitch_yaw,
"index": int(index.split(".")[0]),
}
def dataset(self, stage: str):
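        """Instantiate a MapLocDataset for the given stage."""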
return MapLocDataset(
stage,
self.cfg,
self.image_paths[stage],
self.data[stage],
{self.dummy_scene_name: self.root},
{self.dummy_scene_name: self.tile_manager},
)
def dataloader(
self,
stage: str,
shuffle: bool = False,
        num_workers: Optional[int] = None,
sampler: Optional[torchdata.Sampler] = None,
):
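        """Create a DataLoader for the given stage, with per-stage loading
        options taken from cfg.loading."""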
dataset = self.dataset(stage)
cfg = self.cfg["loading"][stage]
num_workers = cfg["num_workers"] if num_workers is None else num_workers
loader = torchdata.DataLoader(
dataset,
batch_size=cfg["batch_size"],
num_workers=num_workers,
            shuffle=(shuffle or stage == "train") if sampler is None else False,
pin_memory=True,
persistent_workers=num_workers > 0,
worker_init_fn=worker_init_fn,
collate_fn=collate,
sampler=sampler,
)
return loader
def train_dataloader(self, **kwargs):
return self.dataloader("train", **kwargs)
def val_dataloader(self, **kwargs):
return self.dataloader("val", **kwargs)
def test_dataloader(self, **kwargs):
return self.dataloader("test", **kwargs)
def sequence_dataset(self, stage: str, **kwargs):
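        """Group the frames of a stage by sequence (date/drive) and chunk
        each sequence to the required length; returns the dataset and a
        mapping from (sequence, chunk index) to frame indices."""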
keys = self.image_paths[stage]
# group images by sequence (date/drive)
seq2indices = defaultdict(list)
for index, (_, date_drive, _) in enumerate(keys):
seq2indices[date_drive].append(index)
# chunk the sequences to the required length
chunk2indices = {}
for seq, indices in seq2indices.items():
chunks = chunk_sequence(
self.data[stage], indices, names=self.image_paths[stage], **kwargs
)
for i, sub_indices in enumerate(chunks):
chunk2indices[seq, i] = sub_indices
# store the index of each chunk in its sequence
chunk_indices = torch.full((len(keys),), -1)
for (_, chunk_index), idx in chunk2indices.items():
chunk_indices[idx] = chunk_index
self.data[stage]["chunk_index"] = chunk_indices
dataset = self.dataset(stage)
return dataset, chunk2indices
def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
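        """Create a loader that iterates over the frames chunk by chunk
        (batch_size=None), optionally shuffling the order of the chunks."""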
dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
seq_keys = sorted(chunk2idx)
if shuffle:
perm = torch.randperm(len(seq_keys))
seq_keys = [seq_keys[i] for i in perm]
key_indices = [i for key in seq_keys for i in chunk2idx[key]]
num_workers = self.cfg["loading"][stage]["num_workers"]
loader = torchdata.DataLoader(
dataset,
batch_size=None,
sampler=key_indices,
num_workers=num_workers,
shuffle=False,
pin_memory=True,
persistent_workers=num_workers > 0,
worker_init_fn=worker_init_fn,
collate_fn=collate,
)
return loader, seq_keys, chunk2idx
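# Minimal usage sketch (the batch size and worker count below are
# hypothetical; the training pipeline normally fills cfg.loading through its
# own config system):
#
#   cfg = {"loading": {"train": {"batch_size": 4, "num_workers": 4}}}
#   dm = KittiDataModule(cfg)
#   dm.prepare_data()
#   dm.setup("fit")
#   train_loader = dm.train_dataloader()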