|
import logging |
|
from typing import Tuple |
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from tqdm import tqdm |
|
|
|
from siclib.geometry.camera import Pinhole as Camera |
|
from siclib.geometry.gravity import Gravity |
|
from siclib.geometry.perspective_fields import get_perspective_field |
|
from siclib.models.base_model import BaseModel |
|
from siclib.models.utils.metrics import pitch_error, roll_error, vfov_error |
|
from siclib.utils.conversions import deg2rad |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class PerspectiveParamOpt(BaseModel):
    """Fit roll, pitch and vertical FoV to a given perspective field.

    An initial guess is read directly off the input fields
    (``_get_init_params``); the three parameters are then refined with Adam so
    that the perspective field rendered from them by
    ``get_perspective_field`` matches the input ``up_field`` and
    ``latitude_field`` (``optimize``).
    """

    default_conf = {
        # Upper bound on the number of optimizer iterations.
        "max_steps": 1000,
        # Learning rate handed to Adam.
        "lr": 0.01,
        # Class name looked up in ``torch.optim.lr_scheduler`` (None disables
        # scheduling) and the keyword arguments forwarded to its constructor.
        "lr_scheduler": {
            "name": "ReduceLROnPlateau",
            "options": {"mode": "min", "patience": 3},
        },
        # Number of past total losses kept by ``check_convergence``.
        "patience": 3,
        # Stop as soon as the total loss is at or below this value.
        "abs_tol": 1e-7,
        # Stop when the total loss differs from the oldest remembered loss by
        # less than this value (an absolute difference, despite the name).
        "rel_tol": 1e-9,
        # Weight of the latitude term; the up term is weighted by (1 - lamb).
        "lamb": 0.5,
        # Show a tqdm progress bar while optimizing.
        "verbose": False,
    }

    required_data_keys = ["up_field", "latitude_field"]

    def _init(self, conf):
        """Nothing to build: the optimized parameters are created per call."""
        pass

    def cost_function(self, pred, target):
        """Compute cost function for perspective parameter optimization.

        Args:
            pred: dict holding the rendered "up_field" and "latitude_field".
            target: dict holding the reference "up_field" and "latitude_field".

        Returns:
            Dict of scalar tensors: "total" (mean of the lamb-weighted
            per-pixel cost), "up" (mean angle between up vectors) and
            "latitude" (mean absolute latitude difference).
        """
        eps = 1e-7

        lat_loss = F.l1_loss(pred["latitude_field"], target["latitude_field"], reduction="none")
        # Drop the channel axis so the term lines up with the up-vector term.
        lat_loss = lat_loss.squeeze(1)

        cos_sim = F.cosine_similarity(pred["up_field"], target["up_field"], dim=1)
        # Keep the argument strictly inside (-1, 1): the derivative of acos
        # blows up at the end points.
        up_loss = torch.acos(torch.clip(cos_sim, -1 + eps, 1 - eps))

        lamb = self.conf.lamb
        cost = (lamb * lat_loss) + ((1 - lamb) * up_loss)
        return {
            "total": torch.mean(cost),
            "up": torch.mean(up_loss),
            "latitude": torch.mean(lat_loss),
        }

    def check_convergence(self, loss, losses_prev):
        """Check if optimization has converged.

        Args:
            loss: dict as returned by ``cost_function``; only "total" is read.
            losses_prev: list of the most recent total losses (oldest first).

        Returns:
            Tuple ``(converged, losses_prev)`` where the second element is the
            history to pass to the next call.
        """
        current = loss["total"].item()

        if current <= self.conf.abs_tol:
            return True, losses_prev

        # A window of at least one entry avoids indexing an empty history when
        # ``patience`` is configured as 0 or less.
        window = max(self.conf.patience, 1)

        # Once the history is full, compare against its oldest entry: a change
        # below ``rel_tol`` over the whole window counts as a plateau.
        if len(losses_prev) >= window and abs(current - losses_prev[0]) < self.conf.rel_tol:
            return True, losses_prev

        losses_prev.append(current)
        return False, losses_prev[-window:]

    def _update_estimate(self, camera: Camera, gravity: Gravity):
        """Update camera estimate based on current parameters.

        Args:
            camera: only its ``size`` (width, height) is reused.
            gravity: unused; kept for signature symmetry with ``camera``.

        Returns:
            Tuple ``(camera, gravity)`` built from ``self.vfov_opt``,
            ``self.roll_opt`` and ``self.pitch_opt``.
        """
        size = camera.size
        camera = Camera.from_dict(
            {"height": size[..., 1], "width": size[..., 0], "vfov": self.vfov_opt}
        )
        gravity = Gravity.from_rp(self.roll_opt, self.pitch_opt)
        return camera, gravity

    def optimize(self, data, camera_init, gravity_init):
        """Optimize camera parameters to minimize cost function.

        Args:
            data: dict with the reference "up_field" and "latitude_field".
            camera_init: initial camera; provides the image size and vfov.
            gravity_init: initial gravity; provides roll and pitch.

        Returns:
            Dict with "camera_opt" and "gravity_opt" built from the refined
            parameters.
        """
        device = data["up_field"].device

        # Move to the target device *before* wrapping: calling ``.to`` on a
        # Parameter may return a non-leaf copy, which the optimizer rejects.
        self.roll_opt = nn.Parameter(gravity_init.roll.to(device), requires_grad=True)
        self.pitch_opt = nn.Parameter(gravity_init.pitch.to(device), requires_grad=True)
        self.vfov_opt = nn.Parameter(camera_init.vfov.to(device), requires_grad=True)

        optimizer = torch.optim.Adam(
            [self.roll_opt, self.pitch_opt, self.vfov_opt], lr=self.conf.lr
        )

        lr_scheduler = None
        if self.conf.lr_scheduler["name"] is not None:
            scheduler_cls = getattr(torch.optim.lr_scheduler, self.conf.lr_scheduler["name"])
            lr_scheduler = scheduler_cls(optimizer, **self.conf.lr_scheduler["options"])
        # Only plateau schedulers take the monitored metric in ``step``.
        step_with_metric = isinstance(lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau)

        losses_prev = []

        loop = range(self.conf.max_steps)
        pbar = None
        if self.conf.verbose:
            pbar = tqdm(loop, desc="Optimizing", total=len(loop), ncols=100)

        try:
            # Gradients are required even if the caller runs under no_grad.
            with torch.set_grad_enabled(True):
                self.train()
                for _ in loop:
                    optimizer.zero_grad()

                    camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init)

                    up, lat = get_perspective_field(camera_opt, gravity_opt)
                    pred = {"up_field": up, "latitude_field": lat}

                    loss = self.cost_function(pred, data)
                    loss["total"].backward()
                    optimizer.step()

                    if lr_scheduler is not None:
                        if step_with_metric:
                            lr_scheduler.step(loss["total"].item())
                        else:
                            lr_scheduler.step()

                    if pbar is not None:
                        pbar.set_postfix({k[:3]: v.item() for k, v in loss.items()})
                        pbar.update(1)

                    converged, losses_prev = self.check_convergence(loss, losses_prev)
                    if converged:
                        break
        finally:
            # Release the bar whether the loop converged, ran out of steps or
            # raised.
            if pbar is not None:
                pbar.close()

        # Rebuild the estimate so it reflects the last optimizer step.
        camera_opt, gravity_opt = self._update_estimate(camera_init, gravity_init)
        return {"camera_opt": camera_opt, "gravity_opt": gravity_opt}

    def _get_init_params(self, data) -> Tuple[Camera, Gravity]:
        """Get initial camera parameters for optimization.

        The fields are indexed as (batch, channel, row, column).

        Args:
            data: dict with "up_field" and "latitude_field".

        Returns:
            Tuple ``(camera, gravity)`` holding the initial guess.
        """
        up_ref = data["up_field"]
        latitude_ref = data["latitude_field"]

        h, w = latitude_ref.shape[-2:]
        cy, cx = h // 2, w // 2

        # Roll: orientation of the up vector at the image centre.
        init_r = -torch.arctan2(up_ref[:, 0, cy, cx], -up_ref[:, 1, cy, cx])

        # Pitch: latitude at the image centre.
        init_p = latitude_ref[:, 0, cy, cx]

        # Vertical FoV: latitude span between the first and last row of the
        # centre column, clamped to [deg2rad(20), deg2rad(120)].
        init_vfov = torch.abs(latitude_ref[:, 0, 0, cx] - latitude_ref[:, 0, -1, cx])
        init_vfov = init_vfov.clamp(min=deg2rad(20), max=deg2rad(120))

        # One (width, height) entry per batch element, matching the dtype and
        # device of the input field.
        batch_ones = latitude_ref.new_ones(latitude_ref.shape[0])
        params = {"width": batch_ones * w, "height": batch_ones * h, "vfov": init_vfov}
        camera = Camera.from_dict(params)
        gravity = Gravity.from_rp(init_r, init_p)
        return camera, gravity

    def _forward(self, data):
        """Forward pass of optimization model.

        Tensor entries of ``data`` are replaced in place by detached versions
        so that the optimization does not backpropagate into whatever
        produced the fields.

        Args:
            data: dict with "up_field" and "latitude_field" (batch size 1).

        Returns:
            The dict produced by ``optimize``.
        """
        assert data["up_field"].shape[0] == 1, "Batch size must be 1 for optimization model."

        # Only existing keys are reassigned, so updating the dict after the
        # comprehension has been built is safe.
        data.update({k: v.detach() for k, v in data.items() if isinstance(v, torch.Tensor)})

        camera_init, gravity_init = self._get_init_params(data)
        return self.optimize(data, camera_init, gravity_init)

    def metrics(self, pred, data):
        """Compute roll, pitch and vfov errors of the optimized estimate.

        Args:
            pred: dict with "camera_opt" and "gravity_opt".
            data: dict with the reference "camera" and "gravity".

        Returns:
            Dict with "roll_opt_error", "pitch_opt_error" and
            "vfov_opt_error".
        """
        pred_cam, gt_cam = pred["camera_opt"], data["camera"]
        pred_grav, gt_grav = pred["gravity_opt"], data["gravity"]

        return {
            "roll_opt_error": roll_error(pred_grav, gt_grav),
            "pitch_opt_error": pitch_error(pred_grav, gt_grav),
            "vfov_opt_error": vfov_error(pred_cam, gt_cam),
        }

    def loss(self, pred, data):
        """No loss function for this optimization model.

        Returns:
            Tuple of a constant placeholder loss dict and ``metrics(pred, data)``.
        """
        return {"opt_param_total": 0}, self.metrics(pred, data)
|
|