"""Script to create a dataset from panorama images.""" |
import hashlib |
import logging |
from concurrent import futures |
from pathlib import Path |
import hydra |
import matplotlib.pyplot as plt |
import numpy as np |
import pandas as pd |
import scipy |
import torch |
from omegaconf import DictConfig, OmegaConf |
from tqdm import tqdm |
from siclib.geometry.camera import camera_models |
from siclib.geometry.gravity import Gravity |
from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg |
from siclib.utils.image import load_image, write_image |
logger = logging.getLogger(__name__) |
def max_radius(a, b):
    """Evaluate the root ``2 / (sqrt(a^2 - 4b) - a)`` of ``b*x^2 + a*x + 1 = 0``.

    Entries whose discriminant is negative or non-finite, or whose denominator
    is not strictly positive, evaluate to ``inf`` (the denominator is replaced
    by zero before dividing).

    Args:
        a: Tensor holding the linear coefficient.
        b: Tensor (or scalar) holding the quadratic coefficient.

    Returns:
        Tensor with the root where it is real and positive, ``inf`` elsewhere.
    """
    disc = a * a - 4 * b
    has_real_root = torch.isfinite(disc) & (disc >= 0.0)
    denom = torch.sqrt(disc) - a
    usable = has_real_root & (denom > 0.0)
    # Masked-out entries become 0 so that the division yields +inf there.
    return 2.0 / torch.where(usable, denom, 0)
def brown_max_radius(k1, k2):
    """Compute the maximum radius of a Brown distortion model.

    Solves ``1 + 3*k1*x + 5*k2*x^2 = 0`` for ``x`` via ``max_radius`` and
    returns ``sqrt(x)``; entries without a positive real root come back as
    ``inf``.

    Args:
        k1: Tensor with the first radial distortion coefficient.
        k2: Second radial distortion coefficient (tensor or scalar).

    Returns:
        Tensor of radii.
    """
    return torch.sqrt(max_radius(3 * k1, 5 * k2))
class ParallelProcessor:
    """Generic parallel processor: a process pool plus one tqdm bar per batch of tasks."""

    def __init__(self, max_workers):
        """Init processor and pbars.

        Args:
            max_workers: Worker count passed to ``ProcessPoolExecutor``.
        """
        self.max_workers = max_workers
        self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers)
        # Progress bars keyed by the ``pbar_key`` passed to ``submit_tasks``.
        self.pbars = {}

    def update_pbar(self, pbar_key):
        """Advance the progress bar registered under ``pbar_key`` by one step.

        Keys without a registered bar are ignored instead of raising, so a
        stray callback cannot fail with ``AttributeError`` on ``None``.
        """
        pbar = self.pbars.get(pbar_key)
        if pbar is not None:
            pbar.update(1)

    def submit_tasks(self, task_func, task_args, pbar_key):
        """Submit one task per argument tuple and register a progress bar for them.

        Args:
            task_func: Callable executed in the pool as ``task_func(*args)``.
            task_args: Sized iterable of positional-argument tuples.
            pbar_key: Key (also shown in the bar description) identifying this batch.

        Returns:
            List of the submitted futures, in submission order.
        """
        pbar = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80)
        self.pbars[pbar_key] = pbar

        def _on_done(_future):
            self.update_pbar(pbar_key)

        # Named ``submitted`` to avoid shadowing the imported ``futures`` module.
        submitted = []
        for args in task_args:
            future = self.executor.submit(task_func, *args)
            future.add_done_callback(_on_done)
            submitted.append(future)
        return submitted

    def wait_for_completion(self, futures):
        """Wait for completion and return the concatenated results.

        Each future is expected to return an iterable; all of them are
        concatenated into a single list in the order of ``futures``.

        Args:
            futures: Iterable of futures as returned by ``submit_tasks``.

        Returns:
            Flat list of all task results.

        Raises:
            Exception: Whatever a task raised is re-raised by ``Future.result``.
        """
        results = []
        try:
            for future in futures:
                results += future.result()
        finally:
            # Close every registered bar even when a task failed, so the
            # terminal is not left with dangling progress bars.
            for pbar in self.pbars.values():
                pbar.close()
        return results

    def shutdown(self):
        """Close the executor."""
        self.executor.shutdown()
class DatasetGenerator:
    """Dataset generator class to create perspective datasets from panoramas.

    For every split (train/val/test) the panoramas found in the configured
    directory are turned into ``images_per_pano`` perspective crops whose
    parameters are sampled from the configured distributions. The sampled
    parameters are written to one CSV file per split.
    """

    # Defaults merged with the user config in ``__init__``. "???" marks values
    # the user config has to provide.
    default_conf = {
        "name": "???",
        "base_dir": "???",
        "pano_dir": "${.base_dir}/panoramas",
        "pano_train": "${.pano_dir}/train",
        "pano_val": "${.pano_dir}/val",
        "pano_test": "${.pano_dir}/test",
        "perspective_dir": "${.base_dir}/${.name}",
        "perspective_train": "${.perspective_dir}/train",
        "perspective_val": "${.perspective_dir}/val",
        "perspective_test": "${.perspective_dir}/test",
        "train_csv": "${.perspective_dir}/train.csv",
        "val_csv": "${.perspective_dir}/val.csv",
        "test_csv": "${.perspective_dir}/test.csv",
        "camera_model": "pinhole",
        "parameter_dists": {
            "roll": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},
            },
            "pitch": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},
            },
            "vfov": {
                "type": "uniform",
                "options": {"loc": deg2rad(20), "scale": deg2rad(85)},
            },
            "resize_factor": {
                "type": "uniform",
                "options": {"loc": 1.0, "scale": 1.0},
            },
            "shape": {"type": "fix", "value": (640, 640)},
        },
        "images_per_pano": 16,
        "n_workers": 10,
        "device": "cpu",
        "overwrite": False,
    }

    def __init__(self, conf):
        """Init the class by merging and storing the config.

        Args:
            conf: User configuration; its values override ``default_conf``.
        """
        self.conf = OmegaConf.merge(
            OmegaConf.create(self.default_conf),
            OmegaConf.create(conf),
        )
        logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}")

        # Per-split list of info dicts, filled by ``generate_split``.
        self.infos = {}
        self.device = self.conf.device
        self.camera_model = camera_models[self.conf.camera_model]

    def sample_value(self, parameter_name, seed=None):
        """Sample a value from the specified distribution.

        Args:
            parameter_name: Key into ``conf["parameter_dists"]``.
            seed: Optional seed. Non-numeric seeds (e.g. strings) are hashed
                with SHA-256 into a 32-bit integer. ``None`` samples unseeded.

        Returns:
            Tensor holding the fixed value (type ``"fix"``) or one draw from
            the ``scipy.stats`` distribution named by the config's ``type``.
        """
        param_conf = self.conf["parameter_dists"][parameter_name]
        if param_conf.type == "fix":
            return torch.tensor(param_conf.value)

        generator = None
        # Compare against None rather than relying on truthiness: a seed of 0
        # (or an empty string) must still give a reproducible sample.
        if seed is not None:
            if not isinstance(seed, (int, float)):
                seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32)
            generator = np.random.default_rng(seed)

        sampler = getattr(scipy.stats, param_conf.type)
        return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options))

    def plot_distributions(self):
        """Plot parameter distributions.

        Writes ``distributions.pdf`` (per-split histograms of roll, pitch and
        vFoV) and ``distributions_scatter.pdf`` (pairwise scatter plots with
        all splits overlaid) into ``conf.perspective_dir``.
        """
        splits = ("train", "val", "test")
        keys = ("roll", "pitch", "vfov")
        axis_labels = {"roll": "Roll (°)", "pitch": "Pitch (°)", "vfov": "vFoV (°)"}
        out_dir = Path(self.conf.perspective_dir)

        # Convert once; both figures reuse the same values.
        degrees = {
            split: {key: [rad2deg(row[key]) for row in self.infos[split]] for key in keys}
            for split in splits
        }

        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, split in enumerate(splits):
            for j, key in enumerate(keys):
                ax[i, j].hist(degrees[split][key], bins=100)
                ax[i, j].set_xlabel(axis_labels[key])
                ax[i, j].set_ylabel(f"Count {split}")
        fig.tight_layout()
        fig.savefig(out_dir / "distributions.pdf")
        # Release the figure; pyplot otherwise keeps it alive.
        plt.close(fig)

        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, k1 in enumerate(keys):
            for j, k2 in enumerate(keys):
                for split in splits:
                    ax[i, j].scatter(degrees[split][k1], degrees[split][k2], s=1, label=split)
                ax[i, j].set_xlabel(k1)
                ax[i, j].set_ylabel(k2)
                ax[i, j].legend()
        fig.tight_layout()
        fig.savefig(out_dir / "distributions_scatter.pdf")
        plt.close(fig)

    def generate_images_from_pano(self, pano_path: Path, out_dir: Path):
        """Generate perspective images from a single panorama.

        Parameters are sampled with seeds derived from the panorama's file
        stem, so repeated runs sample the same values for the same panorama.

        Args:
            pano_path: Path of the panorama image.
            out_dir: Directory the perspective images are written to.

        Returns:
            List of dicts, one per perspective image, holding ``fname`` and
            the sampled parameters as Python scalars.
        """
        infos = []

        yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False)
        # The seed strings are kept exactly as they are; changing them would
        # change every sampled value.
        params = {
            k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws]
            for k in self.conf.parameter_dists
            if k != "shape"
        }
        shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws]
        params |= {
            "height": [shape[0] for shape in shapes],
            "width": [shape[1] for shape in shapes],
        }

        if "k1_hat" in params:
            height = torch.tensor(params["height"])
            width = torch.tensor(params["width"])
            k1_hat = torch.tensor(params["k1_hat"])
            vfov = torch.tensor(params["vfov"])

            focal = fov2focal(vfov, height)
            rel_focal = focal / height
            k1 = k1_hat * rel_focal

            # Distance from the image center to a corner, in pixels.
            min_permissible_rmax = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2)
            r_max = brown_max_radius(k1=k1, k2=0)
            lowest_possible_f_px = min_permissible_rmax / (r_max * (1 + k1 * r_max**2))

            # Clamp the focal length from below. Written as ``bound > focal`` so
            # that a NaN bound (k1 == 0 gives r_max == inf and 0 * inf == NaN
            # above) keeps the sampled focal length instead of propagating NaN.
            needs_clamp = lowest_possible_f_px > focal
            f = torch.where(needs_clamp, lowest_possible_f_px, focal)
            vfov = focal2fov(f, height)

            params["vfov"] = vfov
            params |= {"k1": k1}

        cam = self.camera_model.from_dict(params).float().to(self.device)
        gravity = Gravity.from_rp(params["roll"], params["pitch"]).float().to(self.device)

        if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite:
            # NOTE(review): only image 0 is checked, and an entry is emitted for
            # every index, including images that an earlier run may have
            # skipped as too black. Confirm this is intended.
            for i in range(self.conf.images_per_pano):
                perspective_name = f"{pano_path.stem}_{i}.jpg"
                info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
                infos.append(info)

            logger.info(f"Perspectives for {pano_path.stem} already exist.")
            return infos

        # Load the panorama only now: it is not needed when the early return
        # above is taken.
        pano = load_image(pano_path).to(self.device)
        perspective_images = cam.get_img_from_pano(
            pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"]
        )

        for i, perspective_image in enumerate(perspective_images):
            perspective_name = f"{pano_path.stem}_{i}.jpg"

            # Skip images in which 1% or more of the pixels sum to zero over
            # the first dimension.
            n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1]
            valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01
            if not valid:
                logger.debug(f"Perspective {perspective_name} has too many black pixels.")
                continue

            write_image(perspective_image, out_dir / perspective_name)
            info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
            infos.append(info)

        return infos

    def generate_split(self, split: str, parallel_processor: ParallelProcessor):
        """Generate a single split of a dataset.

        Args:
            split: Split name; selects the ``pano_{split}``,
                ``perspective_{split}`` and ``{split}_csv`` config entries.
            parallel_processor: Processor the per-panorama tasks are submitted to.
        """
        self.infos[split] = []

        # Sorted so that task order, and therefore CSV row order, does not
        # depend on the filesystem's directory iteration order.
        panorama_paths = sorted(
            path
            for path in Path(self.conf[f"pano_{split}"]).glob("*")
            if not path.name.startswith(".")
        )

        out_dir = Path(self.conf[f"perspective_{split}"])
        logger.info(f"Writing perspective images to {str(out_dir)}")
        out_dir.mkdir(parents=True, exist_ok=True)

        pending = parallel_processor.submit_tasks(
            self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split
        )
        self.infos[split] = parallel_processor.wait_for_completion(pending)

        metadata = pd.DataFrame(data=self.infos[split])
        metadata.to_csv(self.conf[f"{split}_csv"])

    def generate_dataset(self):
        """Generate all splits of a dataset.

        Saves the resolved config as ``config.yaml`` in the output directory,
        generates the train/val/test splits and finally plots the parameter
        distributions.
        """
        out_dir = Path(self.conf.perspective_dir)
        out_dir.mkdir(parents=True, exist_ok=True)
        OmegaConf.save(self.conf, out_dir / "config.yaml")

        splits = ["train", "val", "test"]
        processor = ParallelProcessor(self.conf.n_workers)
        try:
            for split in splits:
                self.generate_split(split=split, parallel_processor=processor)
        finally:
            # Always release the worker processes, even if a split failed.
            processor.shutdown()

        for split in splits:
            logger.info(f"Generated {len(self.infos[split])} {split} images.")

        self.plot_distributions()
@hydra.main(version_base=None, config_path="configs", config_name="SUN360")
def main(cfg: DictConfig) -> None:
    """Build a ``DatasetGenerator`` from the hydra config and generate the dataset."""
    DatasetGenerator(conf=cfg).generate_dataset()


if __name__ == "__main__":
    main()