"""Script to create a dataset from panorama images."""

import hashlib
import logging
from concurrent import futures
from pathlib import Path

import hydra
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats
import torch
from omegaconf import DictConfig, OmegaConf
from tqdm import tqdm

from siclib.geometry.camera import camera_models
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import deg2rad, focal2fov, fov2focal, rad2deg
from siclib.utils.image import load_image, write_image

logger = logging.getLogger(__name__)


def max_radius(a, b):
    """Compute the squared maximum valid radius of a Brown distortion model.

    The distorted radius r * (1 + k1 * r^2 + k2 * r^4) is monotonic up to the
    smallest positive root of its derivative 1 + 3 * k1 * r^2 + 5 * k2 * r^4.
    With a = 3 * k1, b = 5 * k2, and x = r^2, this is the quadratic
    b * x^2 + a * x + 1 = 0; the relevant root is 2 / (sqrt(a^2 - 4b) - a),
    written in a form that stays stable when k2 = 0.
    """
    discrim = a * a - 4 * b
    # The quadratic has a positive root only if the discriminant is
    # non-negative and sqrt(discrim) - a is positive; otherwise the polynomial
    # is monotonic everywhere and the radius is unbounded (2 / 0 = inf).
    valid = torch.isfinite(discrim) & (discrim >= 0.0)
    discrim = torch.sqrt(discrim) - a
    valid &= discrim > 0.0
    return 2.0 / torch.where(valid, discrim, 0)


def brown_max_radius(k1, k2):
    """Compute the maximum valid radius of a Brown distortion model."""
    a = k1 * 3
    b = k2 * 5
    return torch.sqrt(max_radius(a, b))


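# Illustrative check (not used by the pipeline): for a pure-k1 barrel model
# with k1 = -0.2 and k2 = 0, the derivative 1 + 3 * k1 * r^2 vanishes at
# r^2 = 1 / (3 * 0.2), so
#
#   brown_max_radius(torch.tensor(-0.2), torch.tensor(0.0))
#
# returns sqrt(1 / 0.6) ≈ 1.291; points beyond this normalized radius would
# be folded back by the distortion polynomial.

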
class ParallelProcessor:
    """Generic parallel processor class."""

    def __init__(self, max_workers):
        """Init processor and pbars."""
        self.max_workers = max_workers
        self.executor = futures.ProcessPoolExecutor(max_workers=self.max_workers)
        self.pbars = {}

    def update_pbar(self, pbar_key):
        """Update the progress bar."""
        pbar = self.pbars.get(pbar_key)
        pbar.update(1)

    def submit_tasks(self, task_func, task_args, pbar_key):
        """Submit tasks and return the pending futures."""
        pbar = tqdm(total=len(task_args), desc=f"Processing {pbar_key}", ncols=80)
        self.pbars[pbar_key] = pbar

        def on_done(future):
            self.update_pbar(pbar_key)

        # Named task_futures to avoid shadowing the imported concurrent.futures
        # module.
        task_futures = []
        for args in task_args:
            future = self.executor.submit(task_func, *args)
            future.add_done_callback(on_done)
            task_futures.append(future)

        return task_futures

    def wait_for_completion(self, task_futures):
        """Wait for completion and return the concatenated results."""
        results = []
        for f in task_futures:
            results += f.result()

        for key in self.pbars.keys():
            self.pbars[key].close()

        return results

    def shutdown(self):
        """Close the executor."""
        self.executor.shutdown()


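# Usage sketch (hypothetical task function, for illustration only):
#
#   processor = ParallelProcessor(max_workers=4)
#   futs = processor.submit_tasks(my_task, [(arg,) for arg in my_args], "demo")
#   results = processor.wait_for_completion(futs)
#   processor.shutdown()
#
# Note that task_func must return a list, since wait_for_completion
# concatenates the per-task results with +=.

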
class DatasetGenerator:
    """Dataset generator class to create perspective datasets from panoramas."""

    default_conf = {
        "name": "???",
        "base_dir": "???",
        "pano_dir": "${.base_dir}/panoramas",
        "pano_train": "${.pano_dir}/train",
        "pano_val": "${.pano_dir}/val",
        "pano_test": "${.pano_dir}/test",
        "perspective_dir": "${.base_dir}/${.name}",
        "perspective_train": "${.perspective_dir}/train",
        "perspective_val": "${.perspective_dir}/val",
        "perspective_test": "${.perspective_dir}/test",
        "train_csv": "${.perspective_dir}/train.csv",
        "val_csv": "${.perspective_dir}/val.csv",
        "test_csv": "${.perspective_dir}/test.csv",
        "camera_model": "pinhole",
        "parameter_dists": {
            "roll": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},
            },
            "pitch": {
                "type": "uniform",
                "options": {"loc": deg2rad(-45), "scale": deg2rad(90)},
            },
            "vfov": {
                "type": "uniform",
                "options": {"loc": deg2rad(20), "scale": deg2rad(85)},
            },
            "resize_factor": {
                "type": "uniform",
                "options": {"loc": 1.0, "scale": 1.0},
            },
            "shape": {"type": "fix", "value": (640, 640)},
        },
        "images_per_pano": 16,
        "n_workers": 10,
        "device": "cpu",
        "overwrite": False,
    }

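    # "???" marks OmegaConf mandatory values: "name" and "base_dir" must be
    # provided by the Hydra config or on the command line, e.g. (illustrative
    # invocation, script name assumed):
    #
    #   python create_dataset_from_pano.py name=pinhole base_dir=/path/to/data
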
    def __init__(self, conf):
        """Init the class by merging and storing the config."""
        self.conf = OmegaConf.merge(
            OmegaConf.create(self.default_conf),
            OmegaConf.create(conf),
        )
        logger.info(f"Config:\n{OmegaConf.to_yaml(self.conf)}")

        self.infos = {}
        self.device = self.conf.device

        self.camera_model = camera_models[self.conf.camera_model]

    def sample_value(self, parameter_name, seed=None):
        """Sample a value from the distribution specified for the parameter."""
        param_conf = self.conf["parameter_dists"][parameter_name]

        if param_conf.type == "fix":
            return torch.tensor(param_conf.value)

        # Derive a deterministic generator from the seed; string seeds (e.g.
        # "<pano stem><param><index>") are hashed to a 32-bit integer first.
        generator = None
        if seed is not None:
            if not isinstance(seed, (int, float)):
                seed = int(hashlib.sha256(seed.encode()).hexdigest(), 16) % (2**32)
            generator = np.random.default_rng(seed)

        sampler = getattr(scipy.stats, param_conf.type)
        return torch.tensor(sampler.rvs(random_state=generator, **param_conf.options))

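    # Example: the default "roll" entry {"type": "uniform", "options":
    # {"loc": -pi/4, "scale": pi/2}} resolves to
    # scipy.stats.uniform.rvs(loc=-pi/4, scale=pi/2), i.e. a uniform draw on
    # [-45°, +45°]; any scipy.stats distribution name is accepted as "type".
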
    def plot_distributions(self):
        """Plot parameter distributions."""
        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, split in enumerate(["train", "val", "test"]):
            roll_vals = [rad2deg(row["roll"]) for row in self.infos[split]]
            ax[i, 0].hist(roll_vals, bins=100)
            ax[i, 0].set_xlabel("Roll (°)")
            ax[i, 0].set_ylabel(f"Count {split}")

            pitch_vals = [rad2deg(row["pitch"]) for row in self.infos[split]]
            ax[i, 1].hist(pitch_vals, bins=100)
            ax[i, 1].set_xlabel("Pitch (°)")
            ax[i, 1].set_ylabel(f"Count {split}")

            vfov_vals = [rad2deg(row["vfov"]) for row in self.infos[split]]
            ax[i, 2].hist(vfov_vals, bins=100)
            ax[i, 2].set_xlabel("vFoV (°)")
            ax[i, 2].set_ylabel(f"Count {split}")

        plt.tight_layout()
        plt.savefig(Path(self.conf.perspective_dir) / "distributions.pdf")

        fig, ax = plt.subplots(3, 3, figsize=(15, 10))
        for i, k1 in enumerate(["roll", "pitch", "vfov"]):
            for j, k2 in enumerate(["roll", "pitch", "vfov"]):
                # One scatter per split on shared axes.
                for split in ["train", "val", "test"]:
                    ax[i, j].scatter(
                        [rad2deg(row[k1]) for row in self.infos[split]],
                        [rad2deg(row[k2]) for row in self.infos[split]],
                        s=1,
                        label=split,
                    )

                ax[i, j].set_xlabel(k1)
                ax[i, j].set_ylabel(k2)
                ax[i, j].legend()

        plt.tight_layout()
        plt.savefig(Path(self.conf.perspective_dir) / "distributions_scatter.pdf")

    def generate_images_from_pano(self, pano_path: Path, out_dir: Path):
        """Generate perspective images from a single panorama."""
        infos = []

        pano = load_image(pano_path).to(self.device)

        # Sample one view per yaw, seeding each draw with the panorama stem so
        # the dataset is reproducible across runs.
        yaws = np.linspace(0, 2 * np.pi, self.conf.images_per_pano, endpoint=False)
        params = {
            k: [self.sample_value(k, pano_path.stem + k + str(i)) for i in yaws]
            for k in self.conf.parameter_dists
            if k != "shape"
        }
        shapes = [self.sample_value("shape", pano_path.stem + "shape") for _ in yaws]
        params |= {
            "height": [shape[0] for shape in shapes],
            "width": [shape[1] for shape in shapes],
        }

        if "k1_hat" in params:
            height = torch.tensor(params["height"])
            width = torch.tensor(params["width"])
            k1_hat = torch.tensor(params["k1_hat"])
            vfov = torch.tensor(params["vfov"])
            focal = fov2focal(vfov, height)
            rel_focal = focal / height
            k1 = k1_hat * rel_focal

            # The image corner (half a diagonal from the center) must stay
            # inside the monotonic region of the distortion polynomial. This
            # bounds the focal length from below; if the sampled focal is too
            # short, clamp it and recompute the vfov accordingly.
            half_diagonal = torch.sqrt((height / 2) ** 2 + (width / 2) ** 2)
            r_max = brown_max_radius(k1=k1, k2=0)
            lowest_possible_f_px = half_diagonal / (r_max * (1 + k1 * r_max**2))
            valid = lowest_possible_f_px <= focal

            f = torch.where(valid, focal, lowest_possible_f_px)
            vfov = focal2fov(f, height)

            params["vfov"] = vfov
            params |= {"k1": k1}

        cam = self.camera_model.from_dict(params).float().to(self.device)
        gravity = Gravity.from_rp(params["roll"], params["pitch"]).float().to(self.device)

        if (out_dir / f"{pano_path.stem}_0.jpg").exists() and not self.conf.overwrite:
            for i in range(self.conf.images_per_pano):
                perspective_name = f"{pano_path.stem}_{i}.jpg"
                info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
                infos.append(info)

            logger.info(f"Perspectives for {pano_path.stem} already exist.")

            return infos

        perspective_images = cam.get_img_from_pano(
            pano_img=pano, gravity=gravity, yaws=yaws, resize_factor=params["resize_factor"]
        )

        for i, perspective_image in enumerate(perspective_images):
            perspective_name = f"{pano_path.stem}_{i}.jpg"

            # Skip views where more than 1% of the pixels are black, which
            # indicates sampling outside the panorama.
            n_pixels = perspective_image.shape[-2] * perspective_image.shape[-1]
            valid = (torch.sum(perspective_image.sum(0) == 0) / n_pixels) < 0.01
            if not valid:
                logger.debug(f"Perspective {perspective_name} has too many black pixels.")
                continue

            write_image(perspective_image, out_dir / perspective_name)

            info = {"fname": perspective_name} | {k: v[i].item() for k, v in params.items()}
            infos.append(info)

        return infos

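    # Worked example for the k1_hat clamp above (illustrative only, assuming
    # the usual pinhole relation focal = height / (2 * tan(vfov / 2))): with
    # height = width = 640, vfov = 45° and k1_hat = -0.3, the sampled focal is
    # ≈ 772.5 px and k1 ≈ -0.36, giving r_max ≈ 0.96 and a lower bound of
    # ≈ 707 px on the focal, so this sample keeps its vfov unchanged.
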
    def generate_split(self, split: str, parallel_processor: ParallelProcessor):
        """Generate a single split of a dataset."""
        self.infos[split] = []
        panorama_paths = [
            path
            for path in Path(self.conf[f"pano_{split}"]).glob("*")
            if not path.name.startswith(".")
        ]

        out_dir = Path(self.conf[f"perspective_{split}"])
        logger.info(f"Writing perspective images to {str(out_dir)}")
        if not out_dir.exists():
            out_dir.mkdir(parents=True)

        task_futures = parallel_processor.submit_tasks(
            self.generate_images_from_pano, [(f, out_dir) for f in panorama_paths], split
        )
        self.infos[split] = parallel_processor.wait_for_completion(task_futures)

        metadata = pd.DataFrame(data=self.infos[split])
        metadata.to_csv(self.conf[f"{split}_csv"])

    def generate_dataset(self):
        """Generate all splits of a dataset."""
        out_dir = Path(self.conf.perspective_dir)
        if not out_dir.exists():
            out_dir.mkdir(parents=True)

        OmegaConf.save(self.conf, out_dir / "config.yaml")

        processor = ParallelProcessor(self.conf.n_workers)
        for split in ["train", "val", "test"]:
            self.generate_split(split=split, parallel_processor=processor)

        processor.shutdown()

        for split in ["train", "val", "test"]:
            logger.info(f"Generated {len(self.infos[split])} {split} images.")

        self.plot_distributions()


@hydra.main(version_base=None, config_path="configs", config_name="SUN360")
def main(cfg: DictConfig) -> None:
    """Run dataset generation."""
    generator = DatasetGenerator(conf=cfg)
    generator.generate_dataset()


if __name__ == "__main__":
    main()