Filipstrozik
Add initial implementation of EllipseRCNN model and dataset utilities
afc2161
from typing import Literal
import torch
@torch.jit.script
def adjugate_matrix(matrix: torch.Tensor) -> torch.Tensor:
"""Return adjugate matrix [1].
Parameters
----------
matrix:
Input matrix
Returns
-------
torch.Tensor
Adjugate of input matrix
References
----------
.. [1] https://en.wikipedia.org/wiki/Adjugate_matrix
"""
cofactor = torch.inverse(matrix).T * torch.det(matrix)
return cofactor.T
# @torch.jit.script
def unimodular_matrix(matrix: torch.Tensor) -> torch.Tensor:
"""Rescale matrix such that det(ellipses) = 1, in other words, make it unimodular. Doest not work with tensors
of dtype torch.float64.
Parameters
----------
matrix:
Matrix input
Returns
-------
torch.Tensor
Unimodular version of input matrix.
"""
val = 1.0 / torch.det(matrix)
return (torch.sign(val) * torch.pow(torch.abs(val), 1.0 / 3.0))[
..., None, None
] * matrix
# @torch.jit.script
def ellipse_to_conic_matrix(
*,
a: torch.Tensor,
b: torch.Tensor,
x: torch.Tensor | None = None,
y: torch.Tensor | None = None,
theta: torch.Tensor | None = None,
) -> torch.Tensor:
r"""Returns matrix representation for crater derived from ellipse parameters such that _[1]:
| A = a²(sin θ)² + b²(cos θ)²
| B = 2(b² - a²) sin θ cos θ
| C = a²(cos θ)² + b²(sin θ)²
| D = -2Ax₀ - By₀
| E = -Bx₀ - 2Cy₀
| F = Ax₀² + Bx₀y₀ + Cy₀² - a²b²
Resulting in a conic matrix:
::
|A B/2 D/2 |
M = |B/2 C E/2 |
|D/2 E/2 G |
Parameters
----------
a:
Semi-Major ellipse axis
b:
Semi-Minor ellipse axis
theta:
Ellipse angle (radians)
x:
X-position in 2D cartesian coordinate system (coplanar)
y:
Y-position in 2D cartesian coordinate system (coplanar)
Returns
-------
torch.Tensor
Array of ellipse matrices
References
----------
.. [1] https://www.researchgate.net/publication/355490899_Lunar_Crater_Identification_in_Digital_Images
"""
x = x if x is not None else torch.zeros(1)
y = y if y is not None else torch.zeros(1)
theta = theta if theta is not None else torch.zeros(1)
sin_theta = torch.sin(theta)
cos_theta = torch.cos(theta)
a2 = a**2
b2 = b**2
A = a2 * sin_theta**2 + b2 * cos_theta**2
B = 2 * (b2 - a2) * sin_theta * cos_theta
C = a2 * cos_theta**2 + b2 * sin_theta**2
D = -2 * A * x - B * y
F = -B * x - 2 * C * y
G = A * (x**2) + B * x * y + C * (y**2) - a2 * b2
# Create (array of) of conic matrix (N, 3, 3)
conic_matrix = torch.stack(
tensors=(
torch.stack((A, B / 2, D / 2), dim=-1),
torch.stack((B / 2, C, F / 2), dim=-1),
torch.stack((D / 2, F / 2, G), dim=-1),
),
dim=-1,
)
return conic_matrix.squeeze()
def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns center of ellipse in 2D cartesian coordinate system with numerical stability."""
# Extract the top-left 2x2 submatrix of the conic matrix
A = conic_matrix[..., :2, :2]
# Add stabilization for pseudoinverse computation by clamping singular values
A_pinv = torch.linalg.pinv(A, rcond=torch.finfo(A.dtype).eps)
# Extract the last two rows for the linear term
b = -conic_matrix[..., :2, 2][..., None]
# Stabilize any potential numerical instabilities
centers = torch.matmul(A_pinv, b).squeeze()
return centers[..., 0], centers[..., 1]
def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Returns semi-major and semi-minor axes of ellipse in 2D cartesian coordinate system."""
lambdas = (
torch.linalg.eigvalsh(conic_matrix[..., :2, :2])
/ (-torch.det(conic_matrix) / torch.det(conic_matrix[..., :2, :2]))[..., None]
)
axes = torch.sqrt(1 / lambdas)
return axes[..., 0], axes[..., 1]
def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
"""Returns angle of ellipse in radians w.r.t. x-axis."""
return (
-torch.atan2(
2 * conic_matrix[..., 1, 0],
conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0],
)
/ 2
)
def bbox_ellipse(
ellipses: torch.Tensor,
box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
) -> torch.Tensor:
"""Converts (array of) ellipse matrices to bounding box tensor with format [xmin, ymin, xmax, ymax].
Parameters
----------
ellipses:
Array of ellipse matrices
box_type:
Format of bounding boxes, default is "xyxy"
Returns
-------
Array of bounding boxes
"""
cx, cy = conic_center(ellipses)
theta = ellipse_angle(ellipses)
semi_major_axis, semi_minor_axis = ellipse_axes(ellipses)
ux, uy = semi_major_axis * torch.cos(theta), semi_major_axis * torch.sin(theta)
vx, vy = (
semi_minor_axis * torch.cos(theta + torch.pi / 2),
semi_minor_axis * torch.sin(theta + torch.pi / 2),
)
box_halfwidth = torch.sqrt(ux**2 + vx**2)
box_halfheight = torch.sqrt(uy**2 + vy**2)
bboxes = torch.vstack(
(
cx - box_halfwidth,
cy - box_halfheight,
cx + box_halfwidth,
cy + box_halfheight,
)
).T
if box_type != "xyxy":
from torchvision.ops import boxes as box_ops
bboxes = box_ops.box_convert(bboxes, in_fmt="xyxy", out_fmt=box_type)
return bboxes