import logging
import time
from typing import Dict, Tuple
import torch
from torch import nn
import siclib.models.optimization.losses as losses
from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.camera import camera_models
from siclib.geometry.gravity import Gravity
from siclib.geometry.jacobians import J_focal2fov
from siclib.geometry.perspective_fields import J_perspective_field, get_perspective_field
from siclib.models import get_model
from siclib.models.base_model import BaseModel
from siclib.models.optimization.utils import (
early_stop,
get_initial_estimation,
optimizer_step,
update_lambda,
)
from siclib.models.utils.metrics import (
dist_error,
gravity_error,
pitch_error,
roll_error,
vfov_error,
)
from siclib.utils.conversions import rad2deg
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
class LMOptimizer(BaseModel):
default_conf = {
# Camera model parameters
"camera_model": "pinhole", # {"pinhole", "simple_radial", "simple_spherical"}
"shared_intrinsics": False, # share focal length across all images in batch
# LM optimizer parameters
"num_steps": 10,
"lambda_": 0.1,
"fix_lambda": False,
"early_stop": False,
"atol": 1e-8,
"rtol": 1e-8,
"use_spherical_manifold": True, # use spherical manifold for gravity optimization
"use_log_focal": True, # use log focal length for optimization
# Loss function parameters
"loss_fn": "squared_loss", # {"squared_loss", "huber_loss"}
"up_loss_fn_scale": 1e-2,
"lat_loss_fn_scale": 1e-2,
"init_conf": {"name": "trivial"}, # pass config of other models to use as initializer
# Misc
"loss_weight": 1,
"verbose": False,
}
def _init(self, conf):
self.loss_fn = getattr(losses, conf.loss_fn)
self.num_steps = conf.num_steps
self.set_camera_model(conf.camera_model)
self.setup_optimization_and_priors(shared_intrinsics=conf.shared_intrinsics)
self.initializer = None
if self.conf.init_conf.name not in ["trivial", "heuristic"]:
self.initializer = get_model(conf.init_conf.name)(conf.init_conf)
def set_camera_model(self, camera_model: str) -> None:
"""Set the camera model to use for the optimization.
Args:
camera_model (str): Camera model to use.
"""
assert (
camera_model in camera_models.keys()
), f"Unknown camera model: {camera_model} not in {camera_models.keys()}"
self.camera_model = camera_models[camera_model]
self.camera_has_distortion = hasattr(self.camera_model, "dist")
logger.debug(
f"Using camera model: {camera_model} (with distortion: {self.camera_has_distortion})"
)
def setup_optimization_and_priors(
self, data: Dict[str, torch.Tensor] = None, shared_intrinsics: bool = False
) -> None:
"""Setup the optimization and priors for the LM optimizer.
Args:
data (Dict[str, torch.Tensor], optional): Dict potentially containing priors. Defaults
to None.
shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
Defaults to False.
"""
if data is None:
data = {}
self.shared_intrinsics = shared_intrinsics
        if shared_intrinsics:  # shared intrinsics require the pinhole camera model
assert (
self.camera_model == camera_models["pinhole"]
), f"Shared intrinsics only supported with pinhole camera model: {self.camera_model}"
self.estimate_gravity = True
if "prior_gravity" in data:
self.estimate_gravity = False
logger.debug("Using provided gravity as prior.")
self.estimate_focal = True
if "prior_focal" in data:
self.estimate_focal = False
logger.debug("Using provided focal as prior.")
self.estimate_k1 = True
if "prior_k1" in data:
self.estimate_k1 = False
logger.debug("Using provided k1 as prior.")
self.gravity_delta_dims = (0, 1) if self.estimate_gravity else (-1,)
self.focal_delta_dims = (
(max(self.gravity_delta_dims) + 1,) if self.estimate_focal else (-1,)
)
self.k1_delta_dims = (max(self.focal_delta_dims) + 1,) if self.estimate_k1 else (-1,)
logger.debug(f"Camera Model: {self.camera_model}")
logger.debug(f"Optimizing gravity: {self.estimate_gravity} ({self.gravity_delta_dims})")
logger.debug(f"Optimizing focal: {self.estimate_focal} ({self.focal_delta_dims})")
logger.debug(f"Optimizing k1: {self.estimate_k1} ({self.k1_delta_dims})")
logger.debug(f"Shared intrinsics: {self.shared_intrinsics}")
def calculate_residuals(
self, camera: BaseCamera, gravity: Gravity, data: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Calculate the residuals for the optimization.
Args:
camera (BaseCamera): Optimized camera.
gravity (Gravity): Optimized gravity.
data (Dict[str, torch.Tensor]): Input data containing the up and latitude fields.
Returns:
Dict[str, torch.Tensor]: Residuals for the optimization.
"""
perspective_up, perspective_lat = get_perspective_field(camera, gravity)
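        # Latitude residuals are compared in sin-space rather than on the raw angles.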
perspective_lat = torch.sin(perspective_lat)
residuals = {}
if "up_field" in data:
up_residual = (data["up_field"] - perspective_up).permute(0, 2, 3, 1)
residuals["up_residual"] = up_residual.reshape(up_residual.shape[0], -1, 2)
if "latitude_field" in data:
target_lat = torch.sin(data["latitude_field"])
lat_residual = (target_lat - perspective_lat).permute(0, 2, 3, 1)
residuals["latitude_residual"] = lat_residual.reshape(lat_residual.shape[0], -1, 1)
return residuals
def calculate_costs(
self, residuals: torch.Tensor, data: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""Calculate the costs and weights for the optimization.
Args:
residuals (torch.Tensor): Residuals for the optimization.
data (Dict[str, torch.Tensor]): Input data containing the up and latitude confidence.
Returns:
Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: Costs and weights for the
optimization.
"""
costs, weights = {}, {}
if "up_residual" in residuals:
up_cost = (residuals["up_residual"] ** 2).sum(dim=-1)
up_cost, up_weight, _ = losses.scaled_loss(
up_cost, self.loss_fn, self.conf.up_loss_fn_scale
)
if "up_confidence" in data:
up_conf = data["up_confidence"].reshape(up_weight.shape[0], -1)
up_weight = up_weight * up_conf
up_cost = up_cost * up_conf
costs["up_cost"] = up_cost
weights["up_weights"] = up_weight
if "latitude_residual" in residuals:
lat_cost = (residuals["latitude_residual"] ** 2).sum(dim=-1)
lat_cost, lat_weight, _ = losses.scaled_loss(
lat_cost, self.loss_fn, self.conf.lat_loss_fn_scale
)
if "latitude_confidence" in data:
lat_conf = data["latitude_confidence"].reshape(lat_weight.shape[0], -1)
lat_weight = lat_weight * lat_conf
lat_cost = lat_cost * lat_conf
costs["latitude_cost"] = lat_cost
weights["latitude_weights"] = lat_weight
return costs, weights
def calculate_gradient_and_hessian(
self,
J: torch.Tensor,
residuals: torch.Tensor,
weights: torch.Tensor,
shared_intrinsics: bool,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate the gradient and Hessian for given the Jacobian, residuals, and weights.
Args:
J (torch.Tensor): Jacobian.
residuals (torch.Tensor): Residuals.
weights (torch.Tensor): Weights.
shared_intrinsics (bool): Whether to share the intrinsics across the batch.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian.
"""
dims = ()
if self.estimate_gravity:
dims = (0, 1)
if self.estimate_focal:
dims += (2,)
if self.camera_has_distortion and self.estimate_k1:
dims += (3,)
assert dims, "No parameters to optimize"
J = J[..., dims]
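        # Gauss-Newton gradient: per-pixel J^T r, weighted and summed over all pixels.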
Grad = torch.einsum("...Njk,...Nj->...Nk", J, residuals)
Grad = weights[..., None] * Grad
Grad = Grad.sum(-2) # (B, N_params)
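        # With shared intrinsics, stack the per-image gravity gradients and accumulate
        # a single gradient entry for the shared focal length.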
if shared_intrinsics:
# reshape to (1, B * (N_params-1) + 1)
Grad_g = Grad[..., :2].reshape(1, -1)
Grad_f = Grad[..., 2].reshape(1, -1).sum(-1, keepdim=True)
Grad = torch.cat([Grad_g, Grad_f], dim=-1)
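        # Gauss-Newton approximation of the Hessian: per-pixel J^T J, weighted and summed.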
Hess = torch.einsum("...Njk,...Njl->...Nkl", J, J)
Hess = weights[..., None, None] * Hess
Hess = Hess.sum(-3)
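        # With shared intrinsics, assemble a block-diagonal gravity Hessian bordered by a
        # single row/column for the shared focal length.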
if shared_intrinsics:
H_g = torch.block_diag(*list(Hess[..., :2, :2]))
J_fg = Hess[..., :2, 2].flatten()
J_gf = Hess[..., 2, :2].flatten()
J_f = Hess[..., 2, 2].sum()
dims = H_g.shape[-1] + 1
Hess = Hess.new_zeros((dims, dims), dtype=torch.float32)
Hess[:-1, :-1] = H_g
Hess[-1, :-1] = J_gf
Hess[:-1, -1] = J_fg
Hess[-1, -1] = J_f
Hess = Hess.unsqueeze(0)
return Grad, Hess
def setup_system(
self,
camera: BaseCamera,
gravity: Gravity,
residuals: Dict[str, torch.Tensor],
weights: Dict[str, torch.Tensor],
as_rpf: bool = False,
shared_intrinsics: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Calculate the gradient and Hessian for the optimization.
Args:
camera (BaseCamera): Optimized camera.
gravity (Gravity): Optimized gravity.
residuals (Dict[str, torch.Tensor]): Residuals for the optimization.
weights (Dict[str, torch.Tensor]): Weights for the optimization.
            as_rpf (bool, optional): Whether to calculate the gradient and Hessian with respect to
roll, pitch, and focal length. Defaults to False.
shared_intrinsics (bool, optional): Whether to share the intrinsics across the batch.
Defaults to False.
Returns:
Tuple[torch.Tensor, torch.Tensor]: Gradient and Hessian for the optimization.
"""
J_up, J_lat = J_perspective_field(
camera,
gravity,
spherical=self.conf.use_spherical_manifold and not as_rpf,
log_focal=self.conf.use_log_focal and not as_rpf,
)
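        # Jacobians of the up/latitude fields w.r.t. the optimization parameters; with
        # as_rpf=True they are taken w.r.t. roll, pitch, and focal length instead
        # (used for uncertainty estimation).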
J_up = J_up.reshape(J_up.shape[0], -1, J_up.shape[-2], J_up.shape[-1]) # (B, N, 2, 3)
J_lat = J_lat.reshape(J_lat.shape[0], -1, J_lat.shape[-2], J_lat.shape[-1]) # (B, N, 1, 3)
n_params = (
2 * self.estimate_gravity
+ self.estimate_focal
+ (self.camera_has_distortion and self.estimate_k1)
)
Grad = J_up.new_zeros(J_up.shape[0], n_params)
Hess = J_up.new_zeros(J_up.shape[0], n_params, n_params)
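        # With shared intrinsics the system couples the whole batch: two gravity
        # parameters per image plus one shared focal length.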
if shared_intrinsics:
N_params = Grad.shape[0] * (n_params - 1) + 1
Grad = Grad.new_zeros(1, N_params)
Hess = Hess.new_zeros(1, N_params, N_params)
if "up_residual" in residuals:
Up_Grad, Up_Hess = self.calculate_gradient_and_hessian(
J_up, residuals["up_residual"], weights["up_weights"], shared_intrinsics
)
if self.conf.verbose:
logger.info(f"Up J:\n{Up_Grad.mean(0)}")
Grad = Grad + Up_Grad
Hess = Hess + Up_Hess
if "latitude_residual" in residuals:
Lat_Grad, Lat_Hess = self.calculate_gradient_and_hessian(
J_lat,
residuals["latitude_residual"],
weights["latitude_weights"],
shared_intrinsics,
)
if self.conf.verbose:
logger.info(f"Lat J:\n{Lat_Grad.mean(0)}")
Grad = Grad + Lat_Grad
Hess = Hess + Lat_Hess
return Grad, Hess
def estimate_uncertainty(
self,
camera_opt: BaseCamera,
gravity_opt: Gravity,
errors: Dict[str, torch.Tensor],
weights: Dict[str, torch.Tensor],
) -> Dict[str, torch.Tensor]:
"""Estimate the uncertainty of the optimized camera and gravity at the final step.
Args:
camera_opt (BaseCamera): Final optimized camera.
gravity_opt (Gravity): Final optimized gravity.
            errors (Dict[str, torch.Tensor]): Residuals for the optimization.
weights (Dict[str, torch.Tensor]): Weights for the optimization.
Returns:
Dict[str, torch.Tensor]: Uncertainty estimates for the optimized camera and gravity.
"""
_, Hess = self.setup_system(
camera_opt, gravity_opt, errors, weights, as_rpf=True, shared_intrinsics=False
)
Cov = torch.inverse(Hess)
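        # Approximate the covariance as the inverse of the Gauss-Newton Hessian at the
        # final estimate, expressed in roll/pitch/focal parameters (as_rpf=True).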
roll_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
pitch_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
gravity_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
if self.estimate_gravity:
roll_uncertainty = Cov[..., 0, 0]
pitch_uncertainty = Cov[..., 1, 1]
try:
delta_uncertainty = Cov[..., :2, :2]
eigenvalues = torch.linalg.eigvalsh(delta_uncertainty.cpu())
gravity_uncertainty = torch.max(eigenvalues, dim=-1).values.to(Cov.device)
except RuntimeError:
logger.warning("Could not calculate gravity uncertainty")
gravity_uncertainty = Cov.new_zeros(Cov.shape[0])
focal_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
fov_uncertainty = Cov.new_zeros(Cov[..., 0, 0].shape)
if self.estimate_focal:
focal_uncertainty = Cov[..., self.focal_delta_dims[0], self.focal_delta_dims[0]]
fov_uncertainty = (
J_focal2fov(camera_opt.f[..., 1], camera_opt.size[..., 1]) ** 2 * focal_uncertainty
)
return {
"covariance": Cov,
"roll_uncertainty": torch.sqrt(roll_uncertainty),
"pitch_uncertainty": torch.sqrt(pitch_uncertainty),
"gravity_uncertainty": torch.sqrt(gravity_uncertainty),
"focal_uncertainty": torch.sqrt(focal_uncertainty) / 2,
"vfov_uncertainty": torch.sqrt(fov_uncertainty / 2),
}
def update_estimate(
self, camera: BaseCamera, gravity: Gravity, delta: torch.Tensor
) -> Tuple[BaseCamera, Gravity]:
"""Update the camera and gravity estimates with the given delta.
Args:
camera (BaseCamera): Optimized camera.
gravity (Gravity): Optimized gravity.
delta (torch.Tensor): Delta to update the camera and gravity estimates.
Returns:
Tuple[BaseCamera, Gravity]: Updated camera and gravity estimates.
"""
delta_gravity = (
delta[..., self.gravity_delta_dims]
if self.estimate_gravity
else delta.new_zeros(delta.shape[:-1] + (2,))
)
new_gravity = gravity.update(delta_gravity, spherical=self.conf.use_spherical_manifold)
delta_f = (
delta[..., self.focal_delta_dims]
if self.estimate_focal
else delta.new_zeros(delta.shape[:-1] + (1,))
)
new_camera = camera.update_focal(delta_f, as_log=self.conf.use_log_focal)
delta_dist = (
delta[..., self.k1_delta_dims]
if self.camera_has_distortion and self.estimate_k1
else delta.new_zeros(delta.shape[:-1] + (1,))
)
if self.camera_has_distortion:
new_camera = new_camera.update_dist(delta_dist)
return new_camera, new_gravity
def optimize(
self,
data: Dict[str, torch.Tensor],
camera_opt: BaseCamera,
gravity_opt: Gravity,
) -> Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]:
"""Optimize the camera and gravity estimates.
Args:
data (Dict[str, torch.Tensor]): Input data.
camera_opt (BaseCamera): Optimized camera.
gravity_opt (Gravity): Optimized gravity.
Returns:
Tuple[BaseCamera, Gravity, Dict[str, torch.Tensor]]: Optimized camera, gravity
estimates and optimization information.
"""
key = list(data.keys())[0]
B = data[key].shape[0]
lamb = data[key].new_ones(B) * self.conf.lambda_
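        # Levenberg-Marquardt damping: one factor per image, collapsed to a single
        # value below when intrinsics are shared across the batch.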
if self.shared_intrinsics:
lamb = data[key].new_ones(1) * self.conf.lambda_
infos = {"stop_at": self.num_steps}
for i in range(self.num_steps):
if self.conf.verbose:
logger.info(f"Step {i+1}/{self.num_steps}")
errors = self.calculate_residuals(camera_opt, gravity_opt, data)
costs, weights = self.calculate_costs(errors, data)
if i == 0:
prev_cost = sum(c.mean(-1) for c in costs.values())
for k, c in costs.items():
infos[f"initial_{k}"] = c.mean(-1)
infos["initial_cost"] = prev_cost
Grad, Hess = self.setup_system(
camera_opt,
gravity_opt,
errors,
weights,
shared_intrinsics=self.shared_intrinsics,
)
delta = optimizer_step(Grad, Hess, lamb) # (B, N_params)
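            # With shared intrinsics, unstack the per-image gravity deltas and broadcast
            # the single focal delta to all images.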
if self.shared_intrinsics:
delta_g = delta[..., :-1].reshape(B, 2)
delta_f = delta[..., -1].expand(B, 1)
delta = torch.cat([delta_g, delta_f], dim=-1)
# calculate new cost
camera_opt, gravity_opt = self.update_estimate(camera_opt, gravity_opt, delta)
new_cost, _ = self.calculate_costs(
self.calculate_residuals(camera_opt, gravity_opt, data), data
)
new_cost = sum(c.mean(-1) for c in new_cost.values())
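            # Adapt the damping to the cost change (kept fixed when fix_lambda is set
            # or when intrinsics are shared).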
if not self.conf.fix_lambda and not self.shared_intrinsics:
lamb = update_lambda(lamb, prev_cost, new_cost)
if self.conf.verbose:
logger.info(f"Cost:\nPrev: {prev_cost}\nNew: {new_cost}")
logger.info(f"Camera:\n{camera_opt._data}")
if early_stop(new_cost, prev_cost, atol=self.conf.atol, rtol=self.conf.rtol):
infos["stop_at"] = min(i + 1, infos["stop_at"])
if self.conf.early_stop:
if self.conf.verbose:
logger.info(f"Early stopping at step {i+1}")
break
prev_cost = new_cost
if i == self.num_steps - 1 and self.conf.early_stop:
logger.warning("Reached maximum number of steps without convergence.")
        final_errors = self.calculate_residuals(camera_opt, gravity_opt, data)  # per-field (B, N, C) residuals
final_cost, weights = self.calculate_costs(final_errors, data) # (B, N)
if not self.training:
infos |= self.estimate_uncertainty(camera_opt, gravity_opt, final_errors, weights)
infos["stop_at"] = camera_opt.new_ones(camera_opt.shape[0]) * infos["stop_at"]
for k, c in final_cost.items():
infos[f"final_{k}"] = c.mean(-1)
infos["final_cost"] = sum(c.mean(-1) for c in final_cost.values())
return camera_opt, gravity_opt, infos
def _forward(self, data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""Run the LM optimization."""
if self.initializer is None:
camera_init, gravity_init = get_initial_estimation(
data, self.camera_model, trivial_init=self.conf.init_conf.name == "trivial"
)
else:
out = self.initializer(data)
camera_init = out["camera"]
gravity_init = out["gravity"]
self.setup_optimization_and_priors(data, shared_intrinsics=self.shared_intrinsics)
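        # Re-derive which parameters to optimize from any priors present in the input data.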
start = time.time()
camera_opt, gravity_opt, infos = self.optimize(data, camera_init, gravity_init)
if self.conf.verbose:
logger.info(f"Optimization took {(time.time() - start)*1000:.2f} ms")
logger.info(f"Initial camera:\n{rad2deg(camera_init.vfov)}")
logger.info(f"Optimized camera:\n{rad2deg(camera_opt.vfov)}")
logger.info(f"Initial gravity:\n{rad2deg(gravity_init.rp)}")
logger.info(f"Optimized gravity:\n{rad2deg(gravity_opt.rp)}")
return {"camera": camera_opt, "gravity": gravity_opt, **infos}
def metrics(
self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]
) -> Dict[str, torch.Tensor]:
"""Calculate the metrics for the optimization."""
pred_cam, gt_cam = pred["camera"], data["camera"]
pred_gravity, gt_gravity = pred["gravity"], data["gravity"]
infos = {"stop_at": pred["stop_at"]}
for k, v in pred.items():
if "initial" in k or "final" in k:
infos[k] = v
return {
"roll_error": roll_error(pred_gravity, gt_gravity),
"pitch_error": pitch_error(pred_gravity, gt_gravity),
"gravity_error": gravity_error(pred_gravity, gt_gravity),
"vfov_error": vfov_error(pred_cam, gt_cam),
"k1_error": dist_error(pred_cam, gt_cam),
**infos,
}
def loss(
self, pred: Dict[str, torch.Tensor], data: Dict[str, torch.Tensor]
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
"""Calculate the loss for the optimization."""
pred_cam, gt_cam = pred["camera"], data["camera"]
pred_gravity, gt_gravity = pred["gravity"], data["gravity"]
loss_fn = nn.L1Loss(reduction="none")
        # the loss is zero if a parameter is not estimated and its prior is provided during training
gravity_loss = loss_fn(pred_gravity.vec3d, gt_gravity.vec3d)
h = data["camera"].size[0, 0]
focal_loss = loss_fn(pred_cam.f, gt_cam.f).mean(-1) / h
dist_loss = focal_loss.new_zeros(focal_loss.shape)
if self.camera_has_distortion:
dist_loss = loss_fn(pred_cam.dist, gt_cam.dist).sum(-1)
losses = {
"gravity": gravity_loss.sum(-1),
"focal": focal_loss,
"dist": dist_loss,
"param_total": gravity_loss.sum(-1) + focal_loss + dist_loss,
}
losses = {k: v * self.conf.loss_weight for k, v in losses.items()}
return losses, self.metrics(pred, data)
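

# A minimal usage sketch (an assumption, not verified against BaseModel: it presumes the
# constructor merges this dict into `default_conf` and that the module is called on a dict
# of predicted fields):
#
#   optimizer = LMOptimizer({"camera_model": "pinhole", "num_steps": 10})
#   pred = optimizer(
#       {
#           "up_field": up_field,                   # (B, 2, H, W)
#           "latitude_field": latitude_field,       # (B, 1, H, W)
#           "up_confidence": up_confidence,         # optional per-pixel confidence
#           "latitude_confidence": lat_confidence,  # optional per-pixel confidence
#       }
#   )
#   camera, gravity = pred["camera"], pred["gravity"]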