from __future__ import annotations

from typing import Literal

import numpy as np
import torch
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import EllipseCollection, PatchCollection
from matplotlib.figure import Figure
from matplotlib.patches import Rectangle
from torch import Tensor
from torchvision.ops import boxes as box_ops

from ellipse_rcnn.utils.conics import conic_center, ellipse_angle, ellipse_axes


def plot_single_pred(
    image: Tensor,
    prediction: dict[str, Tensor] | list[dict[str, Tensor]],
    min_score: float = 0.75,
) -> Figure:
    """Plot one image overlaid with its predicted ellipses above a score threshold.

    Args:
        image: Channel-first image tensor (it is permuted with ``(1, 2, 0)``
            before display). It may live on any device and may require grad;
            it is detached and copied to the CPU before plotting.
        prediction: A single prediction mapping, or a list holding exactly one.
            Only the ``"scores"`` and ``"ellipse_matrices"`` entries are read.
        min_score: An ellipse is drawn only if its score is strictly greater
            than this value.

    Returns:
        The newly created figure (its background patch is made transparent).

    Raises:
        ValueError: If ``prediction`` is a list that is empty or holds more
            than one element.
    """
    if isinstance(prediction, list):
        if not prediction:
            raise ValueError(
                "Empty prediction list. Please pass a single prediction."
            )
        if len(prediction) > 1:
            raise ValueError(
                "Multiple predictions detected. Please pass a single prediction."
            )
        prediction = prediction[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_alpha(0)

    # imshow cannot consume CUDA tensors or tensors that require grad, so hand
    # it a plain CPU array. "gray" is the canonical colormap name; "grey" is
    # only an alias in newer Matplotlib releases.
    ax.imshow(image.detach().cpu().permute(1, 2, 0).numpy(), cmap="gray")

    score_mask = prediction["scores"] > min_score
    plot_ellipses(prediction["ellipse_matrices"][score_mask], ax=ax)
    return fig


def plot_ellipses(
    A_craters: torch.Tensor,
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha: float = 1.0,
) -> Axes:
    """Draw ellipses described by conic matrices as unfilled outlines.

    Args:
        A_craters: Conic matrices of the ellipses to draw. The geometry is
            extracted by ``ellipse_axes``, ``ellipse_angle`` and
            ``conic_center``. NOTE(review): presumably a batch of 3x3 conic
            matrices, shape (N, 3, 3) — confirm against those helpers.
        figsize: Figure size used only when a new figure is created.
        plot_centers: If true, write each ellipse's index at its centre.
        ax: Axes to draw on. A new figure with equal aspect is created if None.
        rim_color: Any Matplotlib colour spec for the outlines and labels.
        alpha: Opacity of the outlines.

    Returns:
        The axes that were drawn on.
    """
    a_proj, b_proj = ellipse_axes(A_craters)
    psi_proj = ellipse_angle(A_craters)
    x_pix_proj, y_pix_proj = conic_center(A_craters)

    a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = (
        t.detach().cpu().numpy()
        for t in (a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj)
    )

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    ec = EllipseCollection(
        # EllipseCollection takes full axis lengths, hence the doubling
        # (assumes ellipse_axes returns semi-axes).
        a_proj * 2,
        b_proj * 2,
        np.degrees(psi_proj),  # EllipseCollection expects degrees.
        units="xy",
        offsets=np.column_stack((x_pix_proj, y_pix_proj)),
        # ``transOffset`` was renamed to ``offset_transform`` in Matplotlib 3.6
        # and the old spelling is deprecated.
        offset_transform=ax.transData,
        facecolors="none",
        edgecolors=rim_color,
        alpha=alpha,
    )
    ax.add_collection(ec)

    if plot_centers:
        # Label every ellipse at its centre, reusing the centres computed
        # above. (Enumerating the 2-tuple from conic_center directly would
        # label only two points at wrong positions.) atleast_1d also covers
        # a single, unbatched ellipse.
        for k, (x, y) in enumerate(
            zip(np.atleast_1d(x_pix_proj), np.atleast_1d(y_pix_proj))
        ):
            ax.text(float(x), float(y), str(k), color=rim_color)

    return ax


def plot_bboxes(
    boxes: torch.Tensor,
    box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha: float = 1.0,
) -> Axes:
    """Draw axis-aligned bounding boxes as unfilled rectangles.

    Args:
        boxes: Box tensor with one row of four coordinates per box, in the
            format given by ``box_type``.
        box_type: Coordinate format of ``boxes``. Anything other than
            ``"xyxy"`` is converted with ``torchvision.ops.box_convert``.
        figsize: Figure size used only when a new figure is created.
        plot_centers: If true, write each box's index next to it.
            NOTE(review): despite the name, the label is placed at the
            ``(x1, y1)`` corner, not the box centre — confirm this is intended.
        ax: Axes to draw on. A new figure with equal aspect is created if None.
        rim_color: Any Matplotlib colour spec for the outlines and labels.
        alpha: Opacity of the outlines.

    Returns:
        The axes that were drawn on.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    if box_type != "xyxy":
        boxes = box_ops.box_convert(boxes, box_type, "xyxy")
    boxes = boxes.detach().cpu().numpy()

    rectangles = [
        Rectangle((x1, y1), x2 - x1, y2 - y1) for x1, y1, x2, y2 in boxes
    ]
    collection = PatchCollection(
        rectangles, edgecolor=rim_color, facecolor="none", alpha=alpha
    )
    ax.add_collection(collection)

    if plot_centers:
        for k, (x1, y1, _x2, _y2) in enumerate(boxes):
            ax.text(x1, y1, str(k), color=rim_color)

    return ax