"""Simple interface for GeoCalib model."""

from pathlib import Path
from typing import Dict, Optional

import torch
import torch.nn as nn
from torch.nn.functional import interpolate

from siclib.geometry.base_camera import BaseCamera
from siclib.models.networks.geocalib import GeoCalib as Model
from siclib.utils.image import ImagePreprocessor, load_image


class GeoCalib(nn.Module):
    """Simple interface for GeoCalib model."""

    def __init__(self, weights: str = "pinhole"):
        """Initialize the model with optional config overrides.

        Args:
            weights (str, optional): Weights to load. Defaults to "pinhole".
        """
        super().__init__()
        if weights not in {"pinhole", "distorted"}:
            raise ValueError(f"Unknown weights: {weights}")
        url = f"https://github.com/cvg/GeoCalib/releases/download/v1.0/geocalib-{weights}.tar"

        # load checkpoint
        model_dir = f"{torch.hub.get_dir()}/geocalib"
        state_dict = torch.hub.load_state_dict_from_url(
            url, model_dir, map_location="cpu", file_name=f"{weights}.tar"
        )

        self.model = Model({})
        self.model.flexible_load(state_dict["model"])
        self.model.eval()

        # preprocessing applied before inference: resize the image to 320 px and make
        # both edges divisible by 32
        self.image_processor = ImagePreprocessor({"resize": 320, "edge_divisible_by": 32})

    def load_image(self, path: Path) -> torch.Tensor:
        """Load image from path."""
        return load_image(path)

    def _post_process(
        self, camera: BaseCamera, img_data: dict[str, torch.Tensor], out: dict[str, torch.Tensor]
    ) -> tuple[BaseCamera, dict[str, torch.Tensor]]:
        """Post-process model output by undoing scaling and cropping."""
        camera = camera.undo_scale_crop(img_data)

        # camera.size now holds the original (pre-resize) image dimensions
        w, h = camera.size.unbind(-1)
        h = h[0].round().int().item()
        w = w[0].round().int().item()

        # resample the dense predictions back to the original resolution
        for k in ["latitude_field", "up_field"]:
            out[k] = interpolate(out[k], size=(h, w), mode="bilinear")
        for k in ["up_confidence", "latitude_confidence"]:
            out[k] = interpolate(out[k][:, None], size=(h, w), mode="bilinear")[:, 0]

        # convert the focal uncertainty from resized-image pixels back to original-image pixels
        inverse_scales = 1.0 / img_data["scales"]
        zero = camera.new_zeros(camera.f.shape[0])
        out["focal_uncertainty"] = out.get("focal_uncertainty", zero) * inverse_scales[1]
        return camera, out

    @torch.no_grad()
    def calibrate(
        self,
        img: torch.Tensor,
        camera_model: str = "pinhole",
        priors: Optional[Dict[str, torch.Tensor]] = None,
        shared_intrinsics: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Perform calibration with online resizing.

        Assumes input image is in range [0, 1] and in RGB format.

        Args:
            img (torch.Tensor): Input image, shape (C, H, W) or (1, C, H, W)
            camera_model (str, optional): Camera model. Defaults to "pinhole".
            priors (Dict[str, torch.Tensor], optional): Prior parameters ("focal" and/or
                "gravity"). Defaults to None.
            shared_intrinsics (bool, optional): Whether to share intrinsics. Defaults to False.

        Returns:
            Dict[str, torch.Tensor]: Estimated camera and gravity along with per-pixel
                fields, confidences, and uncertainties.
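
        Example:
            # illustrative only; the image path and focal value are placeholders
            model = GeoCalib()
            image = model.load_image(Path("image.jpg"))
            results = model.calibrate(image, priors={"focal": torch.tensor(500.0)})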
        """
        if len(img.shape) == 3:
            img = img[None]  # add batch dim
        if not shared_intrinsics:
            # without shared intrinsics, only a single image per call is supported
            assert len(img.shape) == 4 and img.shape[0] == 1, "Expected a single image (1, C, H, W)"

        img_data = self.image_processor(img)

        if priors is None:
            priors = {}

        prior_values = {}
        if (prior_focal := priors.get("focal")) is not None:
            prior_focal = prior_focal[None] if len(prior_focal.shape) == 0 else prior_focal
            # the focal prior is given in original-image pixels; rescale it to the resized input
            prior_values["prior_focal"] = prior_focal * img_data["scales"][1]

        if "gravity" in priors:
            prior_gravity = priors["gravity"]
            prior_gravity = prior_gravity[None] if len(prior_gravity.shape) == 0 else prior_gravity
            prior_values["prior_gravity"] = prior_gravity

        self.model.optimizer.set_camera_model(camera_model)
        self.model.optimizer.shared_intrinsics = shared_intrinsics

        out = self.model(img_data | prior_values)

        camera, gravity = out["camera"], out["gravity"]
        camera, out = self._post_process(camera, img_data, out)

        return {
            "camera": camera,
            "gravity": gravity,
            "covariance": out["covariance"],
            **{k: out[k] for k in out.keys() if "field" in k},
            **{k: out[k] for k in out.keys() if "confidence" in k},
            **{k: out[k] for k in out.keys() if "uncertainty" in k},
        }
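

if __name__ == "__main__":
    # Minimal usage sketch (illustrative; not part of the library interface).
    # Path("image.jpg") is a placeholder; point it to an RGB image on disk.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = GeoCalib(weights="pinhole").to(device)

    image = model.load_image(Path("image.jpg")).to(device)
    results = model.calibrate(image)  # calibrate already runs under torch.no_grad()

    camera, gravity = results["camera"], results["gravity"]
    print("camera:", camera)
    print("gravity:", gravity)
    print("focal length (pixels):", camera.f)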