File size: 2,541 Bytes
afc2161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
from ellipse_rcnn.utils.conics import conic_center


def gaussian_angle_distance(A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
    """Compute the Gaussian angle distance between two (batches of) conic matrices.

    Each conic is treated as a 2-D Gaussian-like density.
    - Its mean is the conic centre returned by ``conic_center``.
    - Its matrix is the negated top-left 2x2 block of the conic.

    With ``C1``, ``C2`` those matrices and ``d`` the difference of the centres, the
    result is::

        arccos(4 * sqrt(|C1| |C2|) / |C1 + C2|
               * exp(-0.5 * d^T C1 (C1 + C2)^-1 C2 d))

    Parameters
    ----------
    A1, A2 : torch.Tensor
        Conic matrices; only ``[..., :2, :2]`` and the conic centres are used.
        Leading batch dimensions must be broadcastable against each other.
        NOTE(review): assumes the sign convention makes ``-A[..., :2, :2]``
        positive definite for valid ellipses -- confirm against ``conic_center``.

    Returns
    -------
    torch.Tensor
        Distance in radians with the (broadcast) batch shape of the inputs. Entries
        may be NaN for degenerate inputs (e.g. non-positive-definite blocks).
    """
    # Negated top-left 2x2 blocks play the role of the Gaussian matrices.
    cov1 = -A1[..., :2, :2]
    cov2 = -A2[..., :2, :2]

    # Means are the conic centres, stacked into (..., 2, 1) column vectors.
    c1_x, c1_y = conic_center(A1)
    c2_x, c2_y = conic_center(A2)
    m1 = torch.stack((c1_x, c1_y), dim=-1)[..., None]
    m2 = torch.stack((c2_x, c2_y), dim=-1)[..., None]

    # Determinants are clamped to eps so the sqrt and division below stay finite.
    det_cov1 = torch.clamp(cov1.det(), min=torch.finfo(cov1.dtype).eps)
    det_cov2 = torch.clamp(cov2.det(), min=torch.finfo(cov2.dtype).eps)
    cov_sum = cov1 + cov2
    det_cov_sum = torch.clamp(cov_sum.det(), min=torch.finfo(cov_sum.dtype).eps)

    frac_term = (4 * torch.sqrt(det_cov1 * det_cov2)) / det_cov_sum

    mean_diff = m1 - m2
    # Solve (C1 + C2) x = C2 d directly rather than forming an explicit inverse:
    # same value of d^T C1 (C1 + C2)^-1 C2 d, cheaper and better conditioned.
    solved = torch.linalg.solve(cov_sum, cov2 @ mean_diff)
    exp_arg = -0.5 * mean_diff.transpose(-1, -2) @ cov1 @ solved
    # [..., 0, 0] removes only the trailing 1x1 matrix dims. A bare .squeeze() would
    # also drop batch dims of size 1 and then mis-broadcast against frac_term.
    exp_term = torch.exp(torch.clamp(exp_arg, min=-50, max=50))[..., 0, 0]

    angle_term = frac_term * exp_term

    # For positive-definite matrices this product is mathematically <= 1.
    # Round-off can still push it a few ulps above 1 (e.g. for identical ellipses),
    # which arccos would turn into NaN, so such tiny overshoots are snapped to 1.
    # Larger excursions only come from degenerate inputs and are left unchanged,
    # so they still surface as NaN.
    tol = 64 * torch.finfo(angle_term.dtype).eps
    near_one = (angle_term > 1) & (angle_term <= 1 + tol)
    angle_term = angle_term.masked_fill(near_one, 1.0)

    return torch.arccos(angle_term)


class GaussianAngleDistanceLoss(torch.nn.Module):
    """
    Computes the Gaussian Angle Distance loss between two tensors.

    Thin ``torch.nn.Module`` wrapper around ``gaussian_angle_distance``.
    - NaN distances (e.g. from degenerate conics) are replaced by a fixed penalty,
      so a single bad prediction does not poison the whole loss with NaN.
    - No reduction is applied: the output has the shape of the distance tensor.

    Attributes
    ----------
    normalize : bool
        Stored for API compatibility.
        NOTE(review): not used by ``forward``; no normalisation is applied yet.
    nan_to_num : float
        Value substituted for NaN entries of the distance, so invalid inputs yield
        a finite (large) loss instead of propagating NaN.
    """

    def __init__(self, normalize: bool = True, nan_to_num: float = 10.0):
        super().__init__()
        # Previously this argument was accepted and silently discarded; keep it
        # on the instance so the documented attribute actually exists.
        self.normalize = normalize
        self.nan_to_num = nan_to_num

    def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
        """Return the element-wise Gaussian angle distance with NaNs replaced.

        Parameters
        ----------
        A1, A2 : torch.Tensor
            Conic matrices, passed unchanged to ``gaussian_angle_distance``.

        Returns
        -------
        torch.Tensor
            Distances in radians, with NaN entries set to ``self.nan_to_num``.
        """
        distance = gaussian_angle_distance(A1, A2)

        # torch.nan_to_num also maps +/-inf to the dtype's largest finite values;
        # arccos never yields inf, so only the NaN replacement matters here.
        return torch.nan_to_num(distance, nan=self.nan_to_num)