|
"""Dataset for images created with 'create_dataset_from_pano.py'.""" |
|
|
|
import logging |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Tuple |
|
|
|
import pandas as pd |
|
import torch |
|
from omegaconf import DictConfig |
|
|
|
from siclib.datasets.augmentations import IdentityAugmentation, augmentations |
|
from siclib.datasets.base_dataset import BaseDataset |
|
from siclib.geometry.camera import SimpleRadial |
|
from siclib.geometry.gravity import Gravity |
|
from siclib.geometry.perspective_fields import get_perspective_field |
|
from siclib.utils.conversions import fov2focal |
|
from siclib.utils.image import ImagePreprocessor, load_image |
|
from siclib.utils.tools import fork_rng |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
def load_csv(
    csv_file: Path, img_root: Path
) -> Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
    """Load a CSV file containing image information.

    Each row must provide the columns ``fname``, ``height``, ``width``, ``vfov``,
    ``roll`` and ``pitch``. The optional columns ``px``/``py`` default to the image
    center (``width / 2``, ``height / 2``) and ``k1``/``k2`` default to 0 when the
    CSV does not contain them.

    Args:
        csv_file (Path): Path to the CSV file.
        img_root (Path | str): Root directory of the images; each ``fname`` is
            joined onto it.

    Returns:
        Tuple[List[Dict[str, Any]], torch.Tensor, torch.Tensor]:
            - infos: one dict per row with keys ``name`` (the ``fname`` value) and
              ``file_name`` (``img_root / fname`` as a string).
            - params: float32 tensor of shape (N, 8), rows ``[w, h, f, f, px, py, k1, k2]``,
              where ``f`` is computed by ``fov2focal`` from ``vfov`` and the height.
            - gravity: float32 tensor of shape (N, 2), rows ``[roll, pitch]``.
            When the CSV has no data rows, ``infos`` is empty and the tensors have
            shapes (0, 8) and (0, 2).
    """
    df = pd.read_csv(csv_file)
    img_root = Path(img_root)

    infos, params, gravity = [], [], []
    for _, row in df.iterrows():
        h = row["height"]
        w = row["width"]

        # Optional columns: fall back to the image center / zero distortion.
        px = row.get("px", w / 2)
        py = row.get("py", h / 2)
        k1 = row.get("k1", 0)
        k2 = row.get("k2", 0)

        # The same focal length (derived from vfov and the height) is used for fx and fy.
        f = fov2focal(torch.tensor(row["vfov"]), h)
        params.append(torch.tensor([w, h, f, f, px, py, k1, k2]))

        gravity.append(torch.tensor([row["roll"], row["pitch"]]))

        infos.append({"name": row["fname"], "file_name": str(img_root / row["fname"])})

    if not infos:
        # torch.stack raises on an empty list; return correctly shaped empty tensors.
        return (
            infos,
            torch.empty((0, 8), dtype=torch.float32),
            torch.empty((0, 2), dtype=torch.float32),
        )

    return infos, torch.stack(params).float(), torch.stack(gravity).float()
|
|
|
|
|
class SimpleDataset(BaseDataset):
    """Dataset wrapper for images created with 'create_dataset_from_pano.py'.

    Per-split datasets are constructed on demand by ``get_dataset``.
    """

    default_conf = dict(
        # paths: split image folders and CSV files default to locations under dataset_dir
        dataset_dir="???",  # "???" = required value (OmegaConf MISSING marker)
        train_img_dir="${.dataset_dir}/train",
        val_img_dir="${.dataset_dir}/val",
        test_img_dir="${.dataset_dir}/test",
        train_csv="${.dataset_dir}/train.csv",
        val_csv="${.dataset_dir}/val.csv",
        test_csv="${.dataset_dir}/test.csv",
        # which perspective fields and priors are added to each sample
        use_up=True,
        use_latitude=True,
        use_prior_focal=False,
        use_prior_gravity=False,
        use_prior_k1=False,
        # image loading, preprocessing, augmentation and per-sample seeding
        grayscale=False,
        preprocessing=ImagePreprocessor.default_conf,
        augmentations={"name": "geocalib", "verbose": False},
        p_rotate=0.0,
        reseed=False,
        seed=0,
        # data loader settings
        num_workers=8,
        prefetch_factor=2,
        train_batch_size=32,
        val_batch_size=32,
        test_batch_size=32,
    )

    def _init(self, conf):
        """No dataset-level setup; all per-split state lives in ``_SimpleDataset``."""

    def get_dataset(self, split: str) -> torch.utils.data.Dataset:
        """Build the dataset for ``split`` (e.g. "train", "val" or "test")."""
        dataset = _SimpleDataset(self.conf, split)
        return dataset
|
|
|
|
|
class _SimpleDataset(torch.utils.data.Dataset):
    """Split-level dataset for images created with 'create_dataset_from_pano.py'."""

    def __init__(self, conf: DictConfig, split: str):
        """Read the split's CSV index and set up preprocessing and augmentation.

        Args:
            conf: dataset configuration (keys as in ``SimpleDataset.default_conf``).
            split: split name; selects ``{split}_img_dir`` and ``{split}_csv``.
        """
        self.conf = conf
        self.split = split
        self.img_dir = Path(conf.get(f"{split}_img_dir"))

        self.preprocessor = ImagePreprocessor(conf.preprocessing)

        # per-image names/paths plus camera parameters and (roll, pitch) from the CSV
        csv_key = f"{split}_csv"
        assert csv_key in conf, f"Missing {split}_csv in conf"
        self.infos, self.parameters, self.gravity = load_csv(conf.get(csv_key), self.img_dir)

        aug_name = conf.augmentations.name
        known = augmentations.keys()
        assert aug_name in known, f'{aug_name} not in {" ".join(known)}'

        # only the training split is augmented; other splits use a no-op augmentation
        self.augmentation = (
            augmentations[aug_name](conf.augmentations)
            if self.split == "train"
            else IdentityAugmentation()
        )

    def __len__(self):
        return len(self.infos)

    def __getitem__(self, idx):
        if self.conf.reseed:
            # build the sample under a forked RNG seeded with conf.seed + idx
            with fork_rng(self.conf.seed + idx, False):
                return self.getitem(idx)
        return self.getitem(idx)

    def _read_image(
        self, infos: Dict[str, Any], parameters: torch.Tensor, gravity: torch.Tensor
    ) -> Dict[str, Any]:
        """Load, augment and preprocess one image and attach its camera and gravity.

        Args:
            infos: entry from ``load_csv`` with ``name`` and ``file_name`` keys.
            parameters: camera parameter vector for ``SimpleRadial``.
            gravity: ``(roll, pitch)`` pair for ``Gravity.from_rp``.

        Returns:
            Dict with ``name``, ``path``, ``camera``, ``gravity``, any enabled
            ``prior_*`` entries, and everything returned by the preprocessor.
        """
        path = Path(str(infos["file_name"]))

        img = load_image(path, self.conf.grayscale, return_tensor=False)
        img = self.augmentation(img, return_tensor=True)

        # geometry objects are built with a leading batch dimension of size 1
        cam = SimpleRadial(parameters[None]).float()
        roll, pitch = gravity[None].unbind(-1)
        grav = Gravity.from_rp(roll, pitch)

        data = self.preprocessor(img)
        # apply the preprocessor's scaling (and crop, when present) to the camera
        cam = cam.scale(data["scales"])
        if "crop_pad" in data:
            cam = cam.crop(data["crop_pad"])

        priors = {}
        if self.conf.use_prior_gravity:
            priors["prior_gravity"] = grav
        if self.conf.use_prior_focal:
            priors["prior_focal"] = cam.f[..., 1]
        if self.conf.use_prior_k1:
            priors["prior_k1"] = cam.k1

        return {
            "name": infos["name"],
            "path": str(path),
            "camera": cam[0],
            "gravity": grav[0],
            **priors,
            **data,
        }

    def _get_perspective(self, data):
        """Return the enabled perspective fields (up / latitude) for a sample."""
        up_field, lat_field = get_perspective_field(
            data["camera"],
            data["gravity"],
            use_up=self.conf.use_up,
            use_latitude=self.conf.use_latitude,
        )

        fields = {}
        if self.conf.use_up:
            fields["up_field"] = up_field[0]
        if self.conf.use_latitude:
            fields["latitude_field"] = lat_field[0]
        return fields

    def getitem(self, idx: int):
        """Build sample ``idx``: image data plus, if enabled, its perspective fields."""
        sample = self._read_image(self.infos[idx], self.parameters[idx], self.gravity[idx])
        if self.conf.use_up or self.conf.use_latitude:
            sample |= self._get_perspective(sample)
        return sample
|
|
|
|
|
if __name__ == "__main__":
    import argparse

    import matplotlib.pyplot as plt

    from siclib.visualization.visualize_batch import make_perspective_figures

    parser = argparse.ArgumentParser(
        description="Visualize one batch of a dataset created with 'create_dataset_from_pano.py'."
    )
    parser.add_argument("--name", type=str, required=True)
    # dataset_dir is mandatory ("???") in default_conf, so this argument is required.
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--split", type=str, default="train")
    parser.add_argument("--shuffle", action="store_true")
    parser.add_argument("--n_rows", type=int, default=4)
    # NOTE(review): --dpi is parsed but not used anywhere below.
    parser.add_argument("--dpi", type=int, default=100)
    args = parser.parse_intermixed_args()

    # Work on a copy so the class-level default_conf is not mutated by this script.
    dconf = dict(SimpleDataset.default_conf)
    dconf["name"] = args.name
    # Single-process loading; PyTorch's DataLoader requires prefetch_factor=None then.
    dconf["num_workers"] = 0
    dconf["prefetch_factor"] = None

    dconf["dataset_dir"] = args.data_dir
    dconf[f"{args.split}_batch_size"] = args.n_rows

    torch.set_grad_enabled(False)

    dataset = SimpleDataset(dconf)
    loader = dataset.get_data_loader(args.split, args.shuffle)

    with fork_rng(seed=42):
        data = next(iter(loader), None)
    if data is None:
        raise SystemExit(f"No samples found in split '{args.split}' of {args.data_dir}")

    # The batch itself is passed as the "prediction", so the figures show ground truth.
    make_perspective_figures(data, data, n_pairs=args.n_rows)

    plt.show()
|
|