# GeoCalib: siclib/datasets/simple_dataset.py
"""Dataset for images created with 'create_dataset_from_pano.py'."""
import logging
from pathlib import Path
from typing import Any, Dict, List, Tuple
import pandas as pd
import torch
from omegaconf import DictConfig
from siclib.datasets.augmentations import IdentityAugmentation, augmentations
from siclib.datasets.base_dataset import BaseDataset
from siclib.geometry.camera import SimpleRadial
from siclib.geometry.gravity import Gravity
from siclib.geometry.perspective_fields import get_perspective_field
from siclib.utils.conversions import fov2focal
from siclib.utils.image import ImagePreprocessor, load_image
from siclib.utils.tools import fork_rng
logger = logging.getLogger(__name__)

# mypy: ignore-errors


def load_csv(
    csv_file: Path, img_root: Path
) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
"""Load a CSV file containing image information.
Args:
csv_file (str): Path to the CSV file.
img_root (str): Path to the root directory containing the images.
Returns:
list: List of dictionaries containing the image paths and camera parameters.
"""
df = pd.read_csv(csv_file)
infos, params, gravity = [], [], []
for _, row in df.iterrows():
h = row["height"]
w = row["width"]
px = row.get("px", w / 2)
py = row.get("py", h / 2)
vfov = row["vfov"]
f = fov2focal(torch.tensor(vfov), h)
k1 = row.get("k1", 0)
k2 = row.get("k2", 0)
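        # pack the camera parameters as [w, h, fx, fy, px, py, k1, k2] with fx == fy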
params.append(torch.tensor([w, h, f, f, px, py, k1, k2]))
roll = row["roll"]
pitch = row["pitch"]
gravity.append(torch.tensor([roll, pitch]))
infos.append({"name": row["fname"], "file_name": str(img_root / row["fname"])})
params = torch.stack(params).float()
gravity = torch.stack(gravity).float()
    return infos, params, gravity


class SimpleDataset(BaseDataset):
    """Dataset for images created with 'create_dataset_from_pano.py'."""

    default_conf = {
# paths
"dataset_dir": "???",
"train_img_dir": "${.dataset_dir}/train",
"val_img_dir": "${.dataset_dir}/val",
"test_img_dir": "${.dataset_dir}/test",
"train_csv": "${.dataset_dir}/train.csv",
"val_csv": "${.dataset_dir}/val.csv",
"test_csv": "${.dataset_dir}/test.csv",
# data options
"use_up": True,
"use_latitude": True,
"use_prior_focal": False,
"use_prior_gravity": False,
"use_prior_k1": False,
# image options
"grayscale": False,
"preprocessing": ImagePreprocessor.default_conf,
"augmentations": {"name": "geocalib", "verbose": False},
"p_rotate": 0.0, # probability to rotate image by +/- 90°
"reseed": False,
"seed": 0,
# data loader options
"num_workers": 8,
"prefetch_factor": 2,
"train_batch_size": 32,
"val_batch_size": 32,
"test_batch_size": 32,
    }

    def _init(self, conf):
        pass

    def get_dataset(self, split: str) -> torch.utils.data.Dataset:
        """Return a dataset for a given split."""
        return _SimpleDataset(self.conf, split)


class _SimpleDataset(torch.utils.data.Dataset):
    """Dataset for images created with 'create_dataset_from_pano.py'."""

    def __init__(self, conf: DictConfig, split: str):
"""Initialize the dataset."""
self.conf = conf
self.split = split
self.img_dir = Path(conf.get(f"{split}_img_dir"))
self.preprocessor = ImagePreprocessor(conf.preprocessing)
# load image information
assert f"{split}_csv" in conf, f"Missing {split}_csv in conf"
infos_path = self.conf.get(f"{split}_csv")
self.infos, self.parameters, self.gravity = load_csv(infos_path, self.img_dir)
# define augmentations
aug_name = conf.augmentations.name
assert (
aug_name in augmentations.keys()
), f'{aug_name} not in {" ".join(augmentations.keys())}'
if self.split == "train":
self.augmentation = augmentations[aug_name](conf.augmentations)
else:
            self.augmentation = IdentityAugmentation()

    def __len__(self):
        return len(self.infos)

    def __getitem__(self, idx):
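        # optionally fork the RNG with a per-sample seed so augmentations are reproducible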
if not self.conf.reseed:
return self.getitem(idx)
with fork_rng(self.conf.seed + idx, False):
            return self.getitem(idx)

    def _read_image(
        self, infos: Dict[str, Any], parameters: torch.Tensor, gravity: torch.Tensor
    ) -> Dict[str, Any]:
        """Load an image and create its camera, gravity, and optional priors."""
path = Path(str(infos["file_name"]))
# load image as uint8 and HWC for augmentation
image = load_image(path, self.conf.grayscale, return_tensor=False)
image = self.augmentation(image, return_tensor=True)
# create radial camera -> same as pinhole if k1 = 0
camera = SimpleRadial(parameters[None]).float()
roll, pitch = gravity[None].unbind(-1)
gravity = Gravity.from_rp(roll, pitch)
# preprocess
data = self.preprocessor(image)
camera = camera.scale(data["scales"])
camera = camera.crop(data["crop_pad"]) if "crop_pad" in data else camera
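        # optionally expose the ground-truth gravity, focal length, and k1 as priors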
priors = {"prior_gravity": gravity} if self.conf.use_prior_gravity else {}
priors |= {"prior_focal": camera.f[..., 1]} if self.conf.use_prior_focal else {}
priors |= {"prior_k1": camera.k1} if self.conf.use_prior_k1 else {}
return {
"name": infos["name"],
"path": str(path),
"camera": camera[0],
"gravity": gravity[0],
**priors,
**data,
        }

    def _get_perspective(self, data):
        """Get perspective field."""
camera = data["camera"]
gravity = data["gravity"]
up_field, lat_field = get_perspective_field(
camera, gravity, use_up=self.conf.use_up, use_latitude=self.conf.use_latitude
)
out = {}
if self.conf.use_up:
out["up_field"] = up_field[0]
if self.conf.use_latitude:
out["latitude_field"] = lat_field[0]
        return out

    def getitem(self, idx: int):
        """Return a sample from the dataset."""
infos = self.infos[idx]
parameters = self.parameters[idx]
gravity = self.gravity[idx]
data = self._read_image(infos, parameters, gravity)
if self.conf.use_up or self.conf.use_latitude:
data |= self._get_perspective(data)
        return data


if __name__ == "__main__":
# Create a dump of the dataset
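    # Example invocation (paths are illustrative; assumes siclib is on the PYTHONPATH):
    #   python -m siclib.datasets.simple_dataset --name simple_dataset --data_dir data/my_dataset --split train --shuffle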
import argparse
import matplotlib.pyplot as plt
from siclib.visualization.visualize_batch import make_perspective_figures
parser = argparse.ArgumentParser()
parser.add_argument("--name", type=str, required=True)
parser.add_argument("--data_dir", type=str)
parser.add_argument("--split", type=str, default="train")
parser.add_argument("--shuffle", action="store_true")
parser.add_argument("--n_rows", type=int, default=4)
parser.add_argument("--dpi", type=int, default=100)
args = parser.parse_intermixed_args()
dconf = SimpleDataset.default_conf
dconf["name"] = args.name
dconf["num_workers"] = 0
dconf["prefetch_factor"] = None
dconf["dataset_dir"] = args.data_dir
dconf[f"{args.split}_batch_size"] = args.n_rows
torch.set_grad_enabled(False)
dataset = SimpleDataset(dconf)
loader = dataset.get_data_loader(args.split, args.shuffle)
with fork_rng(seed=42):
for data in loader:
pred = data
break
fig = make_perspective_figures(pred, data, n_pairs=args.n_rows)
plt.show()