|
import logging |
|
import resource |
|
from collections import defaultdict |
|
from pathlib import Path |
|
from pprint import pprint |
|
from typing import Dict, List, Tuple |
|
|
|
import matplotlib.pyplot as plt |
|
import numpy as np |
|
import torch |
|
from omegaconf import OmegaConf |
|
from tqdm import tqdm |
|
|
|
from siclib.datasets import get_dataset |
|
from siclib.eval.eval_pipeline import EvalPipeline |
|
from siclib.eval.io import get_eval_parser, load_model, parse_eval_args |
|
from siclib.eval.utils import download_and_extract_benchmark, plot_scatter_grid |
|
from siclib.geometry.base_camera import BaseCamera |
|
from siclib.geometry.camera import Pinhole |
|
from siclib.geometry.gravity import Gravity |
|
from siclib.models.cache_loader import CacheLoader |
|
from siclib.models.utils.metrics import ( |
|
gravity_error, |
|
latitude_error, |
|
pitch_error, |
|
roll_error, |
|
up_error, |
|
vfov_error, |
|
) |
|
from siclib.settings import EVAL_PATH |
|
from siclib.utils.conversions import rad2deg |
|
from siclib.utils.export_predictions import export_predictions |
|
from siclib.utils.tensor import add_batch_dim |
|
from siclib.utils.tools import AUCMetric, set_seed |
|
from siclib.visualization import visualize_batch, viz2d |
|
|
|
|
|
|
|
|
|
logger = logging.getLogger(__name__)

# Raise the soft open-file limit: the HDF5 prediction cache and multi-worker
# data loaders can exhaust the common default of 1024 descriptors. Clamp to
# the hard limit so setrlimit never raises ValueError on systems where the
# hard limit is below 4096 (setrlimit rejects soft > hard).
rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
resource.setrlimit(resource.RLIMIT_NOFILE, (min(4096, rlimit[1]), rlimit[1]))

# Evaluation is inference-only; disable autograd globally for this process.
torch.set_grad_enabled(False)
|
|
|
|
|
def calculate_pixel_projection_error(
    camera_pred: BaseCamera, camera_gt: BaseCamera, N: int = 500, distortion_only: bool = True
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Calculate the pixel projection error between two cameras.

    1. Project a grid of points with the ground truth camera to the image plane.
    2. Project the same grid of points with the estimated camera to the image plane.
    3. Calculate the pixel distance between the ground truth and estimated points.

    Args:
        camera_pred (Camera): Predicted camera.
        camera_gt (Camera): Ground truth camera.
        N (int, optional): Number of points per grid axis (N*N total). Defaults to 500.
        distortion_only (bool, optional): If True, isolate the distortion
            parameters by copying every other parameter from the ground-truth
            camera into the prediction before projecting. Defaults to True.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Pixel distance and valid pixels,
        both restricted to the points the ground-truth camera projects validly.
    """
    # NOTE(review): the names suggest (height, width), but the values are used
    # as the x-range / y-range of the meshgrid below — presumably `size` is
    # stored as (width, height); confirm against BaseCamera before renaming.
    H, W = camera_gt.size.unbind(-1)
    H, W = H.int(), W.int()

    # The comparison below only makes sense on identically sized images.
    assert torch.allclose(
        camera_gt.size, camera_pred.size
    ), f"Cameras must have the same size: {camera_gt.size} != {camera_pred.size}"

    if distortion_only:
        # Build a hybrid camera: ground-truth intrinsics with the predicted
        # trailing two parameters. NOTE(review): assumes the last two entries
        # of `_data` are the distortion coefficients — TODO confirm layout.
        params = camera_gt._data.clone()
        params[..., -2:] = camera_pred._data[..., -2:]
        CameraModel = type(camera_gt)
        camera_pred = CameraModel(params)

    # Regular N x N grid of pixel coordinates covering the image plane.
    x_gt, y_gt = torch.meshgrid(
        torch.linspace(0, H - 1, N), torch.linspace(0, W - 1, N), indexing="xy"
    )
    xy = torch.stack((x_gt, y_gt), dim=-1).reshape(-1, 2)

    # Lift the grid to bearings using an undistorted (pinhole) version of the
    # ground-truth camera, so both cameras reproject the same 3D rays.
    camera_pin_gt = camera_gt.pinhole()
    uv_pin, _ = camera_pin_gt.image2world(xy)

    # Reproject the rays with the ground-truth camera...
    xy_undist_gt, valid_dist_gt = camera_gt.world2image(uv_pin)

    # ...and with the (possibly hybrid) predicted camera.
    xy_undist, valid_dist = camera_pred.world2image(uv_pin)

    # A point only counts when both projections are valid.
    valid = valid_dist_gt & valid_dist

    # Euclidean pixel distance between the two projections.
    dist = (xy_undist - xy_undist_gt) ** 2
    dist = (dist.sum(-1)).sqrt()

    # Keep only points the GT camera handles; `valid` then flags which of
    # those the predicted camera also projected validly.
    return dist[valid_dist_gt], valid[valid_dist_gt]
|
|
|
|
|
def compute_camera_metrics(
    camera_pred: BaseCamera, camera_gt: BaseCamera, thresholds: List[float]
) -> Dict[str, List[float]]:
    """Compute per-sample camera metrics for a predicted/ground-truth pair.

    Args:
        camera_pred (BaseCamera): Predicted camera.
        camera_gt (BaseCamera): Ground truth camera.
        thresholds (List[float]): Pixel thresholds for the projection-error
            recalls; if None, only the parameter metrics are computed.

    Returns:
        Dict[str, List[float]]: Mapping from metric name to a single-element
        list (callers `extend` these into accumulators).
    """
    results = defaultdict(list)
    results["vfov"].append(rad2deg(camera_pred.vfov).item())
    results["vfov_error"].append(vfov_error(camera_pred, camera_gt).item())

    results["focal"].append(camera_pred.f[..., 1].item())
    focal_error = torch.abs(camera_pred.f[..., 1] - camera_gt.f[..., 1])
    results["focal_error"].append(focal_error.item())
    # Reuse the absolute error instead of recomputing the difference.
    results["rel_focal_error"].append((focal_error / camera_gt.f[..., 1]).item())

    # Distortion metrics only exist for camera models that carry k1.
    if hasattr(camera_pred, "k1"):
        results["k1"].append(camera_pred.k1.item())
        results["k1_error"].append(torch.abs(camera_pred.k1 - camera_gt.k1).item())

    if thresholds is None:
        return results

    # Recall of reprojected grid points within each pixel threshold, once for
    # the full camera ("projection") and once isolating the distortion
    # parameters ("distortion"). The denominator is the number of points the
    # GT camera projects validly, so invalid predicted points count as misses.
    for distortion_only, tag in ((False, "projection"), (True, "distortion")):
        err, valid = calculate_pixel_projection_error(
            camera_pred, camera_gt, distortion_only=distortion_only
        )
        for th in thresholds:
            results[f"pixel_{tag}_error@{th}"].append(
                ((err[valid] < th).sum() / len(valid)).float().item()
            )

    return results
|
|
|
|
|
def compute_gravity_metrics(gravity_pred: Gravity, gravity_gt: Gravity) -> Dict[str, float]:
    """Collect per-sample gravity estimates and their errors (all in degrees).

    Returns a dict mapping metric names to single-element lists, in the same
    format as `compute_camera_metrics`.
    """
    metrics = defaultdict(list)
    entries = (
        ("roll", rad2deg(gravity_pred.roll)),
        ("pitch", rad2deg(gravity_pred.pitch)),
        ("roll_error", roll_error(gravity_pred, gravity_gt)),
        ("pitch_error", pitch_error(gravity_pred, gravity_gt)),
        # gravity_error expects batched inputs, hence the [None] indexing.
        ("gravity_error", gravity_error(gravity_pred[None], gravity_gt[None])),
    )
    for key, value in entries:
        metrics[key].append(value.item())
    return metrics
|
|
|
|
|
class SimplePipeline(EvalPipeline):
    """Single-image calibration evaluation pipeline.

    Exports per-sample camera and gravity predictions to an HDF5 cache, then
    replays them against ground truth to compute angular/focal/pixel metrics,
    summary statistics, and optional qualitative figures.
    """

    # Defaults merged with the user config by EvalPipeline. `url`, when set,
    # points at a downloadable benchmark archive (see _init).
    default_conf = {
        "data": {},
        "model": {},
        "eval": {
            # AUC thresholds (degrees) for roll/pitch/vfov/gravity errors.
            "thresholds": [1, 5, 10],
            # Pixel thresholds for projection/distortion error recalls.
            "pixel_thresholds": [0.5, 1, 3, 5],
            # Number of qualitative up/latitude figures to save.
            "num_vis": 10,
            "verbose": True,
        },
        "url": None,
    }

    # Keys the model must produce; written to predictions.h5 for every sample.
    export_keys = [
        "camera",
        "gravity",
    ]

    # Keys exported only when the model provides them.
    optional_export_keys = [
        "focal_uncertainty",
        "vfov_uncertainty",
        "roll_uncertainty",
        "pitch_uncertainty",
        "gravity_uncertainty",
        "up_field",
        "up_confidence",
        "latitude_field",
        "latitude_confidence",
    ]

    def _init(self, conf):
        """Cache eval options, fix the camera model, and optionally download
        the benchmark archive into the dataset directory's parent."""
        self.verbose = conf.eval.verbose
        self.num_vis = self.conf.eval.num_vis

        # Predictions are interpreted with a pinhole model (see run_eval).
        self.CameraModel = Pinhole

        if conf.url is not None:
            ds_dir = Path(conf.data.dataset_dir)
            download_and_extract_benchmark(ds_dir.name, conf.url, ds_dir.parent)

    @classmethod
    def get_dataloader(cls, data_conf=None, batch_size=None):
        """Returns a data loader with samples for each eval datapoint"""
        data_conf = data_conf or cls.default_conf["data"]

        if batch_size is not None:
            data_conf["test_batch_size"] = batch_size

        # NOTE(review): shuffling only when batching — presumably to mix image
        # sizes/aspects within a batch; confirm intent with dataset authors.
        do_shuffle = data_conf["test_batch_size"] > 1
        dataset = get_dataset(data_conf["name"])(data_conf)
        return dataset.get_data_loader("test", shuffle=do_shuffle)

    def get_predictions(self, experiment_dir, model=None, overwrite=False):
        """Export a prediction file for each eval datapoint"""
        # Reuse an existing predictions.h5 unless overwrite is requested; the
        # model is only loaded when an export is actually needed.
        pred_file = experiment_dir / "predictions.h5"
        if not pred_file.exists() or overwrite:
            if model is None:
                model = load_model(self.conf.model, self.conf.checkpoint)
            export_predictions(
                self.get_dataloader(self.conf.data),
                model,
                pred_file,
                keys=self.export_keys,
                optional_keys=self.optional_export_keys,
                verbose=self.verbose,
            )
        return pred_file

    def get_figures(self, results):
        """Build scatter-plot figures from the accumulated per-sample results.

        Returns a dict of figure-name -> matplotlib figure; empty when
        visualization is disabled (num_vis == 0).
        """
        figures = {}

        if self.num_vis == 0:
            return figures

        gl = ["up", "latitude"]
        rpf = ["roll", "pitch", "vfov"]

        if all(k in results for k in rpf):
            x_keys = [f"{k}_gt" for k in rpf]

            # Error vs. ground-truth value for roll/pitch/vfov.
            y_keys = [f"{k}_error" for k in rpf]
            fig, _ = plot_scatter_grid(results, x_keys, y_keys, show_means=False)
            figures |= {"rpf_gt_error": fig}

            # Prediction vs. ground truth (with the identity diagonal drawn).
            y_keys = [f"{k}" for k in rpf]
            fig, _ = plot_scatter_grid(results, x_keys, y_keys, diag=True, show_means=False)
            figures |= {"rpf_gt_pred": fig}

        if all(f"{k}_error" in results for k in gl):
            # NOTE(review): 3 x_keys vs 2 y_keys here — presumably
            # plot_scatter_grid broadcasts or truncates; verify.
            x_keys = [f"{k}_gt" for k in rpf]
            y_keys = [f"{k}_error" for k in gl]
            fig, _ = plot_scatter_grid(results, x_keys, y_keys, show_means=False)
            figures |= {"gl_gt_error": fig}

        return figures

    def run_eval(self, loader, pred_file):
        """Replay cached predictions against the loader's ground truth.

        Returns (summaries, figures, results): aggregate statistics, figures
        from get_figures, and the raw per-sample result lists.
        """
        conf = self.conf.eval
        results = defaultdict(list)

        # Directory for qualitative up/latitude figures.
        save_to = Path(pred_file).parent / "figures"
        if not save_to.exists() and self.num_vis > 0:
            save_to.mkdir()

        # CacheLoader serves the per-sample predictions written by
        # get_predictions, keyed by the sample's name.
        cache_loader = CacheLoader({"path": str(pred_file), "collate": None}).eval()

        if not self.verbose:
            logger.info(f"Evaluating {pred_file}")

        for i, data in enumerate(
            tqdm(loader, desc="Evaluating", total=len(loader), ncols=80, disable=not self.verbose)
        ):
            pred = cache_loader(data)

            # [0] indexing throughout: evaluation iterates batch size 1.
            results["names"].append(data["name"][0])

            gt_cam = data["camera"][0]
            gt_gravity = data["gravity"][0]

            # Ground-truth values recorded in degrees / pixels.
            results["roll_gt"].append(rad2deg(gt_gravity.roll).item())
            results["pitch_gt"].append(rad2deg(gt_gravity.pitch).item())
            results["vfov_gt"].append(rad2deg(gt_cam.vfov).item())
            results["focal_gt"].append(gt_cam.f[1].item())

            # NOTE(review): assumes the GT camera always exposes k1 — confirm
            # for pinhole-only datasets.
            results["k1_gt"].append(gt_cam.k1.item())

            if "camera" in pred:
                pred_cam = self.CameraModel(pred["camera"])

                # Metrics are computed in the original image frame, so undo
                # any scale/crop applied during preprocessing.
                pred_camera = pred_cam[None].undo_scale_crop(data)[0]
                gt_camera = gt_cam[None].undo_scale_crop(data)[0]

                camera_metrics = compute_camera_metrics(
                    pred_camera, gt_camera, conf.pixel_thresholds
                )

                for k, v in camera_metrics.items():
                    results[k].extend(v)

                if "focal_uncertainty" in pred:
                    focal_uncertainty = pred["focal_uncertainty"]
                    results["focal_uncertainty"].append(focal_uncertainty.item())

                if "vfov_uncertainty" in pred:
                    vfov_uncertainty = rad2deg(pred["vfov_uncertainty"])
                    results["vfov_uncertainty"].append(vfov_uncertainty.item())

            if "gravity" in pred:
                pred_gravity = Gravity(pred["gravity"])

                gravity_metrics = compute_gravity_metrics(pred_gravity, gt_gravity)
                for k, v in gravity_metrics.items():
                    results[k].extend(v)

                # Uncertainties are converted from radians to degrees.
                if "roll_uncertainty" in pred:
                    roll_uncertainty = rad2deg(pred["roll_uncertainty"])
                    results["roll_uncertainty"].append(roll_uncertainty.item())

                if "pitch_uncertainty" in pred:
                    pitch_uncertainty = rad2deg(pred["pitch_uncertainty"])
                    results["pitch_uncertainty"].append(pitch_uncertainty.item())

                if "gravity_uncertainty" in pred:
                    gravity_uncertainty = rad2deg(pred["gravity_uncertainty"])
                    results["gravity_uncertainty"].append(gravity_uncertainty.item())

            if "up_field" in pred:
                # Per-pixel up-vector error; mean/median over the spatial dims.
                up_err = up_error(pred["up_field"].unsqueeze(0), data["up_field"])
                results["up_error"].append(up_err.mean(axis=(1, 2)).item())
                results["up_med_error"].append(up_err.median().item())

                if "up_confidence" in pred:
                    # Confidence-weighted mean of the per-pixel error.
                    up_confidence = pred["up_confidence"].unsqueeze(0)
                    weighted_error = (up_err * up_confidence).sum(axis=(1, 2))
                    weighted_error = weighted_error / up_confidence.sum(axis=(1, 2))
                    results["up_weighted_error"].append(weighted_error.item())

                # Save a qualitative figure for the first num_vis samples.
                if i < self.num_vis:
                    pred_batched = add_batch_dim(pred)
                    up_fig = visualize_batch.make_up_figure(pred=pred_batched, data=data)
                    up_fig = up_fig["up"]
                    plt.tight_layout()
                    viz2d.save_plot(save_to / f"up-{i}-{up_err.median().item():.3f}.jpg")
                    plt.close()

            if "latitude_field" in pred:
                # Per-pixel latitude error, mirroring the up-field handling.
                lat_err = latitude_error(
                    pred["latitude_field"].unsqueeze(0), data["latitude_field"]
                )
                results["latitude_error"].append(lat_err.mean(axis=(1, 2)).item())
                results["latitude_med_error"].append(lat_err.median().item())

                if "latitude_confidence" in pred:
                    lat_confidence = pred["latitude_confidence"].unsqueeze(0)
                    weighted_error = (lat_err * lat_confidence).sum(axis=(1, 2))
                    weighted_error = weighted_error / lat_confidence.sum(axis=(1, 2))
                    results["latitude_weighted_error"].append(weighted_error.item())

                if i < self.num_vis:
                    pred_batched = add_batch_dim(pred)
                    lat_fig = visualize_batch.make_latitude_figure(pred=pred_batched, data=data)
                    lat_fig = lat_fig["latitude"]
                    plt.tight_layout()
                    viz2d.save_plot(save_to / f"latitude-{i}-{lat_err.median().item():.3f}.jpg")
                    plt.close()

        # Aggregate the per-sample lists into summary statistics.
        summaries = {}
        for k, v in results.items():
            arr = np.array(v)
            # Skip non-numeric entries such as the sample names.
            if not np.issubdtype(np.array(v).dtype, np.number):
                continue

            # nan-aware mean/median: missing samples don't poison the summary.
            if k.endswith("_error") or "recall" in k or "pixel" in k:
                summaries[f"mean_{k}"] = round(np.nanmean(arr), 3)
                summaries[f"median_{k}"] = round(np.nanmedian(arr), 3)

            # AUC of the error-vs-threshold curve for the angular metrics.
            if any(keyword in k for keyword in ["roll", "pitch", "vfov", "gravity"]):
                if not conf.thresholds:
                    continue

                auc = AUCMetric(
                    elements=arr, thresholds=list(conf.thresholds), min_error=1
                ).compute()
                for i, t in enumerate(conf.thresholds):
                    summaries[f"auc_{k}@{t}"] = round(auc[i], 3)

        return summaries, self.get_figures(results), results
|
|
|
|
|
if __name__ == "__main__": |
|
dataset_name = Path(__file__).stem |
|
parser = get_eval_parser() |
|
args = parser.parse_intermixed_args() |
|
|
|
default_conf = OmegaConf.create(SimplePipeline.default_conf) |
|
|
|
|
|
output_dir = Path(EVAL_PATH, dataset_name) |
|
output_dir.mkdir(exist_ok=True, parents=True) |
|
|
|
name, conf = parse_eval_args(dataset_name, args, "configs/", default_conf) |
|
|
|
experiment_dir = output_dir / name |
|
experiment_dir.mkdir(exist_ok=True) |
|
|
|
pipeline = SimplePipeline(conf) |
|
s, f, r = pipeline.run( |
|
experiment_dir, overwrite=args.overwrite, overwrite_eval=args.overwrite_eval |
|
) |
|
|
|
pprint(s) |
|
|
|
if args.plot: |
|
for name, fig in f.items(): |
|
fig.canvas.manager.set_window_title(name) |
|
plt.show() |
|
|