veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
6.08 kB
import logging
from typing import Dict
import torch
from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import deg2rad, focal2fov
# Module-level logger; optimizer_step reports Cholesky failures through it.
logger: logging.Logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
def get_initial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera, trivial_init: bool = True
) -> tuple[BaseCamera, Gravity]:
    """Get the initial camera and gravity for optimization.

    Dispatches to one of two initialization strategies; ``data`` and ``camera_model``
    are forwarded unchanged.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary.
        camera_model (BaseCamera): Camera model to use.
        trivial_init (bool, optional): If True, use ``get_trivial_estimation`` (zero roll
            and pitch); otherwise use ``get_heuristic_estimation``, which derives the
            initial values from ``up_field`` and ``latitude_field``. Defaults to True.

    Returns:
        tuple[BaseCamera, Gravity]: Initial camera and initial gravity, exactly as
            returned by the selected strategy.
    """
    if trivial_init:
        return get_trivial_estimation(data, camera_model)
    return get_heuristic_estimation(data, camera_model)
def get_heuristic_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> tuple[BaseCamera, Gravity]:
    """Get initial camera and gravity for optimization using heuristics.

    Initial values are derived from the predicted fields:
        - roll is the angle of the up vector at the center of the image (clamped to +-45 deg)
        - pitch is the latitude at the center of the image (clamped to +-45 deg)
        - vfov is the absolute latitude difference between the top and bottom of the
          central column (clamped to [20, 120] deg)
        - distortion parameters are not estimated; ``k1`` is only set from ``prior_k1``

    Prior values take precedence when present in ``data``: ``prior_focal`` replaces the
    vfov estimate, ``prior_k1`` sets ``k1`` and ``prior_gravity`` replaces the roll/pitch
    estimate. ``scales`` is forwarded to the camera as-is.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain ``up_field``
            (indexed as (B, C>=2, H, W)) and ``latitude_field`` (indexed as (B, C>=1, H, W)).
        camera_model (BaseCamera): Camera model to use (``from_dict`` is called on it).

    Returns:
        tuple[BaseCamera, Gravity]: Initial camera and initial gravity, both cast to
            float and moved to the device of ``up_field``.
    """
    up_ref = data["up_field"].detach()
    latitude_ref = data["latitude_field"].detach()
    device = up_ref.device

    h, w = up_ref.shape[-2:]
    cy, cx = h // 2, w // 2  # central pixel
    batch_h = up_ref.new_ones((up_ref.shape[0],)) * h
    batch_w = up_ref.new_ones((up_ref.shape[0],)) * w
    max_angle = deg2rad(45)

    # init roll is the angle of the up vector at the center of the image
    init_r = -torch.atan2(up_ref[:, 0, cy, cx], -up_ref[:, 1, cy, cx])
    init_r = init_r.clamp(min=-max_angle, max=max_angle)

    # init pitch is the value at the center of the latitude map
    init_p = latitude_ref[:, 0, cy, cx].clamp(min=-max_angle, max=max_angle)

    # init vfov comes from the focal prior if given, otherwise it is the latitude span
    # between the central top and bottom of the latitude map
    focal = data.get("prior_focal")
    if focal is None:
        init_vfov = latitude_ref[:, 0, 0, cx] - latitude_ref[:, 0, -1, cx]
        init_vfov = init_vfov.abs().clamp(min=deg2rad(20), max=deg2rad(120))
    else:
        init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    if "scales" in data:
        params["scales"] = data["scales"]
    if "prior_k1" in data:
        params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(params).float().to(device)

    # A gravity prior overrides the roll/pitch heuristics entirely.
    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(device)
    else:
        gravity = Gravity.from_rp(init_r, init_p).float().to(device)
    return camera, gravity
def get_trivial_estimation(
    data: Dict[str, torch.Tensor], camera_model: BaseCamera
) -> tuple[BaseCamera, Gravity]:
    """Get initial camera and gravity with roll=0, pitch=0 and focal=0.7 * max(h, w).

    The vfov is computed from ``prior_focal`` when it is present and not None, otherwise
    from a default focal length of ``0.7 * max(h, w)``, where ``h`` and ``w`` are the
    last two dimensions of the reference field. ``prior_k1`` and ``scales`` are forwarded
    to the camera, and ``prior_gravity`` replaces the zero roll/pitch gravity.

    Args:
        data (Dict[str, torch.Tensor]): Input data dictionary. Must contain ``up_field``
            or ``latitude_field``; only the batch size, spatial size and device of that
            field are used.
        camera_model (BaseCamera): Camera model to use (``from_dict`` is called on it).

    Returns:
        tuple[BaseCamera, Gravity]: Initial camera and initial gravity, both cast to
            float and moved to the device of the reference field.
    """
    # Explicit membership test: data.get("up_field", data["latitude_field"]) would
    # evaluate the default eagerly and raise KeyError when only up_field is present.
    ref = data["up_field"] if "up_field" in data else data["latitude_field"]
    ref = ref.detach()
    device = ref.device

    batch_size = ref.shape[0]
    h, w = ref.shape[-2:]
    batch_h = ref.new_ones((batch_size,)) * h
    batch_w = ref.new_ones((batch_size,)) * w

    init_r = ref.new_zeros((batch_size,))
    init_p = ref.new_zeros((batch_size,))

    # Fall back to the default focal both when the key is missing and when it is
    # explicitly None (previously the latter crashed with UnboundLocalError).
    focal = data.get("prior_focal")
    if focal is None:
        focal = 0.7 * torch.max(batch_h, batch_w)
    init_vfov = focal2fov(focal, h)

    params = {"width": batch_w, "height": batch_h, "vfov": init_vfov}
    if "scales" in data:
        params["scales"] = data["scales"]
    if "prior_k1" in data:
        params["k1"] = data["prior_k1"]
    camera = camera_model.from_dict(params).float().to(device)

    if "prior_gravity" in data:
        gravity = data["prior_gravity"].float().to(device)
    else:
        gravity = Gravity.from_rp(init_r, init_p).float().to(device)
    return camera, gravity
def early_stop(new_cost: torch.Tensor, prev_cost: torch.Tensor, atol: float, rtol: float) -> bool:
    """Decide whether the optimization cost has converged.

    Convergence uses ``torch.allclose`` semantics: every element must satisfy
    ``|new_cost - prev_cost| <= atol + rtol * |prev_cost|``.

    Args:
        new_cost (torch.Tensor): Cost after the latest step.
        prev_cost (torch.Tensor): Cost before the latest step.
        atol (float): Absolute tolerance.
        rtol (float): Relative tolerance (relative to ``prev_cost``).

    Returns:
        bool: True if all costs are within tolerance.
    """
    converged = torch.allclose(new_cost, prev_cost, rtol=rtol, atol=atol)
    return converged
def update_lambda(
    lamb: torch.Tensor,
    prev_cost: torch.Tensor,
    new_cost: torch.Tensor,
    lambda_min: float = 1e-6,
    lambda_max: float = 1e2,
) -> torch.Tensor:
    """Update damping factor for Levenberg-Marquardt optimization.

    Element-wise: where the cost increased (``new_cost > prev_cost``) the damping is
    multiplied by 10; otherwise (including an unchanged cost) it is multiplied by 0.1.
    The result is clamped to ``[lambda_min, lambda_max]``.

    Args:
        lamb (torch.Tensor): Current damping factors.
        prev_cost (torch.Tensor): Cost before the step (broadcastable with ``lamb``).
        new_cost (torch.Tensor): Cost after the step (broadcastable with ``lamb``).
        lambda_min (float, optional): Lower clamp bound. Defaults to 1e-6.
        lambda_max (float, optional): Upper clamp bound. Defaults to 1e2.

    Returns:
        torch.Tensor: Updated damping factors.
    """
    factor = torch.where(new_cost > prev_cost, 10, 0.1)
    return torch.clamp(lamb * factor, lambda_min, lambda_max)
def optimizer_step(
    G: torch.Tensor, H: torch.Tensor, lambda_: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    """One optimization step with Gauss-Newton or Levenberg-Marquardt.

    Solves ``(H + D) delta = G`` where ``D`` is diagonal with entries
    ``max(lambda_ * diag(H), eps)``. The system is factorized with Cholesky on CPU.
    If the factorization fails (e.g. the damped matrix is not positive definite), a
    warning is logged and a zero update is returned for the whole batch.

    Args:
        G (torch.Tensor): Batched gradient tensor of size (..., N).
        H (torch.Tensor): Batched hessian tensor of size (..., N, N).
        lambda_ (torch.Tensor): Damping factor for LM (use GN if lambda_=0) with shape (B,)
            (more generally, the batch shape of ``H``).
        eps (float, optional): Minimum value added to the diagonal. Defaults to 1e-6.

    Returns:
        torch.Tensor: Batched update tensor of size (..., N), on the device of ``H``.
    """
    diag = H.diagonal(dim1=-2, dim2=-1) * lambda_.unsqueeze(-1)  # (..., N)
    # Clamping keeps a small positive damping even for GN (lambda_=0).
    H = H + diag.clamp(min=eps).diag_embed()

    H_cpu, G_cpu = H.cpu(), G.cpu()
    try:
        U = torch.linalg.cholesky(H_cpu)
    except RuntimeError:  # torch.linalg.LinAlgError is a RuntimeError subclass
        logger.warning("Cholesky decomposition failed. Stopping.")
        # Zero update shaped like G, so any number of batch dims (or none) works.
        delta = H.new_zeros(G.shape)
    else:
        delta = torch.cholesky_solve(G_cpu[..., None], U)[..., 0]
    return delta.to(H.device)