GeoCalib / siclib /datasets /create_dataset_from_pano.py
veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
12.4 kB
"""Script to create a dataset from panorama images."""
import hashlib
import logging
from concurrent import futures
from pathlib import Path
import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy
import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm
from siclib.geometry.camera import camera_models
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg
from siclib.utils.image import load_image, write_image
logger = logging.getLogger(__name__)
# mypy: ignore-errors
def max_radius(a, b):
    """Return the squared maximum valid radius of a Brown distortion model.

    Given the folded derivative coefficients ``a`` and ``b`` (see
    ``brown_max_radius``), solves the quadratic in r**2 for the largest
    radius where the distortion polynomial is still monotonic. Entries
    with no finite positive solution yield ``inf`` via division by zero.
    """
    # Discriminant of the derivative treated as a quadratic in r**2.
    disc = a * a - 4 * b
    ok = (disc >= 0.0) & torch.isfinite(disc)
    root_term = torch.sqrt(disc) - a
    ok = ok & (root_term > 0.0)
    # Invalid entries divide 2.0 by zero -> +inf (unbounded radius).
    return 2.0 / torch.where(ok, root_term, 0)
def brown_max_radius(k1, k2):
    """Compute the maximum radius of a Brown distortion model with k1, k2.

    The derivative of r * (1 + k1*r**2 + k2*r**4) folds the polynomial
    constants into 3*k1 and 5*k2, which are passed to ``max_radius``.
    """
    return torch.sqrt(max_radius(3 * k1, 5 * k2))
class ParallelProcessor:
    """Run tasks on a process pool while reporting progress with tqdm bars."""

    def __init__(self, max_workers):
        """Create the process pool and an empty progress-bar registry."""
        self.max_workers = max_workers
        self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers)
        self.pbars = {}

    def update_pbar(self, pbar_key):
        """Advance the progress bar registered under ``pbar_key`` by one."""
        self.pbars.get(pbar_key).update(1)

    def submit_tasks(self, task_func, task_args, pbar_key):
        """Schedule ``task_func(*args)`` for each args tuple; return futures.

        A tqdm bar keyed by ``pbar_key`` is created and advanced once per
        completed task via a done-callback.
        """
        self.pbars[pbar_key] = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80)

        def _on_done(_future):
            self.update_pbar(pbar_key)

        pending = []
        for call_args in task_args:
            fut = self.executor.submit(task_func, *call_args)
            fut.add_done_callback(_on_done)
            pending.append(fut)
        return pending

    def wait_for_completion(self, futures):
        """Wait for all futures, concatenate their results, close the bars."""
        gathered = []
        for fut in futures:
            # Each task returns a list; concatenate them in submission order.
            gathered += fut.result()
        for bar in self.pbars.values():
            bar.close()
        return gathered

    def shutdown(self):
        """Close the executer."""
        self.executor.shutdown()
class DatasetGenerator:
    """Dataset generator class to create perspective datasets from panoramas.

    For each panorama, samples camera parameters (roll, pitch, vfov, ...)
    from configured distributions, crops perspective views at evenly spaced
    yaw angles, writes them to disk, and records the ground-truth parameters
    in one CSV file per split.
    """

    # Default configuration. "???" marks mandatory user-supplied fields;
    # "${...}" entries are OmegaConf interpolations resolved within this config.
    default_conf = {
        "name": "???",
        # paths
        "base_dir": "???",
        "pano_dir": "${.base_dir}/panoramas",
        "pano_train": "${.pano_dir}/train",
        "pano_val": "${.pano_dir}/val",
        "pano_test": "${.pano_dir}/test",
        "perspective_dir": "${.base_dir}/${.name}",
        "perspective_train": "${.perspective_dir}/train",
        "perspective_val": "${.perspective_dir}/val",
        "perspective_test": "${.perspective_dir}/test",
        "train_csv": "${.perspective_dir}/train.csv",
        "val_csv": "${.perspective_dir}/val.csv",
        "test_csv": "${.perspective_dir}/test.csv",
        # data options
        "camera_model": "pinhole",
        # Each entry names a scipy.stats distribution ("type") and its rvs
        # keyword options, or is a fixed value when type == "fix".
        "parameter_dists": {
            "roll": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},  # in [-45, 45]
            },
            "pitch": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},  # in [-45, 45]
            },
            "vfov": {
                "type": "uniform",
                "options": {"loc": deg2rad(20), "scale": deg2rad(85)},  # in [20, 105]
            },
            "resize_factor": {
                "type": "uniform",
                "options": {"loc": 1.0, "scale": 1.0},  # factor in [1.0, 2.0]
            },
            "shape": {"type": "fix", "value": (640, 640)},
        },
        "images_per_pano": 16,
        "n_workers": 10,
        "device": "cpu",
        "overwrite": False,
    }

    def __init__(self, conf):
        """Init the class by merging and storing the config.

        Args:
            conf: Partial config (dict or DictConfig) merged over default_conf.
        """
        self.conf = OmegaConf.merge(
            OmegaConf.create(self.default_conf),
            OmegaConf.create(conf),
        )
        logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}")
        # Per-split lists of parameter dicts, filled by generate_split().
        self.infos = {}
        self.device = self.conf.device
        # Camera class looked up from the project-level registry by name.
        self.camera_model = camera_models[self.conf.camera_model]

    def sample_value(self, parameter_name, seed=None):
        """Sample a value from the specified distribution.

        Args:
            parameter_name: Key into conf.parameter_dists.
            seed: Optional seed for reproducibility. Non-numeric seeds
                (e.g. strings) are hashed into a 32-bit integer.

        Returns:
            Sampled value as a torch tensor.
        """
        param_conf = self.conf["parameter_dists"][parameter_name]
        if param_conf.type == "fix":
            # Fixed parameters bypass sampling entirely.
            return torch.tensor(param_conf.value)
        # fix seed for reproducibility
        generator = None
        if seed:
            if not isinstance(seed, (int, float)):
                # Hash arbitrary (string) seeds into a valid 32-bit seed.
                seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32)
            generator = np.random.default_rng(seed)
        # Look up the scipy.stats distribution named by the config (e.g. "uniform").
        sampler = getattr(scipy.stats, param_conf.type)
        return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options))

    def plot_distributions(self):
        """Plot parameter distributions.

        Writes two PDFs to the perspective directory: per-split histograms of
        roll/pitch/vfov and pairwise scatter plots across splits.
        """
        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        # One histogram row per split, one column per parameter.
        for i, split in enumerate(["train", "val", "test"]):
            roll_vals = [rad2deg(row["roll"]) for row in self.infos[split]]
            ax[i, 0].hist(roll_vals, bins=100)
            ax[i, 0].set_xlabel("Roll (°)")
            ax[i, 0].set_ylabel(f"Count {split}")
            pitch_vals = [rad2deg(row["pitch"]) for row in self.infos[split]]
            ax[i, 1].hist(pitch_vals, bins=100)
            ax[i, 1].set_xlabel("Pitch (°)")
            ax[i, 1].set_ylabel(f"Count {split}")
            vfov_vals = [rad2deg(row["vfov"]) for row in self.infos[split]]
            ax[i, 2].hist(vfov_vals, bins=100)
            ax[i, 2].set_xlabel("vFoV (°)")
            ax[i, 2].set_ylabel(f"Count {split}")
        plt.tight_layout()
        plt.savefig(Path(self.conf.perspective_dir) / "distributions.pdf")
        # Pairwise scatter of the three sampled angles, all splits overlaid.
        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, k1 in enumerate(["roll", "pitch", "vfov"]):
            for j, k2 in enumerate(["roll", "pitch", "vfov"]):
                ax[i, j].scatter(
                    [rad2deg(row[k1]) for row in self.infos["train"]],
                    [rad2deg(row[k2]) for row in self.infos["train"]],
                    s=1,
                    label="train",
                )
                ax[i, j].scatter(
                    [rad2deg(row[k1]) for row in self.infos["val"]],
                    [rad2deg(row[k2]) for row in self.infos["val"]],
                    s=1,
                    label="val",
                )
                ax[i, j].scatter(
                    [rad2deg(row[k1]) for row in self.infos["test"]],
                    [rad2deg(row[k2]) for row in self.infos["test"]],
                    s=1,
                    label="test",
                )
                ax[i, j].set_xlabel(k1)
                ax[i, j].set_ylabel(k2)
                ax[i, j].legend()
        plt.tight_layout()
        plt.savefig(Path(self.conf.perspective_dir) / "distributions_scatter.pdf")

    def generate_images_from_pano(self, pano_path: Path, out_dir: Path):
        """Generate perspective images from a single panorama.

        Args:
            pano_path: Path to the panorama image file.
            out_dir: Directory the perspective crops are written to.

        Returns:
            List of per-image dicts with the filename and sampled parameters.
        """
        infos = []
        pano = load_image(pano_path).to(self.device)
        # Evenly spaced yaw angles covering the full 360° panorama.
        yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False)
        # Sample each parameter once per yaw; seeding with the panorama stem,
        # parameter name and yaw value keeps sampling reproducible per image.
        params = {
            k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws]
            for k in self.conf.parameter_dists
            if k != "shape"
        }
        # Same shape seed for all yaws of a panorama (one shape per pano).
        shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws]
        params |= {
            "height": [shape[0] for shape in shapes],
            "width": [shape[1] for shape in shapes],
        }
        # Optional Brown distortion: k1_hat is a focal-relative coefficient;
        # the focal length may need to be raised so the model's maximum valid
        # radius still reaches the image corners.
        if "k1_hat" in params:
            height = torch.tensor(params["height"])
            width = torch.tensor(params["width"])
            k1_hat = torch.tensor(params["k1_hat"])
            vfov = torch.tensor(params["vfov"])
            focal = fov2focal(vfov, height)
            focal = focal  # NOTE(review): no-op self-assignment, kept as-is
            rel_focal = focal / height
            k1 = k1_hat * rel_focal
            # distance to image corner
            # r_max_im = f_px * r_max * (1 + k1*r_max**2)
            # function of r_max_im: f_px = r_max_im / (r_max * (1 + k1*r_max**2))
            min_permissible_rmax = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2)
            r_max = brown_max_radius(k1=k1, k2=0)
            lowest_possible_f_px = min_permissible_rmax / (r_max * (1 + k1 * r_max**2))
            # Clamp the focal up to the minimum that keeps corners valid,
            # then recompute the vfov from the (possibly adjusted) focal.
            valid = lowest_possible_f_px <= focal
            f = torch.where(valid, focal, lowest_possible_f_px)
            vfov = focal2fov(f, height)
            params["vfov"] = vfov
            params |= {"k1": k1}
        cam = self.camera_model.from_dict(params).float().to(self.device)
        gravity = Gravity.from_rp(params["roll"], params["pitch"]).float().to(self.device)
        # Skip re-rendering when the first crop already exists (unless
        # overwrite); still return the infos so the split CSV stays complete.
        if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite:
            for i in range(self.conf.images_per_pano):
                perspective_name = f"{pano_path.stem}_{i}.jpg"
                info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
                infos.append(info)
            logger.info(f"Perspectives for {pano_path.stem} already exist.")
            return infos
        perspective_images = cam.get_img_from_pano(
            pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"]
        )
        for i, perspective_image in enumerate(perspective_images):
            perspective_name = f"{pano_path.stem}_{i}.jpg"
            # Reject crops where >= 1% of pixels are black (channel sum zero),
            # e.g. views extending past the panorama's vertical coverage.
            n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1]
            valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01
            if not valid:
                logger.debug(f"Perspective {perspective_name} has too many black pixels.")
                continue
            write_image(perspective_image, out_dir / perspective_name)
            info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
            infos.append(info)
        return infos

    def generate_split(self, split: str, parallel_processor: ParallelProcessor):
        """Generate a single split of a dataset.

        Args:
            split: One of "train", "val" or "test".
            parallel_processor: Shared processor used to fan out per-pano work.
        """
        self.infos[split] = []
        # Collect panorama files, skipping hidden files (e.g. .DS_Store).
        panorama_paths = [
            path
            for path in Path(self.conf[f"pano_{split}"]).glob("*")
            if not path.name.startswith(".")
        ]
        out_dir = Path(self.conf[f"perspective_{split}"])
        logger.info(f"Writing perspective images to {str(out_dir)}")
        if not out_dir.exists():
            out_dir.mkdir(parents=True)
        # One task per panorama; the processor's pbar is keyed by split name.
        futures = parallel_processor.submit_tasks(
            self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split
        )
        self.infos[split] = parallel_processor.wait_for_completion(futures)
        # parallel_processor.shutdown()
        # Persist ground-truth parameters for this split as CSV.
        metadata = pd.DataFrame(data=self.infos[split])
        metadata.to_csv(self.conf[f"{split}_csv"])

    def generate_dataset(self):
        """Generate all splits of a dataset.

        Creates the output directory, saves the resolved config, generates
        train/val/test splits with a shared process pool, then plots the
        resulting parameter distributions.
        """
        out_dir = Path(self.conf.perspective_dir)
        if not out_dir.exists():
            out_dir.mkdir(parents=True)
        # Save the resolved config next to the data for reproducibility.
        OmegaConf.save(self.conf, out_dir / "config.yaml")
        processor = ParallelProcessor(self.conf.n_workers)
        for split in ["train", "val", "test"]:
            self.generate_split(split=split, parallel_processor=processor)
        processor.shutdown()
        for split in ["train", "val", "test"]:
            logger.info(f"Generated {len(self.infos[split])} {split} images.")
        self.plot_distributions()
@hydra.main(version_base=None, config_path="configs", config_name="SUN360")
def main(cfg: DictConfig) -> None:
    """Run dataset generation with the hydra-resolved config."""
    DatasetGenerator(conf=cfg).generate_dataset()


if __name__ == "__main__":
    main()