Filipstrozik committed
Commit afc2161 · Parent(s): 2c3fe3e

Add initial implementation of EllipseRCNN model and dataset utilities

- Introduced core model and loss functions for ellipse detection.
- Added dataset classes for loading and processing crater images.
- Included type definitions for better code clarity.
- Created requirements file for necessary dependencies.
- Added README documentation for core functionalities.
- app.py +225 -0
- ellipse_rcnn/__init__.py +1 -0
- ellipse_rcnn/__pycache__/__init__.cpython-312.pyc +0 -0
- ellipse_rcnn/core/README.md +40 -0
- ellipse_rcnn/core/__init__.py +1 -0
- ellipse_rcnn/core/__pycache__/__init__.cpython-312.pyc +0 -0
- ellipse_rcnn/core/__pycache__/ellipse_roi_head.cpython-312.pyc +0 -0
- ellipse_rcnn/core/__pycache__/kld.cpython-312.pyc +0 -0
- ellipse_rcnn/core/__pycache__/model.cpython-312.pyc +0 -0
- ellipse_rcnn/core/__pycache__/wd.cpython-312.pyc +0 -0
- ellipse_rcnn/core/ellipse_roi_head.py +429 -0
- ellipse_rcnn/core/ga.py +67 -0
- ellipse_rcnn/core/kld.py +124 -0
- ellipse_rcnn/core/model.py +282 -0
- ellipse_rcnn/core/wd.py +128 -0
- ellipse_rcnn/utils/__init__.py +0 -0
- ellipse_rcnn/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- ellipse_rcnn/utils/__pycache__/conics.cpython-312.pyc +0 -0
- ellipse_rcnn/utils/__pycache__/types.cpython-312.pyc +0 -0
- ellipse_rcnn/utils/__pycache__/viz.cpython-312.pyc +0 -0
- ellipse_rcnn/utils/conics.py +209 -0
- ellipse_rcnn/utils/data/__init__.py +0 -0
- ellipse_rcnn/utils/data/__pycache__/__init__.cpython-312.pyc +0 -0
- ellipse_rcnn/utils/data/__pycache__/base.cpython-312.pyc +0 -0
- ellipse_rcnn/utils/data/base.py +62 -0
- ellipse_rcnn/utils/data/craters.py +54 -0
- ellipse_rcnn/utils/data/fddb.py +239 -0
- ellipse_rcnn/utils/types.py +46 -0
- ellipse_rcnn/utils/viz.py +106 -0
- examples/image1.jpg +0 -0
- examples/image2.jpg +0 -0
- examples/image3.jpg +0 -0
- requirements.txt +8 -0
app.py
ADDED
@@ -0,0 +1,225 @@
import io

import gradio as gr
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import torchvision.transforms as transforms
import torch
from huggingface_hub import hf_hub_download
from ellipse_rcnn import EllipseRCNN


# Load model.pth from the Filipstrozik/sat-tree-detection-v0 repository on Hugging Face
load_state_dict = torch.load(
    hf_hub_download("Filipstrozik/sat-tree-detection-v0", "model.pth"),
    weights_only=True,
)
model = EllipseRCNN()

model.load_state_dict(load_state_dict)
model.eval()


def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns center of ellipse in 2D cartesian coordinate system with numerical stability."""
    # Extract the top-left 2x2 submatrix of the conic matrix
    A = conic_matrix[..., :2, :2]

    # Add stabilization for pseudoinverse computation by clamping singular values
    A_pinv = torch.linalg.pinv(A, rcond=torch.finfo(A.dtype).eps)

    # Extract the last two rows for the linear term
    b = -conic_matrix[..., :2, 2][..., None]

    # Stabilize any potential numerical instabilities
    centers = torch.matmul(A_pinv, b).squeeze()

    return centers[..., 0], centers[..., 1]


def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Returns semi-major and semi-minor axes of ellipse in 2D cartesian coordinate system."""
    lambdas = (
        torch.linalg.eigvalsh(conic_matrix[..., :2, :2])
        / (-torch.det(conic_matrix) / torch.det(conic_matrix[..., :2, :2]))[..., None]
    )
    axes = torch.sqrt(1 / lambdas)
    return axes[..., 0], axes[..., 1]


def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
    """Returns angle of ellipse in radians w.r.t. x-axis."""
    return (
        -torch.atan2(
            2 * conic_matrix[..., 1, 0],
            conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0],
        )
        / 2
    )


def get_ellipse_params_from_matrices(ellipse_matrices):
    if ellipse_matrices.shape[0] == 0:
        return None
    a, b = ellipse_axes(ellipse_matrices)
    cx, cy = conic_center(ellipse_matrices)
    theta = ellipse_angle(ellipse_matrices)

    a = a.view(-1)
    b = b.view(-1)
    cx = cx.view(-1)
    cy = cy.view(-1)
    theta = theta.view(-1)

    ellipses = torch.stack([a, b, cx, cy, theta], dim=1).reshape(-1, 5)
    return ellipses


def plot_ellipses(
    ellipse_params: torch.Tensor,
    image: torch.Tensor,
    plot_centers: bool = False,
    rim_color: str = "r",
    alpha: float = 0.25,
) -> None:
    if ellipse_params is None:
        return
    a, b, cx, cy, theta = ellipse_params.unbind(-1)

    # multiply all pixel values by 4
    cx = cx * 4
    cy = cy * 4

    # draw ellipses
    for i in range(len(a)):
        ellipse = mpatches.Ellipse(
            (cx[i], cy[i]),
            width=a[i],
            height=b[i],
            angle=theta[i],
            fill=True,
            alpha=alpha,
            color=rim_color,
        )
        plt.gca().add_patch(ellipse)

        if plot_centers:
            plt.scatter(cx[i], cy[i], c=rim_color, s=10)

    plt.imshow(image)


# Define the necessary transformations and the inverse normalization
def invert_normalization(image, mean, std):
    for t, m, s in zip(image, mean, std):
        t.mul_(s).add_(m)
    return torch.clamp(image, 0, 1)


def process_image(image):
    original_size = image.size

    # Define the transform pipeline
    transform = transforms.Compose(
        [
            transforms.Resize((1024, 1024)),
            transforms.PILToTensor(),
            transforms.ConvertImageDtype(torch.float),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )

    image_tensor = transform(image).unsqueeze(0)  # Add batch dimension
    return image_tensor, original_size


def generate_prediction(image, rpn_nms_thresh, score_thresh, nms_thresh):
    # Preprocess image
    image_tensor, original_size = process_image(image)
    image_tensor = image_tensor.to("cpu")

    # Ensure the model is in evaluation mode
    model.rpn.nms_thresh = rpn_nms_thresh
    model.roi_heads.score_thresh = score_thresh
    model.roi_heads.nms_thresh = nms_thresh

    with torch.no_grad():
        prediction = model(image_tensor)[0]

    # Invert normalization for display
    mean = [0.485, 0.456, 0.406]
    std = [0.229, 0.224, 0.225]
    inverted_image = (
        invert_normalization(image_tensor, mean, std)
        .squeeze(0)
        .permute(1, 2, 0)
        .cpu()
        .numpy()
    )

    # Plot results with ellipses
    plt.figure(figsize=(10, 10))
    plt.imshow(inverted_image)
    plot_ellipses(
        get_ellipse_params_from_matrices(prediction["ellipse_matrices"]),
        inverted_image,
        plot_centers=True,
        rim_color="red",
        alpha=0.25,
    )
    red_patch = mpatches.Patch(color="red", label="Predicted")
    plt.legend(handles=[red_patch], loc="upper right")
    plt.gca().set_aspect(original_size[0] / original_size[1])
    plt.axis("off")
    plt.tight_layout()
    # Save the figure to a buffer and return as an image
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)
    with Image.open(buf) as output_image:
        output_image = output_image.copy()
    buf.close()
    return output_image


# Define Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("## Tree Detection from Satellite Images")
    gr.Markdown("Upload an image and see the detected trees with ellipses.")

    with gr.Row():
        image_input = gr.Image(label="Input Image", type="pil")
        image_output = gr.Image(label="Detected Trees")

    examples = [
        ["examples/image1.jpg"],
        ["examples/image2.jpg"],
        ["examples/image3.jpg"],
    ]

    with gr.Row():
        rpn_nms_slider = gr.Slider(
            0.0, 1.0, value=model.rpn.nms_thresh, label="RPN NMS Threshold"
        )
        score_thresh_slider = gr.Slider(
            0.0,
            1.0,
            value=model.roi_heads.score_thresh,
            label="ROI Heads Score Threshold",
        )
        nms_thresh_slider = gr.Slider(
            0.0, 1.0, value=model.roi_heads.nms_thresh, label="ROI Heads NMS Threshold"
        )

    submit_button = gr.Button("Detect Trees")
    submit_button.click(
        fn=generate_prediction,
        inputs=[image_input, rpn_nms_slider, score_thresh_slider, nms_thresh_slider],
        outputs=image_output,
    )

    gr.Examples(examples=examples, inputs=image_input, outputs=image_output)


demo.launch()
ellipse_rcnn/__init__.py
ADDED
@@ -0,0 +1 @@
from .core.model import EllipseRCNN
ellipse_rcnn/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (223 Bytes)
ellipse_rcnn/core/README.md
ADDED
@@ -0,0 +1,40 @@
# README.md

## Symmetric Kullback-Leibler Divergence Loss

This directory provides an implementation of a **Symmetric Kullback-Leibler (KL) Divergence Loss** tailored for tensors
representing ellipses in matrix form. The loss function is designed to measure the difference between two elliptical
shapes and is particularly useful in optimization and generative modeling tasks.

## Loss Calculation

### **Kullback-Leibler Divergence**

For two ellipses represented by their matrix forms ( $A_1$ ) and ( $A_2$ ), the KL divergence is calculated as:

$$ D_{KL}(A_1 \parallel A_2) = \frac{1}{2} \left( \text{Tr}(C_2^{-1}C_1) + (\mu_1 - \mu_2)^T C_2^{-1} (\mu_1 - \mu_2) - 2 + \log\left(\frac{\det(C_2)}{\det(C_1)}\right) \right) $$

Where:

- ( $C_1$, $C_2$ ): Covariance matrices extracted from ( $A_1$, $A_2$ ).
- ( $\mu_1$, $\mu_2$ ): Centers (means) of the ellipses, computed from the conic representation.
- ( $\text{Tr}$ ): Trace operator.
- ( $C_2^{-1}$ ): Inverse of the covariance matrix of ( $A_2$ ).
- ( $\det(C_1)$, $\det(C_2)$ ): Determinants of covariance matrices.

A regularization term ( $\epsilon$ ) is added to ensure numerical stability when computing inverses and determinants.

### **Symmetric KL Divergence**

The symmetric version of the KL divergence combines the calculations in both directions:

$$ D_{KL}^{\text{sym}}(A_1, A_2) = \frac{1}{2} \left( D_{KL}(A_1 \parallel A_2) + D_{KL}(A_2 \parallel A_1) \right) $$

This ensures a bidirectional comparison, making the function suitable as a loss metric in optimization tasks.

## Features of the Loss

- **Shape-Only Comparison**: Option to ignore translation and compute divergence based purely on the shapes (covariance matrices).
- **NaN Handling**: Replaces NaN values with a specified constant, ensuring robust loss evaluation.
- **Normalization**: An optional normalization step that rescales the divergence for certain applications.

### Usage

The loss is encapsulated in the `SymmetricKLDLoss` class, which integrates seamlessly into PyTorch-based workflows.
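A minimal usage sketch, assuming the `SymmetricKLDLoss` and `ellipse_to_conic_matrix` helpers added in this commit; the ellipse parameters below are illustrative only:

```python
import torch

from ellipse_rcnn.core.kld import SymmetricKLDLoss
from ellipse_rcnn.utils.conics import ellipse_to_conic_matrix

# Two single-ellipse batches given as (semi-axes a, b, center x, y, angle theta).
A1 = ellipse_to_conic_matrix(
    a=torch.tensor([20.0]), b=torch.tensor([10.0]),
    x=torch.tensor([64.0]), y=torch.tensor([64.0]), theta=torch.tensor([0.3]),
)
A2 = ellipse_to_conic_matrix(
    a=torch.tensor([22.0]), b=torch.tensor([9.0]),
    x=torch.tensor([66.0]), y=torch.tensor([63.0]), theta=torch.tensor([0.2]),
)

loss_fn = SymmetricKLDLoss(shape_only=False, nan_to_num=10.0, normalize=True)
loss = loss_fn(A1, A2).mean()  # scalar; lower values indicate more similar ellipses
```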
ellipse_rcnn/core/__init__.py
ADDED
@@ -0,0 +1 @@
from .model import EllipseRCNN  # noqa: F401
ellipse_rcnn/core/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (223 Bytes)
ellipse_rcnn/core/__pycache__/ellipse_roi_head.cpython-312.pyc
ADDED
Binary file (19.2 kB)
ellipse_rcnn/core/__pycache__/kld.cpython-312.pyc
ADDED
Binary file (5.72 kB)
ellipse_rcnn/core/__pycache__/model.cpython-312.pyc
ADDED
Binary file (10.3 kB)
ellipse_rcnn/core/__pycache__/wd.cpython-312.pyc
ADDED
Binary file (5.82 kB)
ellipse_rcnn/core/ellipse_roi_head.py
ADDED
@@ -0,0 +1,429 @@
from typing import Dict, List, Tuple, Optional, TypedDict, NamedTuple, Self

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models.detection.roi_heads import RoIHeads, fastrcnn_loss

from .kld import SymmetricKLDLoss
from .wd import WassersteinLoss
from ..utils.conics import (
    ellipse_to_conic_matrix,
    ellipse_axes,
    ellipse_angle,
    conic_center,
)


class RegressorPrediction(NamedTuple):
    """
    Represents the processed outputs of a regression model as a named tuple.

    This class encapsulates regression model outputs in a structured format, where
    each attribute corresponds to a specific component of the regression output.
    These outputs can be directly used for post-processing steps such as transformation
    into conic matrices or further evaluations of ellipse geometry.

    Attributes
    ----------
    d_a : torch.Tensor
        The normalized semi-major axis scale factor (logarithmic) used to compute
        the actual semi-major axis length of ellipses.
    d_b : torch.Tensor
        The normalized semi-minor axis scale factor (logarithmic) used to compute
        the actual semi-minor axis length of ellipses.
    d_x : torch.Tensor
        The normalized x-coordinate translation factor, specifying the adjustment
        to the center of bounding boxes for ellipse placement.
    d_y : torch.Tensor
        The normalized y-coordinate translation factor, specifying the adjustment
        to the center of bounding boxes for ellipse placement.
    d_theta : torch.Tensor
        The normalized rotation angle factor which is processed to derive the
        actual rotation angle (in radians) of ellipses.

    Notes
    -----
    - The attributes `d_a` and `d_b`, representing scale factors for the semi-major
      and semi-minor axes, are typically bounded between 0 and 1 using a sigmoid activation.
    - The attributes `d_x` and `d_y` serve as adjustments to bounding box centers, normalized
      with respect to the bounding box diagonals.
    - The attribute `d_theta` is normalized to ensure the rotation angle lies within
      a valid range (after transformation, typically between -π/2 and π/2 radians).
    - These normalized outputs are post-processed together with bounding box information
      to construct actionable ellipse parameters such as their axes lengths, centers,
      and angles.
    - This structure simplifies downstream regression tasks, such as conversion into
      conic matrices or calculation of geometrical losses.
    """

    d_a: torch.Tensor
    d_b: torch.Tensor
    d_theta: torch.Tensor

    @property
    def device(self) -> torch.device:
        return self.d_a.device

    @property
    def dtype(self) -> torch.dtype:
        return self.d_a.dtype

    def split(self, split_size: list[int] | int, dim: int = 0) -> list[Self]:
        return [
            RegressorPrediction(*tensors)
            for tensors in zip(
                *[torch.split(attr, split_size, dim=dim) for attr in self]
            )
        ]


class EllipseRegressor(nn.Module):
    """
    EllipseRegressor is a neural network module designed to predict parameters of
    an ellipse given input features.

    This class is a PyTorch module that uses a feedforward neural network to predict
    the normalized five parameters of an ellipse: semi-major axis `a`, semi-minor axis `b`, center
    coordinates (`x`, `y`), and orientation `theta`. The class includes mechanisms
    for batch normalization and uses Xavier weight initialization for improved
    training stability and convergence.

    Attributes
    ----------
    ffnn : nn.Sequential
        A feedforward neural network with two hidden layers and ReLU activations.
    """

    def __init__(self, in_channels: int = 1024, hidden_size: int = 64):
        super().__init__()
        # Separate prediction heads for better gradient flow
        self.ffnn = nn.Sequential(
            nn.Linear(in_channels, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 3),
            nn.Tanh(),
        )

        # Initialize with small values
        for lin in self.ffnn:
            if isinstance(lin, nn.Linear):
                nn.init.xavier_uniform_(lin.weight, gain=0.01)
                nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor) -> RegressorPrediction:
        x = x.flatten(start_dim=1)
        x = self.ffnn(x)

        d_a, d_b, d_theta = x.unbind(dim=-1)

        return RegressorPrediction(d_a=d_a, d_b=d_b, d_theta=d_theta)


def postprocess_ellipse_predictor(
    pred: RegressorPrediction,
    box_proposals: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Processes elliptical predictor outputs and converts them into conic matrices.

    Parameters
    ----------
    pred : RegressorPrediction
        The output of the elliptical predictor model.
    box_proposals : torch.Tensor
        Tensor containing proposed bounding box information, with shape (N, 4). Each box
        is represented as a 4-tuple (x_min, y_min, x_max, y_max).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing:
            - a (torch.Tensor): Computed semi-major axis of the ellipses.
            - b (torch.Tensor): Computed semi-minor axis of the ellipses.
            - x (torch.Tensor): X-coordinates of the ellipse centers.
            - y (torch.Tensor): Y-coordinates of the ellipse centers.
            - theta (torch.Tensor): Rotation angles (in radians) for the ellipses.
    """
    d_a, d_b, d_theta = pred

    # Pre-compute box width, height, and diagonal
    box_width = box_proposals[:, 2] - box_proposals[:, 0]
    box_height = box_proposals[:, 3] - box_proposals[:, 1]
    box_diag = torch.sqrt(box_width**2 + box_height**2)

    a = box_diag * d_a.exp()
    b = box_diag * d_b.exp()

    box_x = box_proposals[:, 0] + box_width * 0.5
    box_y = box_proposals[:, 1] + box_height * 0.5

    theta = (d_theta * 2.0 - 1.0) * (torch.pi / 2)

    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    theta = torch.where(
        cos_theta >= 0,
        torch.atan2(sin_theta, cos_theta),
        torch.atan2(-sin_theta, -cos_theta),
    )

    return a, b, box_x, box_y, theta


class EllipseLossDict(TypedDict):
    loss_ellipse_kld: torch.Tensor
    loss_ellipse_smooth_l1: torch.Tensor
    loss_ellipse_wasserstein: torch.Tensor


def ellipse_loss(
    pred: RegressorPrediction,
    A_target: List[torch.Tensor],
    pos_matched_idxs: List[torch.Tensor],
    box_proposals: List[torch.Tensor],
    kld_loss_fn: SymmetricKLDLoss,
    wd_loss_fn: WassersteinLoss,
) -> EllipseLossDict:
    pos_matched_idxs_batched = torch.cat(pos_matched_idxs, dim=0)
    A_target = torch.cat(A_target, dim=0)[pos_matched_idxs_batched]

    box_proposals = torch.cat(box_proposals, dim=0)

    if A_target.numel() == 0:
        return {
            "loss_ellipse_kld": torch.tensor(0.0, device=pred.device, dtype=pred.dtype),
            "loss_ellipse_smooth_l1": torch.tensor(
                0.0, device=pred.device, dtype=pred.dtype
            ),
            "loss_ellipse_wasserstein": torch.tensor(
                0.0, device=pred.device, dtype=pred.dtype
            ),
        }

    a_target, b_target = ellipse_axes(A_target)
    theta_target = ellipse_angle(A_target)

    # Box proposal parameters
    box_width = box_proposals[:, 2] - box_proposals[:, 0]
    box_height = box_proposals[:, 3] - box_proposals[:, 1]
    box_diag = torch.sqrt(box_width**2 + box_height**2).clamp(min=1e-6)

    # Normalize target variables
    da_target = (a_target / box_diag).log()
    db_target = (b_target / box_diag).log()
    dtheta_target = (theta_target / (torch.pi / 2) + 1) / 2

    # Direct parameter losses
    d_a, d_b, d_theta = pred

    pred_t = torch.stack([d_a, d_b, d_theta], dim=1)
    target_t = torch.stack([da_target, db_target, dtheta_target], dim=1)

    loss_smooth_l1 = F.smooth_l1_loss(pred_t, target_t, beta=(1 / 9), reduction="sum")
    loss_smooth_l1 /= box_proposals.shape[0]
    loss_smooth_l1 = loss_smooth_l1.nan_to_num(nan=0.0).clip(max=float(1e4))

    a, b, x, y, theta = postprocess_ellipse_predictor(pred, box_proposals)

    A_pred = ellipse_to_conic_matrix(a=a, b=b, theta=theta, x=x, y=y)

    loss_kld = kld_loss_fn.forward(A_pred, A_target).clip(max=float(1e4)).mean() * 0.1
    loss_wd = torch.zeros(1, device=pred.device, dtype=pred.dtype)
    # loss_wd = wd_loss_fn.forward(A_pred, A_target).clip(max=float(1e4)).mean() * 0.1

    return {
        "loss_ellipse_kld": loss_kld,
        "loss_ellipse_smooth_l1": loss_smooth_l1,
        "loss_ellipse_wasserstein": loss_wd,
    }


class EllipseRoIHeads(RoIHeads):
    def __init__(
        self,
        box_roi_pool: nn.Module,
        box_head: nn.Module,
        box_predictor: nn.Module,
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        bbox_reg_weights: Optional[Tuple[float, float, float, float]],
        score_thresh: float,
        nms_thresh: float,
        detections_per_img: int,
        ellipse_roi_pool: nn.Module,
        ellipse_head: nn.Module,
        ellipse_predictor: nn.Module,
        # Loss parameters
        kld_shape_only: bool = False,
        kld_normalize: bool = False,
        # Numerical stability parameters
        nan_to_num: float = 10.0,
        loss_scale: float = 1.0,
    ):
        super().__init__(
            box_roi_pool,
            box_head,
            box_predictor,
            fg_iou_thresh,
            bg_iou_thresh,
            batch_size_per_image,
            positive_fraction,
            bbox_reg_weights,
            score_thresh,
            nms_thresh,
            detections_per_img,
        )

        self.ellipse_roi_pool = ellipse_roi_pool
        self.ellipse_head = ellipse_head
        self.ellipse_predictor = ellipse_predictor

        self.kld_loss = SymmetricKLDLoss(
            shape_only=kld_shape_only,
            normalize=kld_normalize,
            nan_to_num=nan_to_num,
        )
        self.wd_loss = WassersteinLoss(
            nan_to_num=nan_to_num,
            normalize=kld_normalize,
        )
        self.loss_scale = loss_scale

    def has_ellipse_reg(self) -> bool:
        if self.ellipse_roi_pool is None:
            return False
        if self.ellipse_head is None:
            return False
        if self.ellipse_predictor is None:
            return False
        return True

    def postprocess_ellipse_regressions(self):
        pass

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        proposals: List[torch.Tensor],
        image_shapes: List[Tuple[int, int]],
        targets: Optional[List[Dict[str, torch.Tensor]]] = None,
    ) -> Tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
        if targets is not None:
            for t in targets:
                floating_point_types = (torch.float, torch.double, torch.half)
                if t["boxes"].dtype not in floating_point_types:
                    raise TypeError("target boxes must be of float type")
                if t["ellipse_matrices"].dtype not in floating_point_types:
                    raise TypeError("target ellipse_offsets must be of float type")
                if t["labels"].dtype != torch.int64:
                    raise TypeError("target labels must be of int64 type")

        if self.training:
            proposals, matched_idxs, labels, regression_targets = (
                self.select_training_samples(proposals, targets)
            )
        else:
            labels = None
            regression_targets = None
            matched_idxs = None

        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None or regression_targets is None:
                raise ValueError(
                    "Labels and regression targets must not be None during training"
                )
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets
            )
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            boxes, scores, labels = self.postprocess_detections(
                class_logits, box_regression, proposals, image_shapes
            )
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_ellipse_reg():
            ellipse_box_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("matched_idxs must not be None during training")
                # during training, only focus on positive boxes
                num_images = len(proposals)
                ellipse_box_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    ellipse_box_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None  # type: ignore

            if self.ellipse_roi_pool is not None:
                ellipse_features = self.ellipse_roi_pool(
                    features, ellipse_box_proposals, image_shapes
                )
                ellipse_features = self.ellipse_head(ellipse_features)
                ellipse_shapes_normalised = self.ellipse_predictor(ellipse_features)
            else:
                raise Exception("Expected ellipse_roi_pool to be not None")

            loss_ellipse_regressor = {}
            if self.training:
                if targets is None:
                    raise ValueError("Targets must not be None during training")
                if pos_matched_idxs is None:
                    raise ValueError(
                        "pos_matched_idxs must not be None during training"
                    )
                if ellipse_shapes_normalised is None:
                    raise ValueError(
                        "ellipse_shapes_normalised must not be None during training"
                    )

                ellipse_matrix_targets = [t["ellipse_matrices"] for t in targets]
                rcnn_loss_ellipse = ellipse_loss(
                    ellipse_shapes_normalised,
                    ellipse_matrix_targets,
                    pos_matched_idxs,
                    ellipse_box_proposals,
                    self.kld_loss,
                    self.wd_loss,
                )

                if self.loss_scale != 1.0:
                    rcnn_loss_ellipse["loss_ellipse_kld"] *= self.loss_scale
                    rcnn_loss_ellipse["loss_ellipse_smooth_l1"] *= self.loss_scale

                loss_ellipse_regressor.update(rcnn_loss_ellipse)
            else:
                ellipses_per_image = [lbl.shape[0] for lbl in labels]
                for pred, r, box in zip(
                    ellipse_shapes_normalised.split(ellipses_per_image, dim=0),
                    result,
                    ellipse_box_proposals,
                ):
                    a, b, x, y, theta = postprocess_ellipse_predictor(pred, box)
                    A_pred = ellipse_to_conic_matrix(a=a, b=b, theta=theta, x=x, y=y)
                    r["ellipse_matrices"] = A_pred
                    # r["boxes"] = bbox_ellipse(A_pred)

            losses.update(loss_ellipse_regressor)

        return result, losses
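The regressor above predicts ellipse parameters relative to each proposal box (log-scaled axes, normalized angle). A minimal decoding sketch, assuming only the functions defined in this file; all numeric values are illustrative:

```python
import torch

from ellipse_rcnn.core.ellipse_roi_head import (
    RegressorPrediction,
    postprocess_ellipse_predictor,
)

# One proposal box (x_min, y_min, x_max, y_max); values are made up.
boxes = torch.tensor([[10.0, 20.0, 110.0, 80.0]])

# d_a and d_b are log scale factors relative to the box diagonal;
# d_theta in [0, 1] maps to an angle in [-pi/2, pi/2].
pred = RegressorPrediction(
    d_a=torch.tensor([-0.5]),      # a = diag * exp(-0.5)
    d_b=torch.tensor([-1.0]),      # b = diag * exp(-1.0)
    d_theta=torch.tensor([0.75]),  # theta = (0.75 * 2 - 1) * pi / 2 = pi / 4
)

a, b, x, y, theta = postprocess_ellipse_predictor(pred, boxes)
# x, y are the box center; a, b, theta define the decoded ellipse.
```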
ellipse_rcnn/core/ga.py
ADDED
@@ -0,0 +1,67 @@
import torch
from ellipse_rcnn.utils.conics import conic_center


def gaussian_angle_distance(A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
    # Extract covariance matrices (negative of the top-left blocks)
    cov1, cov2 = map(lambda arr: -arr[..., :2, :2], (A1, A2))

    # Extract the means by computing conic centers
    c1_x, c1_y = conic_center(A1)
    c2_x, c2_y = conic_center(A2)

    # Stack the conic centers into the appropriate shape for computation
    m1 = torch.stack((c1_x, c1_y), dim=-1)[..., None]
    m2 = torch.stack((c2_x, c2_y), dim=-1)[..., None]

    # Compute determinants for covariance matrices
    det_cov1 = torch.clamp(cov1.det(), min=torch.finfo(cov1.dtype).eps)
    det_cov2 = torch.clamp(cov2.det(), min=torch.finfo(cov2.dtype).eps)
    cov_sum = cov1 + cov2

    # Determinant of sum (clamped for numerical stability)
    det_cov_sum = torch.clamp(cov_sum.det(), min=torch.finfo(cov_sum.dtype).eps)

    # Compute fractional term with stabilized determinants
    frac_term = (4 * torch.sqrt(det_cov1 * det_cov2)) / det_cov_sum
    # Stable computation of the exponential term
    mean_diff = m1 - m2
    cov_sum_inv = torch.linalg.solve(
        cov_sum, torch.eye(cov_sum.size(-1), dtype=cov_sum.dtype, device=cov_sum.device)
    )
    exp_arg = -0.5 * mean_diff.transpose(-1, -2) @ cov1 @ cov_sum_inv @ cov2 @ mean_diff
    exp_term = torch.exp(torch.clamp(exp_arg, min=-50, max=50)).squeeze()

    angle_term = frac_term * exp_term

    return torch.arccos(angle_term)


class GaussianAngleDistanceLoss(torch.nn.Module):
    """
    Computes the Gaussian Angle Distance loss between two tensors.

    This class serves as a wrapper around the `gaussian_angle_distance` function,
    providing a clean interface and ensuring numerical stability.

    Attributes
    ----------
    normalize : bool

    nan_to_num : float
        The value to replace NaN entries in the computation with. Helps maintain numerical
        stability in cases where the input tensors contain undefined or invalid values.
    """

    def __init__(self, normalize: bool = True, nan_to_num: float = 10.0):
        super().__init__()
        self.nan_to_num = nan_to_num

    def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
        # Calculate the Gaussian angle distance
        distance = gaussian_angle_distance(A1, A2)

        # Replace NaN values with a predefined constant for numerical stability
        distance = torch.nan_to_num(distance, nan=self.nan_to_num)

        return distance
ellipse_rcnn/core/kld.py
ADDED
@@ -0,0 +1,124 @@
import torch

from ellipse_rcnn.utils.conics import conic_center


def mv_kullback_leibler_divergence(
    A1: torch.Tensor,
    A2: torch.Tensor,
    *,
    shape_only: bool = False,
) -> torch.Tensor:
    """
    Compute multi-variate KL divergence between ellipses represented by their matrices.

    Args:
        A1, A2: Ellipse matrices of shape (..., 3, 3)
        shape_only: If True, ignores displacement term
    """

    # Ensure that batch sizes are equal
    if A1.shape[:-2] != A2.shape[:-2]:
        raise ValueError(
            f"Batch size mismatch: A1 has shape {A1.shape[:-2]}, A2 has shape {A2.shape[:-2]}"
        )

    # Extract the upper 2x2 blocks as covariance matrices
    cov1 = A1[..., :2, :2]
    cov2 = A2[..., :2, :2]

    # Compute centers
    m1 = torch.vstack(conic_center(A1)).T[..., None]
    m2 = torch.vstack(conic_center(A2)).T[..., None]

    # Compute inverse
    try:
        cov2_inv = torch.linalg.inv(cov2)
    except RuntimeError:
        cov2_inv = torch.linalg.pinv(cov2)

    # Trace term
    trace_term = (cov2_inv @ cov1).diagonal(dim2=-2, dim1=-1).sum(1)

    # Log determinant term
    det_cov1 = torch.det(cov1)
    det_cov2 = torch.det(cov2)
    log_term = torch.log(det_cov2 / det_cov1).nan_to_num(nan=0.0)

    if shape_only:
        displacement_term = 0
    else:
        # Mean difference term
        displacement_term = (
            ((m1 - m2).transpose(-1, -2) @ cov2_inv @ (m1 - m2)).squeeze().abs()
        )

    return 0.5 * (trace_term + displacement_term - 2 + log_term)


def symmetric_kl_divergence(
    A1: torch.Tensor,
    A2: torch.Tensor,
    *,
    shape_only: bool = False,
    nan_to_num: float = float(1e4),
    normalize: bool = False,
) -> torch.Tensor:
    """
    Compute symmetric KL divergence between ellipses.
    """
    kl_12 = torch.nan_to_num(
        mv_kullback_leibler_divergence(A1, A2, shape_only=shape_only), nan_to_num
    )
    kl_21 = torch.nan_to_num(
        mv_kullback_leibler_divergence(A2, A1, shape_only=shape_only), nan_to_num
    )
    kl = (kl_12 + kl_21) / 2

    if kl.lt(0).any():
        raise ValueError("Negative KL divergence encountered.")

    if normalize:
        kl = 1 - torch.exp(-kl)
    return kl


class SymmetricKLDLoss(torch.nn.Module):
    """
    Computes the symmetric Kullback-Leibler divergence (KLD) loss between two tensors.

    SymmetricKLDLoss is used for measuring the divergence between two probability
    distributions or tensors, which can be useful in tasks such as generative modeling
    or optimization. The function allows for options such as normalizing the tensors or
    focusing only on their shape for comparison. Additionally, it includes a feature
    to handle NaN values by replacing them with a numeric constant.

    Attributes
    ----------
    shape_only : bool
        If True, computes the divergence based on the shape of the tensors only. This
        can be used to evaluate similarity without considering magnitude differences.
    nan_to_num : float
        The value to replace NaN entries in the tensors with. Helps maintain numerical
        stability in cases where the input tensors contain undefined or invalid values.
    normalize : bool
        If True, normalizes the tensors before computing the divergence. This is
        typically used when the inputs are not already probability distributions.
    """

    def __init__(
        self, shape_only: bool = True, nan_to_num: float = 10.0, normalize: bool = False
    ):
        super().__init__()
        self.shape_only = shape_only
        self.nan_to_num = nan_to_num
        self.normalize = normalize

    def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
        return symmetric_kl_divergence(
            A1,
            A2,
            shape_only=self.shape_only,
            nan_to_num=self.nan_to_num,
            normalize=self.normalize,
        )
ellipse_rcnn/core/model.py
ADDED
@@ -0,0 +1,282 @@
from types import NoneType
from typing import List, Tuple, Optional, Any

import pytorch_lightning as pl
import torch
from torch import nn
from torchvision.models import ResNet50_Weights, WeightsEnum
from torchvision.models.detection.anchor_utils import AnchorGenerator
from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor  # noqa: F
from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork
from torchvision.models.detection.transform import GeneralizedRCNNTransform
from torchvision.ops import MultiScaleRoIAlign

from .ellipse_roi_head import EllipseRoIHeads, EllipseRegressor
from ..utils.types import CollatedBatchType


class EllipseRCNN(GeneralizedRCNN):
    def __init__(
        self,
        num_classes: int = 2,
        # transform parameters
        backbone_name: str = "resnet50",
        weights: WeightsEnum | str = ResNet50_Weights.IMAGENET1K_V1,
        min_size: int = 256,
        max_size: int = 512,
        image_mean: Optional[List[float]] = None,
        image_std: Optional[List[float]] = None,
        # Region Proposal Network parameters
        rpn_anchor_generator: Optional[nn.Module] = None,
        rpn_head: Optional[nn.Module] = None,
        rpn_pre_nms_top_n_train: int = 2000,
        rpn_pre_nms_top_n_test: int = 1000,
        rpn_post_nms_top_n_train: int = 2000,
        rpn_post_nms_top_n_test: int = 1000,
        rpn_nms_thresh: float = 0.7,
        rpn_fg_iou_thresh: float = 0.7,
        rpn_bg_iou_thresh: float = 0.3,
        rpn_batch_size_per_image: int = 256,
        rpn_positive_fraction: float = 0.5,
        rpn_score_thresh: float = 0.0,
        # Box parameters
        box_roi_pool: Optional[nn.Module] = None,
        box_head: Optional[nn.Module] = None,
        box_predictor: Optional[nn.Module] = None,
        box_score_thresh: float = 0.05,
        box_nms_thresh: float = 0.5,
        box_detections_per_img: int = 100,
        box_fg_iou_thresh: float = 0.5,
        box_bg_iou_thresh: float = 0.5,
        box_batch_size_per_image: int = 512,
        box_positive_fraction: float = 0.25,
        bbox_reg_weights: Optional[Tuple[float, float, float, float]] = None,
        # Ellipse regressor
        ellipse_roi_pool: Optional[nn.Module] = None,
        ellipse_head: Optional[nn.Module] = None,
        ellipse_predictor: Optional[nn.Module] = None,
        ellipse_loss_scale: float = 1.0,
        ellipse_loss_normalize: bool = False,
    ):
        if backbone_name != "resnet50" and weights == ResNet50_Weights.IMAGENET1K_V1:
            raise ValueError(
                "If backbone_name is not resnet50, weights_enum must be specified"
            )

        backbone = resnet_fpn_backbone(
            backbone_name=backbone_name, weights=weights, trainable_layers=5
        )

        if not hasattr(backbone, "out_channels"):
            raise ValueError(
                "backbone should contain an attribute out_channels "
                "specifying the number of output channels (assumed to be the "
                "same for all the levels)"
            )

        if not isinstance(rpn_anchor_generator, (AnchorGenerator, NoneType)):
            raise TypeError(
                "rpn_anchor_generator must be an instance of AnchorGenerator or None"
            )

        if not isinstance(box_roi_pool, (MultiScaleRoIAlign, NoneType)):
            raise TypeError(
                "box_roi_pool must be an instance of MultiScaleRoIAlign or None"
            )

        if num_classes is not None:
            if box_predictor is not None:
                raise ValueError(
                    "num_classes should be None when box_predictor is specified"
                )
        else:
            if box_predictor is None:
                raise ValueError(
                    "num_classes should not be None when box_predictor "
                    "is not specified"
                )

        out_channels = backbone.out_channels

        if rpn_anchor_generator is None:
            anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
            aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
            rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
        if rpn_head is None:
            rpn_head = RPNHead(
                out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
            )

        rpn_pre_nms_top_n = dict(
            training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test
        )
        rpn_post_nms_top_n = dict(
            training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test
        )

        rpn = RegionProposalNetwork(
            rpn_anchor_generator,
            rpn_head,
            rpn_fg_iou_thresh,
            rpn_bg_iou_thresh,
            rpn_batch_size_per_image,
            rpn_positive_fraction,
            rpn_pre_nms_top_n,
            rpn_post_nms_top_n,
            rpn_nms_thresh,
            score_thresh=rpn_score_thresh,
        )

        default_representation_size = 1024

        if box_roi_pool is None:
            box_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )

        if box_head is None:
            resolution = box_roi_pool.output_size[0]
            if isinstance(resolution, int):
                box_head = TwoMLPHead(
                    out_channels * resolution**2, default_representation_size
                )
            else:
                raise ValueError(
                    "resolution should be an int but is {}".format(resolution)
                )

        if box_predictor is None:
            box_predictor = FastRCNNPredictor(default_representation_size, num_classes)

        if ellipse_roi_pool is None:
            ellipse_roi_pool = MultiScaleRoIAlign(
                featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
            )

        resolution = box_roi_pool.output_size[0]
        if ellipse_head is None:
            if isinstance(resolution, int):
                ellipse_head = TwoMLPHead(
                    out_channels * resolution**2, default_representation_size
                )
            else:
                raise ValueError(
                    "resolution should be an int but is {}".format(resolution)
                )

        if ellipse_predictor is None:
            ellipse_predictor = EllipseRegressor(
                default_representation_size, num_classes
            )

        roi_heads = EllipseRoIHeads(
            # Box
            box_roi_pool,
            box_head,
            box_predictor,
            box_fg_iou_thresh,
            box_bg_iou_thresh,
            box_batch_size_per_image,
            box_positive_fraction,
            bbox_reg_weights,
            box_score_thresh,
            box_nms_thresh,
            box_detections_per_img,
            # Ellipse
            ellipse_roi_pool=ellipse_roi_pool,
            ellipse_head=ellipse_head,
            ellipse_predictor=ellipse_predictor,
            loss_scale=ellipse_loss_scale,
            kld_normalize=ellipse_loss_normalize,
        )

        if image_mean is None:
            image_mean = [0.485, 0.456, 0.406]
        if image_std is None:
            image_std = [0.229, 0.224, 0.225]
        transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)

        super().__init__(backbone, rpn, roi_heads, transform)


class EllipseRCNNLightning(pl.LightningModule):
    def __init__(
        self,
        model: EllipseRCNN,
        lr: float = 1e-4,
        weight_decay: float = 1e-4,
    ):
        super().__init__()
        self.model = model
        self.save_hyperparameters(ignore=["model"])

    def configure_optimizers(self) -> Any:
        optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay,
            amsgrad=True,
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-6
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss_total"},
        }

    def training_step(
        self, batch: CollatedBatchType, batch_idx: int = 0
    ) -> torch.Tensor:
        images, targets = batch
        loss_dict = self.model(images, targets)
        self.log_dict(
            {f"train/{k}": v for k, v in loss_dict.items()},
            prog_bar=True,
            logger=True,
            on_step=True,
        )

        loss = sum(loss_dict.values())
        self.log("train/loss_total", loss, prog_bar=True, logger=True, on_step=True)

        return loss

    def validation_step(
        self, batch: CollatedBatchType, batch_idx: int = 0
    ) -> torch.Tensor:
        self.train(True)
        images, targets = batch

        loss_dict = self.model(images, targets)

        self.log_dict(
            {f"val/{k}": v for k, v in loss_dict.items()},
            logger=True,
            on_step=False,
            on_epoch=True,
        )

        val_loss = sum(loss_dict.values())
        self.log(
            "val/loss_total",
            val_loss,
            prog_bar=True,
            logger=True,
            on_step=False,
            on_epoch=True,
        )

        self.log(
            "hp_metric",
            val_loss,
        )

        self.log(
            "lr",
            self.lr_schedulers().get_last_lr()[0],
        )

        return val_loss
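A minimal training sketch for the Lightning wrapper above; the data loaders and trainer settings are placeholders and not part of this commit:

```python
import pytorch_lightning as pl

from ellipse_rcnn.core.model import EllipseRCNN, EllipseRCNNLightning

model = EllipseRCNN(num_classes=2)  # ResNet-50 FPN backbone by default
lit_model = EllipseRCNNLightning(model, lr=1e-4, weight_decay=1e-4)

trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# train_loader / val_loader are hypothetical DataLoaders yielding CollatedBatchType
# batches (images plus targets with "boxes", "labels" and "ellipse_matrices").
# trainer.fit(lit_model, train_loader, val_loader)
```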
ellipse_rcnn/core/wd.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from ellipse_rcnn.utils.conics import conic_center
|
4 |
+
|
5 |
+
|
6 |
+
def wasserstein_distance(
|
7 |
+
A1: torch.Tensor,
|
8 |
+
A2: torch.Tensor,
|
9 |
+
*,
|
10 |
+
shape_only: bool = False,
|
11 |
+
) -> torch.Tensor:
|
12 |
+
"""
|
13 |
+
Compute the squared Wasserstein-2 distance between ellipses represented by their matrices.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
A1, A2: Ellipse matrices of shape (..., 3, 3)
|
17 |
+
shape_only: If True, ignores displacement term
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Tensor containing Wasserstein distances
|
21 |
+
"""
|
22 |
+
# Ensure batch sizes match
|
23 |
+
if A1.shape[:-2] != A2.shape[:-2]:
|
24 |
+
raise ValueError(
|
25 |
+
f"Batch size mismatch: A1 has shape {A1.shape[:-2]}, A2 has shape {A2.shape[:-2]}"
|
26 |
+
)
|
27 |
+
|
28 |
+
# Extract covariance matrices (upper 2x2 blocks)
|
29 |
+
cov1 = A1[..., :2, :2]
|
30 |
+
cov2 = A2[..., :2, :2]
|
31 |
+
|
32 |
+
if shape_only:
|
33 |
+
displacement_term = 0
|
34 |
+
else:
|
35 |
+
# Compute centers
|
36 |
+
m1 = torch.vstack(conic_center(A1)).T[..., None]
|
37 |
+
m2 = torch.vstack(conic_center(A2)).T[..., None]
|
38 |
+
|
39 |
+
# Mean difference term
|
40 |
+
displacement_term = torch.sum((m1 - m2) ** 2, dim=(1, 2))
|
41 |
+
|
42 |
+
# Compute the matrix square root term
|
43 |
+
eigenvalues1, eigenvectors1 = torch.linalg.eigh(cov1)
|
44 |
+
sqrt_eigenvalues1 = torch.sqrt(torch.clamp(eigenvalues1, min=1e-7))
|
45 |
+
sqrt_cov1 = (
|
46 |
+
eigenvectors1
|
47 |
+
@ torch.diag_embed(sqrt_eigenvalues1)
|
48 |
+
@ eigenvectors1.transpose(-2, -1)
|
49 |
+
)
|
50 |
+
|
51 |
+
inner_term = sqrt_cov1 @ cov2 @ sqrt_cov1
|
52 |
+
eigenvalues_inner, eigenvectors_inner = torch.linalg.eigh(inner_term)
|
53 |
+
sqrt_inner = (
|
54 |
+
eigenvectors_inner
|
55 |
+
@ torch.diag_embed(torch.sqrt(torch.clamp(eigenvalues_inner, min=1e-7)))
|
56 |
+
@ eigenvectors_inner.transpose(-2, -1)
|
57 |
+
)
|
58 |
+
|
59 |
+
trace_term = (
|
60 |
+
torch.diagonal(cov1, dim1=-2, dim2=-1).sum(-1)
|
61 |
+
+ torch.diagonal(cov2, dim1=-2, dim2=-1).sum(-1)
|
62 |
+
- 2 * torch.diagonal(sqrt_inner, dim1=-2, dim2=-1).sum(-1)
|
63 |
+
)
|
64 |
+
|
65 |
+
return displacement_term + trace_term
|
66 |
+
|
67 |
+
|
68 |
+
def symmetric_wasserstein_distance(
|
69 |
+
A1: torch.Tensor,
|
70 |
+
A2: torch.Tensor,
|
71 |
+
*,
|
72 |
+
shape_only: bool = False,
|
73 |
+
nan_to_num: float = float(1e4),
|
74 |
+
normalize: bool = False,
|
75 |
+
) -> torch.Tensor:
|
76 |
+
"""
|
77 |
+
Compute symmetric Wasserstein distance between ellipses.
|
78 |
+
|
79 |
+
Args:
|
80 |
+
A1, A2: Ellipse matrices
|
81 |
+
shape_only: If True, ignores displacement term
|
82 |
+
nan_to_num: Value to replace NaN entries with
|
83 |
+
normalize: If True, normalizes the output to [0, 1]
|
84 |
+
"""
|
85 |
+
w = torch.nan_to_num(
|
86 |
+
wasserstein_distance(A1, A2, shape_only=shape_only), nan=nan_to_num
|
87 |
+
)
|
88 |
+
|
89 |
+
if w.lt(0).any():
|
90 |
+
raise ValueError("Negative Wasserstein distance encountered.")
|
91 |
+
|
92 |
+
if normalize:
|
93 |
+
w = 1 - torch.exp(-w)
|
94 |
+
return w
|
95 |
+
|
96 |
+
|
97 |
+
class WassersteinLoss(torch.nn.Module):
|
98 |
+
"""
|
99 |
+
Computes the Wasserstein distance loss between two ellipse tensors.
|
100 |
+
|
101 |
+
The Wasserstein distance provides a natural metric for comparing probability
|
102 |
+
distributions or shapes, with advantages over KL divergence such as:
|
103 |
+
- It's symmetric by definition
|
104 |
+
- It provides a true metric (satisfies triangle inequality)
|
105 |
+
- It's well-behaved even when distributions have different supports
|
106 |
+
|
107 |
+
Attributes:
|
108 |
+
shape_only: If True, computes distance based on shape without considering position
|
109 |
+
nan_to_num: Value to replace NaN entries with
|
110 |
+
normalize: If True, normalizes output to [0, 1] using exponential scaling
|
111 |
+
"""
|
112 |
+
|
113 |
+
def __init__(
|
114 |
+
self, shape_only: bool = True, nan_to_num: float = 10.0, normalize: bool = False
|
115 |
+
):
|
116 |
+
super().__init__()
|
117 |
+
self.shape_only = shape_only
|
118 |
+
self.nan_to_num = nan_to_num
|
119 |
+
self.normalize = normalize
|
120 |
+
|
121 |
+
def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
|
122 |
+
return symmetric_wasserstein_distance(
|
123 |
+
A1,
|
124 |
+
A2,
|
125 |
+
shape_only=self.shape_only,
|
126 |
+
nan_to_num=self.nan_to_num,
|
127 |
+
normalize=self.normalize,
|
128 |
+
)
|
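A minimal usage sketch for WassersteinLoss, assuming this file is importable as ellipse_rcnn.core.wd and that the (N, 3, 3) conic matrices are built with ellipse_to_conic_matrix from ellipse_rcnn.utils.conics (added later in this commit); all numeric values below are illustrative only.

import torch

from ellipse_rcnn.core.wd import WassersteinLoss
from ellipse_rcnn.utils.conics import ellipse_to_conic_matrix

# Two small batches of ellipses (a, b, x, y, theta); the values are made up.
pred = ellipse_to_conic_matrix(
    a=torch.tensor([10.0, 8.0]),
    b=torch.tensor([6.0, 5.0]),
    x=torch.tensor([32.0, 50.0]),
    y=torch.tensor([40.0, 44.0]),
    theta=torch.tensor([0.3, -0.1]),
)
target = ellipse_to_conic_matrix(
    a=torch.tensor([11.0, 8.5]),
    b=torch.tensor([5.5, 5.0]),
    x=torch.tensor([33.0, 49.0]),
    y=torch.tensor([41.0, 44.5]),
    theta=torch.tensor([0.25, -0.05]),
)

criterion = WassersteinLoss(shape_only=False, normalize=True)
loss = criterion(pred, target)  # one distance per ellipse pair, in [0, 1) when normalize=True
print(loss.mean())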
ellipse_rcnn/utils/__init__.py
ADDED
File without changes
|
ellipse_rcnn/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (180 Bytes). View file
|
|
ellipse_rcnn/utils/__pycache__/conics.cpython-312.pyc
ADDED
Binary file (8.37 kB). View file
|
|
ellipse_rcnn/utils/__pycache__/types.cpython-312.pyc
ADDED
Binary file (2.73 kB). View file
|
|
ellipse_rcnn/utils/__pycache__/viz.cpython-312.pyc
ADDED
Binary file (4.9 kB). View file
|
|
ellipse_rcnn/utils/conics.py
ADDED
@@ -0,0 +1,209 @@
1 |
+
from typing import Literal
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
@torch.jit.script
|
7 |
+
def adjugate_matrix(matrix: torch.Tensor) -> torch.Tensor:
|
8 |
+
"""Return adjugate matrix [1].
|
9 |
+
|
10 |
+
Parameters
|
11 |
+
----------
|
12 |
+
matrix:
|
13 |
+
Input matrix
|
14 |
+
|
15 |
+
Returns
|
16 |
+
-------
|
17 |
+
torch.Tensor
|
18 |
+
Adjugate of input matrix
|
19 |
+
|
20 |
+
References
|
21 |
+
----------
|
22 |
+
.. [1] https://en.wikipedia.org/wiki/Adjugate_matrix
|
23 |
+
"""
|
24 |
+
|
25 |
+
cofactor = torch.inverse(matrix).T * torch.det(matrix)
|
26 |
+
return cofactor.T
|
27 |
+
|
28 |
+
|
29 |
+
# @torch.jit.script
|
30 |
+
def unimodular_matrix(matrix: torch.Tensor) -> torch.Tensor:
|
31 |
+
"""Rescale matrix such that det(ellipses) = 1, in other words, make it unimodular. Doest not work with tensors
|
32 |
+
of dtype torch.float64.
|
33 |
+
|
34 |
+
Parameters
|
35 |
+
----------
|
36 |
+
matrix:
|
37 |
+
Matrix input
|
38 |
+
|
39 |
+
Returns
|
40 |
+
-------
|
41 |
+
torch.Tensor
|
42 |
+
Unimodular version of input matrix.
|
43 |
+
"""
|
44 |
+
val = 1.0 / torch.det(matrix)
|
45 |
+
return (torch.sign(val) * torch.pow(torch.abs(val), 1.0 / 3.0))[
|
46 |
+
..., None, None
|
47 |
+
] * matrix
|
48 |
+
|
49 |
+
|
50 |
+
# @torch.jit.script
|
51 |
+
def ellipse_to_conic_matrix(
|
52 |
+
*,
|
53 |
+
a: torch.Tensor,
|
54 |
+
b: torch.Tensor,
|
55 |
+
x: torch.Tensor | None = None,
|
56 |
+
y: torch.Tensor | None = None,
|
57 |
+
theta: torch.Tensor | None = None,
|
58 |
+
) -> torch.Tensor:
|
59 |
+
r"""Returns matrix representation for crater derived from ellipse parameters such that _[1]:
|
60 |
+
|
61 |
+
| A = a²(sin θ)² + b²(cos θ)²
|
62 |
+
| B = 2(b² - a²) sin θ cos θ
|
63 |
+
| C = a²(cos θ)² + b²(sin θ)²
|
64 |
+
| D = -2Ax₀ - By₀
|
65 |
+
| E = -Bx₀ - 2Cy₀
|
66 |
+
| F = Ax₀² + Bx₀y₀ + Cy₀² - a²b²
|
67 |
+
|
68 |
+
Resulting in a conic matrix:
|
69 |
+
::
|
70 |
+
|A B/2 D/2 |
|
71 |
+
M = |B/2 C E/2 |
|
72 |
+
|D/2 E/2 F |
|
73 |
+
|
74 |
+
Parameters
|
75 |
+
----------
|
76 |
+
a:
|
77 |
+
Semi-Major ellipse axis
|
78 |
+
b:
|
79 |
+
Semi-Minor ellipse axis
|
80 |
+
theta:
|
81 |
+
Ellipse angle (radians)
|
82 |
+
x:
|
83 |
+
X-position in 2D cartesian coordinate system (coplanar)
|
84 |
+
y:
|
85 |
+
Y-position in 2D cartesian coordinate system (coplanar)
|
86 |
+
|
87 |
+
Returns
|
88 |
+
-------
|
89 |
+
torch.Tensor
|
90 |
+
Array of ellipse matrices
|
91 |
+
|
92 |
+
References
|
93 |
+
----------
|
94 |
+
.. [1] https://www.researchgate.net/publication/355490899_Lunar_Crater_Identification_in_Digital_Images
|
95 |
+
"""
|
96 |
+
|
97 |
+
x = x if x is not None else torch.zeros(1)
|
98 |
+
y = y if y is not None else torch.zeros(1)
|
99 |
+
theta = theta if theta is not None else torch.zeros(1)
|
100 |
+
|
101 |
+
sin_theta = torch.sin(theta)
|
102 |
+
cos_theta = torch.cos(theta)
|
103 |
+
|
104 |
+
a2 = a**2
|
105 |
+
b2 = b**2
|
106 |
+
|
107 |
+
A = a2 * sin_theta**2 + b2 * cos_theta**2
|
108 |
+
B = 2 * (b2 - a2) * sin_theta * cos_theta
|
109 |
+
C = a2 * cos_theta**2 + b2 * sin_theta**2
|
110 |
+
D = -2 * A * x - B * y
|
111 |
+
F = -B * x - 2 * C * y
|
112 |
+
G = A * (x**2) + B * x * y + C * (y**2) - a2 * b2
|
113 |
+
|
114 |
+
# Create (array of) of conic matrix (N, 3, 3)
|
115 |
+
conic_matrix = torch.stack(
|
116 |
+
tensors=(
|
117 |
+
torch.stack((A, B / 2, D / 2), dim=-1),
|
118 |
+
torch.stack((B / 2, C, F / 2), dim=-1),
|
119 |
+
torch.stack((D / 2, F / 2, G), dim=-1),
|
120 |
+
),
|
121 |
+
dim=-1,
|
122 |
+
)
|
123 |
+
|
124 |
+
return conic_matrix.squeeze()
|
125 |
+
|
126 |
+
|
127 |
+
def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
128 |
+
"""Returns center of ellipse in 2D cartesian coordinate system with numerical stability."""
|
129 |
+
# Extract the top-left 2x2 submatrix of the conic matrix
|
130 |
+
A = conic_matrix[..., :2, :2]
|
131 |
+
|
132 |
+
# Stabilize the pseudoinverse by thresholding small singular values via rcond
|
133 |
+
A_pinv = torch.linalg.pinv(A, rcond=torch.finfo(A.dtype).eps)
|
134 |
+
|
135 |
+
# Extract the linear term (first two entries of the last column)
|
136 |
+
b = -conic_matrix[..., :2, 2][..., None]
|
137 |
+
|
138 |
+
# Solve for the ellipse centers
|
139 |
+
centers = torch.matmul(A_pinv, b).squeeze()
|
140 |
+
|
141 |
+
return centers[..., 0], centers[..., 1]
|
142 |
+
|
143 |
+
|
144 |
+
def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
145 |
+
"""Returns semi-major and semi-minor axes of ellipse in 2D cartesian coordinate system."""
|
146 |
+
lambdas = (
|
147 |
+
torch.linalg.eigvalsh(conic_matrix[..., :2, :2])
|
148 |
+
/ (-torch.det(conic_matrix) / torch.det(conic_matrix[..., :2, :2]))[..., None]
|
149 |
+
)
|
150 |
+
axes = torch.sqrt(1 / lambdas)
|
151 |
+
return axes[..., 0], axes[..., 1]
|
152 |
+
|
153 |
+
|
154 |
+
def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
|
155 |
+
"""Returns angle of ellipse in radians w.r.t. x-axis."""
|
156 |
+
return (
|
157 |
+
-torch.atan2(
|
158 |
+
2 * conic_matrix[..., 1, 0],
|
159 |
+
conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0],
|
160 |
+
)
|
161 |
+
/ 2
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
def bbox_ellipse(
|
166 |
+
ellipses: torch.Tensor,
|
167 |
+
box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
|
168 |
+
) -> torch.Tensor:
|
169 |
+
"""Converts (array of) ellipse matrices to bounding box tensor with format [xmin, ymin, xmax, ymax].
|
170 |
+
|
171 |
+
Parameters
|
172 |
+
----------
|
173 |
+
ellipses:
|
174 |
+
Array of ellipse matrices
|
175 |
+
box_type:
|
176 |
+
Format of bounding boxes, default is "xyxy"
|
177 |
+
|
178 |
+
Returns
|
179 |
+
-------
|
180 |
+
Array of bounding boxes
|
181 |
+
"""
|
182 |
+
cx, cy = conic_center(ellipses)
|
183 |
+
theta = ellipse_angle(ellipses)
|
184 |
+
semi_major_axis, semi_minor_axis = ellipse_axes(ellipses)
|
185 |
+
|
186 |
+
ux, uy = semi_major_axis * torch.cos(theta), semi_major_axis * torch.sin(theta)
|
187 |
+
vx, vy = (
|
188 |
+
semi_minor_axis * torch.cos(theta + torch.pi / 2),
|
189 |
+
semi_minor_axis * torch.sin(theta + torch.pi / 2),
|
190 |
+
)
|
191 |
+
|
192 |
+
box_halfwidth = torch.sqrt(ux**2 + vx**2)
|
193 |
+
box_halfheight = torch.sqrt(uy**2 + vy**2)
|
194 |
+
|
195 |
+
bboxes = torch.vstack(
|
196 |
+
(
|
197 |
+
cx - box_halfwidth,
|
198 |
+
cy - box_halfheight,
|
199 |
+
cx + box_halfwidth,
|
200 |
+
cy + box_halfheight,
|
201 |
+
)
|
202 |
+
).T
|
203 |
+
|
204 |
+
if box_type != "xyxy":
|
205 |
+
from torchvision.ops import boxes as box_ops
|
206 |
+
|
207 |
+
bboxes = box_ops.box_convert(bboxes, in_fmt="xyxy", out_fmt=box_type)
|
208 |
+
|
209 |
+
return bboxes
|
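A quick self-check of the helpers above: build conic matrices from ellipse parameters and recover the same parameters and their bounding boxes. The numbers are arbitrary and only meant to illustrate the roundtrip.

import torch

from ellipse_rcnn.utils.conics import (
    ellipse_to_conic_matrix,
    conic_center,
    ellipse_axes,
    ellipse_angle,
    bbox_ellipse,
)

a = torch.tensor([20.0, 12.0])     # semi-major axes
b = torch.tensor([10.0, 9.0])      # semi-minor axes
x = torch.tensor([64.0, 100.0])    # center x-coordinates
y = torch.tensor([64.0, 80.0])     # center y-coordinates
theta = torch.tensor([0.5, -0.2])  # rotation in radians

conics = ellipse_to_conic_matrix(a=a, b=b, x=x, y=y, theta=theta)  # shape (2, 3, 3)

cx, cy = conic_center(conics)        # close to (x, y)
a_rec, b_rec = ellipse_axes(conics)  # close to (a, b)
angle = ellipse_angle(conics)        # radians w.r.t. the x-axis
boxes = bbox_ellipse(conics)         # (2, 4) boxes in xyxy format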
ellipse_rcnn/utils/data/__init__.py
ADDED
File without changes
|
ellipse_rcnn/utils/data/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (185 Bytes). View file
|
|
ellipse_rcnn/utils/data/__pycache__/base.cpython-312.pyc
ADDED
Binary file (2.01 kB). View file
|
|
ellipse_rcnn/utils/data/base.py
ADDED
@@ -0,0 +1,62 @@
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from typing import Any
|
3 |
+
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
|
6 |
+
from ellipse_rcnn.utils.types import (
|
7 |
+
TargetDict,
|
8 |
+
CollatedBatchType,
|
9 |
+
UncollatedBatchType,
|
10 |
+
)
|
11 |
+
|
12 |
+
|
13 |
+
def collate_fn(batch: UncollatedBatchType) -> CollatedBatchType:
|
14 |
+
"""
|
15 |
+
Collate function for the :class:`DataLoader`.
|
16 |
+
|
17 |
+
Parameters
|
18 |
+
----------
|
19 |
+
batch:
|
20 |
+
A batch of data.
|
21 |
+
"""
|
22 |
+
return tuple(zip(*batch)) # type: ignore
|
23 |
+
|
24 |
+
|
25 |
+
class EllipseDatasetBase(ABC, Dataset):
|
26 |
+
@abstractmethod
|
27 |
+
def load_image(self, index: int) -> Any:
|
28 |
+
"""
|
29 |
+
Load the image for the given index.
|
30 |
+
|
31 |
+
Parameters
|
32 |
+
----------
|
33 |
+
index:
|
34 |
+
The index of the image.
|
35 |
+
|
36 |
+
Returns
|
37 |
+
-------
|
38 |
+
image:
|
39 |
+
The raw image.
|
40 |
+
"""
|
41 |
+
pass
|
42 |
+
|
43 |
+
@abstractmethod
|
44 |
+
def load_target_dict(self, index: int) -> TargetDict:
|
45 |
+
"""
|
46 |
+
Load the target dict for the given index.
|
47 |
+
|
48 |
+
Parameters
|
49 |
+
----------
|
50 |
+
index:
|
51 |
+
The index of the target dict.
|
52 |
+
|
53 |
+
Returns
|
54 |
+
-------
|
55 |
+
target_dict:
|
56 |
+
The target dictionary.
|
57 |
+
"""
|
58 |
+
pass
|
59 |
+
|
60 |
+
@abstractmethod
|
61 |
+
def __len__(self) -> int:
|
62 |
+
pass
|
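The collate_fn above simply transposes a list of (image, target) pairs into a tuple of images and a tuple of targets, which is what detection models expect for variable-size inputs. A sketch of wiring it into a DataLoader, assuming the FDDB dataset defined later in this commit and a placeholder data directory:

from torch.utils.data import DataLoader

from ellipse_rcnn.utils.data.base import collate_fn
from ellipse_rcnn.utils.data.fddb import FDDB  # concrete EllipseDatasetBase subclass, defined below

dataset = FDDB("data/FDDB")  # placeholder root with images/ and labels/ subfolders
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

for images, targets in loader:
    # images: tuple of (C, H, W) tensors, possibly with different spatial sizes
    # targets: tuple of TargetDict entries (boxes, labels, ellipse_matrices, ...)
    break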
ellipse_rcnn/utils/data/craters.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
import h5py
|
2 |
+
import torch
|
3 |
+
from torch.utils.data import Dataset
|
4 |
+
|
5 |
+
from ellipse_rcnn.utils.types import TargetDict, ImageTargetTuple
|
6 |
+
from ellipse_rcnn.utils.conics import bbox_ellipse
|
7 |
+
|
8 |
+
|
9 |
+
class CraterEllipseDataset(Dataset):
|
10 |
+
"""
|
11 |
+
Dataset for crater ellipse detection. Mostly meant as an example in combination with
|
12 |
+
https://github.com/wdoppenberg/crater-detection.
|
13 |
+
"""
|
14 |
+
|
15 |
+
def __init__(self, file_path: str, group: str) -> None:
|
16 |
+
self.file_path = file_path
|
17 |
+
self.group = group
|
18 |
+
|
19 |
+
def __getitem__(self, idx: torch.Tensor) -> ImageTargetTuple:
|
20 |
+
with h5py.File(self.file_path, "r") as dataset:
|
21 |
+
image = torch.tensor(dataset[self.group]["images"][idx])
|
22 |
+
|
23 |
+
# The number of instances varies per image, so a separate index array stores the offsets of the
|
24 |
+
# instances.
|
25 |
+
start_idx = dataset[self.group]["craters/crater_list_idx"][idx]
|
26 |
+
end_idx = dataset[self.group]["craters/crater_list_idx"][idx + 1]
|
27 |
+
ellipse_matrices = torch.tensor(
|
28 |
+
dataset[self.group]["craters/A_craters"][start_idx:end_idx]
|
29 |
+
)
|
30 |
+
|
31 |
+
boxes = bbox_ellipse(ellipse_matrices)
|
32 |
+
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
33 |
+
|
34 |
+
num_objs = len(boxes)
|
35 |
+
|
36 |
+
labels = torch.ones((num_objs,), dtype=torch.int64)
|
37 |
+
image_id = torch.tensor([idx])
|
38 |
+
|
39 |
+
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
|
40 |
+
|
41 |
+
target = TargetDict(
|
42 |
+
boxes=boxes,
|
43 |
+
labels=labels,
|
44 |
+
image_id=image_id,
|
45 |
+
area=area,
|
46 |
+
iscrowd=iscrowd,
|
47 |
+
ellipse_matrices=ellipse_matrices,
|
48 |
+
)
|
49 |
+
|
50 |
+
return image, target
|
51 |
+
|
52 |
+
def __len__(self) -> int:
|
53 |
+
with h5py.File(self.file_path, "r") as f:
|
54 |
+
return len(f[self.group]["images"])
|
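A usage sketch for the crater dataset, assuming an HDF5 file laid out as the class expects (an images array plus craters/crater_list_idx offsets and craters/A_craters conic matrices under the chosen group); the file name and group below are placeholders.

from torch.utils.data import DataLoader

from ellipse_rcnn.utils.data.base import collate_fn
from ellipse_rcnn.utils.data.craters import CraterEllipseDataset

dataset = CraterEllipseDataset(file_path="craters.h5", group="training")  # placeholder path/group
image, target = dataset[0]  # single sample: image tensor and its TargetDict

loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)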
ellipse_rcnn/utils/data/fddb.py
ADDED
@@ -0,0 +1,239 @@
1 |
+
"""
|
2 |
+
Data loader and module for the FDDB dataset.
|
3 |
+
https://vis-www.cs.umass.edu/fddb/
|
4 |
+
"""
|
5 |
+
|
6 |
+
from glob import glob
|
7 |
+
from typing import Any
|
8 |
+
from pathlib import Path
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import pandas as pd
|
12 |
+
import PIL.Image
|
13 |
+
import torchvision.transforms
|
14 |
+
import pytorch_lightning as pl
|
15 |
+
from torch.utils.data import DataLoader, random_split
|
16 |
+
|
17 |
+
from ellipse_rcnn.utils.types import TargetDict, ImageTargetTuple, EllipseTuple
|
18 |
+
from ellipse_rcnn.utils.conics import bbox_ellipse, ellipse_to_conic_matrix, conic_center, unimodular_matrix
|
19 |
+
from ellipse_rcnn.utils.data.base import EllipseDatasetBase, collate_fn
|
20 |
+
|
21 |
+
|
22 |
+
def preprocess_label_files(root_path: str) -> dict[str, list[EllipseTuple]]:
|
23 |
+
label_files = glob(f"{root_path}/labels/*.txt")
|
24 |
+
|
25 |
+
file_paths = []
|
26 |
+
ellipse_data = []
|
27 |
+
|
28 |
+
for filename in label_files:
|
29 |
+
with open(filename) as f:
|
30 |
+
if "ellipseList" not in filename:
|
31 |
+
file_paths += [p.strip("\n") for p in f.readlines()]
|
32 |
+
else:
|
33 |
+
ellipse_data += [p.strip("\n") for p in f.readlines()]
|
34 |
+
|
35 |
+
pdf_file_paths = pd.DataFrame({"path": file_paths})
|
36 |
+
pdf_file_paths["path_idx"] = pdf_file_paths.index
|
37 |
+
|
38 |
+
pdf_ellipse_data = pd.DataFrame({"data": ellipse_data})
|
39 |
+
pdf_ellipse_data["data_idx"] = pdf_ellipse_data.index
|
40 |
+
|
41 |
+
pdf_file_data_mapping = pdf_file_paths.merge(
|
42 |
+
pdf_ellipse_data, left_on="path", right_on="data", how="left"
|
43 |
+
)
|
44 |
+
|
45 |
+
ellipse_dict: dict[str, list[EllipseTuple]] = {
|
46 |
+
str(k): [] for k in pdf_file_paths["path"]
|
47 |
+
}
|
48 |
+
|
49 |
+
for i, r in pdf_file_data_mapping.iterrows():
|
50 |
+
data_idx = r["data_idx"]
|
51 |
+
num_ellipses = int(ellipse_data[data_idx + 1])
|
52 |
+
file_path = r["path"]
|
53 |
+
for j in range(data_idx + 2, data_idx + num_ellipses + 2):
|
54 |
+
a, b, theta, x, y = [
|
55 |
+
float(v) for v in ellipse_data[j].split(" ")[:-1] if len(v) > 0
|
56 |
+
]
|
57 |
+
ellipse_params = EllipseTuple(a, b, theta, x, y)
|
58 |
+
ellipse_dict[file_path].append(ellipse_params)
|
59 |
+
|
60 |
+
return ellipse_dict
|
61 |
+
|
62 |
+
|
63 |
+
class FDDB(EllipseDatasetBase):
|
64 |
+
def __init__(
|
65 |
+
self,
|
66 |
+
root_path: str | Path,
|
67 |
+
ellipse_dict: dict[str, list[EllipseTuple]] | None = None,
|
68 |
+
transform: Any = None,
|
69 |
+
) -> None:
|
70 |
+
self.root_path = Path(root_path) if isinstance(root_path, str) else root_path
|
71 |
+
if transform is None:
|
72 |
+
self.transform = torchvision.transforms.Compose(
|
73 |
+
[
|
74 |
+
torchvision.transforms.ToTensor(),
|
75 |
+
torchvision.transforms.Normalize(
|
76 |
+
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
|
77 |
+
),
|
78 |
+
]
|
79 |
+
)
|
80 |
+
else:
|
81 |
+
self.transform = transform
|
82 |
+
self.ellipse_dict = ellipse_dict or preprocess_label_files(root_path)
|
83 |
+
|
84 |
+
def __len__(self) -> int:
|
85 |
+
return len(self.ellipse_dict)
|
86 |
+
|
87 |
+
def load_target_dict(self, index: int) -> TargetDict:
|
88 |
+
key = list(self.ellipse_dict.keys())[index]
|
89 |
+
ellipses_list = self.ellipse_dict[key]
|
90 |
+
|
91 |
+
a = torch.tensor([[e.a for e in ellipses_list]])
|
92 |
+
b = torch.tensor([[e.b for e in ellipses_list]])
|
93 |
+
theta = torch.tensor([[e.theta for e in ellipses_list]])
|
94 |
+
x = torch.tensor([[e.x for e in ellipses_list]])
|
95 |
+
y = torch.tensor([[e.y for e in ellipses_list]])
|
96 |
+
|
97 |
+
ellipse_matrices = ellipse_to_conic_matrix(a=a, b=b, x=x, y=y, theta=theta)
|
98 |
+
|
99 |
+
if torch.stack(conic_center(ellipse_matrices)).isnan().any():
|
100 |
+
raise ValueError("NaN values in ellipse matrices. Please check the data.")
|
101 |
+
|
102 |
+
if len(ellipse_matrices.shape) == 2:
|
103 |
+
ellipse_matrices = ellipse_matrices.unsqueeze(0)
|
104 |
+
|
105 |
+
boxes = bbox_ellipse(ellipse_matrices, box_type="xyxy")
|
106 |
+
|
107 |
+
num_objs = len(boxes)
|
108 |
+
|
109 |
+
labels = torch.ones((num_objs,), dtype=torch.int64)
|
110 |
+
image_id = torch.tensor([index])
|
111 |
+
area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
|
112 |
+
iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
|
113 |
+
|
114 |
+
target = TargetDict(
|
115 |
+
boxes=boxes,
|
116 |
+
labels=labels,
|
117 |
+
image_id=image_id,
|
118 |
+
area=area,
|
119 |
+
iscrowd=iscrowd,
|
120 |
+
ellipse_matrices=ellipse_matrices,
|
121 |
+
)
|
122 |
+
|
123 |
+
return target
|
124 |
+
|
125 |
+
def load_image(self, index: int) -> PIL.Image.Image:
|
126 |
+
key = list(self.ellipse_dict.keys())[index]
|
127 |
+
file_path = str(Path(self.root_path) / "images" / Path(key)) + ".jpg"
|
128 |
+
return PIL.Image.open(file_path)
|
129 |
+
|
130 |
+
def __getitem__(self, idx: int) -> ImageTargetTuple:
|
131 |
+
image = self.load_image(idx)
|
132 |
+
target_dict = self.load_target_dict(idx)
|
133 |
+
|
134 |
+
# If the image is grayscale, convert it to RGB
|
135 |
+
if image.mode == "L":
|
136 |
+
image = image.convert("RGB")
|
137 |
+
|
138 |
+
image = self.transform(image)
|
139 |
+
|
140 |
+
return image, target_dict
|
141 |
+
|
142 |
+
def __repr__(self) -> str:
|
143 |
+
return f"FDDB<img={len(self)}>"
|
144 |
+
|
145 |
+
def split(self, fraction: float, shuffle: bool = False) -> tuple["FDDB", "FDDB"]:
|
146 |
+
"""
|
147 |
+
Splits the dataset into two subsets based on the given fraction.
|
148 |
+
|
149 |
+
Args:
|
150 |
+
fraction (float): Fraction of the dataset for the first subset (0 < fraction < 1).
|
151 |
+
shuffle (bool): If True, dataset keys will be shuffled before splitting.
|
152 |
+
|
153 |
+
Returns:
|
154 |
+
tuple[FDDB, FDDB]: Two FDDB instances, one with the fraction of data,
|
155 |
+
and the other with the remaining data.
|
156 |
+
"""
|
157 |
+
if not (0 < fraction < 1):
|
158 |
+
raise ValueError("The fraction must be between 0 and 1.")
|
159 |
+
|
160 |
+
keys = list(self.ellipse_dict.keys())
|
161 |
+
if shuffle:
|
162 |
+
import random
|
163 |
+
|
164 |
+
random.shuffle(keys)
|
165 |
+
|
166 |
+
total_length = len(keys)
|
167 |
+
split_index = int(total_length * fraction)
|
168 |
+
|
169 |
+
subset1_keys = keys[:split_index]
|
170 |
+
subset2_keys = keys[split_index:]
|
171 |
+
|
172 |
+
subset1_ellipse_dict = {key: self.ellipse_dict[key] for key in subset1_keys}
|
173 |
+
subset2_ellipse_dict = {key: self.ellipse_dict[key] for key in subset2_keys}
|
174 |
+
|
175 |
+
subset1 = FDDB(
|
176 |
+
self.root_path, ellipse_dict=subset1_ellipse_dict, transform=self.transform
|
177 |
+
)
|
178 |
+
subset2 = FDDB(
|
179 |
+
self.root_path, ellipse_dict=subset2_ellipse_dict, transform=self.transform
|
180 |
+
)
|
181 |
+
|
182 |
+
return subset1, subset2
|
183 |
+
|
184 |
+
|
185 |
+
class FDDBLightningDataModule(pl.LightningDataModule):
|
186 |
+
def __init__(
|
187 |
+
self,
|
188 |
+
data_dir: str,
|
189 |
+
batch_size: int = 16,
|
190 |
+
train_fraction: float = 0.8,
|
191 |
+
transform: Any = None,
|
192 |
+
num_workers: int = 0,
|
193 |
+
) -> None:
|
194 |
+
super().__init__()
|
195 |
+
self.data_dir = data_dir
|
196 |
+
self.batch_size = batch_size
|
197 |
+
self.train_fraction = train_fraction
|
198 |
+
self.transform = transform
|
199 |
+
self.dataset: FDDB | None = None
|
200 |
+
self.train_dataset = None
|
201 |
+
self.val_dataset = None
|
202 |
+
self.num_workers = num_workers
|
203 |
+
|
204 |
+
def prepare_data(self) -> None:
|
205 |
+
# Ensure data preparation or downloading is done here.
|
206 |
+
pass
|
207 |
+
|
208 |
+
def setup(self, stage: str | None = None) -> None:
|
209 |
+
# Instantiate the FDDB dataset and split it into training and validation subsets.
|
210 |
+
self.dataset = FDDB(self.data_dir, transform=self.transform)
|
211 |
+
|
212 |
+
train_size = int(len(self.dataset) * self.train_fraction)
|
213 |
+
val_size = len(self.dataset) - train_size
|
214 |
+
self.train_dataset, self.val_dataset = random_split(
|
215 |
+
self.dataset, [train_size, val_size]
|
216 |
+
)
|
217 |
+
|
218 |
+
def train_dataloader(self) -> DataLoader[ImageTargetTuple]:
|
219 |
+
return DataLoader(
|
220 |
+
self.train_dataset,
|
221 |
+
batch_size=self.batch_size,
|
222 |
+
shuffle=True,
|
223 |
+
collate_fn=collate_fn,
|
224 |
+
num_workers=self.num_workers,
|
225 |
+
)
|
226 |
+
|
227 |
+
def val_dataloader(self) -> DataLoader[ImageTargetTuple]:
|
228 |
+
return DataLoader(
|
229 |
+
self.val_dataset,
|
230 |
+
batch_size=self.batch_size,
|
231 |
+
collate_fn=collate_fn,
|
232 |
+
num_workers=self.num_workers,
|
233 |
+
)
|
234 |
+
|
235 |
+
def test_dataloader(self) -> DataLoader[ImageTargetTuple]:
|
236 |
+
# Placeholder for test data; currently returns the validation dataloader as a default.
|
237 |
+
return DataLoader(
|
238 |
+
self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn
|
239 |
+
)
|
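A sketch of driving the datamodule, assuming the FDDB images and label files sit under a local directory with the images/ and labels/ layout that preprocess_label_files expects; the commented-out training call assumes a LightningModule wrapping EllipseRCNN, which is not part of this commit.

import pytorch_lightning as pl

from ellipse_rcnn.utils.data.fddb import FDDBLightningDataModule

dm = FDDBLightningDataModule("data/FDDB", batch_size=8, train_fraction=0.8, num_workers=4)
dm.setup()

train_loader = dm.train_dataloader()
val_loader = dm.val_dataloader()

# With a LightningModule wrapping EllipseRCNN (hypothetical, not included here):
# trainer = pl.Trainer(max_epochs=20, accelerator="auto")
# trainer.fit(lightning_module, datamodule=dm)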
ellipse_rcnn/utils/types.py
ADDED
@@ -0,0 +1,46 @@
1 |
+
from typing import TypedDict, NamedTuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
class TargetDict(TypedDict):
|
7 |
+
boxes: torch.Tensor
|
8 |
+
labels: torch.Tensor
|
9 |
+
image_id: torch.Tensor
|
10 |
+
area: torch.Tensor
|
11 |
+
iscrowd: torch.Tensor
|
12 |
+
ellipse_matrices: torch.Tensor
|
13 |
+
|
14 |
+
|
15 |
+
class LossDict(TypedDict, total=False):
|
16 |
+
loss_classifier: torch.Tensor
|
17 |
+
loss_box_reg: torch.Tensor
|
18 |
+
loss_objectness: torch.Tensor
|
19 |
+
loss_rpn_box_reg: torch.Tensor
|
20 |
+
loss_ellipse_kld: torch.Tensor
|
21 |
+
loss_ellipse_smooth_l1: torch.Tensor
|
22 |
+
loss_total: torch.Tensor
|
23 |
+
|
24 |
+
|
25 |
+
class PredictionDict(TypedDict):
|
26 |
+
bboxes: torch.Tensor
|
27 |
+
labels: torch.Tensor
|
28 |
+
scores: torch.Tensor
|
29 |
+
ellipse_matrices: torch.Tensor
|
30 |
+
|
31 |
+
|
32 |
+
type ImageTargetTuple = tuple[torch.Tensor, TargetDict] # Tensor shape: (C, H, W)
|
33 |
+
type CollatedBatchType = tuple[
|
34 |
+
tuple[torch.Tensor, ...], tuple[TargetDict, ...]
|
35 |
+
] # Tensor shape: (C, H, W)
|
36 |
+
type UncollatedBatchType = list[ImageTargetTuple]
|
37 |
+
|
38 |
+
type EllipseType = torch.Tensor
|
39 |
+
|
40 |
+
|
41 |
+
class EllipseTuple(NamedTuple):
|
42 |
+
a: float
|
43 |
+
b: float
|
44 |
+
theta: float
|
45 |
+
x: float
|
46 |
+
y: float
|
ellipse_rcnn/utils/viz.py
ADDED
@@ -0,0 +1,106 @@
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
from typing import Literal
|
4 |
+
from torch import Tensor
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torchvision.ops import boxes as box_ops
|
8 |
+
from matplotlib import pyplot as plt
|
9 |
+
from matplotlib.axes import Axes
|
10 |
+
from matplotlib.collections import EllipseCollection, PatchCollection
|
11 |
+
from matplotlib.patches import Rectangle
|
12 |
+
from ellipse_rcnn.utils.conics import ellipse_angle, conic_center, ellipse_axes
|
13 |
+
from matplotlib.figure import Figure
|
14 |
+
|
15 |
+
|
16 |
+
def plot_single_pred(
|
17 |
+
image: Tensor,
|
18 |
+
prediction,
|
19 |
+
min_score: float = 0.75,
|
20 |
+
) -> Figure:
|
21 |
+
if isinstance(prediction, list):
|
22 |
+
if len(prediction) > 1:
|
23 |
+
raise ValueError(
|
24 |
+
"Multiple predictions detected. Please pass a single prediction."
|
25 |
+
)
|
26 |
+
prediction = prediction[0]
|
27 |
+
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
|
28 |
+
fig.patch.set_alpha(0)
|
29 |
+
ax.imshow(image.permute(1, 2, 0), cmap="grey")
|
30 |
+
score_mask = prediction["scores"] > min_score
|
31 |
+
|
32 |
+
plot_ellipses(prediction["ellipse_matrices"][score_mask], ax=ax)
|
33 |
+
|
34 |
+
return fig
|
35 |
+
|
36 |
+
|
37 |
+
def plot_ellipses(
|
38 |
+
A_craters: torch.Tensor,
|
39 |
+
figsize: tuple[float, float] = (15, 15),
|
40 |
+
plot_centers: bool = False,
|
41 |
+
ax: Axes | None = None,
|
42 |
+
rim_color="r",
|
43 |
+
alpha=1.0,
|
44 |
+
):
|
45 |
+
a_proj, b_proj = ellipse_axes(A_craters)
|
46 |
+
psi_proj = ellipse_angle(A_craters)
|
47 |
+
x_pix_proj, y_pix_proj = conic_center(A_craters)
|
48 |
+
|
49 |
+
a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = map(
|
50 |
+
lambda t: t.detach().cpu().numpy(),
|
51 |
+
(a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj),
|
52 |
+
)
|
53 |
+
|
54 |
+
if ax is None:
|
55 |
+
fig, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})
|
56 |
+
|
57 |
+
ec = EllipseCollection(
|
58 |
+
a_proj * 2,
|
59 |
+
b_proj * 2,
|
60 |
+
np.degrees(psi_proj),
|
61 |
+
units="xy",
|
62 |
+
offsets=np.column_stack((x_pix_proj, y_pix_proj)),
|
63 |
+
transOffset=ax.transData,
|
64 |
+
facecolors="None",
|
65 |
+
edgecolors=rim_color,
|
66 |
+
alpha=alpha,
|
67 |
+
)
|
68 |
+
ax.add_collection(ec)
|
69 |
+
|
70 |
+
if plot_centers:
|
71 |
+
crater_centers = conic_center(A_craters)
|
72 |
+
for k, c_i in enumerate(crater_centers):
|
73 |
+
x, y = c_i[0], c_i[1]
|
74 |
+
ax.text(x.item(), y.item(), str(k), color=rim_color)
|
75 |
+
|
76 |
+
|
77 |
+
def plot_bboxes(
|
78 |
+
boxes: torch.Tensor,
|
79 |
+
box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
|
80 |
+
figsize: tuple[float, float] = (15, 15),
|
81 |
+
plot_centers: bool = False,
|
82 |
+
ax: Axes | None = None,
|
83 |
+
rim_color="r",
|
84 |
+
alpha=1.0,
|
85 |
+
):
|
86 |
+
if ax is None:
|
87 |
+
fig, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})
|
88 |
+
|
89 |
+
if box_type != "xyxy":
|
90 |
+
boxes = box_ops.box_convert(boxes, box_type, "xyxy")
|
91 |
+
|
92 |
+
boxes = boxes.detach().cpu().numpy()
|
93 |
+
rectangles = []
|
94 |
+
for k, b_i in enumerate(boxes):
|
95 |
+
x1, y1, x2, y2 = b_i
|
96 |
+
rectangles.append(Rectangle((x1, y1), x2 - x1, y2 - y1))
|
97 |
+
|
98 |
+
collection = PatchCollection(
|
99 |
+
rectangles, edgecolor=rim_color, facecolor="none", alpha=alpha
|
100 |
+
)
|
101 |
+
ax.add_collection(collection)
|
102 |
+
|
103 |
+
if plot_centers:
|
104 |
+
for k, b_i in enumerate(boxes):
|
105 |
+
x1, y1, x2, y2 = b_i
|
106 |
+
ax.text(x1, y1, str(k), color=rim_color)
|
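A sketch connecting the visualization helpers to the model, assuming a trained checkpoint model.pth is available locally (placeholder path), that EllipseRCNN follows the torchvision detection convention of taking a list of image tensors in eval mode and returning a list of prediction dicts, and using the same ImageNet normalization as the FDDB transform above.

import torch
import torchvision.transforms as T
from PIL import Image

from ellipse_rcnn import EllipseRCNN
from ellipse_rcnn.utils.viz import plot_single_pred

model = EllipseRCNN()
model.load_state_dict(torch.load("model.pth", weights_only=True))  # placeholder checkpoint path
model.eval()

transform = T.Compose(
    [T.ToTensor(), T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])]
)
image = transform(Image.open("examples/image1.jpg").convert("RGB"))

with torch.no_grad():
    predictions = model([image])  # assumed torchvision-style forward: list in, list of dicts out

fig = plot_single_pred(image, predictions, min_score=0.75)
fig.savefig("prediction.png")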
examples/image1.jpg
ADDED
examples/image2.jpg
ADDED
examples/image3.jpg
ADDED
requirements.txt
ADDED
@@ -0,0 +1,8 @@
1 |
+
gradio
|
2 |
+
torch
|
3 |
+
torchvision
|
4 |
+
matplotlib
|
5 |
+
Pillow
|
6 |
+
joblib
|
7 |
+
huggingface_hub
|
8 |
+
lightning
|