from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

from siclib.geometry.camera import Pinhole
from siclib.geometry.gravity import Gravity
from siclib.geometry.perspective_fields import get_latitude_field, get_up_field
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import (
    latitude_error,
    pitch_error,
    roll_error,
    up_error,
    vfov_error,
)
from siclib.utils.conversions import skew_symmetric

# flake8: noqa
# mypy: ignore-errors


def get_up_lines(up, xy):
    """Build homogeneous image lines through each point along its up vector.

    Args:
        up: Up vectors of shape (..., 2).
        xy: Pixel coordinates of shape (..., 2).

    Returns:
        torch.Tensor: Lines of shape (..., 3), computed as
        ``skew_symmetric(xy1) @ (xy1 + [up, 0])`` with ``xy1`` the homogeneous
        point (i.e. the line joining the two points, assuming
        ``skew_symmetric`` builds the usual cross-product matrix).
    """
    up_lines = torch.cat([up, torch.zeros_like(up[..., :1])], dim=-1)
    xy1 = torch.cat([xy, torch.ones_like(xy[..., :1])], dim=-1)
    xy2 = xy1 + up_lines
    return torch.einsum("...ij,...j->...i", skew_symmetric(xy1), xy2)


def calculate_vvp(line1, line2):
    """Intersect two homogeneous lines: ``skew_symmetric(line1) @ line2``."""
    return torch.einsum("...ij,...j->...i", skew_symmetric(line1), line2)


def calculate_vvps(xs, ys, up):
    """Estimate one vertical vanishing point per hypothesis.

    Args:
        xs, ys: Integer pixel coordinates of shape (B, N, 3); only the first
            two samples of each triplet are used here.
        up: Up vectors sampled at those two points, shape (B, N, 2, 2).

    Returns:
        torch.Tensor: Vanishing points of shape (B, N, 3), divided by their
        last coordinate.
    """
    xy_grav = torch.stack([xs[..., :2], ys[..., :2]], dim=-1).float()
    up_lines = get_up_lines(up, xy_grav)  # (B, N, 2, D)
    vvp = calculate_vvp(up_lines[..., 0, :], up_lines[..., 1, :])  # (B, N, 3)
    # NOTE(review): parallel up lines give vvp[..., 2] == 0 -> inf/NaN here;
    # such hypotheses end up with NaN parameters and score no inliers.
    vvp = vvp / vvp[..., (2,)]
    return vvp


def get_up_samples(pred, xs, ys):
    """Gather the up vectors at the first two points of each triplet.

    Args:
        pred: Dict holding ``"up_field"`` indexed as (batch, channel, y, x);
            channels 0 and 1 are read as the x and y components.
        xs, ys: Integer pixel coordinates of shape (B, N, 3).

    Returns:
        torch.Tensor: Up vectors of shape (B, N, 2, 2) ordered
        (hypothesis point, component).
    """
    B, N = xs.shape[:2]
    batch_indices = torch.arange(B).unsqueeze(1).unsqueeze(2).expand(B, N, 3).to(xs.device)
    zeros = torch.zeros_like(xs).to(xs.device)
    ones = torch.ones_like(xs).to(xs.device)

    sample_indices_x = torch.stack([batch_indices, zeros, ys, xs], dim=-1).long()  # (B, N, 3, 4)
    sample_indices_y = torch.stack([batch_indices, ones, ys, xs], dim=-1).long()  # (B, N, 3, 4)
    up_x = pred["up_field"][sample_indices_x[..., (0, 1), :].unbind(-1)]  # (B, N, 2)
    up_y = pred["up_field"][sample_indices_y[..., (0, 1), :].unbind(-1)]  # (B, N, 2)
    return torch.stack([up_x, up_y], dim=-1)  # (B, N, 2, D)


def get_latitude_samples(pred, xs, ys):
    """Return ``sin(latitude)`` at the third point of each triplet, shape (B, N).

    Args:
        pred: Dict holding ``"latitude_field"`` indexed as (batch, 0, y, x).
        xs, ys: Integer pixel coordinates of shape (B, N, 3).
    """
    B, N = xs.shape[:2]
    batch_indices = torch.arange(B).unsqueeze(1).unsqueeze(2).expand(B, N, 3).to(xs.device)
    zeros = torch.zeros_like(xs).to(xs.device)
    sample_indices = torch.stack([batch_indices, zeros, ys, xs], dim=-1).long()  # (B, N, 3, 4)
    latitude = pred["latitude_field"][sample_indices[..., 2, :].unbind(-1)]
    return torch.sin(latitude)  # (B, N)


class MinimalSolver:
    """Closed-form minimal solver for roll, pitch and focal length.

    All methods are static and batched. The principal points ``c`` have shape
    (B, 2) ordered (cx, cy) and are broadcast over the hypothesis dimension
    via ``c.unsqueeze(1)``.
    """

    def __init__(self):
        pass

    @staticmethod
    def solve_focal(
        L: torch.Tensor, xy: torch.Tensor, vvp: torch.Tensor, c: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Solve for focal length.

        Requiring the sine of the latitude at ``xy`` to equal ``L`` gives the
        quadratic ``a0 * (f^2)^2 + a1 * f^2 + a2 = 0``, solved in closed form.

        Args:
            L (torch.Tensor): Sine of the latitude at the sample points, (B, N).
            xy (torch.Tensor): xy of latitude samples of shape (..., 2).
            vvp (torch.Tensor): Vertical vanishing points of shape (..., 3).
            c (torch.Tensor): Principal points of shape (B, 2), ordered (cx, cy).

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Square roots of the "+sqrt(disc)"
            and "-sqrt(disc)" roots for f^2. Either is NaN when the
            discriminant or the corresponding root is negative.
        """
        c = c.unsqueeze(1)
        u, v = (xy - c).unbind(-1)
        vx, vy, vz = vvp.unbind(-1)
        cx, cy = c.unbind(-1)
        # Shift the vanishing point into principal-point-centered coordinates.
        vx = vx - cx * vz
        vy = vy - cy * vz

        # Solve quadratic equation in f^2
        a0 = (L**2 - 1) * vz**2
        a1 = L**2 * (vz**2 * (u**2 + v**2) + vx**2 + vy**2) - 2 * vz * (vx * u + vy * v)
        a2 = L**2 * (v**2 + u**2) * (vx**2 + vy**2) - (u * vx + v * vy) ** 2

        # Avoid division by zero for a degenerate leading coefficient.
        a0 = torch.where(a0 == 0, torch.ones_like(a0) * 1e-6, a0)

        sqrt_disc = torch.sqrt(a1**2 - 4 * a0 * a2)
        f2_pos = -a1 / (2 * a0) + sqrt_disc / (2 * a0)
        f2_neg = -a1 / (2 * a0) - sqrt_disc / (2 * a0)

        f_pos, f_neg = torch.sqrt(f2_pos), torch.sqrt(f2_neg)

        return f_pos, f_neg

    @staticmethod
    def solve_scale(
        L: torch.Tensor, xy: torch.Tensor, vvp: torch.Tensor, c: torch.Tensor, f: torch.Tensor
    ) -> torch.Tensor:
        """Solve for scale of homogeneous vector.

        Args:
            L (torch.Tensor): Latitude samples.
            xy (torch.Tensor): xy of latitude samples of shape (..., 2).
            vvp (torch.Tensor): Vertical vanishing points of shape (..., 3).
            c (torch.Tensor): Principal points of shape (B, 2), ordered (cx, cy).
            f (torch.Tensor): Focal lengths.

        Returns:
            torch.Tensor: Estimated scales.
        """
        c = c.unsqueeze(1)
        u, v = (xy - c).unbind(-1)
        vx, vy, vz = vvp.unbind(-1)
        cx, cy = c.unbind(-1)
        vx = vx - cx * vz
        vy = vy - cy * vz

        w2 = (f**2 * L**2 * (u**2 + v**2 + f**2)) / (vx * u + vy * v + vz * f**2) ** 2
        return torch.sqrt(w2)

    @staticmethod
    def solve_abc(
        vvp: torch.Tensor, c: torch.Tensor, f: torch.Tensor, w: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Solve for abc vector (solution to homogeneous equation).

        ``abc = ((vx - cx * vz) / f, (vy - cy * vz) / f, vz)``, unit-normalized
        when ``w`` is None and scaled by ``w`` otherwise.

        Args:
            vvp (torch.Tensor): Vertical vanishing points of shape (..., 3).
            c (torch.Tensor): Principal points of shape (B, 2), ordered (cx, cy).
            f (torch.Tensor): Focal lengths.
            w (torch.Tensor): Scales.

        Returns:
            torch.Tensor: Estimated abc vector.
        """
        vx, vy, vz = vvp.unbind(-1)
        cx, cy = c.unsqueeze(1).unbind(-1)
        vx = vx - cx * vz
        vy = vy - cy * vz

        a = vx / f
        b = vy / f
        c = vz

        abc = torch.stack((a, b, c), dim=-1)

        return F.normalize(abc, dim=-1) if w is None else abc * w.unsqueeze(-1)

    @staticmethod
    def solve_rp(abc: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Solve for roll, pitch.

        Args:
            abc (torch.Tensor): Estimated (unit) abc vector.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: Estimated roll and pitch.
        """
        a, _, c = abc.unbind(-1)
        roll = torch.asin(-a / torch.sqrt(1 - c**2))
        pitch = torch.asin(c)
        return roll, pitch


class RPFSolver(BaseModel):
    """Sample-and-score solver for roll, pitch and focal length.

    For each image, ``n_iter`` pixel triplets are drawn. Each triplet yields two
    hypotheses (one per focal-length root). Each hypothesis is scored by its
    weighted count of up-field and latitude-field inliers, and the best-scoring
    hypothesis is returned.
    """

    default_conf = {
        "n_iter": 1000,  # hypotheses (pixel triplets) per image
        "up_inlier_th": 1,
        "latitude_inlier_th": 1,
        "error_fn": "angle",  # angle or mse
        "up_weight": 1,
        "latitude_weight": 1,
        "loss_weight": 1,
        "use_latitude": True,
        # Read by `loss`: when True, the focal term is added to the total loss.
        "estimate_focal": True,
    }

    def _init(self, conf):
        self.solver = MinimalSolver()

    @staticmethod
    def _repeat_per_hypothesis(tensor, N):
        """Tile a (B, ...) tensor to (B * N, ...), keeping each image's N copies contiguous."""
        B = tensor.shape[0]
        tiled = tensor.unsqueeze(1).expand(B, N, *tensor.shape[1:])
        return tiled.reshape(B * N, *tensor.shape[1:])

    def _pixel_error(self, est, pred, angle_error_fn):
        """Per-pixel error between estimated and predicted fields according to `error_fn`."""
        if self.conf.error_fn == "angle":
            return angle_error_fn(est, pred)
        if self.conf.error_fn == "mse":
            return F.mse_loss(est, pred, reduction="none").mean(1)
        raise ValueError(f"Unknown error function: {self.conf.error_fn}")

    def check_up_inliers(self, pred, est_camera, est_gravity, N=1):
        """Confidence-weighted up-field inlier map for each hypothesis.

        Args:
            pred: Dict with ``"up_field"`` (B, C, H, W) and optionally
                ``"up_confidence"`` (B, H, W); missing confidence counts as 1.
            est_camera, est_gravity: B * N estimates, ordered image-major.
            N: Number of hypotheses per image.

        Returns:
            torch.Tensor: (B * N, H, W) map, ``confidence`` where the error is
            below ``up_inlier_th`` and 0 elsewhere.
        """
        pred_up = pred["up_field"]
        B = pred_up.shape[0]

        conf = pred.get("up_confidence")
        if conf is None:
            # Built from B (not B * N) so it can be tiled like a real confidence.
            conf = pred_up.new_ones(B, *pred_up.shape[-2:])

        pred_up = self._repeat_per_hypothesis(pred_up, N)
        est_up = get_up_field(est_camera, est_gravity).permute(0, 3, 1, 2)
        error = self._pixel_error(est_up, pred_up, up_error)

        conf = self._repeat_per_hypothesis(conf, N)
        return (error < self.conf.up_inlier_th) * conf

    def check_latitude_inliers(self, pred, est_camera, est_gravity, N=1):
        """Confidence-weighted latitude-field inlier map for each hypothesis.

        Returns an all-zero (B * N, H, W) map when ``pred`` has no
        ``"latitude_field"``. Otherwise behaves like ``check_up_inliers``, using
        ``"latitude_confidence"`` and ``latitude_inlier_th``.
        """
        B = pred["up_field"].shape[0]
        pred_latitude = pred.get("latitude_field")
        if pred_latitude is None:
            shape = (B * N, *pred["up_field"].shape[-2:])
            return est_camera.new_zeros(shape)

        conf = pred.get("latitude_confidence")
        if conf is None:
            # Built from B (not B * N) so it can be tiled like a real confidence.
            conf = pred_latitude.new_ones(B, *pred_latitude.shape[-2:])

        pred_latitude = self._repeat_per_hypothesis(pred_latitude, N)
        est_latitude = get_latitude_field(est_camera, est_gravity).permute(0, 3, 1, 2)
        error = self._pixel_error(est_latitude, pred_latitude, latitude_error)

        conf = self._repeat_per_hypothesis(conf, N)
        return (error < self.conf.latitude_inlier_th) * conf

    def get_best_index(self, data, camera, gravity, inliers=None):
        """Pick the highest-scoring hypothesis per image.

        Args:
            data: Prediction dict (see ``check_up_inliers``).
            camera, gravity: B * n_iter hypotheses, ordered image-major.
            inliers: Optional (B, H, W) mask restricting which pixels are counted.

        Returns:
            Tuple of (best index per image (B,), its weighted inlier score (B,)).
        """
        B, _, H, W = data["up_field"].shape
        N = self.conf.n_iter
        up_inliers = self.check_up_inliers(data, camera, gravity, N).reshape(B, N, H, W)
        latitude_inliers = self.check_latitude_inliers(data, camera, gravity, N).reshape(
            B, N, H, W
        )

        if inliers is not None:
            mask = inliers.unsqueeze(1)
            up_inliers = up_inliers * mask
            latitude_inliers = latitude_inliers * mask

        total_inliers = self.conf.up_weight * up_inliers.sum(
            (2, 3)
        ) + self.conf.latitude_weight * latitude_inliers.sum((2, 3))

        best_idx = total_inliers.argmax(-1)
        batch = torch.arange(B, device=best_idx.device)
        return best_idx, total_inliers[batch, best_idx]

    def solve_rpf(self, pred, xs, ys, principal_points, focal=None):
        """Solve roll, pitch and focal for every sampled triplet.

        The first two points of each triplet give the vertical vanishing point.
        The third gives the latitude constraint for the focal length, unless
        ``focal`` (shape (B,)) is given.

        Returns:
            Tuple of two (B, N, 3) tensors of (roll, pitch, focal), one per
            focal-length root (identical focal when ``focal`` is given).
        """
        device = pred["up_field"].device

        # Get samples
        up = get_up_samples(pred, xs, ys)

        # Calculate vvps
        vvp = calculate_vvps(xs, ys, up).to(device)

        # Solve for focal length
        xy = torch.stack([xs[..., 2], ys[..., 2]], dim=-1).float()
        if focal is not None:
            f = focal.new_ones(xs[..., 2].shape) * focal.unsqueeze(-1)
            f_pos, f_neg = f, f
        else:
            L = get_latitude_samples(pred, xs, ys)
            f_pos, f_neg = self.solver.solve_focal(L, xy, vvp, principal_points)

        # Solve for abc
        abc_pos = self.solver.solve_abc(vvp, principal_points, f_pos)
        abc_neg = self.solver.solve_abc(vvp, principal_points, f_neg)

        # Solve for roll, pitch
        roll_pos, pitch_pos = self.solver.solve_rp(abc_pos)
        roll_neg, pitch_neg = self.solver.solve_rp(abc_neg)

        rpf_pos = torch.stack([roll_pos, pitch_pos, f_pos], dim=-1)
        rpf_neg = torch.stack([roll_neg, pitch_neg, f_neg], dim=-1)

        return rpf_pos, rpf_neg

    def get_camera_and_gravity(self, pred, rpf):
        """Build B * N pinhole cameras (principal point at the image center) and gravities.

        Args:
            pred: Dict with ``"up_field"`` providing (B, _, H, W).
            rpf: (B, N, 3) tensor of (roll, pitch, focal).
        """
        B, _, H, W = pred["up_field"].shape
        N = rpf.shape[1]

        w = pred["up_field"].new_ones(B, N) * W
        h = pred["up_field"].new_ones(B, N) * H
        cx = w / 2.0
        cy = h / 2.0

        roll, pitch, focal = rpf.unbind(-1)

        params = torch.stack([w, h, focal, focal, cx, cy], dim=-1)
        params = params.reshape(B * N, params.shape[-1])
        cam = Pinhole(params)

        roll, pitch = roll.reshape(B * N), pitch.reshape(B * N)
        gravity = Gravity.from_rp(roll, pitch)

        return cam, gravity

    def _sample_pixels(self, data, B, H, W, device):
        """Draw ``n_iter`` pixel triplets per image; returns ``xs, ys`` of shape (B, n_iter, 3).

        When ``data["inliers"]`` (B, H, W) is given, pixels equal to 1 are
        sampled with replacement. Images without any such pixel fall back to
        uniform sampling so that every image still gets ``n_iter`` triplets.
        """
        n_iter = self.conf.n_iter
        inliers = data.get("inliers")
        if inliers is None:
            xs = torch.randint(0, W, (B, n_iter, 3)).to(device)
            ys = torch.randint(0, H, (B, n_iter, 3)).to(device)
            return xs, ys

        indices = torch.nonzero(inliers == 1, as_tuple=False)  # rows: (b, y, x)
        sampled = []
        for b in range(B):
            candidates = indices[indices[:, 0] == b, 1:]  # rows: (y, x)
            n_candidates = int(candidates.shape[0])  # plain int: works for CUDA tensors
            if n_candidates == 0:
                ys_b = torch.randint(0, H, (n_iter, 3)).to(device)
                xs_b = torch.randint(0, W, (n_iter, 3)).to(device)
                sampled.append(torch.stack([ys_b, xs_b], dim=-1))
                continue
            choice = np.random.choice(n_candidates, n_iter * 3, replace=True)
            choice = torch.from_numpy(choice).to(candidates.device).reshape(n_iter, 3)
            sampled.append(candidates[choice])

        ys, xs = torch.stack(sampled, dim=0).unbind(-1)
        return xs, ys

    def _forward(self, data):
        device = data["up_field"].device
        B, _, H, W = data["up_field"].shape

        # (cx, cy) order, matching MinimalSolver and get_camera_and_gravity.
        principal_points = torch.tensor([W / 2.0, H / 2.0]).expand(B, 2).to(device)

        if not self.conf.use_latitude and "latitude_field" in data:
            # Shallow copy so the caller's dict is left untouched.
            data = {k: v for k, v in data.items() if k != "latitude_field"}

        xs, ys = self._sample_pixels(data, B, H, W, device)

        rpf_pos, rpf_neg = self.solve_rpf(
            data, xs, ys, principal_points, focal=data.get("prior_focal")
        )

        cams_pos, gravity_pos = self.get_camera_and_gravity(data, rpf_pos)
        cams_neg, gravity_neg = self.get_camera_and_gravity(data, rpf_neg)

        inliers = data.get("inliers", None)
        best_pos, score_pos = self.get_best_index(data, cams_pos, gravity_pos, inliers)
        best_neg, score_neg = self.get_best_index(data, cams_neg, gravity_neg, inliers)

        batch = torch.arange(B, device=best_pos.device)
        use_neg = (score_neg > score_pos).unsqueeze(-1)
        rpf = torch.where(use_neg, rpf_neg[batch, best_neg], rpf_pos[batch, best_pos])

        cam, gravity = self.get_camera_and_gravity(data, rpf.unsqueeze(1))

        return {
            "camera_opt": cam,
            "gravity_opt": gravity,
            "up_inliers": self.check_up_inliers(data, cam, gravity),
            "latitude_inliers": self.check_latitude_inliers(data, cam, gravity),
        }

    def metrics(self, pred, data):
        """Roll, pitch and vertical-FoV errors of the selected solution."""
        pred_cam, gt_cam = pred["camera_opt"], data["camera"]
        pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"]

        return {
            "roll_opt_error": roll_error(pred_gravity, gt_gravity),
            "pitch_opt_error": pitch_error(pred_gravity, gt_gravity),
            "vfov_opt_error": vfov_error(pred_cam, gt_cam),
        }

    def loss(self, pred, data):
        """L1 losses on gravity and normalized focal; returns (losses, metrics)."""
        pred_cam, gt_cam = pred["camera_opt"], data["camera"]
        pred_gravity, gt_gravity = pred["gravity_opt"], data["gravity"]

        # NOTE(review): named `h` but reads size[0, 0]; if `size` is ordered
        # (w, h) this is the width of the first image — confirm intent.
        h = data["camera"].size[0, 0]

        gravity_loss = F.l1_loss(pred_gravity.vec3d, gt_gravity.vec3d, reduction="none")
        focal_loss = F.l1_loss(pred_cam.f, gt_cam.f, reduction="none").sum(-1) / h

        total_loss = gravity_loss.sum(-1)
        if self.conf.estimate_focal:
            total_loss = total_loss + focal_loss

        losses = {
            "opt_gravity": gravity_loss.sum(-1),
            "opt_focal": focal_loss,
            "opt_param_total": total_loss,
        }

        losses = {k: v * self.conf.loss_weight for k, v in losses.items()}
        return losses, self.metrics(pred, data)