File size: 3,119 Bytes
afc2161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
from __future__ import annotations

from typing import Literal
from torch import Tensor
import numpy as np
import torch
from torchvision.ops import boxes as box_ops
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.collections import EllipseCollection, PatchCollection
from matplotlib.patches import Rectangle
from ellipse_rcnn.utils.conics import ellipse_angle, conic_center, ellipse_axes
from matplotlib.figure import Figure


def plot_single_pred(
    image: Tensor,
    prediction,
    min_score: float = 0.75,
) -> Figure:
    """Plot one image overlaid with its predicted ellipses above a score threshold.

    Args:
        image: Image tensor in channel-first ``(C, H, W)`` layout (a 2-D
            ``(H, W)`` tensor is also accepted). It may live on any device and
            may require grad; it is detached and copied to the CPU for plotting.
        prediction: A mapping with ``"scores"`` and ``"ellipse_matrices"``
            entries, or a list holding exactly one such mapping.
        min_score: Only detections whose score is strictly greater than this
            value are drawn.

    Returns:
        The newly created figure (its background patch is made transparent).

    Raises:
        ValueError: If ``prediction`` is a list that does not contain exactly
            one element.
    """
    if isinstance(prediction, list):
        if len(prediction) > 1:
            raise ValueError(
                "Multiple predictions detected. Please pass a single prediction."
            )
        if not prediction:
            raise ValueError(
                "Empty prediction list. Please pass a single prediction."
            )
        prediction = prediction[0]

    # matplotlib converts via NumPy, which requires host memory and no autograd graph.
    img = image.detach().cpu()
    if img.ndim == 3:
        img = img.permute(1, 2, 0)  # (C, H, W) -> (H, W, C)
        if img.shape[-1] == 1:
            # Single channel: drop the channel axis so the colormap is applied.
            img = img.squeeze(-1)

    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    fig.patch.set_alpha(0)
    # "gray" is the canonical colormap name; "grey" is only an alias in newer matplotlib.
    ax.imshow(img.numpy(), cmap="gray")
    score_mask = prediction["scores"] > min_score

    plot_ellipses(prediction["ellipse_matrices"][score_mask], ax=ax)

    return fig


def plot_ellipses(
    A_craters: torch.Tensor,
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
):
    """Draw ellipses given as conic matrices onto a matplotlib axes.

    Semi-axes, rotation angle and center are obtained from ``ellipse_axes``,
    ``ellipse_angle`` and ``conic_center`` respectively, and the outlines are
    drawn as a single unfilled ``EllipseCollection`` in data coordinates.

    Args:
        A_craters: Conic matrices of the ellipses to draw (presumably a batch
            shaped ``(N, 3, 3)`` -- TODO confirm against ``ellipse_rcnn.utils.conics``).
        figsize: Figure size, used only when ``ax`` is None and a new figure
            is created.
        plot_centers: If True, write each ellipse's index at its center.
        ax: Axes to draw on. When None, a new figure with equal aspect is
            created and its view is autoscaled to the drawn ellipses.
        rim_color: Colour of the ellipse outlines and of the index labels.
        alpha: Opacity of the outlines.
    """
    a_proj, b_proj = ellipse_axes(A_craters)
    psi_proj = ellipse_angle(A_craters)
    x_pix_proj, y_pix_proj = conic_center(A_craters)

    # matplotlib needs host-side NumPy arrays without an autograd graph.
    a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = (
        t.detach().cpu().numpy()
        for t in (a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj)
    )

    created_axes = ax is None
    if created_axes:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    ec = EllipseCollection(
        # EllipseCollection expects full widths/heights, not semi-axes.
        a_proj * 2,
        b_proj * 2,
        np.degrees(psi_proj),
        units="xy",
        offsets=np.column_stack((x_pix_proj, y_pix_proj)),
        transOffset=ax.transData,
        facecolors="None",
        edgecolors=rim_color,
        alpha=alpha,
    )
    ax.add_collection(ec)

    if created_axes:
        # add_collection updates the data limits but not the view; without this
        # a fresh axes stays at its default (0, 1) limits.
        ax.autoscale_view()

    if plot_centers:
        # conic_center yields an (x, y) pair of per-ellipse coordinate arrays,
        # so zip them to get one center per ellipse. ravel() also covers a
        # single unbatched matrix, where the coordinates are 0-d.
        centers = zip(np.ravel(x_pix_proj), np.ravel(y_pix_proj))
        for k, (x, y) in enumerate(centers):
            ax.text(float(x), float(y), str(k), color=rim_color)


def plot_bboxes(
    boxes: torch.Tensor,
    box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
    figsize: tuple[float, float] = (15, 15),
    plot_centers: bool = False,
    ax: Axes | None = None,
    rim_color="r",
    alpha=1.0,
):
    """Draw axis-aligned bounding boxes as unfilled rectangles on a matplotlib axes.

    Args:
        boxes: Box coordinates, one box per row (a single 1-D box of four
            values is also accepted). Interpreted according to ``box_type``.
        box_type: Coordinate format of ``boxes``; anything other than
            ``"xyxy"`` is converted with ``torchvision.ops.box_convert``.
        figsize: Figure size, used only when ``ax`` is None and a new figure
            is created.
        plot_centers: If True, write each box's index next to it. Note that,
            despite the parameter name, the label is placed at the box's
            ``(x1, y1)`` corner rather than its center.
        ax: Axes to draw on. When None, a new figure with equal aspect is
            created and its view is autoscaled to the drawn boxes.
        rim_color: Colour of the rectangle edges and of the index labels.
        alpha: Opacity of the rectangles.
    """
    created_axes = ax is None
    if created_axes:
        _, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})

    if box_type != "xyxy":
        boxes = box_ops.box_convert(boxes, box_type, "xyxy")

    # Host-side NumPy copy; reshape so a lone (4,) box is treated as one row.
    boxes = boxes.detach().cpu().numpy().reshape(-1, 4)
    rectangles = [
        Rectangle((x1, y1), x2 - x1, y2 - y1) for x1, y1, x2, y2 in boxes
    ]

    collection = PatchCollection(
        rectangles, edgecolor=rim_color, facecolor="none", alpha=alpha
    )
    ax.add_collection(collection)

    if created_axes:
        # add_collection updates the data limits but not the view; without this
        # a fresh axes stays at its default (0, 1) limits.
        ax.autoscale_view()

    if plot_centers:
        for k, (x1, y1, _, _) in enumerate(boxes):
            ax.text(x1, y1, str(k), color=rim_color)