|
"""Simple interface for GeoCalib model.""" |
|
|
|
from pathlib import Path |
|
from typing import Dict, Optional |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn.functional import interpolate |
|
|
|
from siclib.geometry.base_camera import BaseCamera |
|
from siclib.models.networks.geocalib import GeoCalib as Model |
|
from siclib.utils.image import ImagePreprocessor, load_image |
|
|
|
|
|
class GeoCalib(nn.Module):
    """Simple interface for GeoCalib model.

    Wraps the underlying network together with an image pre-processor and
    exposes a single :meth:`calibrate` entry point that resizes the input,
    runs the model and maps the outputs back to the input resolution.
    """

    def __init__(self, weights: str = "pinhole"):
        """Build the model and load pretrained weights.

        The checkpoint is fetched with ``torch.hub.load_state_dict_from_url``
        into ``<torch hub dir>/geocalib`` and loaded on the CPU.

        Args:
            weights (str, optional): Name of the released checkpoint to load,
                either "pinhole" or "distorted". Defaults to "pinhole".

        Raises:
            ValueError: If ``weights`` is not one of the known checkpoint names.
        """
        super().__init__()
        if weights not in {"pinhole", "distorted"}:
            raise ValueError(f"Unknown weights: {weights}")
        url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

        model_dir = f"{torch.hub.get_dir()}/geocalib"
        state_dict = torch.hub.load_state_dict_from_url(
            url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
        )

        self.model = Model({})
        self.model.flexible_load(state_dict["model"])
        # Inference-only wrapper: keep the network in eval mode.
        self.model.eval()

        self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})

    def load_image(self, path: Path) -> torch.Tensor:
        """Load image from path (delegates to the module-level ``load_image``)."""
        return load_image(path)

    def _post_process(
        self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor]
    ) -> tuple[BaseCamera, dict[str, torch.Tensor]]:
        """Post-process model output by undoing scaling and cropping.

        Args:
            camera (BaseCamera): Camera estimated on the pre-processed image.
            img_data (dict[str, torch.Tensor]): Output of the image
                pre-processor; ``img_data["scales"]`` is read here.
            out (dict[str, torch.Tensor]): Raw model output. It is updated in
                place and also returned.

        Returns:
            tuple[BaseCamera, dict[str, torch.Tensor]]: The camera after
            ``undo_scale_crop`` and the updated output dictionary.
        """
        camera = camera.undo_scale_crop(img_data)

        # Target resolution is taken from the first entry of camera.size and
        # used for the whole batch.
        w, h = camera.size.unbind(-1)
        h = h[0].round().int().item()
        w = w[0].round().int().item()

        for k in ["latitude_field", "up_field"]:
            out[k] = interpolate(out[k], size=(h, w), mode="bilinear")
        for k in ["up_confidence", "latitude_confidence"]:
            # Add a temporary dim at index 1 for interpolate, then drop it again.
            out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0]

        inverse_scales = 1.0 / img_data["scales"]
        # Fall back to zeros when the model does not report a focal uncertainty.
        zero = camera.new_zeros(camera.f.shape[0])
        # NOTE(review): index 1 of "scales" mirrors the factor applied to the
        # focal prior in calibrate() -- presumably the vertical scale; confirm.
        out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1]
        return camera, out

    @torch.no_grad()
    def calibrate(
        self,
        img: torch.Tensor,
        camera_model: str = "pinhole",
        priors: Optional[Dict[str, torch.Tensor]] = None,
        shared_intrinsics: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Perform calibration with online resizing.

        Assumes input image is in range [0, 1] and in RGB format.

        Args:
            img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W).
                A batch of more than one image is only accepted when
                ``shared_intrinsics`` is True.
            camera_model (str, optional): Camera model. Defaults to "pinhole".
            priors (Dict[str, torch.Tensor], optional): Prior parameters; the
                keys "focal" and "gravity" are read. Defaults to None, which is
                treated as no priors.
            shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False.

        Returns:
            Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties.
            Contains "camera", "gravity" and "covariance", plus every model
            output whose key contains "field", "confidence" or "uncertainty".
        """
        if len(img.shape) == 3:
            img = img[None]
        if not shared_intrinsics:
            assert len(img.shape) == 4 and img.shape[0] == 1, (
                "expected a single image of shape (C, H, W) or (1, C, H, W) "
                "when shared_intrinsics is False"
            )

        img_data = self.image_processor(img)

        if priors is None:
            priors = {}

        prior_values = {}

        # Compare against None explicitly: truth-testing a tensor raises for
        # tensors with more than one element and is False for a zero scalar.
        prior_focal = priors.get("focal")
        if prior_focal is not None:
            prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal
            # The focal prior is given at input resolution; rescale it to the
            # resolution the network actually sees.
            prior_values["prior_focal"] = prior_focal * img_data["scales"][1]

        prior_gravity = priors.get("gravity")
        if prior_gravity is not None:
            prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity
            prior_values["prior_gravity"] = prior_gravity

        self.model.optimizer.set_camera_model(camera_model)
        self.model.optimizer.shared_intrinsics = shared_intrinsics

        out = self.model(img_data | prior_values)

        camera, gravity = out["camera"], out["gravity"]
        camera, out = self._post_process(camera, img_data, out)

        # Keep the original grouping order: fields, then confidences, then
        # uncertainties.
        extras = {
            k: out[k]
            for tag in ("field", "confidence", "uncertainty")
            for k in out
            if tag in k
        }
        return {
            "camera": camera,
            "gravity": gravity,
            "covariance": out["covariance"],
            **extras,
        }
|
|