|
import logging |
|
from typing import Dict |
|
|
|
import torch |
|
|
|
from siclib.geometry.base_camera import BaseCamera |
|
from siclib.geometry.gravity import Gravity |
|
from siclib.utils.conversions import deg2rad, focal2fov |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
def get_initial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera, trivial_init: bool = True
) -> tuple[BaseCamera, Gravity]:
    """Get the initial camera and gravity used to start the optimization.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary, forwarded unchanged.
        camera_model (BaseCamera): Camera model to use, forwarded unchanged.
        trivial_init (bool, optional): If True, delegate to `get_trivial_estimation`;
            otherwise delegate to `get_heuristic_estimation`. Defaults to True.

    Returns:
        tuple[BaseCamera, Gravity]: The (camera, gravity) pair returned by the selected
        initialization function.
    """
    if trivial_init:
        return get_trivial_estimation(data, camera_model)
    return get_heuristic_estimation(data, camera_model)
|
|
|
|
|
def get_heuristic_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> tuple[BaseCamera, Gravity]:
    """Get initial camera and gravity for optimization using heuristics.

    Unless a prior is available, the estimate is initialized as follows:
    - roll is the angle of the up vector at the center of the image,
      clamped to [-deg2rad(45), deg2rad(45)]
    - pitch is the value at the center of the latitude map,
      clamped to [-deg2rad(45), deg2rad(45)]
    - vfov is the absolute difference between the central top and bottom of the latitude
      map, clamped to [deg2rad(20), deg2rad(120)]
    - distortions are left to the camera model defaults unless "prior_k1" is given

    Priors take precedence over the heuristics: "prior_focal" replaces the vfov heuristic
    (converted with `focal2fov` and the image height), "prior_gravity" replaces the
    roll/pitch heuristic, and "scales" / "prior_k1" are forwarded to the camera parameters.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain "up_field" and
            "latitude_field"; both are indexed as (batch, channel, row, column).
        camera_model (BaseCamera): Camera model to use (its `from_dict` builds the camera).

    Returns:
        tuple[BaseCamera, Gravity]: Initial (camera, gravity) for optimization, both cast to
        float and moved to the device of "up_field".
    """
    up_ref = data["up_field"].detach()
    latitude_ref = data["latitude_field"].detach()
    device = up_ref.device

    batch_size = up_ref.shape[0]
    h, w = up_ref.shape[-2:]
    batch_h = up_ref.new_ones((batch_size,)) * h
    batch_w = up_ref.new_ones((batch_size,)) * w

    # Central pixel of the fields (floor division equals int(x / 2) for image sizes).
    cy, cx = h // 2, w // 2

    focal = data.get("prior_focal")
    if focal is None:
        # Vertical extent of the latitude map along the central column.
        init_vfov = torch.abs(latitude_ref[:, 0, 0, cx] - latitude_ref[:, 0, -1, cx])
        init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120))
    else:
        init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    if "scales" in data:
        params["scales"] = data["scales"]
    if "prior_k1" in data:
        params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(params).float().to(device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(device)
    else:
        max_angle = deg2rad(45)

        init_r = -torch.atan2(up_ref[:, 0, cy, cx], -up_ref[:, 1, cy, cx])
        init_r = init_r.clamp(min=-max_angle, max=max_angle)

        init_p = latitude_ref[:, 0, cy, cx]
        init_p = init_p.clamp(min=-max_angle, max=max_angle)

        gravity = Gravity.from_rp(init_r, init_p).float().to(device)

    return camera, gravity
|
|
|
|
|
def get_trivial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> tuple[BaseCamera, Gravity]:
    """Get initial camera and gravity with roll=0, pitch=0 and focal=0.7 * max(h, w).

    The default focal is converted to a vertical field of view with `focal2fov` and the
    image height. Priors take precedence: "prior_focal" replaces the default focal (a
    value of None is treated as "no prior"), "prior_gravity" replaces the zero roll/pitch
    gravity, and "scales" / "prior_k1" are forwarded to the camera parameters.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain "up_field" or
            "latitude_field"; only its batch size, spatial size, dtype and device are used.
        camera_model (BaseCamera): Camera model to use (its `from_dict` builds the camera).

    Returns:
        tuple[BaseCamera, Gravity]: Initial (camera, gravity) for optimization, both cast to
        float and moved to the device of the reference field.
    """
    # Look up the fallback lazily so that "latitude_field" is only required when
    # "up_field" is missing.
    ref = data["up_field"] if "up_field" in data else data["latitude_field"]
    ref = ref.detach()

    batch_size = ref.shape[0]
    h, w = ref.shape[-2:]
    batch_h = ref.new_ones((batch_size,)) * h
    batch_w = ref.new_ones((batch_size,)) * w

    init_r = ref.new_zeros((batch_size,))
    init_p = ref.new_zeros((batch_size,))

    focal = data.get("prior_focal")
    if focal is None:
        focal = 0.7 * torch.max(batch_h, batch_w)
    init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    if "scales" in data:
        params["scales"] = data["scales"]
    if "prior_k1" in data:
        params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(params).float().to(ref.device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(ref.device)
    else:
        gravity = Gravity.from_rp(init_r, init_p).float().to(ref.device)

    return camera, gravity
|
|
|
|
|
def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
    """Tell whether the cost has stopped changing between two iterations.

    Args:
        new_cost (torch.Tensor): Cost after the latest update.
        prev_cost (torch.Tensor): Cost before the latest update.
        atol (float): Absolute tolerance passed to `torch.allclose`.
        rtol (float): Relative tolerance passed to `torch.allclose`.

    Returns:
        bool: True if `torch.allclose` finds all elements of both costs equal within the
        given tolerances.
    """
    tolerances = {"rtol": rtol, "atol": atol}
    has_converged = torch.allclose(new_cost, prev_cost, **tolerances)
    return has_converged
|
|
|
|
|
def update_lambda(
    lamb: torch.Tensor,
    prev_cost: torch.Tensor,
    new_cost: torch.Tensor,
    lambda_min: float = 1e-6,
    lambda_max: float = 1e2,
) -> torch.Tensor:
    """Update damping factor for Levenberg-Marquardt optimization.

    Element-wise, the damping is multiplied by 10 where the cost increased
    (new_cost > prev_cost) and by 0.1 otherwise, then clamped to
    [lambda_min, lambda_max]. The input tensor is not modified.

    Args:
        lamb (torch.Tensor): Current damping factor.
        prev_cost (torch.Tensor): Cost before the latest update.
        new_cost (torch.Tensor): Cost after the latest update.
        lambda_min (float, optional): Lower bound of the damping. Defaults to 1e-6.
        lambda_max (float, optional): Upper bound of the damping. Defaults to 1e2.

    Returns:
        torch.Tensor: Updated damping factor.
    """
    # Cost went up -> damp more; cost went down or stayed the same -> damp less.
    scale = torch.where(new_cost > prev_cost, 10, 0.1)
    return torch.clamp(lamb * scale, lambda_min, lambda_max)
|
|
|
|
|
def optimizer_step(
    G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """One optimization step with Gauss-Newton or Levenberg-Marquardt.

    The diagonal of H is scaled by lambda_, clamped from below by eps and added back to H.
    The damped system is then solved for the update through a Cholesky factorization
    computed on CPU copies of the inputs; the result is moved back to the device of H.

    Args:
        G (torch.Tensor): Batched gradient tensor of size (..., N).
        H (torch.Tensor): Batched hessian tensor of size (..., N, N).
        lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,).
            It is unsqueezed on the last dimension and broadcast against the diagonal of H.
        eps (float, optional): Epsilon for damping. Defaults to 1e-6.

    Returns:
        torch.Tensor: Batched update tensor of size (..., N). If the Cholesky decomposition
        fails, a warning is logged and an all-zero update with the shape of G is returned.
    """
    damping = H.diagonal(dim1=-2, dim2=-1) * lambda_.unsqueeze(-1)
    H = H + damping.clamp(min=eps).diag_embed()

    H_, G_ = H.cpu(), G.cpu()
    try:
        # torch.linalg.cholesky returns the lower factor, which is what
        # torch.cholesky_solve expects by default (upper=False).
        chol = torch.linalg.cholesky(H_)
    except RuntimeError:
        logger.warning("Cholesky decomposition failed. Stopping.")
        # Match the shape of G so the fallback is valid for any batch layout.
        delta = torch.zeros_like(G)
    else:
        delta = torch.cholesky_solve(G_[..., None], chol)[..., 0]

    return delta.to(H.device)
|
|