GeoCalib / siclib /models /optimization /perspective_opt.py
veichta's picture
Upload folder using huggingface_hub
205a7af verified
raw
history blame
6.97 kB
import logging
from typing import Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
from siclib.geometry.camera import Pinhole as Camera
from siclib.geometry.gravity import Gravity
from siclib.geometry.perspective_fields import get_perspective_field
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import pitch_error, roll_error, vfov_error
from siclib.utils.conversions import deg2rad
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
class PerspectiveParamOpt(BaseModel):
default_conf = {
"max_steps": 1000,
"lr": 0.01,
"lr_scheduler": {
"name": "ReduceLROnPlateau",
"options": {"mode": "min", "patience": 3},
},
"patience": 3,
"abs_tol": 1e-7,
"rel_tol": 1e-9,
"lamb": 0.5,
"verbose": False,
}
required_data_keys = ["up_field", "latitude_field"]
def _init(self, conf):
pass
def cost_function(self, pred, target):
"""Compute cost function for perspective parameter optimization."""
eps = 1e-7
lat_loss = F.l1_loss(pred["latitude_field"], target["latitude_field"], reduction="none")
lat_loss = lat_loss.squeeze(1)
up_loss = F.cosine_similarity(pred["up_field"], target["up_field"], dim=1)
up_loss = torch.acos(torch.clip(up_loss, -1 + eps, 1 - eps))
cost = (self.conf.lamb * lat_loss) + ((1 - self.conf.lamb) * up_loss)
return {
"total": torch.mean(cost),
"up": torch.mean(up_loss),
"latitude": torch.mean(lat_loss),
}
def check_convergence(self, loss, losses_prev):
"""Check if optimization has converged."""
if loss["total"].item() <= self.conf.abs_tol:
return True, losses_prev
if len(losses_prev) < self.conf.patience:
losses_prev.append(loss["total"].item())
elif np.abs(loss["total"].item() - losses_prev[0]) < self.conf.rel_tol:
return True, losses_prev
else:
losses_prev.append(loss["total"].item())
losses_prev = losses_prev[-self.conf.patience :]
return False, losses_prev
def _update_estimate(self, camera: Camera, gravity: Gravity):
"""Update camera estimate based on current parameters."""
camera = Camera.from_dict(
{"height": camera.size[..., 1], "width": camera.size[..., 0], "vfov": self.vfov_opt}
)
gravity = Gravity.from_rp(self.roll_opt, self.pitch_opt)
return camera, gravity
def optimize(self, data, camera_init, gravity_init):
"""Optimize camera parameters to minimize cost function."""
device = data["up_field"].device
self.roll_opt = nn.Parameter(gravity_init.roll, requires_grad=True).to(device)
self.pitch_opt = nn.Parameter(gravity_init.pitch, requires_grad=True).to(device)
self.vfov_opt = nn.Parameter(camera_init.vfov, requires_grad=True).to(device)
optimizer = torch.optim.Adam(
[self.roll_opt, self.pitch_opt, self.vfov_opt], lr=self.conf.lr
)
lr_scheduler = None
if self.conf.lr_scheduler["name"] is not None:
lr_scheduler = getattr(torch.optim.lr_scheduler, self.conf.lr_scheduler["name"])(
optimizer, **self.conf.lr_scheduler["options"]
)
losses_prev = []
loop = range(self.conf.max_steps)
if self.conf.verbose:
pbar = tqdm(loop, desc="Optimizing", total=len(loop), ncols=100)
with torch.set_grad_enabled(True):
self.train()
for _ in loop:
optimizer.zero_grad()
camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init)
up, lat = get_perspective_field(camera_opt, gravity_opt)
pred = {"up_field": up, "latitude_field": lat}
loss = self.cost_function(pred, data)
loss["total"].backward()
optimizer.step()
if lr_scheduler is not None:
lr_scheduler.step(loss["total"])
if self.conf.verbose:
pbar.set_postfix({k[:3]: v.item() for k, v in loss.items()})
pbar.update(1)
converged, losses_prev = self.check_convergence(loss, losses_prev)
if converged:
if self.conf.verbose:
pbar.close()
break
camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init)
return {"camera_opt": camera_opt, "gravity_opt": gravity_opt}
def _get_init_params(self, data) -> Tuple[Camera, Gravity]:
"""Get initial camera parameters for optimization."""
up_ref = data["up_field"]
latitude_ref = data["latitude_field"]
h, w = latitude_ref.shape[-2:]
# init roll is angle of the up vector at the center of the image
init_r = -torch.arctan2(
up_ref[:, 0, int(h / 2), int(w / 2)],
-up_ref[:, 1, int(h / 2), int(w / 2)],
)
# init pitch is the value at the center of the latitude map
init_p = latitude_ref[:, 0, int(h / 2), int(w / 2)]
# init vfov is the difference between the central top and bottom of the latitude map
init_vfov = latitude_ref[:, 0, 0, int(w / 2)] - latitude_ref[:, 0, -1, int(w / 2)]
init_vfov = torch.abs(init_vfov)
init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120))
h, w = (
latitude_ref.new_ones(latitude_ref.shape[0]) * h,
latitude_ref.new_ones(latitude_ref.shape[0]) * w,
)
params = {"width": w, "height": h, "vfov": init_vfov}
camera = Camera.from_dict(params)
gravity = Gravity.from_rp(init_r, init_p)
return camera, gravity
def _forward(self, data):
"""Forward pass of optimization model."""
assert data["up_field"].shape[0] == 1, "Batch size must be 1 for optimization model."
# detach all tensors to avoid backprop
for k, v in data.items():
if isinstance(v, torch.Tensor):
data[k] = v.detach()
camera_init, gravity_init = self._get_init_params(data)
return self.optimize(data, camera_init, gravity_init)
def metrics(self, pred, data):
pred_cam, gt_cam = pred["camera_opt"], data["camera"]
pred_grav, gt_grav = pred["gravity_opt"], data["gravity"]
return {
"roll_opt_error": roll_error(pred_grav, gt_grav),
"pitch_opt_error": pitch_error(pred_grav, gt_grav),
"vfov_opt_error": vfov_error(pred_cam, gt_cam),
}
def loss(self, pred, data):
"""No loss function for this optimization model."""
return {"opt_param_total": 0}, self.metrics(pred, data)