# Copyright (c) Meta Platforms, Inc. and affiliates.

import collections.abc
from collections import defaultdict
from pathlib import Path
from typing import Optional

import numpy as np
import pytorch_lightning as pl
import torch
import torch.utils.data as torchdata
from omegaconf import OmegaConf
from scipy.spatial.transform import Rotation

from ... import DATASETS_PATH, logger
from ...osm.tiling import TileManager
from ..dataset import MapLocDataset
from ..sequential import chunk_sequence
from ..torch import collate, worker_init_fn
from .utils import get_camera_calibration, parse_gps_file, parse_split_file


class KittiDataModule(pl.LightningDataModule):
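    """Lightning data module for single-image localization on the KITTI raw
    sequences, wrapping ``MapLocDataset`` with KITTI-specific split parsing,
    camera calibration, GPS-to-camera pose conversion, and map tile loading.

    A minimal usage sketch (the ``loading.train`` values below are
    illustrative, not shipped defaults):

        cfg = OmegaConf.create(
            {"loading": {"train": {"batch_size": 4, "num_workers": 2}}}
        )
        dm = KittiDataModule(cfg)
        dm.prepare_data()
        dm.setup("fit")
        train_loader = dm.train_dataloader()
    """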
default_cfg = {
**MapLocDataset.default_cfg,
"name": "kitti",
# paths and fetch
"data_dir": DATASETS_PATH / "kitti",
"tiles_filename": "tiles.pkl",
"splits": {
"train": "train_files.txt",
"val": "test1_files.txt",
"test": "test2_files.txt",
},
"loading": {
"train": "???",
"val": "${.test}",
"test": {"batch_size": 1, "num_workers": 0},
},
"max_num_val": 500,
"selection_subset_val": "furthest",
"drop_train_too_close_to_val": 5.0,
"skip_frames": 1,
"camera_index": 2,
        # overrides of MapLocDataset.default_cfg
"crop_size_meters": 64,
"max_init_error": 20,
"max_init_error_rotation": 10,
"add_map_mask": True,
"mask_pad": 2,
"target_focal_length": 256,
}
dummy_scene_name = "kitti"

    def __init__(self, cfg, tile_manager: Optional[TileManager] = None):
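        """Merge the user configuration into ``default_cfg`` and validate it.

        Args:
            cfg: overrides for ``default_cfg``; ``loading.train`` is mandatory.
            tile_manager: optional preloaded tile manager; if None, it is
                loaded from ``tiles_filename`` during ``setup``.
        """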
super().__init__()
default_cfg = OmegaConf.create(self.default_cfg)
OmegaConf.set_struct(default_cfg, True) # cannot add new keys
self.cfg = OmegaConf.merge(default_cfg, cfg)
self.root = Path(self.cfg.data_dir)
self.tile_manager = tile_manager
        if self.cfg.crop_size_meters < self.cfg.max_init_error:
            raise ValueError(
                "crop_size_meters must be at least max_init_error, otherwise "
                "the ground truth location can fall outside the map crop."
            )
assert self.cfg.selection_subset_val in ["random", "furthest"]
self.splits = {}
self.shifts = {}
self.calibrations = {}
self.data = {}
self.image_paths = {}

    def prepare_data(self):
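        """Lightning hook: verify that the dataset has been downloaded."""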
if not (self.root.exists() and (self.root / ".downloaded").exists()):
raise FileNotFoundError(
"Cannot find the KITTI dataset, run maploc.data.kitti.prepare"
)

    def parse_split(self, split_arg):
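        """Resolve a split into a list of (date, drive, filename) tuples.

        ``split_arg`` is either the name of a split file relative to the
        dataset root (which may also provide per-frame pose shifts) or a
        sequence of ``date/drive`` strings whose frames are globbed from
        disk and subsampled by ``skip_frames``.
        """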
if isinstance(split_arg, str):
names, shifts = parse_split_file(self.root / split_arg)
elif isinstance(split_arg, collections.abc.Sequence):
names = []
shifts = None
for date_drive in split_arg:
data_dir = (
self.root / date_drive / f"image_{self.cfg.camera_index:02}/data"
)
assert data_dir.exists(), data_dir
date_drive = tuple(date_drive.split("/"))
n = sorted(date_drive + (p.name,) for p in data_dir.glob("*.png"))
names.extend(n[:: self.cfg.skip_frames])
        else:
            raise ValueError(f"Unknown split specification: {split_arg}")
return names, shifts

    def setup(self, stage: Optional[str] = None):
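        """Parse the splits, load calibrations and map tiles, pack the data.

        Optionally subsamples the validation split (randomly or by keeping
        the frames furthest from the training set) and drops training
        frames closer than ``drop_train_too_close_to_val`` meters to any
        validation frame.
        """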
if stage == "fit":
stages = ["train", "val"]
elif stage is None:
stages = ["train", "val", "test"]
else:
stages = [stage]
for stage in stages:
self.splits[stage], self.shifts[stage] = self.parse_split(
self.cfg.splits[stage]
)
do_val_subset = "val" in stages and self.cfg.max_num_val is not None
if do_val_subset and self.cfg.selection_subset_val == "random":
select = np.random.RandomState(self.cfg.seed).choice(
len(self.splits["val"]), self.cfg.max_num_val, replace=False
)
self.splits["val"] = [self.splits["val"][i] for i in select]
if self.shifts["val"] is not None:
self.shifts["val"] = self.shifts["val"][select]
dates = {d for ns in self.splits.values() for d, _, _ in ns}
for d in dates:
self.calibrations[d] = get_camera_calibration(
self.root / d, self.cfg.camera_index
)
if self.tile_manager is None:
logger.info("Loading the tile manager...")
self.tile_manager = TileManager.load(self.root / self.cfg.tiles_filename)
self.cfg.num_classes = {k: len(g) for k, g in self.tile_manager.groups.items()}
self.cfg.pixel_per_meter = self.tile_manager.ppm
# pack all attributes in a single tensor to optimize memory access
self.pack_data(stages)
dists = None
if do_val_subset and self.cfg.selection_subset_val == "furthest":
dists = torch.cdist(
self.data["val"]["t_c2w"][:, :2].double(),
self.data["train"]["t_c2w"][:, :2].double(),
)
min_dists = dists.min(1).values
select = torch.argsort(min_dists)[-self.cfg.max_num_val :]
dists = dists[select]
self.splits["val"] = [self.splits["val"][i] for i in select]
if self.shifts["val"] is not None:
self.shifts["val"] = self.shifts["val"][select]
for k in list(self.data["val"]):
if k != "cameras":
self.data["val"][k] = self.data["val"][k][select]
self.image_paths["val"] = self.image_paths["val"][select]
if "train" in stages and self.cfg.drop_train_too_close_to_val is not None:
if dists is None:
dists = torch.cdist(
self.data["val"]["t_c2w"][:, :2].double(),
self.data["train"]["t_c2w"][:, :2].double(),
)
drop = torch.any(dists < self.cfg.drop_train_too_close_to_val, 0)
select = torch.where(~drop)[0]
            logger.info(
                "Dropping %d (%.1f%%) images that are too close to validation images.",
                drop.sum(),
                100 * drop.float().mean(),
            )
self.splits["train"] = [self.splits["train"][i] for i in select]
if self.shifts["train"] is not None:
self.shifts["train"] = self.shifts["train"][select]
for k in list(self.data["train"]):
if k != "cameras":
self.data["train"][k] = self.data["train"][k][select]
self.image_paths["train"] = self.image_paths["train"][select]

    def pack_data(self, stages):
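        """Stack the per-frame poses of each stage into contiguous tensors
        and record the image paths and per-sequence camera calibrations."""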
for stage in stages:
names = []
data = {}
for i, (date, drive, index) in enumerate(self.splits[stage]):
d = self.get_frame_data(date, drive, index)
for k, v in d.items():
if i == 0:
data[k] = []
data[k].append(v)
path = f"{date}/{drive}/image_{self.cfg.camera_index:02}/data/{index}"
names.append((self.dummy_scene_name, f"{date}/{drive}", path))
for k in list(data):
data[k] = torch.from_numpy(np.stack(data[k]))
data["camera_id"] = np.full(len(names), self.cfg.camera_index)
sequences = {date_drive for _, date_drive, _ in names}
data["cameras"] = {
self.dummy_scene_name: {
seq: {
self.cfg.camera_index: self.calibrations[seq.split("/")[0]][0]
}
for seq in sequences
}
}
shifts = self.shifts[stage]
if shifts is not None:
data["shifts"] = torch.from_numpy(shifts.astype(np.float32))
self.data[stage] = data
self.image_paths[stage] = np.array(names)

    def get_frame_data(self, date, drive, index):
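        """Compute the camera pose of a frame from its raw GPS/IMU record.

        Returns the camera-to-world translation ``t_c2w``, the roll, pitch,
        and yaw angles in degrees, and the integer frame index.
        """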
_, R_cam_gps, t_cam_gps = self.calibrations[date]
# Transform the GPS pose to the camera pose
gps_path = (
self.root / date / drive / "oxts/data" / Path(index).with_suffix(".txt")
)
_, R_world_gps, t_world_gps = parse_gps_file(
gps_path, self.tile_manager.projection
)
R_world_cam = R_world_gps @ R_cam_gps.T
t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps
        # Rotate the camera axes (x right, y down, z forward in the OpenCV
        # convention) into an x-forward, y-left, z-up frame before extracting
        # the Euler angles from R_world_cam.
        R_cv_xyz = Rotation.from_euler("YX", [-90, 90], degrees=True).as_matrix()
        R_world_cam_xyz = R_world_cam @ R_cv_xyz
        y, p, r = Rotation.from_matrix(R_world_cam_xyz).as_euler("ZYX", degrees=True)
        roll, pitch, yaw = r, -p, 90 - y
        # Flip the signs of roll and pitch to match the angle convention
        # expected downstream.
        roll_pitch_yaw = np.array([-roll, -pitch, yaw], np.float32)
return {
"t_c2w": t_world_cam.astype(np.float32),
"roll_pitch_yaw": roll_pitch_yaw,
"index": int(index.split(".")[0]),
}

    def dataset(self, stage: str):
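        """Build a ``MapLocDataset`` over the packed data of the given stage."""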
return MapLocDataset(
stage,
self.cfg,
self.image_paths[stage],
self.data[stage],
{self.dummy_scene_name: self.root},
{self.dummy_scene_name: self.tile_manager},
)

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
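        """Build a data loader for the given stage; the batch size comes from
        ``cfg.loading[stage]`` and ``num_workers`` can be overridden."""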
dataset = self.dataset(stage)
cfg = self.cfg["loading"][stage]
num_workers = cfg["num_workers"] if num_workers is None else num_workers
loader = torchdata.DataLoader(
dataset,
batch_size=cfg["batch_size"],
num_workers=num_workers,
shuffle=shuffle or (stage == "train"),
pin_memory=True,
persistent_workers=num_workers > 0,
worker_init_fn=worker_init_fn,
collate_fn=collate,
sampler=sampler,
)
return loader

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)

    def sequence_dataset(self, stage: str, **kwargs):
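        """Group the frames of each drive into chunks for sequential evaluation.

        ``kwargs`` are forwarded to ``chunk_sequence``. Returns the dataset
        and a mapping from (sequence, chunk index) to dataset indices.
        """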
keys = self.image_paths[stage]
# group images by sequence (date/drive)
seq2indices = defaultdict(list)
for index, (_, date_drive, _) in enumerate(keys):
seq2indices[date_drive].append(index)
# chunk the sequences to the required length
chunk2indices = {}
for seq, indices in seq2indices.items():
chunks = chunk_sequence(
self.data[stage], indices, names=self.image_paths[stage], **kwargs
)
for i, sub_indices in enumerate(chunks):
chunk2indices[seq, i] = sub_indices
# store the index of each chunk in its sequence
chunk_indices = torch.full((len(keys),), -1)
for (_, chunk_index), idx in chunk2indices.items():
chunk_indices[idx] = chunk_index
self.data[stage]["chunk_index"] = chunk_indices
dataset = self.dataset(stage)
return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
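        """Build a loader that yields frames chunk by chunk; ``shuffle``
        permutes the order of the chunks but keeps the frames within each
        chunk in temporal order.

        A sketch, reusing a data module ``dm`` that was set up for testing:

            loader, seq_keys, chunk2idx = dm.sequence_dataloader("test")
        """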
dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
seq_keys = sorted(chunk2idx)
if shuffle:
perm = torch.randperm(len(seq_keys))
seq_keys = [seq_keys[i] for i in perm]
key_indices = [i for key in seq_keys for i in chunk2idx[key]]
num_workers = self.cfg["loading"][stage]["num_workers"]
loader = torchdata.DataLoader(
dataset,
batch_size=None,
sampler=key_indices,
num_workers=num_workers,
shuffle=False,
pin_memory=True,
persistent_workers=num_workers > 0,
worker_init_fn=worker_init_fn,
collate_fn=collate,
)
return loader, seq_keys, chunk2idx