# Uploaded by veichta via huggingface_hub (commit 205a7af, verified); original size 4.14 kB.
"""Various metrics for evaluating predictions."""
import logging
import torch
from torch.nn import functional as F
from siclib.geometry.base_camera import BaseCamera
from siclib.geometry.gravity import Gravity
from siclib.utils.conversions import rad2deg
# Module-level logger; dist_error uses it to report cameras that lack distortion parameters.
logger = logging.getLogger(name=__name__)
def pitch_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
    """Computes the pitch error between two gravities.

    Args:
        pred_gravity (Gravity): Predicted gravity.
        target_gravity (Gravity): Ground truth gravity.

    Returns:
        torch.Tensor: Absolute pitch error in degrees.
    """
    return rad2deg(torch.abs(pred_gravity.pitch - target_gravity.pitch))
def roll_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
    """Measure how far the predicted roll is from the ground-truth roll.

    Args:
        pred_gravity (Gravity): Gravity estimate under evaluation.
        target_gravity (Gravity): Reference (ground truth) gravity.

    Returns:
        torch.Tensor: Absolute roll difference, expressed in degrees.
    """
    roll_delta = pred_gravity.roll - target_gravity.roll
    return rad2deg(torch.abs(roll_delta))
def gravity_error(pred_gravity: Gravity, target_gravity: Gravity) -> torch.Tensor:
    """Angle between the predicted and ground-truth gravity directions.

    Both gravities must expose ``vec3d`` tensors of identical shape (B, 3).

    Args:
        pred_gravity (Gravity): Gravity estimate under evaluation.
        target_gravity (Gravity): Reference (ground truth) gravity.

    Returns:
        torch.Tensor: Per-sample angular error in degrees, shape (B,).
    """
    assert pred_gravity.vec3d.shape == target_gravity.vec3d.shape, f"{pred_gravity.vec3d.shape} != {target_gravity.vec3d.shape}"
    assert pred_gravity.vec3d.ndim == 2, f"{pred_gravity.vec3d.ndim} != 2"
    assert pred_gravity.vec3d.shape[1] == 3, f"{pred_gravity.vec3d.shape[1]} != 3"

    # Cosine of the angle between the two vectors (reduced over the xyz axis).
    cos_angle = F.cosine_similarity(pred_gravity.vec3d, target_gravity.vec3d, dim=-1)
    # Rounding can push the cosine just outside [-1, 1], where acos would return NaN.
    cos_angle = cos_angle.clamp(-1, 1)
    return rad2deg(torch.acos(cos_angle))
def vfov_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor:
    """Absolute difference in vertical field of view between two cameras.

    Args:
        pred_cam (BaseCamera): Camera estimate under evaluation.
        target_cam (BaseCamera): Reference (ground truth) camera.

    Returns:
        torch.Tensor: Vertical field-of-view error in degrees.
    """
    fov_delta = pred_cam.vfov - target_cam.vfov
    return rad2deg(torch.abs(fov_delta))
def dist_error(pred_cam: BaseCamera, target_cam: BaseCamera) -> torch.Tensor:
    """Computes the distortion parameter error between two cameras.

    Only the first distortion coefficient (``dist[..., 0]``) is compared. If either
    camera has no ``dist`` attribute, a zero tensor of length ``pred_cam.f.shape[0]``
    is returned instead.

    Args:
        pred_cam (BaseCamera): Predicted camera.
        target_cam (BaseCamera): Ground truth camera.

    Returns:
        torch.Tensor: Absolute error of the first distortion coefficient
        (zeros when distortion parameters are unavailable).
    """
    if hasattr(pred_cam, "dist") and hasattr(target_cam, "dist"):
        return torch.abs(pred_cam.dist[..., 0] - target_cam.dist[..., 0])

    # Lazy %-style arguments: the cameras are only converted to strings when DEBUG
    # logging is actually enabled, instead of on every call as with an f-string.
    logger.debug(
        "Predicted / target camera doesn't have distortion parameters: %s/%s",
        pred_cam,
        target_cam,
    )
    return pred_cam.new_zeros(pred_cam.f.shape[0])
def latitude_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Per-pixel absolute difference between two latitude fields.

    Args:
        predictions (torch.Tensor): Estimated latitude field, shape (B, 1, H, W).
        targets (torch.Tensor): Reference latitude field, shape (B, 1, H, W).

    Returns:
        torch.Tensor: Latitude error in degrees, shape (B, H, W).
    """
    abs_delta = torch.abs(predictions - targets)
    error_deg = rad2deg(abs_delta)
    # Remove the singleton channel axis: (B, 1, H, W) -> (B, H, W).
    return error_deg.squeeze(1)
def up_error(predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
    """Per-pixel angle between two 2D up-vector fields.

    Args:
        predictions (torch.Tensor): Estimated up field, shape (B, 2, H, W).
        targets (torch.Tensor): Reference up field, shape (B, 2, H, W).

    Returns:
        torch.Tensor: Angular up error in degrees, shape (B, H, W).
    """
    assert predictions.shape == targets.shape, f"{predictions.shape} != {targets.shape}"
    assert predictions.ndim == 4, f"{predictions.ndim} != 4"
    assert predictions.shape[1] == 2, f"{predictions.shape[1]} != 2"

    # Cosine between the vectors, reduced over the channel axis (dim=1).
    cos_angle = F.cosine_similarity(predictions, targets, dim=1)
    # Keep acos well-defined when rounding pushes the cosine just past +/-1.
    return rad2deg(torch.acos(cos_angle.clamp(-1, 1)))