"""Simple interface for GeoCalib model.""" from pathlib import Path from typing import Dict, Optional import torch import torch.nn as nn from torch.nn.functional import interpolate from siclib.geometry.base_camera import BaseCamera from siclib.models.networks.geocalib import GeoCalib as Model from siclib.utils.image import ImagePreprocessor, load_image class GeoCalib(nn.Module): """Simple interface for GeoCalib model.""" def __init__(self, weights: str = "pinhole"): """Initialize the model with optional config overrides. Args: weights (str, optional): Weights to load. Defaults to "pinhole". """ super().__init__() if weights not in {"pinhole", "distorted"}: raise ValueError(f"Unknown weights: {weights}") url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar" # load checkpoint model_dir = f"{torch.hub.get_dir()}/geocalib" state_dict = torch.hub.load_state_dict_from_url( url, model_dir, map_location="cpu", file_name=f"{weights}.tar" ) self.model = Model({}) self.model.flexible_load(state_dict["model"]) self.model.eval() self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32}) def load_image(self, path: Path) -> torch.Tensor: """Load image from path.""" return load_image(path) def _post_process( self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor] ) -> tuple[BaseCamera, dict[str, torch.Tensor]]: """Post-process model output by undoing scaling and cropping.""" camera = camera.undo_scale_crop(img_data) w, h = camera.size.unbind(-1) h = h[0].round().int().item() w = w[0].round().int().item() for k in ["latitude_field", "up_field"]: out[k] = interpolate(out[k], size=(h, w), mode="bilinear") for k in ["up_confidence", "latitude_confidence"]: out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0] inverse_scales = 1.0 / img_data["scales"] zero = camera.new_zeros(camera.f.shape[0]) out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1] return camera, out @torch.no_grad() def calibrate( self, img: torch.Tensor, camera_model: str = "pinhole", priors: Optional[Dict[str, torch.Tensor]] = None, shared_intrinsics: bool = False, ) -> Dict[str, torch.Tensor]: """Perform calibration with online resizing. Assumes input image is in range [0, 1] and in RGB format. Args: img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W) camera_model (str, optional): Camera model. Defaults to "pinhole". priors (Dict[str, torch.Tensor], optional): Prior parameters. Defaults to {}. shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False. Returns: Dict[str, torch.Tensor]: camera and gravity vectors and uncertainties. """ if len(img.shape) == 3: img = img[None] # add batch dim if not shared_intrinsics: assert len(img.shape) == 4 and img.shape[0] == 1 img_data = self.image_processor(img) if priors is None: priors = {} prior_values = {} if prior_focal := priors.get("focal"): prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal prior_values["prior_focal"] = prior_focal * img_data["scales"][1] if "gravity" in priors: prior_gravity = priors["gravity"] prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity prior_values["prior_gravity"] = prior_gravity self.model.optimizer.set_camera_model(camera_model) self.model.optimizer.shared_intrinsics = shared_intrinsics out = self.model(img_data | prior_values) camera, gravity = out["camera"], out["gravity"] camera, out = self._post_process(camera, img_data, out) return { "camera": camera, "gravity": gravity, "covariance": out["covariance"], **{k: out[k] for k in out.keys() if "field" in k}, **{k: out[k] for k in out.keys() if "confidence" in k}, **{k: out[k] for k in out.keys() if "uncertainty" in k}, }