File size: 3,119 Bytes
afc2161 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from __future__ import annotations
from typing import Literal
from torch import Tensor
import numpy as np
import torch
from torchvision.ops import boxes as box_ops
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import EllipseCollection, PatchCollection
from matplotlib.patches import Rectangle
from ellipse_rcnn.utils.conics import ellipse_angle, conic_center, ellipse_axes
from matplotlib.figure import Figure
def plot_single_pred(
    image: Tensor,
    prediction,
    min_score: float = 0.75,
) -> Figure:
    """Plot one image with the predicted ellipses that pass a score threshold.

    Args:
        image: Channel-first image tensor (it is permuted ``(1, 2, 0)`` to
            channel-last before display). It is detached and moved to the CPU
            first, so GPU tensors and tensors that require grad are accepted.
        prediction: A prediction mapping with ``"scores"`` and
            ``"ellipse_matrices"`` entries, or a list containing exactly one
            such mapping.
        min_score: Only ellipses whose score is strictly greater than this
            value are drawn.

    Returns:
        The newly created figure. Its background patch is made transparent.

    Raises:
        ValueError: If ``prediction`` is a list that is empty or holds more
            than one prediction.
    """
    if isinstance(prediction, list):
        if len(prediction) > 1:
            raise ValueError(
                "Multiple predictions detected. Please pass a single prediction."
            )
        if not prediction:
            raise ValueError(
                "Empty prediction list. Please pass a single prediction."
            )
        prediction = prediction[0]

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_alpha(0)

    # matplotlib converts its input with numpy, which fails for CUDA tensors
    # and tensors that require grad, so bring the image to host memory first.
    # The colormap only takes effect for single-channel images.
    image_np = image.detach().cpu().permute(1, 2, 0).numpy()
    ax.imshow(image_np, cmap="gray")

    score_mask = prediction["scores"] > min_score
    plot_ellipses(prediction["ellipse_matrices"][score_mask], ax=ax)
    return fig
def plot_ellipses(
    A_craters: torch.Tensor,
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
):
    """Draw ellipses, given as conic matrices, as outlines on a matplotlib axes.

    Args:
        A_craters: Ellipse conic matrices. They are passed to ``ellipse_axes``,
            ``ellipse_angle`` and ``conic_center`` (presumably a batch of
            shape ``(N, 3, 3)`` -- TODO confirm against those helpers).
        figsize: Size of the figure created when ``ax`` is None.
        plot_centers: If True, write each ellipse's index at its center.
        ax: Axes to draw on. When None, a new equal-aspect figure/axes is
            created.
        rim_color: Color of the ellipse outlines and of the index labels.
        alpha: Transparency of the ellipse outlines.
    """
    a_proj, b_proj = ellipse_axes(A_craters)
    psi_proj = ellipse_angle(A_craters)
    # conic_center yields the x and y center coordinates as separate values.
    x_pix_proj, y_pix_proj = conic_center(A_craters)

    # Move every geometric parameter to host memory as numpy arrays.
    params = (a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj)
    a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = [
        t.detach().cpu().numpy() for t in params
    ]

    if ax is None:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    # EllipseCollection takes full widths/heights (hence the factor 2 on the
    # semi-axes) and angles in degrees. units="xy" sizes them in data units.
    ec = EllipseCollection(
        a_proj * 2,
        b_proj * 2,
        np.degrees(psi_proj),
        units="xy",
        offsets=np.column_stack((x_pix_proj, y_pix_proj)),
        transOffset=ax.transData,
        facecolors="None",
        edgecolors=rim_color,
        alpha=alpha,
    )
    ax.add_collection(ec)

    if plot_centers:
        # Pair the x and y coordinates so that every ellipse gets one label at
        # its own center. Iterating conic_center's result directly would walk
        # over the x-array and then the y-array instead. ravel also copes with
        # a single (0-d) center.
        centers = zip(np.ravel(x_pix_proj), np.ravel(y_pix_proj))
        for k, (x, y) in enumerate(centers):
            ax.text(x.item(), y.item(), str(k), color=rim_color)
def plot_bboxes(
    boxes: torch.Tensor,
    box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
):
    """Draw axis-aligned bounding boxes as unfilled rectangles on an axes.

    Args:
        boxes: One box per row, in the coordinate format named by
            ``box_type``.
        box_type: Format of ``boxes``. Anything other than ``"xyxy"`` is
            converted to ``"xyxy"`` with torchvision's ``box_convert``.
        figsize: Size of the figure created when ``ax`` is None.
        plot_centers: If True, write each box's index next to it. Note that
            the label is anchored at the box's ``(x1, y1)`` corner, not its
            center.
        ax: Axes to draw on. When None, a new equal-aspect figure/axes is
            created.
        rim_color: Color of the rectangle edges and of the index labels.
        alpha: Transparency of the rectangles.
    """
    if ax is None:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})
    if box_type != "xyxy":
        boxes = box_ops.box_convert(boxes, box_type, "xyxy")
    boxes = boxes.detach().cpu().numpy()

    # Rectangle wants an anchor corner plus width and height, not two corners.
    rectangles = [
        Rectangle((x1, y1), x2 - x1, y2 - y1) for x1, y1, x2, y2 in boxes
    ]
    collection = PatchCollection(
        rectangles, edgecolor=rim_color, facecolor="none", alpha=alpha
    )
    ax.add_collection(collection)

    if plot_centers:
        for k, (x1, y1, _x2, _y2) in enumerate(boxes):
            ax.text(x1, y1, str(k), color=rim_color)
|