Filipstrozik
Add initial implementation of EllipseRCNN model and dataset utilities
afc2161
import torch
from ellipse_rcnn.utils.conics import conic_center
def gaussian_angle_distance(A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
    """
    Compute the Gaussian angle distance between two (batches of) conic matrices.

    Each conic is treated as a 2D Gaussian. Its covariance-like matrix ``C`` is
    the negated upper-left 2x2 block of the conic matrix, and its mean ``m`` is
    the conic center returned by ``conic_center``. The distance is::

        arccos( 4 * sqrt(det(C1) * det(C2)) / det(C1 + C2)
                * exp(-0.5 * d^T C1 (C1 + C2)^-1 C2 d) ),   d = m1 - m2

    Parameters
    ----------
    A1, A2 : torch.Tensor
        Conic matrices indexed as ``(..., N, N)`` with ``N >= 2``. They are
        presumably ``(..., 3, 3)`` with matching batch dimensions; TODO confirm
        against the ``conic_center`` contract.

    Returns
    -------
    torch.Tensor
        Distance with the inputs' batch shape ``(...)``. This assumes
        ``conic_center`` returns per-conic coordinates of that shape. NaN is
        returned where the arccos argument exceeds 1 by more than round-off,
        e.g. for degenerate or non-elliptical inputs. The loss wrapper relies
        on this NaN to apply its penalty.
    """
    # Covariance-like matrices: negated upper-left 2x2 blocks of the conics.
    cov1 = -A1[..., :2, :2]
    cov2 = -A2[..., :2, :2]

    # Means are the conic centers, stacked as column vectors of shape (..., 2, 1).
    c1_x, c1_y = conic_center(A1)
    c2_x, c2_y = conic_center(A2)
    m1 = torch.stack((c1_x, c1_y), dim=-1)[..., None]
    m2 = torch.stack((c2_x, c2_y), dim=-1)[..., None]

    # Determinants are clamped to eps so sqrt/division never see zero or negatives.
    det_cov1 = torch.clamp(cov1.det(), min=torch.finfo(cov1.dtype).eps)
    det_cov2 = torch.clamp(cov2.det(), min=torch.finfo(cov2.dtype).eps)
    cov_sum = cov1 + cov2
    det_cov_sum = torch.clamp(cov_sum.det(), min=torch.finfo(cov_sum.dtype).eps)
    frac_term = (4 * torch.sqrt(det_cov1 * det_cov2)) / det_cov_sum

    # Solve (C1 + C2) x = C2 d rather than forming an explicit inverse.
    mean_diff = m1 - m2
    weighted = torch.linalg.solve(cov_sum, cov2 @ mean_diff)
    exp_arg = -0.5 * (mean_diff.transpose(-1, -2) @ cov1 @ weighted)
    # Drop only the trailing (1, 1) matrix dims. A bare .squeeze() would also
    # drop size-1 batch dims and mis-broadcast against frac_term.
    exp_term = torch.exp(torch.clamp(exp_arg[..., 0, 0], min=-50, max=50))

    angle_term = frac_term * exp_term

    # For positive-definite covariances the argument is <= 1. The determinant
    # term is bounded by the Minkowski determinant inequality, and the quadratic
    # form is non-negative. Round-off can still push near-identical conics just
    # above 1, where arccos yields NaN. Snap those cases to exactly 1. Larger
    # excursions (invalid input) are left untouched so they still produce NaN.
    tol = torch.finfo(angle_term.dtype).eps ** 0.5
    near_one = (angle_term > 1) & (angle_term <= 1 + tol)
    angle_term = torch.where(near_one, torch.ones_like(angle_term), angle_term)

    return torch.arccos(angle_term)
class GaussianAngleDistanceLoss(torch.nn.Module):
    """
    Computes the Gaussian Angle Distance loss between two tensors.

    This class wraps the `gaussian_angle_distance` function. It replaces NaN
    entries of the resulting distance with a fixed penalty value, so undefined
    distances do not propagate NaNs into the loss.

    Attributes
    ----------
    normalize : bool
        Stored for API compatibility. It is not currently used by `forward`.
    nan_to_num : float
        The value substituted for NaN entries of the computed distance. This
        keeps the loss finite when the inputs produce an undefined distance.
    """

    def __init__(self, normalize: bool = True, nan_to_num: float = 10.0):
        super().__init__()
        # NOTE(review): `normalize` was previously accepted and silently
        # discarded. It is now stored so the documented attribute exists, but
        # forward() does not apply any normalization yet.
        self.normalize = normalize
        self.nan_to_num = nan_to_num

    def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
        """
        Return the element-wise Gaussian angle distance between `A1` and `A2`.

        NaN entries are replaced with `self.nan_to_num`. Because only `nan=` is
        passed to `torch.nan_to_num`, any +/-inf entries would also be replaced,
        with the dtype's largest/smallest finite value.
        """
        distance = gaussian_angle_distance(A1, A2)
        return torch.nan_to_num(distance, nan=self.nan_to_num)