import torch

from ellipse_rcnn.utils.conics import conic_center


def wasserstein_distance(
    A1: torch.Tensor,
    A2: torch.Tensor,
    *,
    shape_only: bool = False,
) -> torch.Tensor:
    """
    Compute the squared Wasserstein-2 distance between ellipses represented by their conic matrices.

    Each ellipse is treated as a 2D Gaussian whose mean is the ellipse center
    and whose covariance is the top-left 2x2 block of the conic matrix. For
    Gaussians the squared distance has the closed form

        W2^2 = ||m1 - m2||^2 + tr(S1 + S2 - 2 * (S1^(1/2) S2 S1^(1/2))^(1/2))

    Args:
        A1, A2: Ellipse matrices of shape (..., 3, 3).
        shape_only: If True, ignores the displacement (center offset) term.

    Returns:
        Tensor of squared Wasserstein-2 distances over the batch dimensions.
    """
    if A1.shape[:-2] != A2.shape[:-2]:
        raise ValueError(
            f"Batch size mismatch: A1 has shape {A1.shape[:-2]}, A2 has shape {A2.shape[:-2]}"
        )

    # The top-left 2x2 blocks play the role of the Gaussian covariances.
    cov1 = A1[..., :2, :2]
    cov2 = A2[..., :2, :2]

    if shape_only:
        displacement_term = 0
    else:
        # Stack the (x, y) centers into column vectors of shape (N, 2, 1).
        m1 = torch.vstack(conic_center(A1)).T[..., None]
        m2 = torch.vstack(conic_center(A2)).T[..., None]

        # Squared Euclidean distance between the ellipse centers.
        displacement_term = torch.sum((m1 - m2) ** 2, dim=(1, 2))

    # Matrix square root of cov1 via eigendecomposition; eigenvalues are
    # clamped so the square root stays real and gradients stay finite.
    eigenvalues1, eigenvectors1 = torch.linalg.eigh(cov1)
    sqrt_eigenvalues1 = torch.sqrt(torch.clamp(eigenvalues1, min=1e-7))
    sqrt_cov1 = (
        eigenvectors1
        @ torch.diag_embed(sqrt_eigenvalues1)
        @ eigenvectors1.transpose(-2, -1)
    )

    # Square root of cov1^(1/2) @ cov2 @ cov1^(1/2), the Bures coupling term.
    inner_term = sqrt_cov1 @ cov2 @ sqrt_cov1
    eigenvalues_inner, eigenvectors_inner = torch.linalg.eigh(inner_term)
    sqrt_inner = (
        eigenvectors_inner
        @ torch.diag_embed(torch.sqrt(torch.clamp(eigenvalues_inner, min=1e-7)))
        @ eigenvectors_inner.transpose(-2, -1)
    )

    # tr(S1) + tr(S2) - 2 * tr((S1^(1/2) S2 S1^(1/2))^(1/2))
    trace_term = (
        torch.diagonal(cov1, dim1=-2, dim2=-1).sum(-1)
        + torch.diagonal(cov2, dim1=-2, dim2=-1).sum(-1)
        - 2 * torch.diagonal(sqrt_inner, dim1=-2, dim2=-1).sum(-1)
    )

    return displacement_term + trace_term
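
# A minimal usage sketch with hypothetical values. It assumes the standard
# conic convention, in which a unit circle centered at (x0, y0) is
#   [[1, 0, -x0], [0, 1, -y0], [-x0, -y0, x0**2 + y0**2 - 1]],
# so its top-left 2x2 block is the identity and conic_center recovers (x0, y0):
#
#     c1 = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]])
#     c2 = torch.tensor([[[1.0, 0.0, -3.0], [0.0, 1.0, 0.0], [-3.0, 0.0, 8.0]]])
#     wasserstein_distance(c1, c2)                   # tensor([9.]) -- pure displacement
#     wasserstein_distance(c1, c2, shape_only=True)  # tensor([0.]) -- identical shapes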


def symmetric_wasserstein_distance(
    A1: torch.Tensor,
    A2: torch.Tensor,
    *,
    shape_only: bool = False,
    nan_to_num: float = 1e4,
    normalize: bool = False,
) -> torch.Tensor:
    """
    Compute the symmetric squared Wasserstein distance between ellipses, with
    NaN handling and optional normalization.

    Args:
        A1, A2: Ellipse matrices of shape (..., 3, 3).
        shape_only: If True, ignores the displacement term.
        nan_to_num: Value to replace NaN entries with.
        normalize: If True, normalizes the output to [0, 1).

    Returns:
        Tensor of (optionally normalized) squared Wasserstein-2 distances.
    """
    w = torch.nan_to_num(
        wasserstein_distance(A1, A2, shape_only=shape_only), nan=nan_to_num
    )

    if w.lt(0).any():
        raise ValueError("Negative Wasserstein distance encountered.")

    if normalize:
        # 1 - exp(-w) maps the unbounded distance onto [0, 1): zero stays
        # zero and large distances saturate toward 1.
        w = 1 - torch.exp(-w)
    return w
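
# Usage sketch (hypothetical values, continuing the unit-circle example above):
#
#     symmetric_wasserstein_distance(c1, c2, normalize=True)  # 1 - exp(-9), i.e. ~tensor([0.9999])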


class WassersteinLoss(torch.nn.Module):
    """
    Computes the Wasserstein distance loss between two ellipse tensors.

    The Wasserstein distance provides a natural metric for comparing probability
    distributions or shapes, with advantages over KL divergence such as:
    - It is symmetric by definition
    - It is a true metric (satisfies the triangle inequality)
    - It remains well-behaved even when distributions have different supports

    Attributes:
        shape_only: If True, computes the distance from shape alone, ignoring position.
        nan_to_num: Value to replace NaN entries with.
        normalize: If True, normalizes the output to [0, 1) using exponential scaling.
    """

    def __init__(
        self, shape_only: bool = True, nan_to_num: float = 10.0, normalize: bool = False
    ):
        super().__init__()
        self.shape_only = shape_only
        self.nan_to_num = nan_to_num
        self.normalize = normalize

    def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
        return symmetric_wasserstein_distance(
            A1,
            A2,
            shape_only=self.shape_only,
            nan_to_num=self.nan_to_num,
            normalize=self.normalize,
        )
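

if __name__ == "__main__":
    # Smoke-test sketch (hypothetical values): two unit circles, one centered
    # at the origin and one at (3, 0), using the conic convention noted above.
    c1 = torch.tensor([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]])
    c2 = torch.tensor([[[1.0, 0.0, -3.0], [0.0, 1.0, 0.0], [-3.0, 0.0, 8.0]]])

    loss = WassersteinLoss(shape_only=False)
    print(loss(c1, c2))  # expected: tensor([9.]) -- the squared center distance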