|
from __future__ import annotations |
|
|
|
from typing import Literal |
|
from torch import Tensor |
|
import numpy as np |
|
import torch |
|
from torchvision.ops import boxes as box_ops |
|
from matplotlib import pyplot as plt |
|
from matplotlib.axes import Axes |
|
from matplotlib.collections import EllipseCollection, PatchCollection |
|
from matplotlib.patches import Rectangle |
|
from ellipse_rcnn.utils.conics import ellipse_angle, conic_center, ellipse_axes |
|
from matplotlib.figure import Figure |
|
|
|
|
|
def plot_single_pred(
    image: Tensor,
    prediction,
    min_score: float = 0.75,
) -> Figure:
    """Plot one image with its predicted ellipses overlaid.

    Args:
        image: Image tensor. A 3-D tensor is treated as channel-first and
            permuted to channel-last for ``imshow``; a 2-D tensor is shown
            as-is. The tensor is detached and copied to the CPU first, so
            CUDA tensors and tensors that require grad are accepted.
        prediction: A mapping with ``"scores"`` and ``"ellipse_matrices"``
            entries, or a list containing exactly one such mapping.
        min_score: Only ellipses whose score is strictly greater than this
            threshold are drawn.

    Returns:
        The newly created figure (its background patch is made transparent).

    Raises:
        ValueError: If ``prediction`` is a list with zero or more than one
            element.
    """
    if isinstance(prediction, list):
        if len(prediction) > 1:
            raise ValueError(
                "Multiple predictions detected. Please pass a single prediction."
            )
        if not prediction:
            raise ValueError("Empty prediction list. Please pass a single prediction.")
        prediction = prediction[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_alpha(0)

    # matplotlib converts the tensor with numpy, which fails for GPU tensors
    # and for tensors still attached to the autograd graph.
    img = image.detach().cpu()
    if img.ndim == 3:
        # (C, H, W) -> (H, W, C), the layout imshow expects.
        img = img.permute(1, 2, 0)
    # "gray" is the canonical colormap name accepted by every matplotlib version.
    ax.imshow(img, cmap="gray")

    score_mask = prediction["scores"] > min_score
    plot_ellipses(prediction["ellipse_matrices"][score_mask], ax=ax)

    return fig
|
|
|
|
|
def plot_ellipses(
    A_craters: torch.Tensor,
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
):
    """Draw ellipses, given as conic matrices, onto a matplotlib axes.

    The semi-axes, orientation and center of each ellipse are derived from
    ``A_craters`` via ``ellipse_axes``, ``ellipse_angle`` and ``conic_center``.
    They are drawn as an unfilled ``EllipseCollection`` in data coordinates.

    Args:
        A_craters: Conic matrices describing the ellipses (presumably of shape
            ``(N, 3, 3)``; verify against ``ellipse_rcnn.utils.conics``).
        figsize: Figure size used only when ``ax`` is None and a new figure
            is created.
        plot_centers: If True, label every ellipse with its index at its
            center.
        ax: Axes to draw on. If None, a new figure with equal aspect is made
            and its view is autoscaled to the drawn ellipses.
        rim_color: Edge color of the ellipses and color of the index labels.
        alpha: Transparency of the ellipse outlines.
    """
    a_proj, b_proj = ellipse_axes(A_craters)
    psi_proj = ellipse_angle(A_craters)
    x_pix_proj, y_pix_proj = conic_center(A_craters)

    a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = (
        t.detach().cpu().numpy()
        for t in (a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj)
    )

    created_ax = ax is None
    if created_ax:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    ec = EllipseCollection(
        a_proj * 2,  # EllipseCollection takes full widths/heights, not semi-axes
        b_proj * 2,
        np.degrees(psi_proj),
        units="xy",
        offsets=np.column_stack((x_pix_proj, y_pix_proj)),
        transOffset=ax.transData,
        facecolors="None",
        edgecolors=rim_color,
        alpha=alpha,
    )
    ax.add_collection(ec)

    if plot_centers:
        # conic_center returns an (x, y) pair of arrays, so pair them up
        # per ellipse; ravel also covers a single, unbatched conic.
        centers = zip(np.ravel(x_pix_proj), np.ravel(y_pix_proj))
        for k, (x, y) in enumerate(centers):
            ax.text(float(x), float(y), str(k), color=rim_color)

    if created_ax:
        # add_collection does not rescale the view; without this a fresh
        # axes keeps its default limits and pixel-space ellipses are hidden.
        ax.autoscale_view()
|
|
|
|
|
def plot_bboxes(
    boxes: torch.Tensor,
    box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
):
    """Draw axis-aligned bounding boxes onto a matplotlib axes.

    Args:
        boxes: Box coordinates, one box per row (a single 1-D box of four
            values is also accepted). Interpreted according to ``box_type``.
        box_type: Coordinate format of ``boxes``. Non-``"xyxy"`` formats are
            converted with ``torchvision.ops.boxes.box_convert``.
        figsize: Figure size used only when ``ax`` is None and a new figure
            is created.
        plot_centers: If True, label every box with its index. Note: the
            label is anchored at the box's (x1, y1) corner, not its center.
        ax: Axes to draw on. If None, a new figure with equal aspect is made
            and its view is autoscaled to the drawn boxes.
        rim_color: Edge color of the boxes and color of the index labels.
        alpha: Transparency of the box outlines.
    """
    created_ax = ax is None
    if created_ax:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    if box_type != "xyxy":
        boxes = box_ops.box_convert(boxes, box_type, "xyxy")

    # reshape so a single (4,) box unpacks the same way as an (N, 4) batch.
    boxes_np = boxes.detach().cpu().numpy().reshape(-1, 4)

    rectangles = [
        Rectangle((x1, y1), x2 - x1, y2 - y1) for x1, y1, x2, y2 in boxes_np
    ]
    collection = PatchCollection(
        rectangles, edgecolor=rim_color, facecolor="none", alpha=alpha
    )
    ax.add_collection(collection)

    if plot_centers:
        for k, (x1, y1, _x2, _y2) in enumerate(boxes_np):
            ax.text(x1, y1, str(k), color=rim_color)

    if created_ax:
        # add_collection does not rescale the view; without this a fresh
        # axes keeps its default limits and pixel-space boxes are hidden.
        ax.autoscale_view()
|
|