Filipstrozik committed
Commit afc2161 · 1 Parent(s): 2c3fe3e

Add initial implementation of EllipseRCNN model and dataset utilities


- Introduced core model and loss functions for ellipse detection.
- Added dataset classes for loading and processing crater images.
- Included type definitions for better code clarity.
- Created requirements file for necessary dependencies.
- Added README documentation for core functionalities.

app.py ADDED
@@ -0,0 +1,225 @@
1
+ import io
3
+ import gradio as gr
4
+ from PIL import Image
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.patches as mpatches
7
+ import torchvision.transforms as transforms
8
+ import torch
9
+ from huggingface_hub import hf_hub_download
10
+ from ellipse_rcnn import EllipseRCNN
11
+
12
+
13
+ # load model.pth from Filipstrozik/sat-tree-detection-v0 repository in hugging face
14
+ load_state_dict = torch.load(
15
+ hf_hub_download("Filipstrozik/sat-tree-detection-v0", "model.pth"),
16
+ weights_only=True,
17
+ )
18
+ model = EllipseRCNN()
19
+
20
+ model.load_state_dict(load_state_dict)
21
+ model.eval()
22
+
23
+
24
+ def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
25
+ """Returns center of ellipse in 2D cartesian coordinate system with numerical stability."""
26
+ # Extract the top-left 2x2 submatrix of the conic matrix
27
+ A = conic_matrix[..., :2, :2]
28
+
29
+ # Add stabilization for pseudoinverse computation by clamping singular values
30
+ A_pinv = torch.linalg.pinv(A, rcond=torch.finfo(A.dtype).eps)
31
+
32
+ # Extract the last two rows for the linear term
33
+ b = -conic_matrix[..., :2, 2][..., None]
34
+
35
+ # Stabilize any potential numerical instabilities
36
+ centers = torch.matmul(A_pinv, b).squeeze()
37
+
38
+ return centers[..., 0], centers[..., 1]
39
+
40
+
41
+ def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
42
+ """Returns semi-major and semi-minor axes of ellipse in 2D cartesian coordinate system."""
43
+ lambdas = (
44
+ torch.linalg.eigvalsh(conic_matrix[..., :2, :2])
45
+ / (-torch.det(conic_matrix) / torch.det(conic_matrix[..., :2, :2]))[..., None]
46
+ )
47
+ axes = torch.sqrt(1 / lambdas)
48
+ return axes[..., 0], axes[..., 1]
49
+
50
+
51
+ def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
52
+ """Returns angle of ellipse in radians w.r.t. x-axis."""
53
+ return (
54
+ -torch.atan2(
55
+ 2 * conic_matrix[..., 1, 0],
56
+ conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0],
57
+ )
58
+ / 2
59
+ )
60
+
61
+
62
+ def get_ellipse_params_from_matrices(ellipse_matrices):
63
+ if ellipse_matrices.shape[0] == 0:
64
+ return None
65
+ a, b = ellipse_axes(ellipse_matrices)
66
+ cx, cy = conic_center(ellipse_matrices)
67
+ theta = ellipse_angle(ellipse_matrices)
68
+
69
+ a = a.view(-1)
70
+ b = b.view(-1)
71
+ cx = cx.view(-1)
72
+ cy = cy.view(-1)
73
+ theta = theta.view(-1)
74
+
75
+ ellipses = torch.stack([a, b, cx, cy, theta], dim=1).reshape(-1, 5)
76
+ return ellipses
77
+
78
+
79
+ def plot_ellipses(
80
+ ellipse_params: torch.Tensor,
81
+ image: torch.Tensor,
82
+ plot_centers: bool = False,
83
+ rim_color: str = "r",
84
+ alpha: float = 0.25,
85
+ ) -> None:
86
+ if ellipse_params is None:
87
+ return
88
+ a, b, cx, cy, theta = ellipse_params.unbind(-1)
89
+
90
+ # scale the predicted centers up by a factor of 4 to match the displayed image resolution
91
+ cx = cx * 4
92
+ cy = cy * 4
93
+
94
+ # draw ellipses
95
+ for i in range(len(a)):
96
+ ellipse = mpatches.Ellipse(
97
+ (cx[i], cy[i]),
98
+ width=a[i],
99
+ height=b[i],
100
+ angle=torch.rad2deg(theta[i]),  # matplotlib Ellipse expects degrees; theta is in radians
101
+ fill=True,
102
+ alpha=alpha,
103
+ color=rim_color,
104
+ )
105
+ plt.gca().add_patch(ellipse)
106
+
107
+ if plot_centers:
108
+ plt.scatter(cx[i], cy[i], c=rim_color, s=10)
109
+
110
+ plt.imshow(image)
111
+
112
+
113
+ # Define the necessary transformations and the inverse normalization
114
+ def invert_normalization(image, mean, std):
115
+ for t, m, s in zip(image, mean, std):
116
+ t.mul_(s).add_(m)
117
+ return torch.clamp(image, 0, 1)
118
+
119
+
120
+ def process_image(image):
121
+ original_size = image.size
122
+
123
+ # Define the transform pipeline
124
+ transform = transforms.Compose(
125
+ [
126
+ transforms.Resize((1024, 1024)),
127
+ transforms.PILToTensor(),
128
+ transforms.ConvertImageDtype(torch.float),
129
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
130
+ ]
131
+ )
132
+
133
+ image_tensor = transform(image).unsqueeze(0) # Add batch dimension
134
+ return image_tensor, original_size
135
+
136
+
137
+ def generate_prediction(image, rpn_nms_thresh, score_thresh, nms_thresh):
138
+ # Preprocess image
139
+ image_tensor, original_size = process_image(image)
140
+ image_tensor = image_tensor.to("cpu")
141
+
142
+ # Ensure the model is in evaluation mode
143
+ model.rpn.nms_thresh = rpn_nms_thresh
144
+ model.roi_heads.score_thresh = score_thresh
145
+ model.roi_heads.nms_thresh = nms_thresh
146
+
147
+ with torch.no_grad():
148
+ prediction = model(image_tensor)[0]
149
+
150
+ # Invert normalization for display
151
+ mean = [0.485, 0.456, 0.406]
152
+ std = [0.229, 0.224, 0.225]
153
+ inverted_image = (
154
+ invert_normalization(image_tensor, mean, std)
155
+ .squeeze(0)
156
+ .permute(1, 2, 0)
157
+ .cpu()
158
+ .numpy()
159
+ )
160
+
161
+ # Plot results with ellipses
162
+ plt.figure(figsize=(10, 10))
163
+ plt.imshow(inverted_image)
164
+ plot_ellipses(
165
+ get_ellipse_params_from_matrices(prediction["ellipse_matrices"]),
166
+ inverted_image,
167
+ plot_centers=True,
168
+ rim_color="red",
169
+ alpha=0.25,
170
+ )
171
+ red_patch = mpatches.Patch(color="red", label="Predicted")
172
+ plt.legend(handles=[red_patch], loc="upper right")
173
+ plt.gca().set_aspect(original_size[0] / original_size[1])
174
+ plt.axis("off")
175
+ plt.tight_layout()
176
+ # Save the figure to a buffer and return as an image
177
+ buf = io.BytesIO()
178
+ plt.savefig(buf, format="png")
179
+ buf.seek(0)
180
+ with Image.open(buf) as output_image:
181
+ output_image = output_image.copy()
182
+ buf.close()
183
+ return output_image
184
+
185
+
186
+ # Define Gradio interface
187
+ with gr.Blocks() as demo:
188
+ gr.Markdown("## Tree Detection from Satellite Images")
189
+ gr.Markdown("Upload an image and see the detected trees with ellipses.")
190
+
191
+ with gr.Row():
192
+ image_input = gr.Image(label="Input Image", type="pil")
193
+ image_output = gr.Image(label="Detected Trees")
194
+
195
+ examples = [
196
+ ["examples/image1.jpg"],
197
+ ["examples/image2.jpg"],
198
+ ["examples/image3.jpg"],
199
+ ]
200
+
201
+ with gr.Row():
202
+ rpn_nms_slider = gr.Slider(
203
+ 0.0, 1.0, value=model.rpn.nms_thresh, label="RPN NMS Threshold"
204
+ )
205
+ score_thresh_slider = gr.Slider(
206
+ 0.0,
207
+ 1.0,
208
+ value=model.roi_heads.score_thresh,
209
+ label="ROI Heads Score Threshold",
210
+ )
211
+ nms_thresh_slider = gr.Slider(
212
+ 0.0, 1.0, value=model.roi_heads.nms_thresh, label="ROI Heads NMS Threshold"
213
+ )
214
+
215
+ submit_button = gr.Button("Detect Trees")
216
+ submit_button.click(
217
+ fn=generate_prediction,
218
+ inputs=[image_input, rpn_nms_slider, score_thresh_slider, nms_thresh_slider],
219
+ outputs=image_output,
220
+ )
221
+
222
+ gr.Examples(examples=examples, inputs=image_input, outputs=image_output)
223
+
224
+
225
+ demo.launch()
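For reference, a minimal sketch of calling the prediction pipeline defined above outside the Gradio UI. The image path is illustrative, and the threshold values simply mirror the model's defaults.

```python
from PIL import Image

# Assumes the definitions above (model, generate_prediction) are already in scope,
# e.g. in a notebook cell, and that one of the bundled example images is present.
image = Image.open("examples/image1.jpg").convert("RGB")

# Threshold values shown here mirror the EllipseRCNN defaults.
output = generate_prediction(image, rpn_nms_thresh=0.7, score_thresh=0.05, nms_thresh=0.5)
output.save("prediction.png")
```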
ellipse_rcnn/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .core.model import EllipseRCNN
ellipse_rcnn/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (223 Bytes).
 
ellipse_rcnn/core/README.md ADDED
@@ -0,0 +1,40 @@
1
+ # README.md
2
+
3
+ ## Symmetric Kullback-Leibler Divergence Loss
4
+
5
+ This directory provides an implementation of a **Symmetric Kullback-Leibler (KL) Divergence Loss** tailored for tensors
6
+ representing ellipses in matrix form. The loss function is designed to measure the difference between two elliptical
7
+ shapes and is particularly useful in optimization and generative modeling tasks.
8
+
9
+ ## Loss Calculation
10
+
11
+ ### **Kullback-Leibler Divergence**
12
+
13
+ For two ellipses represented by their matrix forms ( $A_1$ ) and ( $A_2$ ), the KL divergence is calculated as:
14
+ $$ D_{KL}(A_1 \parallel A_2) = \frac{1}{2} \left( \text{Tr}(C_2^{-1}C_1) + (\mu_1 - \mu_2)^T C_2^{-1} (\mu_1 - \mu_2) - 2 + \log\left(\frac{\det(C_2)}{\det(C_1)}\right) \right) $$
15
+ Where:
16
+
17
+ - ( $C_1$, $C_2$ ): Covariance matrices extracted from ( $A_1$, $A_2$ ).
18
+ - ( $\mu_1$, $\mu_2$ ): Centers (means) of the ellipses, computed from the conic representation.
19
+ - ( $\text{Tr}$ ): Trace operator.
20
+ - ( $C_2^{-1}$ ): Inverse of the covariance matrix of ( $A_2$ ).
21
+ - ( $\det(C_1)$, $\det(C_2)$ ): Determinants of covariance matrices.
22
+
23
+ A regularization term ( $\epsilon$ ) is added to ensure numerical stability when computing inverses and determinants.
24
+
25
+ ### **Symmetric KL Divergence**
26
+
27
+ The symmetric version of the KL divergence combines the calculations in both directions:
28
+ $$ D_{KL}^{\text{sym}}(A_1, A_2) = \frac{1}{2} \left( D_{KL}(A_1 \parallel A_2) + D_{KL}(A_2 \parallel A_1) \right) $$
29
+ This ensures a bidirectional comparison, making the function suitable as a loss metric in optimization tasks.
30
+
31
+ ## Features of the Loss
32
+
33
+ - **Shape-Only Comparison**: Option to ignore translation and compute divergence based purely on the shapes (covariance
34
+ matrices).
35
+ - **NaN Handling**: Replaces NaN values with a specified constant, ensuring robust loss evaluation.
36
+ - **Normalization**: An optional normalization step that rescales the divergence for certain applications.
37
+
38
+ ### Usage
39
+
40
+ The loss is encapsulated in the `SymmetricKLDLoss` class, which integrates seamlessly into PyTorch-based workflows.
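A minimal usage sketch, assuming ellipses are first converted to their 3×3 conic matrices with `ellipse_to_conic_matrix` from `ellipse_rcnn.utils.conics` (the ellipse parameters below are illustrative):

```python
import torch

from ellipse_rcnn.core.kld import SymmetricKLDLoss
from ellipse_rcnn.utils.conics import ellipse_to_conic_matrix

# Two batches of ellipses parameterised as (a, b, x, y, theta).
a = torch.tensor([10.0, 8.0])
b = torch.tensor([5.0, 4.0])
x = torch.tensor([32.0, 40.0])
y = torch.tensor([32.0, 28.0])
theta = torch.tensor([0.1, -0.3])

A1 = ellipse_to_conic_matrix(a=a, b=b, x=x, y=y, theta=theta)              # "predicted"
A2 = ellipse_to_conic_matrix(a=a * 1.1, b=b, x=x + 2.0, y=y, theta=theta)  # "target"

loss_fn = SymmetricKLDLoss(shape_only=False, nan_to_num=10.0, normalize=False)
loss = loss_fn(A1, A2).mean()  # scalar, suitable as a training loss term
```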
ellipse_rcnn/core/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .model import EllipseRCNN # noqa: F401
ellipse_rcnn/core/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (223 Bytes).
 
ellipse_rcnn/core/__pycache__/ellipse_roi_head.cpython-312.pyc ADDED
Binary file (19.2 kB).
 
ellipse_rcnn/core/__pycache__/kld.cpython-312.pyc ADDED
Binary file (5.72 kB).
 
ellipse_rcnn/core/__pycache__/model.cpython-312.pyc ADDED
Binary file (10.3 kB).
 
ellipse_rcnn/core/__pycache__/wd.cpython-312.pyc ADDED
Binary file (5.82 kB).
 
ellipse_rcnn/core/ellipse_roi_head.py ADDED
@@ -0,0 +1,429 @@
1
+ from typing import Dict, List, Tuple, Optional, TypedDict, NamedTuple, Self
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+ from torchvision.models.detection.roi_heads import RoIHeads, fastrcnn_loss
7
+
8
+ from .kld import SymmetricKLDLoss
9
+ from .wd import WassersteinLoss
10
+ from ..utils.conics import (
11
+ ellipse_to_conic_matrix,
12
+ ellipse_axes,
13
+ ellipse_angle,
14
+ conic_center,
15
+ )
16
+
17
+
18
+ class RegressorPrediction(NamedTuple):
19
+ """
20
+ Represents the processed outputs of a regression model as a named tuple.
21
+
22
+ This class encapsulates regression model outputs in a structured format, where
23
+ each attribute corresponds to a specific component of the regression output.
24
+ These outputs can be directly used for post-processing steps such as transformation
25
+ into conic matrices or further evaluations of ellipse geometry.
26
+
27
+ Attributes
28
+ ----------
29
+ d_a : torch.Tensor
30
+ The normalized semi-major axis scale factor (logarithmic) used to compute
31
+ the actual semi-major axis length of ellipses.
32
+ d_b : torch.Tensor
33
+ The normalized semi-minor axis scale factor (logarithmic) used to compute
34
+ the actual semi-minor axis length of ellipses.
41
+ d_theta : torch.Tensor
42
+ The normalized rotation angle factor which is processed to derive the
43
+ actual rotation angle (in radians) of ellipses.
44
+
45
+ Notes
46
+ -----
47
+ - The attributes `d_a` and `d_b`, representing scale factors for the semi-major
48
+ and semi-minor axes, are bounded between -1 and 1 by the final tanh activation.
49
+ - The ellipse centers are not regressed here; they are taken from the centers of
50
+ the matched bounding-box proposals during post-processing.
51
+ - The attribute `d_theta` is normalized to ensure the rotation angle lies within
52
+ a valid range (after transformation, typically between -π/2 and π/2 radians).
53
+ - These normalized outputs are post-processed together with bounding box information
54
+ to construct actionable ellipse parameters such as their axes lengths, centers,
55
+ and angles.
56
+ - This structure simplifies downstream regression tasks, such as conversion into
57
+ conic matrices or calculation of geometrical losses.
58
+ """
59
+
60
+ d_a: torch.Tensor
61
+ d_b: torch.Tensor
62
+ d_theta: torch.Tensor
63
+
64
+ @property
65
+ def device(self) -> torch.device:
66
+ return self.d_a.device
67
+
68
+ @property
69
+ def dtype(self) -> torch.dtype:
70
+ return self.d_a.dtype
71
+
72
+ def split(self, split_size: list[int] | int, dim: int = 0) -> list[Self]:
73
+ return [
74
+ RegressorPrediction(*tensors)
75
+ for tensors in zip(
76
+ *[torch.split(attr, split_size, dim=dim) for attr in self]
77
+ )
78
+ ]
79
+
80
+
81
+ class EllipseRegressor(nn.Module):
82
+ """
83
+ EllipseRegressor is a neural network module designed to predict parameters of
84
+ an ellipse given input features.
85
+
86
+ This class is a PyTorch module that uses a feedforward neural network to predict
87
+ the three normalized ellipse parameters: the semi-major axis scale `d_a`, the semi-minor
88
+ axis scale `d_b`, and the orientation offset `d_theta`. The network output is bounded by
89
+ a tanh activation, and the class uses Xavier weight initialization for improved
90
+ training stability and convergence.
91
+
92
+ Attributes
93
+ ----------
94
+ ffnn : nn.Sequential
95
+ A feedforward neural network with two hidden layers and ReLU activations.
96
+ """
97
+
98
+ def __init__(self, in_channels: int = 1024, hidden_size: int = 64):
99
+ super().__init__()
100
+ # Separate prediction heads for better gradient flow
101
+ self.ffnn = nn.Sequential(
102
+ nn.Linear(in_channels, hidden_size),
103
+ nn.ReLU(),
104
+ nn.Linear(hidden_size, 3),
105
+ nn.Tanh(),
106
+ )
107
+
108
+ # Initialize with small values
109
+ for lin in self.ffnn:
110
+ if isinstance(lin, nn.Linear):
111
+ nn.init.xavier_uniform_(lin.weight, gain=0.01)
112
+ nn.init.zeros_(lin.bias)
113
+
114
+ def forward(self, x: torch.Tensor) -> RegressorPrediction:
115
+ x = x.flatten(start_dim=1)
116
+ x = self.ffnn(x)
117
+
118
+ d_a, d_b, d_theta = x.unbind(dim=-1)
119
+
120
+ return RegressorPrediction(d_a=d_a, d_b=d_b, d_theta=d_theta)
121
+
122
+
123
+ def postprocess_ellipse_predictor(
124
+ pred: RegressorPrediction,
125
+ box_proposals: torch.Tensor,
126
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
127
+ """Processes elliptical predictor outputs and converts them into conic matrices.
128
+
129
+ Parameters
130
+ ----------
131
+ pred : RegressorPrediction
132
+ The output of the elliptical predictor model.
133
+ box_proposals : torch.Tensor
134
+ Tensor containing proposed bounding box information, with shape (N, 4). Each box
135
+ is represented as a 4-tuple (x_min, y_min, x_max, y_max).
136
+
137
+ Returns
138
+ -------
139
+ tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
140
+ A tuple containing:
141
+ - a (torch.Tensor): Computed semi-major axis of the ellipses.
142
+ - b (torch.Tensor): Computed semi-minor axis of the ellipses.
143
+ - x (torch.Tensor): X-coordinates of the ellipse centers.
144
+ - y (torch.Tensor): Y-coordinates of the ellipse centers.
145
+ - theta (torch.Tensor): Rotation angles (in radians) for the ellipses.
146
+
147
+ """
148
+ d_a, d_b, d_theta = pred
149
+
150
+ # Pre-compute box width, height, and diagonal
151
+ box_width = box_proposals[:, 2] - box_proposals[:, 0]
152
+ box_height = box_proposals[:, 3] - box_proposals[:, 1]
153
+ box_diag = torch.sqrt(box_width**2 + box_height**2)
154
+
155
+ a = box_diag * d_a.exp()
156
+ b = box_diag * d_b.exp()
157
+
158
+ box_x = box_proposals[:, 0] + box_width * 0.5
159
+ box_y = box_proposals[:, 1] + box_height * 0.5
160
+
161
+ theta = (d_theta * 2.0 - 1.0) * (torch.pi / 2)
162
+
163
+ cos_theta = torch.cos(theta)
164
+ sin_theta = torch.sin(theta)
165
+ theta = torch.where(
166
+ cos_theta >= 0,
167
+ torch.atan2(sin_theta, cos_theta),
168
+ torch.atan2(-sin_theta, -cos_theta),
169
+ )
170
+
171
+ return a, b, box_x, box_y, theta
172
+
173
+
174
+ class EllipseLossDict(TypedDict):
175
+ loss_ellipse_kld: torch.Tensor
176
+ loss_ellipse_smooth_l1: torch.Tensor
177
+ loss_ellipse_wasserstein: torch.Tensor
178
+
179
+
180
+ def ellipse_loss(
181
+ pred: RegressorPrediction,
182
+ A_target: List[torch.Tensor],
183
+ pos_matched_idxs: List[torch.Tensor],
184
+ box_proposals: List[torch.Tensor],
185
+ kld_loss_fn: SymmetricKLDLoss,
186
+ wd_loss_fn: WassersteinLoss,
187
+ ) -> EllipseLossDict:
188
+ pos_matched_idxs_batched = torch.cat(pos_matched_idxs, dim=0)
189
+ A_target = torch.cat(A_target, dim=0)[pos_matched_idxs_batched]
190
+
191
+ box_proposals = torch.cat(box_proposals, dim=0)
192
+
193
+ if A_target.numel() == 0:
194
+ return {
195
+ "loss_ellipse_kld": torch.tensor(0.0, device=pred.device, dtype=pred.dtype),
196
+ "loss_ellipse_smooth_l1": torch.tensor(
197
+ 0.0, device=pred.device, dtype=pred.dtype
198
+ ),
199
+ "loss_ellipse_wasserstein": torch.tensor(
200
+ 0.0, device=pred.device, dtype=pred.dtype
201
+ ),
202
+ }
203
+
204
+ a_target, b_target = ellipse_axes(A_target)
205
+ theta_target = ellipse_angle(A_target)
206
+
207
+ # Box proposal parameters
208
+ box_width = box_proposals[:, 2] - box_proposals[:, 0]
209
+ box_height = box_proposals[:, 3] - box_proposals[:, 1]
210
+ box_diag = torch.sqrt(box_width**2 + box_height**2).clamp(min=1e-6)
211
+
212
+ # Normalize target variables
213
+ da_target = (a_target / box_diag).log()
214
+ db_target = (b_target / box_diag).log()
215
+ dtheta_target = (theta_target / (torch.pi / 2) + 1) / 2
216
+
217
+ # Direct parameter losses
218
+ d_a, d_b, d_theta = pred
219
+
220
+ pred_t = torch.stack([d_a, d_b, d_theta], dim=1)
221
+ target_t = torch.stack([da_target, db_target, dtheta_target], dim=1)
222
+
223
+ loss_smooth_l1 = F.smooth_l1_loss(pred_t, target_t, beta=(1 / 9), reduction="sum")
224
+ loss_smooth_l1 /= box_proposals.shape[0]
225
+ loss_smooth_l1 = loss_smooth_l1.nan_to_num(nan=0.0).clip(max=float(1e4))
226
+
227
+ a, b, x, y, theta = postprocess_ellipse_predictor(pred, box_proposals)
228
+
229
+ A_pred = ellipse_to_conic_matrix(a=a, b=b, theta=theta, x=x, y=y)
230
+
231
+ loss_kld = kld_loss_fn.forward(A_pred, A_target).clip(max=float(1e4)).mean() * 0.1
232
+ loss_wd = torch.zeros(1, device=pred.device, dtype=pred.dtype)
233
+ # loss_wd = wd_loss_fn.forward(A_pred, A_target).clip(max=float(1e4)).mean() * 0.1
234
+
235
+ return {
236
+ "loss_ellipse_kld": loss_kld,
237
+ "loss_ellipse_smooth_l1": loss_smooth_l1,
238
+ "loss_ellipse_wasserstein": loss_wd,
239
+ }
240
+
241
+
242
+ class EllipseRoIHeads(RoIHeads):
243
+ def __init__(
244
+ self,
245
+ box_roi_pool: nn.Module,
246
+ box_head: nn.Module,
247
+ box_predictor: nn.Module,
248
+ fg_iou_thresh: float,
249
+ bg_iou_thresh: float,
250
+ batch_size_per_image: int,
251
+ positive_fraction: float,
252
+ bbox_reg_weights: Optional[Tuple[float, float, float, float]],
253
+ score_thresh: float,
254
+ nms_thresh: float,
255
+ detections_per_img: int,
256
+ ellipse_roi_pool: nn.Module,
257
+ ellipse_head: nn.Module,
258
+ ellipse_predictor: nn.Module,
259
+ # Loss parameters
260
+ kld_shape_only: bool = False,
261
+ kld_normalize: bool = False,
262
+ # Numerical stability parameters
263
+ nan_to_num: float = 10.0,
264
+ loss_scale: float = 1.0,
265
+ ):
266
+ super().__init__(
267
+ box_roi_pool,
268
+ box_head,
269
+ box_predictor,
270
+ fg_iou_thresh,
271
+ bg_iou_thresh,
272
+ batch_size_per_image,
273
+ positive_fraction,
274
+ bbox_reg_weights,
275
+ score_thresh,
276
+ nms_thresh,
277
+ detections_per_img,
278
+ )
279
+
280
+ self.ellipse_roi_pool = ellipse_roi_pool
281
+ self.ellipse_head = ellipse_head
282
+ self.ellipse_predictor = ellipse_predictor
283
+
284
+ self.kld_loss = SymmetricKLDLoss(
285
+ shape_only=kld_shape_only,
286
+ normalize=kld_normalize,
287
+ nan_to_num=nan_to_num,
288
+ )
289
+ self.wd_loss = WassersteinLoss(
290
+ nan_to_num=nan_to_num,
291
+ normalize=kld_normalize,
292
+ )
293
+ self.loss_scale = loss_scale
294
+
295
+ def has_ellipse_reg(self) -> bool:
296
+ if self.ellipse_roi_pool is None:
297
+ return False
298
+ if self.ellipse_head is None:
299
+ return False
300
+ if self.ellipse_predictor is None:
301
+ return False
302
+ return True
303
+
304
+ def postprocess_ellipse_regressions(self):
305
+ pass
306
+
307
+ def forward(
308
+ self,
309
+ features: Dict[str, torch.Tensor],
310
+ proposals: List[torch.Tensor],
311
+ image_shapes: List[Tuple[int, int]],
312
+ targets: Optional[List[Dict[str, torch.Tensor]]] = None,
313
+ ) -> Tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
314
+ if targets is not None:
315
+ for t in targets:
316
+ floating_point_types = (torch.float, torch.double, torch.half)
317
+ if t["boxes"].dtype not in floating_point_types:
318
+ raise TypeError("target boxes must be of float type")
319
+ if t["ellipse_matrices"].dtype not in floating_point_types:
320
+ raise TypeError("target ellipse_offsets must be of float type")
321
+ if t["labels"].dtype != torch.int64:
322
+ raise TypeError("target labels must be of int64 type")
323
+
324
+ if self.training:
325
+ proposals, matched_idxs, labels, regression_targets = (
326
+ self.select_training_samples(proposals, targets)
327
+ )
328
+ else:
329
+ labels = None
330
+ regression_targets = None
331
+ matched_idxs = None
332
+
333
+ box_features = self.box_roi_pool(features, proposals, image_shapes)
334
+ box_features = self.box_head(box_features)
335
+ class_logits, box_regression = self.box_predictor(box_features)
336
+
337
+ result: List[Dict[str, torch.Tensor]] = []
338
+ losses = {}
339
+ if self.training:
340
+ if labels is None or regression_targets is None:
341
+ raise ValueError(
342
+ "Labels and regression targets must not be None during training"
343
+ )
344
+ loss_classifier, loss_box_reg = fastrcnn_loss(
345
+ class_logits, box_regression, labels, regression_targets
346
+ )
347
+ losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
348
+ else:
349
+ boxes, scores, labels = self.postprocess_detections(
350
+ class_logits, box_regression, proposals, image_shapes
351
+ )
352
+ num_images = len(boxes)
353
+ for i in range(num_images):
354
+ result.append(
355
+ {
356
+ "boxes": boxes[i],
357
+ "labels": labels[i],
358
+ "scores": scores[i],
359
+ }
360
+ )
361
+
362
+ if self.has_ellipse_reg():
363
+ ellipse_box_proposals = [p["boxes"] for p in result]
364
+ if self.training:
365
+ if matched_idxs is None:
366
+ raise ValueError("matched_idxs must not be None during training")
367
+ # during training, only focus on positive boxes
368
+ num_images = len(proposals)
369
+ ellipse_box_proposals = []
370
+ pos_matched_idxs = []
371
+ for img_id in range(num_images):
372
+ pos = torch.where(labels[img_id] > 0)[0]
373
+ ellipse_box_proposals.append(proposals[img_id][pos])
374
+ pos_matched_idxs.append(matched_idxs[img_id][pos])
375
+ else:
376
+ pos_matched_idxs = None # type: ignore
377
+
378
+ if self.ellipse_roi_pool is not None:
379
+ ellipse_features = self.ellipse_roi_pool(
380
+ features, ellipse_box_proposals, image_shapes
381
+ )
382
+ ellipse_features = self.ellipse_head(ellipse_features)
383
+ ellipse_shapes_normalised = self.ellipse_predictor(ellipse_features)
384
+ else:
385
+ raise Exception("Expected ellipse_roi_pool to be not None")
386
+
387
+ loss_ellipse_regressor = {}
388
+ if self.training:
389
+ if targets is None:
390
+ raise ValueError("Targets must not be None during training")
391
+ if pos_matched_idxs is None:
392
+ raise ValueError(
393
+ "pos_matched_idxs must not be None during training"
394
+ )
395
+ if ellipse_shapes_normalised is None:
396
+ raise ValueError(
397
+ "ellipse_shapes_normalised must not be None during training"
398
+ )
399
+
400
+ ellipse_matrix_targets = [t["ellipse_matrices"] for t in targets]
401
+ rcnn_loss_ellipse = ellipse_loss(
402
+ ellipse_shapes_normalised,
403
+ ellipse_matrix_targets,
404
+ pos_matched_idxs,
405
+ ellipse_box_proposals,
406
+ self.kld_loss,
407
+ self.wd_loss,
408
+ )
409
+
410
+ if self.loss_scale != 1.0:
411
+ rcnn_loss_ellipse["loss_ellipse_kld"] *= self.loss_scale
412
+ rcnn_loss_ellipse["loss_ellipse_smooth_l1"] *= self.loss_scale
413
+
414
+ loss_ellipse_regressor.update(rcnn_loss_ellipse)
415
+ else:
416
+ ellipses_per_image = [lbl.shape[0] for lbl in labels]
417
+ for pred, r, box in zip(
418
+ ellipse_shapes_normalised.split(ellipses_per_image, dim=0),
419
+ result,
420
+ ellipse_box_proposals,
421
+ ):
422
+ a, b, x, y, theta = postprocess_ellipse_predictor(pred, box)
423
+ A_pred = ellipse_to_conic_matrix(a=a, b=b, theta=theta, x=x, y=y)
424
+ r["ellipse_matrices"] = A_pred
425
+ # r["boxes"] = bbox_ellipse(A_pred)
426
+
427
+ losses.update(loss_ellipse_regressor)
428
+
429
+ return result, losses
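A small sketch of how the regressor outputs map back to ellipse parameters through `postprocess_ellipse_predictor`; the proposal box and offset values below are made up for illustration.

```python
import torch

from ellipse_rcnn.core.ellipse_roi_head import RegressorPrediction, postprocess_ellipse_predictor

# One proposal box in (x_min, y_min, x_max, y_max) format and one set of normalized offsets.
boxes = torch.tensor([[10.0, 20.0, 60.0, 80.0]])
pred = RegressorPrediction(
    d_a=torch.tensor([-0.5]),     # semi-major axis: a = box_diag * exp(d_a)
    d_b=torch.tensor([-1.0]),     # semi-minor axis: b = box_diag * exp(d_b)
    d_theta=torch.tensor([0.6]),  # mapped to an angle within (-pi/2, pi/2)
)

a, b, x, y, theta = postprocess_ellipse_predictor(pred, boxes)
# x and y sit at the box centre (35.0, 50.0); a and b scale with the box diagonal.
```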
ellipse_rcnn/core/ga.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ from ellipse_rcnn.utils.conics import conic_center
3
+
4
+
5
+ def gaussian_angle_distance(A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
6
+ # Extract covariance matrices (negative of the top-left blocks)
7
+ cov1, cov2 = map(lambda arr: -arr[..., :2, :2], (A1, A2))
8
+
9
+ # Extract the means by computing conic centers
10
+ c1_x, c1_y = conic_center(A1)
11
+ c2_x, c2_y = conic_center(A2)
12
+
13
+ # Stack the conic centers into the appropriate shape for computation
14
+ m1 = torch.stack((c1_x, c1_y), dim=-1)[..., None]
15
+ m2 = torch.stack((c2_x, c2_y), dim=-1)[..., None]
16
+
17
+ # Compute determinants for covariance matrices
18
+ det_cov1 = torch.clamp(cov1.det(), min=torch.finfo(cov1.dtype).eps)
19
+ det_cov2 = torch.clamp(cov2.det(), min=torch.finfo(cov2.dtype).eps)
20
+ cov_sum = cov1 + cov2
21
+
22
+ # Determinant of sum (clamped for numerical stability)
23
+ det_cov_sum = torch.clamp(cov_sum.det(), min=torch.finfo(cov_sum.dtype).eps)
24
+
25
+ # Compute fractional term with stabilized determinants
26
+ frac_term = (4 * torch.sqrt(det_cov1 * det_cov2)) / det_cov_sum
27
+ # Stable computation of the exponential term
28
+ mean_diff = m1 - m2
29
+ cov_sum_inv = torch.linalg.solve(
30
+ cov_sum, torch.eye(cov_sum.size(-1), dtype=cov_sum.dtype, device=cov_sum.device)
31
+ )
32
+ exp_arg = -0.5 * mean_diff.transpose(-1, -2) @ cov1 @ cov_sum_inv @ cov2 @ mean_diff
33
+ exp_term = torch.exp(torch.clamp(exp_arg, min=-50, max=50)).squeeze()
34
+
35
+ angle_term = frac_term * exp_term
36
+
37
+ return torch.arccos(angle_term)
38
+
39
+
40
+ class GaussianAngleDistanceLoss(torch.nn.Module):
41
+ """
42
+ Computes the Gaussian Angle Distance loss between two tensors.
43
+
44
+ This class serves as a wrapper around the `gaussian_angle_distance` function,
45
+ providing a clean interface and ensuring numerical stability.
46
+
47
+ Attributes
48
+ ----------
49
+ normalize : bool
50
+ Whether to normalize the computed distance; accepted for interface parity with the other loss modules, though the current `forward` returns the unnormalized angular distance.
51
+ nan_to_num : float
52
+ The value to replace NaN entries in the computation with. Helps maintain numerical
53
+ stability in cases where the input tensors contain undefined or invalid values.
54
+ """
55
+
56
+ def __init__(self, normalize: bool = True, nan_to_num: float = 10.0):
57
+ super().__init__()
58
+ self.nan_to_num = nan_to_num
59
+
60
+ def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
61
+ # Calculate the Gaussian angle distance
62
+ distance = gaussian_angle_distance(A1, A2)
63
+
64
+ # Replace NaN values with a predefined constant for numerical stability
65
+ distance = torch.nan_to_num(distance, nan=self.nan_to_num)
66
+
67
+ return distance
ellipse_rcnn/core/kld.py ADDED
@@ -0,0 +1,124 @@
1
+ import torch
2
+
3
+ from ellipse_rcnn.utils.conics import conic_center
4
+
5
+
6
+ def mv_kullback_leibler_divergence(
7
+ A1: torch.Tensor,
8
+ A2: torch.Tensor,
9
+ *,
10
+ shape_only: bool = False,
11
+ ) -> torch.Tensor:
12
+ """
13
+ Compute multi-variate KL divergence between ellipses represented by their matrices.
14
+
15
+ Args:
16
+ A1, A2: Ellipse matrices of shape (..., 3, 3)
17
+ shape_only: If True, ignores displacement term
18
+ """
19
+
20
+ # Ensure that batch sizes are equal
21
+ if A1.shape[:-2] != A2.shape[:-2]:
22
+ raise ValueError(
23
+ f"Batch size mismatch: A1 has shape {A1.shape[:-2]}, A2 has shape {A2.shape[:-2]}"
24
+ )
25
+
26
+ # Extract the upper 2x2 blocks as covariance matrices
27
+ cov1 = A1[..., :2, :2]
28
+ cov2 = A2[..., :2, :2]
29
+
30
+ # Compute centers
31
+ m1 = torch.vstack(conic_center(A1)).T[..., None]
32
+ m2 = torch.vstack(conic_center(A2)).T[..., None]
33
+
34
+ # Compute inverse
35
+ try:
36
+ cov2_inv = torch.linalg.inv(cov2)
37
+ except RuntimeError:
38
+ cov2_inv = torch.linalg.pinv(cov2)
39
+
40
+ # Trace term
41
+ trace_term = (cov2_inv @ cov1).diagonal(dim2=-2, dim1=-1).sum(1)
42
+
43
+ # Log determinant term
44
+ det_cov1 = torch.det(cov1)
45
+ det_cov2 = torch.det(cov2)
46
+ log_term = torch.log(det_cov2 / det_cov1).nan_to_num(nan=0.0)
47
+
48
+ if shape_only:
49
+ displacement_term = 0
50
+ else:
51
+ # Mean difference term
52
+ displacement_term = (
53
+ ((m1 - m2).transpose(-1, -2) @ cov2_inv @ (m1 - m2)).squeeze().abs()
54
+ )
55
+
56
+ return 0.5 * (trace_term + displacement_term - 2 + log_term)
57
+
58
+
59
+ def symmetric_kl_divergence(
60
+ A1: torch.Tensor,
61
+ A2: torch.Tensor,
62
+ *,
63
+ shape_only: bool = False,
64
+ nan_to_num: float = float(1e4),
65
+ normalize: bool = False,
66
+ ) -> torch.Tensor:
67
+ """
68
+ Compute symmetric KL divergence between ellipses.
69
+ """
70
+ kl_12 = torch.nan_to_num(
71
+ mv_kullback_leibler_divergence(A1, A2, shape_only=shape_only), nan_to_num
72
+ )
73
+ kl_21 = torch.nan_to_num(
74
+ mv_kullback_leibler_divergence(A2, A1, shape_only=shape_only), nan_to_num
75
+ )
76
+ kl = (kl_12 + kl_21) / 2
77
+
78
+ if kl.lt(0).any():
79
+ raise ValueError("Negative KL divergence encountered.")
80
+
81
+ if normalize:
82
+ kl = 1 - torch.exp(-kl)
83
+ return kl
84
+
85
+
86
+ class SymmetricKLDLoss(torch.nn.Module):
87
+ """
88
+ Computes the symmetric Kullback-Leibler divergence (KLD) loss between two tensors.
89
+
90
+ SymmetricKLDLoss is used for measuring the divergence between two probability
91
+ distributions or tensors, which can be useful in tasks such as generative modeling
92
+ or optimization. The function allows for options such as normalizing the tensors or
93
+ focusing only on their shape for comparison. Additionally, it includes a feature
94
+ to handle NaN values by replacing them with a numeric constant.
95
+
96
+ Attributes
97
+ ----------
98
+ shape_only : bool
99
+ If True, computes the divergence based on the shape of the tensors only. This
100
+ can be used to evaluate similarity without considering magnitude differences.
101
+ nan_to_num : float
102
+ The value to replace NaN entries in the tensors with. Helps maintain numerical
103
+ stability in cases where the input tensors contain undefined or invalid values.
104
+ normalize : bool
105
+ If True, normalizes the tensors before computing the divergence. This is
106
+ typically used when the inputs are not already probability distributions.
107
+ """
108
+
109
+ def __init__(
110
+ self, shape_only: bool = True, nan_to_num: float = 10.0, normalize: bool = False
111
+ ):
112
+ super().__init__()
113
+ self.shape_only = shape_only
114
+ self.nan_to_num = nan_to_num
115
+ self.normalize = normalize
116
+
117
+ def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
118
+ return symmetric_kl_divergence(
119
+ A1,
120
+ A2,
121
+ shape_only=self.shape_only,
122
+ nan_to_num=self.nan_to_num,
123
+ normalize=self.normalize,
124
+ )
ellipse_rcnn/core/model.py ADDED
@@ -0,0 +1,282 @@
1
+ from types import NoneType
2
+ from typing import List, Tuple, Optional, Any
3
+
4
+ import pytorch_lightning as pl
5
+ import torch
6
+ from torch import nn
7
+ from torchvision.models import ResNet50_Weights, WeightsEnum
8
+ from torchvision.models.detection.anchor_utils import AnchorGenerator
9
+ from torchvision.models.detection.backbone_utils import resnet_fpn_backbone
10
+ from torchvision.models.detection.faster_rcnn import TwoMLPHead, FastRCNNPredictor # noqa: F
11
+ from torchvision.models.detection.generalized_rcnn import GeneralizedRCNN
12
+ from torchvision.models.detection.rpn import RPNHead, RegionProposalNetwork
13
+ from torchvision.models.detection.transform import GeneralizedRCNNTransform
14
+ from torchvision.ops import MultiScaleRoIAlign
15
+
16
+ from .ellipse_roi_head import EllipseRoIHeads, EllipseRegressor
17
+ from ..utils.types import CollatedBatchType
18
+
19
+
20
+ class EllipseRCNN(GeneralizedRCNN):
21
+ def __init__(
22
+ self,
23
+ num_classes: int = 2,
24
+ # transform parameters
25
+ backbone_name: str = "resnet50",
26
+ weights: WeightsEnum | str = ResNet50_Weights.IMAGENET1K_V1,
27
+ min_size: int = 256,
28
+ max_size: int = 512,
29
+ image_mean: Optional[List[float]] = None,
30
+ image_std: Optional[List[float]] = None,
31
+ # Region Proposal Network parameters
32
+ rpn_anchor_generator: Optional[nn.Module] = None,
33
+ rpn_head: Optional[nn.Module] = None,
34
+ rpn_pre_nms_top_n_train: int = 2000,
35
+ rpn_pre_nms_top_n_test: int = 1000,
36
+ rpn_post_nms_top_n_train: int = 2000,
37
+ rpn_post_nms_top_n_test: int = 1000,
38
+ rpn_nms_thresh: float = 0.7,
39
+ rpn_fg_iou_thresh: float = 0.7,
40
+ rpn_bg_iou_thresh: float = 0.3,
41
+ rpn_batch_size_per_image: int = 256,
42
+ rpn_positive_fraction: float = 0.5,
43
+ rpn_score_thresh: float = 0.0,
44
+ # Box parameters
45
+ box_roi_pool: Optional[nn.Module] = None,
46
+ box_head: Optional[nn.Module] = None,
47
+ box_predictor: Optional[nn.Module] = None,
48
+ box_score_thresh: float = 0.05,
49
+ box_nms_thresh: float = 0.5,
50
+ box_detections_per_img: int = 100,
51
+ box_fg_iou_thresh: float = 0.5,
52
+ box_bg_iou_thresh: float = 0.5,
53
+ box_batch_size_per_image: int = 512,
54
+ box_positive_fraction: float = 0.25,
55
+ bbox_reg_weights: Optional[Tuple[float, float, float, float]] = None,
56
+ # Ellipse regressor
57
+ ellipse_roi_pool: Optional[nn.Module] = None,
58
+ ellipse_head: Optional[nn.Module] = None,
59
+ ellipse_predictor: Optional[nn.Module] = None,
60
+ ellipse_loss_scale: float = 1.0,
61
+ ellipse_loss_normalize: bool = False,
62
+ ):
63
+ if backbone_name != "resnet50" and weights == ResNet50_Weights.IMAGENET1K_V1:
64
+ raise ValueError(
65
+ "If backbone_name is not resnet50, weights_enum must be specified"
66
+ )
67
+
68
+ backbone = resnet_fpn_backbone(
69
+ backbone_name=backbone_name, weights=weights, trainable_layers=5
70
+ )
71
+
72
+ if not hasattr(backbone, "out_channels"):
73
+ raise ValueError(
74
+ "backbone should contain an attribute out_channels "
75
+ "specifying the number of output channels (assumed to be the "
76
+ "same for all the levels)"
77
+ )
78
+
79
+ if not isinstance(rpn_anchor_generator, (AnchorGenerator, NoneType)):
80
+ raise TypeError(
81
+ "rpn_anchor_generator must be an instance of AnchorGenerator or None"
82
+ )
83
+
84
+ if not isinstance(box_roi_pool, (MultiScaleRoIAlign, NoneType)):
85
+ raise TypeError(
86
+ "box_roi_pool must be an instance of MultiScaleRoIAlign or None"
87
+ )
88
+
89
+ if num_classes is not None:
90
+ if box_predictor is not None:
91
+ raise ValueError(
92
+ "num_classes should be None when box_predictor is specified"
93
+ )
94
+ else:
95
+ if box_predictor is None:
96
+ raise ValueError(
97
+ "num_classes should not be None when box_predictor "
98
+ "is not specified"
99
+ )
100
+
101
+ out_channels = backbone.out_channels
102
+
103
+ if rpn_anchor_generator is None:
104
+ anchor_sizes = ((32,), (64,), (128,), (256,), (512,))
105
+ aspect_ratios = ((0.5, 1.0, 2.0),) * len(anchor_sizes)
106
+ rpn_anchor_generator = AnchorGenerator(anchor_sizes, aspect_ratios)
107
+ if rpn_head is None:
108
+ rpn_head = RPNHead(
109
+ out_channels, rpn_anchor_generator.num_anchors_per_location()[0]
110
+ )
111
+
112
+ rpn_pre_nms_top_n = dict(
113
+ training=rpn_pre_nms_top_n_train, testing=rpn_pre_nms_top_n_test
114
+ )
115
+ rpn_post_nms_top_n = dict(
116
+ training=rpn_post_nms_top_n_train, testing=rpn_post_nms_top_n_test
117
+ )
118
+
119
+ rpn = RegionProposalNetwork(
120
+ rpn_anchor_generator,
121
+ rpn_head,
122
+ rpn_fg_iou_thresh,
123
+ rpn_bg_iou_thresh,
124
+ rpn_batch_size_per_image,
125
+ rpn_positive_fraction,
126
+ rpn_pre_nms_top_n,
127
+ rpn_post_nms_top_n,
128
+ rpn_nms_thresh,
129
+ score_thresh=rpn_score_thresh,
130
+ )
131
+
132
+ default_representation_size = 1024
133
+
134
+ if box_roi_pool is None:
135
+ box_roi_pool = MultiScaleRoIAlign(
136
+ featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
137
+ )
138
+
139
+ if box_head is None:
140
+ resolution = box_roi_pool.output_size[0]
141
+ if isinstance(resolution, int):
142
+ box_head = TwoMLPHead(
143
+ out_channels * resolution**2, default_representation_size
144
+ )
145
+ else:
146
+ raise ValueError(
147
+ "resolution should be an int but is {}".format(resolution)
148
+ )
149
+
150
+ if box_predictor is None:
151
+ box_predictor = FastRCNNPredictor(default_representation_size, num_classes)
152
+
153
+ if ellipse_roi_pool is None:
154
+ ellipse_roi_pool = MultiScaleRoIAlign(
155
+ featmap_names=["0", "1", "2", "3"], output_size=7, sampling_ratio=2
156
+ )
157
+
158
+ resolution = box_roi_pool.output_size[0]
159
+ if ellipse_head is None:
160
+ if isinstance(resolution, int):
161
+ ellipse_head = TwoMLPHead(
162
+ out_channels * resolution**2, default_representation_size
163
+ )
164
+ else:
165
+ raise ValueError(
166
+ "resolution should be an int but is {}".format(resolution)
167
+ )
168
+
169
+ if ellipse_predictor is None:
170
+ ellipse_predictor = EllipseRegressor(
171
+ default_representation_size, num_classes
172
+ )
173
+
174
+ roi_heads = EllipseRoIHeads(
175
+ # Box
176
+ box_roi_pool,
177
+ box_head,
178
+ box_predictor,
179
+ box_fg_iou_thresh,
180
+ box_bg_iou_thresh,
181
+ box_batch_size_per_image,
182
+ box_positive_fraction,
183
+ bbox_reg_weights,
184
+ box_score_thresh,
185
+ box_nms_thresh,
186
+ box_detections_per_img,
187
+ # Ellipse
188
+ ellipse_roi_pool=ellipse_roi_pool,
189
+ ellipse_head=ellipse_head,
190
+ ellipse_predictor=ellipse_predictor,
191
+ loss_scale=ellipse_loss_scale,
192
+ kld_normalize=ellipse_loss_normalize,
193
+ )
194
+
195
+ if image_mean is None:
196
+ image_mean = [0.485, 0.456, 0.406]
197
+ if image_std is None:
198
+ image_std = [0.229, 0.224, 0.225]
199
+ transform = GeneralizedRCNNTransform(min_size, max_size, image_mean, image_std)
200
+
201
+ super().__init__(backbone, rpn, roi_heads, transform)
202
+
203
+
204
+ class EllipseRCNNLightning(pl.LightningModule):
205
+ def __init__(
206
+ self,
207
+ model: EllipseRCNN,
208
+ lr: float = 1e-4,
209
+ weight_decay: float = 1e-4,
210
+ ):
211
+ super().__init__()
212
+ self.model = model
213
+ self.save_hyperparameters(ignore=["model"])
214
+
215
+ def configure_optimizers(self) -> Any:
216
+ optimizer = torch.optim.AdamW(
217
+ self.model.parameters(),
218
+ lr=self.hparams.lr,
219
+ weight_decay=self.hparams.weight_decay,
220
+ amsgrad=True,
221
+ )
222
+ scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
223
+ optimizer, mode="min", factor=0.5, patience=2, min_lr=1e-6
224
+ )
225
+ return {
226
+ "optimizer": optimizer,
227
+ "lr_scheduler": {"scheduler": scheduler, "monitor": "val/loss_total"},
228
+ }
229
+
230
+ def training_step(
231
+ self, batch: CollatedBatchType, batch_idx: int = 0
232
+ ) -> torch.Tensor:
233
+ images, targets = batch
234
+ loss_dict = self.model(images, targets)
235
+ self.log_dict(
236
+ {f"train/{k}": v for k, v in loss_dict.items()},
237
+ prog_bar=True,
238
+ logger=True,
239
+ on_step=True,
240
+ )
241
+
242
+ loss = sum(loss_dict.values())
243
+ self.log("train/loss_total", loss, prog_bar=True, logger=True, on_step=True)
244
+
245
+ return loss
246
+
247
+ def validation_step(
248
+ self, batch: CollatedBatchType, batch_idx: int = 0
249
+ ) -> torch.Tensor:
250
+ self.train(True)
251
+ images, targets = batch
252
+
253
+ loss_dict = self.model(images, targets)
254
+
255
+ self.log_dict(
256
+ {f"val/{k}": v for k, v in loss_dict.items()},
257
+ logger=True,
258
+ on_step=False,
259
+ on_epoch=True,
260
+ )
261
+
262
+ val_loss = sum(loss_dict.values())
263
+ self.log(
264
+ "val/loss_total",
265
+ val_loss,
266
+ prog_bar=True,
267
+ logger=True,
268
+ on_step=False,
269
+ on_epoch=True,
270
+ )
271
+
272
+ self.log(
273
+ "hp_metric",
274
+ val_loss,
275
+ )
276
+
277
+ self.log(
278
+ "lr",
279
+ self.lr_schedulers().get_last_lr()[0],
280
+ )
281
+
282
+ return val_loss
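A minimal inference sketch, assuming the package layout from this commit. The input size is illustrative, and the ellipse head is randomly initialised here, so a trained checkpoint (as loaded in `app.py`) is needed for meaningful output.

```python
import torch

from ellipse_rcnn import EllipseRCNN

model = EllipseRCNN()  # resnet50-FPN backbone, 2 classes by default
model.eval()

# Inference takes a list of 3xHxW float tensors (normalization and resizing are handled internally).
images = [torch.rand(3, 256, 256)]
with torch.no_grad():
    predictions = model(images)

# Each prediction dict holds boxes, labels, scores and the regressed 3x3 conic
# matrices under the "ellipse_matrices" key.
print(predictions[0]["ellipse_matrices"].shape)
```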
ellipse_rcnn/core/wd.py ADDED
@@ -0,0 +1,128 @@
1
+ import torch
2
+
3
+ from ellipse_rcnn.utils.conics import conic_center
4
+
5
+
6
+ def wasserstein_distance(
7
+ A1: torch.Tensor,
8
+ A2: torch.Tensor,
9
+ *,
10
+ shape_only: bool = False,
11
+ ) -> torch.Tensor:
12
+ """
13
+ Compute the squared Wasserstein-2 distance between ellipses represented by their matrices.
14
+
15
+ Args:
16
+ A1, A2: Ellipse matrices of shape (..., 3, 3)
17
+ shape_only: If True, ignores displacement term
18
+
19
+ Returns:
20
+ Tensor containing Wasserstein distances
21
+ """
22
+ # Ensure batch sizes match
23
+ if A1.shape[:-2] != A2.shape[:-2]:
24
+ raise ValueError(
25
+ f"Batch size mismatch: A1 has shape {A1.shape[:-2]}, A2 has shape {A2.shape[:-2]}"
26
+ )
27
+
28
+ # Extract covariance matrices (upper 2x2 blocks)
29
+ cov1 = A1[..., :2, :2]
30
+ cov2 = A2[..., :2, :2]
31
+
32
+ if shape_only:
33
+ displacement_term = 0
34
+ else:
35
+ # Compute centers
36
+ m1 = torch.vstack(conic_center(A1)).T[..., None]
37
+ m2 = torch.vstack(conic_center(A2)).T[..., None]
38
+
39
+ # Mean difference term
40
+ displacement_term = torch.sum((m1 - m2) ** 2, dim=(1, 2))
41
+
42
+ # Compute the matrix square root term
43
+ eigenvalues1, eigenvectors1 = torch.linalg.eigh(cov1)
44
+ sqrt_eigenvalues1 = torch.sqrt(torch.clamp(eigenvalues1, min=1e-7))
45
+ sqrt_cov1 = (
46
+ eigenvectors1
47
+ @ torch.diag_embed(sqrt_eigenvalues1)
48
+ @ eigenvectors1.transpose(-2, -1)
49
+ )
50
+
51
+ inner_term = sqrt_cov1 @ cov2 @ sqrt_cov1
52
+ eigenvalues_inner, eigenvectors_inner = torch.linalg.eigh(inner_term)
53
+ sqrt_inner = (
54
+ eigenvectors_inner
55
+ @ torch.diag_embed(torch.sqrt(torch.clamp(eigenvalues_inner, min=1e-7)))
56
+ @ eigenvectors_inner.transpose(-2, -1)
57
+ )
58
+
59
+ trace_term = (
60
+ torch.diagonal(cov1, dim1=-2, dim2=-1).sum(-1)
61
+ + torch.diagonal(cov2, dim1=-2, dim2=-1).sum(-1)
62
+ - 2 * torch.diagonal(sqrt_inner, dim1=-2, dim2=-1).sum(-1)
63
+ )
64
+
65
+ return displacement_term + trace_term
66
+
67
+
68
+ def symmetric_wasserstein_distance(
69
+ A1: torch.Tensor,
70
+ A2: torch.Tensor,
71
+ *,
72
+ shape_only: bool = False,
73
+ nan_to_num: float = float(1e4),
74
+ normalize: bool = False,
75
+ ) -> torch.Tensor:
76
+ """
77
+ Compute symmetric Wasserstein distance between ellipses.
78
+
79
+ Args:
80
+ A1, A2: Ellipse matrices
81
+ shape_only: If True, ignores displacement term
82
+ nan_to_num: Value to replace NaN entries with
83
+ normalize: If True, normalizes the output to [0, 1]
84
+ """
85
+ w = torch.nan_to_num(
86
+ wasserstein_distance(A1, A2, shape_only=shape_only), nan=nan_to_num
87
+ )
88
+
89
+ if w.lt(0).any():
90
+ raise ValueError("Negative Wasserstein distance encountered.")
91
+
92
+ if normalize:
93
+ w = 1 - torch.exp(-w)
94
+ return w
95
+
96
+
97
+ class WassersteinLoss(torch.nn.Module):
98
+ """
99
+ Computes the Wasserstein distance loss between two ellipse tensors.
100
+
101
+ The Wasserstein distance provides a natural metric for comparing probability
102
+ distributions or shapes, with advantages over KL divergence such as:
103
+ - It's symmetric by definition
104
+ - It provides a true metric (satisfies triangle inequality)
105
+ - It's well-behaved even when distributions have different supports
106
+
107
+ Attributes:
108
+ shape_only: If True, computes distance based on shape without considering position
109
+ nan_to_num: Value to replace NaN entries with
110
+ normalize: If True, normalizes output to [0, 1] using exponential scaling
111
+ """
112
+
113
+ def __init__(
114
+ self, shape_only: bool = True, nan_to_num: float = 10.0, normalize: bool = False
115
+ ):
116
+ super().__init__()
117
+ self.shape_only = shape_only
118
+ self.nan_to_num = nan_to_num
119
+ self.normalize = normalize
120
+
121
+ def forward(self, A1: torch.Tensor, A2: torch.Tensor) -> torch.Tensor:
122
+ return symmetric_wasserstein_distance(
123
+ A1,
124
+ A2,
125
+ shape_only=self.shape_only,
126
+ nan_to_num=self.nan_to_num,
127
+ normalize=self.normalize,
128
+ )
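A brief sketch of the Wasserstein loss on conic matrices; the parameter values are illustrative, and `ellipse_to_conic_matrix` comes from `ellipse_rcnn.utils.conics`.

```python
import torch

from ellipse_rcnn.core.wd import WassersteinLoss
from ellipse_rcnn.utils.conics import ellipse_to_conic_matrix

A1 = ellipse_to_conic_matrix(
    a=torch.tensor([12.0, 9.0]), b=torch.tensor([6.0, 4.0]),
    x=torch.tensor([50.0, 20.0]), y=torch.tensor([50.0, 30.0]),
    theta=torch.tensor([0.2, -0.4]),
)
A2 = ellipse_to_conic_matrix(
    a=torch.tensor([11.0, 9.5]), b=torch.tensor([7.0, 4.5]),
    x=torch.tensor([52.0, 21.0]), y=torch.tensor([49.0, 29.0]),
    theta=torch.tensor([0.25, -0.35]),
)

# shape_only=True compares only the covariance structure and ignores the centre offset.
loss_fn = WassersteinLoss(shape_only=True, normalize=True)
print(loss_fn(A1, A2))  # per-ellipse distances, mapped to [0, 1) by normalize=True
```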
ellipse_rcnn/utils/__init__.py ADDED
File without changes
ellipse_rcnn/utils/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (180 Bytes).
 
ellipse_rcnn/utils/__pycache__/conics.cpython-312.pyc ADDED
Binary file (8.37 kB).
 
ellipse_rcnn/utils/__pycache__/types.cpython-312.pyc ADDED
Binary file (2.73 kB).
 
ellipse_rcnn/utils/__pycache__/viz.cpython-312.pyc ADDED
Binary file (4.9 kB).
 
ellipse_rcnn/utils/conics.py ADDED
@@ -0,0 +1,209 @@
1
+ from typing import Literal
2
+
3
+ import torch
4
+
5
+
6
+ @torch.jit.script
7
+ def adjugate_matrix(matrix: torch.Tensor) -> torch.Tensor:
8
+ """Return adjugate matrix [1].
9
+
10
+ Parameters
11
+ ----------
12
+ matrix:
13
+ Input matrix
14
+
15
+ Returns
16
+ -------
17
+ torch.Tensor
18
+ Adjugate of input matrix
19
+
20
+ References
21
+ ----------
22
+ .. [1] https://en.wikipedia.org/wiki/Adjugate_matrix
23
+ """
24
+
25
+ cofactor = torch.inverse(matrix).T * torch.det(matrix)
26
+ return cofactor.T
27
+
28
+
29
+ # @torch.jit.script
30
+ def unimodular_matrix(matrix: torch.Tensor) -> torch.Tensor:
31
+ """Rescale matrix such that det(ellipses) = 1, in other words, make it unimodular. Doest not work with tensors
32
+ of dtype torch.float64.
33
+
34
+ Parameters
35
+ ----------
36
+ matrix:
37
+ Matrix input
38
+
39
+ Returns
40
+ -------
41
+ torch.Tensor
42
+ Unimodular version of input matrix.
43
+ """
44
+ val = 1.0 / torch.det(matrix)
45
+ return (torch.sign(val) * torch.pow(torch.abs(val), 1.0 / 3.0))[
46
+ ..., None, None
47
+ ] * matrix
48
+
49
+
50
+ # @torch.jit.script
51
+ def ellipse_to_conic_matrix(
52
+ *,
53
+ a: torch.Tensor,
54
+ b: torch.Tensor,
55
+ x: torch.Tensor | None = None,
56
+ y: torch.Tensor | None = None,
57
+ theta: torch.Tensor | None = None,
58
+ ) -> torch.Tensor:
59
+ r"""Returns matrix representation for crater derived from ellipse parameters such that _[1]:
60
+
61
+ | A = a²(sin θ)² + b²(cos θ)²
62
+ | B = 2(b² - a²) sin θ cos θ
63
+ | C = a²(cos θ)² + b²(sin θ)²
64
+ | D = -2Ax₀ - By₀
65
+ | E = -Bx₀ - 2Cy₀
66
+ | F = Ax₀² + Bx₀y₀ + Cy₀² - a²b²
67
+
68
+ Resulting in a conic matrix:
69
+ ::
70
+ |A B/2 D/2 |
71
+ M = |B/2 C E/2 |
72
+ |D/2 E/2 F |
73
+
74
+ Parameters
75
+ ----------
76
+ a:
77
+ Semi-Major ellipse axis
78
+ b:
79
+ Semi-Minor ellipse axis
80
+ theta:
81
+ Ellipse angle (radians)
82
+ x:
83
+ X-position in 2D cartesian coordinate system (coplanar)
84
+ y:
85
+ Y-position in 2D cartesian coordinate system (coplanar)
86
+
87
+ Returns
88
+ -------
89
+ torch.Tensor
90
+ Array of ellipse matrices
91
+
92
+ References
93
+ ----------
94
+ .. [1] https://www.researchgate.net/publication/355490899_Lunar_Crater_Identification_in_Digital_Images
95
+ """
96
+
97
+ x = x if x is not None else torch.zeros(1)
98
+ y = y if y is not None else torch.zeros(1)
99
+ theta = theta if theta is not None else torch.zeros(1)
100
+
101
+ sin_theta = torch.sin(theta)
102
+ cos_theta = torch.cos(theta)
103
+
104
+ a2 = a**2
105
+ b2 = b**2
106
+
107
+ A = a2 * sin_theta**2 + b2 * cos_theta**2
108
+ B = 2 * (b2 - a2) * sin_theta * cos_theta
109
+ C = a2 * cos_theta**2 + b2 * sin_theta**2
110
+ D = -2 * A * x - B * y
111
+ F = -B * x - 2 * C * y
112
+ G = A * (x**2) + B * x * y + C * (y**2) - a2 * b2
113
+
114
+ # Create (array of) of conic matrix (N, 3, 3)
115
+ conic_matrix = torch.stack(
116
+ tensors=(
117
+ torch.stack((A, B / 2, D / 2), dim=-1),
118
+ torch.stack((B / 2, C, F / 2), dim=-1),
119
+ torch.stack((D / 2, F / 2, G), dim=-1),
120
+ ),
121
+ dim=-1,
122
+ )
123
+
124
+ return conic_matrix.squeeze()
125
+
126
+
127
+ def conic_center(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
128
+ """Returns center of ellipse in 2D cartesian coordinate system with numerical stability."""
129
+ # Extract the top-left 2x2 submatrix of the conic matrix
130
+ A = conic_matrix[..., :2, :2]
131
+
132
+ # Add stabilization for pseudoinverse computation by clamping singular values
133
+ A_pinv = torch.linalg.pinv(A, rcond=torch.finfo(A.dtype).eps)
134
+
135
+ # Extract the last two rows for the linear term
136
+ b = -conic_matrix[..., :2, 2][..., None]
137
+
138
+ # Stabilize any potential numerical instabilities
139
+ centers = torch.matmul(A_pinv, b).squeeze()
140
+
141
+ return centers[..., 0], centers[..., 1]
142
+
143
+
144
+ def ellipse_axes(conic_matrix: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
145
+ """Returns semi-major and semi-minor axes of ellipse in 2D cartesian coordinate system."""
146
+ lambdas = (
147
+ torch.linalg.eigvalsh(conic_matrix[..., :2, :2])
148
+ / (-torch.det(conic_matrix) / torch.det(conic_matrix[..., :2, :2]))[..., None]
149
+ )
150
+ axes = torch.sqrt(1 / lambdas)
151
+ return axes[..., 0], axes[..., 1]
152
+
153
+
154
+ def ellipse_angle(conic_matrix: torch.Tensor) -> torch.Tensor:
155
+ """Returns angle of ellipse in radians w.r.t. x-axis."""
156
+ return (
157
+ -torch.atan2(
158
+ 2 * conic_matrix[..., 1, 0],
159
+ conic_matrix[..., 1, 1] - conic_matrix[..., 0, 0],
160
+ )
161
+ / 2
162
+ )
163
+
164
+
165
+ def bbox_ellipse(
166
+ ellipses: torch.Tensor,
167
+ box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
168
+ ) -> torch.Tensor:
169
+ """Converts (array of) ellipse matrices to bounding box tensor with format [xmin, ymin, xmax, ymax].
170
+
171
+ Parameters
172
+ ----------
173
+ ellipses:
174
+ Array of ellipse matrices
175
+ box_type:
176
+ Format of bounding boxes, default is "xyxy"
177
+
178
+ Returns
179
+ -------
180
+ Array of bounding boxes
181
+ """
182
+ cx, cy = conic_center(ellipses)
183
+ theta = ellipse_angle(ellipses)
184
+ semi_major_axis, semi_minor_axis = ellipse_axes(ellipses)
185
+
186
+ ux, uy = semi_major_axis * torch.cos(theta), semi_major_axis * torch.sin(theta)
187
+ vx, vy = (
188
+ semi_minor_axis * torch.cos(theta + torch.pi / 2),
189
+ semi_minor_axis * torch.sin(theta + torch.pi / 2),
190
+ )
191
+
192
+ box_halfwidth = torch.sqrt(ux**2 + vx**2)
193
+ box_halfheight = torch.sqrt(uy**2 + vy**2)
194
+
195
+ bboxes = torch.vstack(
196
+ (
197
+ cx - box_halfwidth,
198
+ cy - box_halfheight,
199
+ cx + box_halfwidth,
200
+ cy + box_halfheight,
201
+ )
202
+ ).T
203
+
204
+ if box_type != "xyxy":
205
+ from torchvision.ops import boxes as box_ops
206
+
207
+ bboxes = box_ops.box_convert(bboxes, in_fmt="xyxy", out_fmt=box_type)
208
+
209
+ return bboxes
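A round-trip sketch of the helpers above (parameter values are illustrative):

```python
import torch

from ellipse_rcnn.utils.conics import (
    bbox_ellipse,
    conic_center,
    ellipse_angle,
    ellipse_axes,
    ellipse_to_conic_matrix,
)

a = torch.tensor([20.0, 15.0])     # semi-major axes
b = torch.tensor([10.0, 5.0])      # semi-minor axes
x = torch.tensor([64.0, 100.0])    # centre x-coordinates
y = torch.tensor([64.0, 40.0])     # centre y-coordinates
theta = torch.tensor([0.3, -0.7])  # rotation angles in radians

A = ellipse_to_conic_matrix(a=a, b=b, x=x, y=y, theta=theta)  # shape (2, 3, 3)

cx, cy = conic_center(A)                  # cx ≈ [64., 100.], cy ≈ [64., 40.]
semi_major, semi_minor = ellipse_axes(A)  # ≈ a and b
angle = ellipse_angle(A)                  # ≈ theta
boxes = bbox_ellipse(A)                   # axis-aligned xyxy boxes enclosing each ellipse
```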
ellipse_rcnn/utils/data/__init__.py ADDED
File without changes
ellipse_rcnn/utils/data/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (185 Bytes).
 
ellipse_rcnn/utils/data/__pycache__/base.cpython-312.pyc ADDED
Binary file (2.01 kB).
 
ellipse_rcnn/utils/data/base.py ADDED
@@ -0,0 +1,62 @@
1
+ from abc import ABC, abstractmethod
2
+ from typing import Any
3
+
4
+ from torch.utils.data import Dataset
5
+
6
+ from ellipse_rcnn.utils.types import (
7
+ TargetDict,
8
+ CollatedBatchType,
9
+ UncollatedBatchType,
10
+ )
11
+
12
+
13
+ def collate_fn(batch: UncollatedBatchType) -> CollatedBatchType:
14
+ """
15
+ Collate function for the :class:`DataLoader`.
16
+
17
+ Parameters
18
+ ----------
19
+ batch:
20
+ A batch of data.
21
+ """
22
+ return tuple(zip(*batch)) # type: ignore
23
+
24
+
25
+ class EllipseDatasetBase(ABC, Dataset):
26
+ @abstractmethod
27
+ def load_image(self, index: int) -> Any:
28
+ """
29
+ Load the image for the given index.
30
+
31
+ Parameters
32
+ ----------
33
+ index:
34
+ The index of the image.
35
+
36
+ Returns
37
+ -------
38
+ image:
39
+ The raw image.
40
+ """
41
+ pass
42
+
43
+ @abstractmethod
44
+ def load_target_dict(self, index: int) -> TargetDict:
45
+ """
46
+ Load the target dict for the given index.
47
+
48
+ Parameters
49
+ ----------
50
+ index:
51
+ The index of the target dict.
52
+
53
+ Returns
54
+ -------
55
+ target_dict:
56
+ The target dictionary.
57
+ """
58
+ pass
59
+
60
+ @abstractmethod
61
+ def __len__(self) -> int:
62
+ pass
ellipse_rcnn/utils/data/craters.py ADDED
@@ -0,0 +1,54 @@
+ import h5py
+ import torch
+ from torch.utils.data import Dataset
+
+ from ellipse_rcnn.utils.types import TargetDict, ImageTargetTuple
+ from ellipse_rcnn.utils.conics import bbox_ellipse
+
+
+ class CraterEllipseDataset(Dataset):
+     """
+     Dataset for crater ellipse detection. Mostly meant as an example in combination with
+     https://github.com/wdoppenberg/crater-detection.
+     """
+
+     def __init__(self, file_path: str, group: str) -> None:
+         self.file_path = file_path
+         self.group = group
+
+     def __getitem__(self, idx: torch.Tensor) -> ImageTargetTuple:
+         with h5py.File(self.file_path, "r") as dataset:
+             image = torch.tensor(dataset[self.group]["images"][idx])
+
+             # The number of instances can vary, hence a separate array holds the
+             # start indices of each image's instances.
+             start_idx = dataset[self.group]["craters/crater_list_idx"][idx]
+             end_idx = dataset[self.group]["craters/crater_list_idx"][idx + 1]
+             ellipse_matrices = torch.tensor(
+                 dataset[self.group]["craters/A_craters"][start_idx:end_idx]
+             )
+
+             boxes = bbox_ellipse(ellipse_matrices)
+             area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
+
+             num_objs = len(boxes)
+
+             labels = torch.ones((num_objs,), dtype=torch.int64)
+             image_id = torch.tensor([idx])
+
+             iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
+
+             target = TargetDict(
+                 boxes=boxes,
+                 labels=labels,
+                 image_id=image_id,
+                 area=area,
+                 iscrowd=iscrowd,
+                 ellipse_matrices=ellipse_matrices,
+             )
+
+             return image, target
+
+     def __len__(self) -> int:
+         with h5py.File(self.file_path, "r") as f:
+             return len(f[self.group]["images"])
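A hypothetical read of a single sample. The file path and group name are placeholders; the HDF5 layout must match the crater-detection export referenced in the docstring above.

from ellipse_rcnn.utils.data.craters import CraterEllipseDataset

ds = CraterEllipseDataset(file_path="data/craters.h5", group="training")
image, target = ds[0]
print(image.shape, target["boxes"].shape, target["ellipse_matrices"].shape)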
ellipse_rcnn/utils/data/fddb.py ADDED
@@ -0,0 +1,239 @@
+ """
+ Data loader and module for the FDDB dataset.
+ https://vis-www.cs.umass.edu/fddb/
+ """
+
+ from glob import glob
+ from typing import Any
+ from pathlib import Path
+
+ import torch
+ import pandas as pd
+ import PIL.Image
+ import torchvision.transforms
+ import pytorch_lightning as pl
+ from torch.utils.data import DataLoader, random_split
+
+ from ellipse_rcnn.utils.types import TargetDict, ImageTargetTuple, EllipseTuple
+ from ellipse_rcnn.utils.conics import bbox_ellipse, ellipse_to_conic_matrix, conic_center, unimodular_matrix
+ from ellipse_rcnn.utils.data.base import EllipseDatasetBase, collate_fn
+
+
+ def preprocess_label_files(root_path: str) -> dict[str, list[EllipseTuple]]:
+     label_files = glob(f"{root_path}/labels/*.txt")
+
+     file_paths = []
+     ellipse_data = []
+
+     for filename in label_files:
+         with open(filename) as f:
+             if "ellipseList" not in filename:
+                 file_paths += [p.strip("\n") for p in f.readlines()]
+             else:
+                 ellipse_data += [p.strip("\n") for p in f.readlines()]
+
+     pdf_file_paths = pd.DataFrame({"path": file_paths})
+     pdf_file_paths["path_idx"] = pdf_file_paths.index
+
+     pdf_ellipse_data = pd.DataFrame({"data": ellipse_data})
+     pdf_ellipse_data["data_idx"] = pdf_ellipse_data.index
+
+     pdf_file_data_mapping = pdf_file_paths.merge(
+         pdf_ellipse_data, left_on="path", right_on="data", how="left"
+     )
+
+     ellipse_dict: dict[str, list[EllipseTuple]] = {
+         str(k): [] for k in pdf_file_paths["path"]
+     }
+
+     for i, r in pdf_file_data_mapping.iterrows():
+         data_idx = r["data_idx"]
+         num_ellipses = int(ellipse_data[data_idx + 1])
+         file_path = r["path"]
+         for j in range(data_idx + 2, data_idx + num_ellipses + 2):
+             a, b, theta, x, y = [
+                 float(v) for v in ellipse_data[j].split(" ")[:-1] if len(v) > 0
+             ]
+             ellipse_params = EllipseTuple(a, b, theta, x, y)
+             ellipse_dict[file_path].append(ellipse_params)
+
+     return ellipse_dict
+
+
+ class FDDB(EllipseDatasetBase):
+     def __init__(
+         self,
+         root_path: str | Path,
+         ellipse_dict: dict[str, list[EllipseTuple]] | None = None,
+         transform: Any = None,
+     ) -> None:
+         self.root_path = Path(root_path) if isinstance(root_path, str) else root_path
+         if transform is None:
+             self.transform = torchvision.transforms.Compose(
+                 [
+                     torchvision.transforms.ToTensor(),
+                     torchvision.transforms.Normalize(
+                         mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+                     ),
+                 ]
+             )
+         else:
+             self.transform = transform
+         self.ellipse_dict = ellipse_dict or preprocess_label_files(root_path)
+
+     def __len__(self) -> int:
+         return len(self.ellipse_dict)
+
+     def load_target_dict(self, index: int) -> TargetDict:
+         key = list(self.ellipse_dict.keys())[index]
+         ellipses_list = self.ellipse_dict[key]
+
+         a = torch.tensor([[e.a for e in ellipses_list]])
+         b = torch.tensor([[e.b for e in ellipses_list]])
+         theta = torch.tensor([[e.theta for e in ellipses_list]])
+         x = torch.tensor([[e.x for e in ellipses_list]])
+         y = torch.tensor([[e.y for e in ellipses_list]])
+
+         ellipse_matrices = ellipse_to_conic_matrix(a=a, b=b, x=x, y=y, theta=theta)
+
+         if torch.stack(conic_center(ellipse_matrices)).isnan().any():
+             raise ValueError("NaN values in ellipse matrices. Please check the data.")
+
+         if len(ellipse_matrices.shape) == 2:
+             ellipse_matrices = ellipse_matrices.unsqueeze(0)
+
+         boxes = bbox_ellipse(ellipse_matrices, box_type="xyxy")
+
+         num_objs = len(boxes)
+
+         labels = torch.ones((num_objs,), dtype=torch.int64)
+         image_id = torch.tensor([index])
+         area = (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 2] - boxes[:, 0])
+         iscrowd = torch.zeros((num_objs,), dtype=torch.int64)
+
+         target = TargetDict(
+             boxes=boxes,
+             labels=labels,
+             image_id=image_id,
+             area=area,
+             iscrowd=iscrowd,
+             ellipse_matrices=ellipse_matrices,
+         )
+
+         return target
+
+     def load_image(self, index: int) -> PIL.Image.Image:
+         key = list(self.ellipse_dict.keys())[index]
+         file_path = str(Path(self.root_path) / "images" / Path(key)) + ".jpg"
+         return PIL.Image.open(file_path)
+
+     def __getitem__(self, idx: int) -> ImageTargetTuple:
+         image = self.load_image(idx)
+         target_dict = self.load_target_dict(idx)
+
+         # If the image is grayscale, convert it to RGB
+         if image.mode == "L":
+             image = image.convert("RGB")
+
+         image = self.transform(image)
+
+         return image, target_dict
+
+     def __repr__(self) -> str:
+         return f"FDDB<img={len(self)}>"
+
+     def split(self, fraction: float, shuffle: bool = False) -> tuple["FDDB", "FDDB"]:
+         """
+         Splits the dataset into two subsets based on the given fraction.
+
+         Args:
+             fraction (float): Fraction of the dataset for the first subset (0 < fraction < 1).
+             shuffle (bool): If True, dataset keys will be shuffled before splitting.
+
+         Returns:
+             tuple[FDDB, FDDB]: Two FDDB instances, one with the fraction of data,
+                 and the other with the remaining data.
+         """
+         if not (0 < fraction < 1):
+             raise ValueError("The fraction must be between 0 and 1.")
+
+         keys = list(self.ellipse_dict.keys())
+         if shuffle:
+             import random
+
+             random.shuffle(keys)
+
+         total_length = len(keys)
+         split_index = int(total_length * fraction)
+
+         subset1_keys = keys[:split_index]
+         subset2_keys = keys[split_index:]
+
+         subset1_ellipse_dict = {key: self.ellipse_dict[key] for key in subset1_keys}
+         subset2_ellipse_dict = {key: self.ellipse_dict[key] for key in subset2_keys}
+
+         subset1 = FDDB(
+             self.root_path, ellipse_dict=subset1_ellipse_dict, transform=self.transform
+         )
+         subset2 = FDDB(
+             self.root_path, ellipse_dict=subset2_ellipse_dict, transform=self.transform
+         )
+
+         return subset1, subset2
+
+
+ class FDDBLightningDataModule(pl.LightningDataModule):
+     def __init__(
+         self,
+         data_dir: str,
+         batch_size: int = 16,
+         train_fraction: float = 0.8,
+         transform: Any = None,
+         num_workers: int = 0,
+     ) -> None:
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.train_fraction = train_fraction
+         self.transform = transform
+         self.dataset: FDDB | None = None
+         self.train_dataset = None
+         self.val_dataset = None
+         self.num_workers = num_workers
+
+     def prepare_data(self) -> None:
+         # Ensure data preparation or downloading is done here.
+         pass
+
+     def setup(self, stage: str | None = None) -> None:
+         # Instantiate the FDDB dataset and split it into training and validation subsets.
+         self.dataset = FDDB(self.data_dir, transform=self.transform)
+
+         train_size = int(len(self.dataset) * self.train_fraction)
+         val_size = len(self.dataset) - train_size
+         self.train_dataset, self.val_dataset = random_split(
+             self.dataset, [train_size, val_size]
+         )
+
+     def train_dataloader(self) -> DataLoader[ImageTargetTuple]:
+         return DataLoader(
+             self.train_dataset,
+             batch_size=self.batch_size,
+             shuffle=True,
+             collate_fn=collate_fn,
+             num_workers=self.num_workers,
+         )
+
+     def val_dataloader(self) -> DataLoader[ImageTargetTuple]:
+         return DataLoader(
+             self.val_dataset,
+             batch_size=self.batch_size,
+             collate_fn=collate_fn,
+             num_workers=self.num_workers,
+         )
+
+     def test_dataloader(self) -> DataLoader[ImageTargetTuple]:
+         # Placeholder for test data; currently returns the validation dataloader as a default.
+         return DataLoader(
+             self.val_dataset, batch_size=self.batch_size, collate_fn=collate_fn
+         )
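Sketch of using the data module outside a Lightning `Trainer`, e.g. to inspect one training batch. The data directory is a placeholder.

from ellipse_rcnn.utils.data.fddb import FDDBLightningDataModule

dm = FDDBLightningDataModule("data/FDDB", batch_size=8, train_fraction=0.8)
dm.setup()

images, targets = next(iter(dm.train_dataloader()))
# images: tuple of 8 tensors (C, H, W); targets: tuple of 8 TargetDicts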
ellipse_rcnn/utils/types.py ADDED
@@ -0,0 +1,46 @@
+ from typing import TypedDict, NamedTuple
+
+ import torch
+
+
+ class TargetDict(TypedDict):
+     boxes: torch.Tensor
+     labels: torch.Tensor
+     image_id: torch.Tensor
+     area: torch.Tensor
+     iscrowd: torch.Tensor
+     ellipse_matrices: torch.Tensor
+
+
+ class LossDict(TypedDict, total=False):
+     loss_classifier: torch.Tensor
+     loss_box_reg: torch.Tensor
+     loss_objectness: torch.Tensor
+     loss_rpn_box_reg: torch.Tensor
+     loss_ellipse_kld: torch.Tensor
+     loss_ellipse_smooth_l1: torch.Tensor
+     loss_total: torch.Tensor
+
+
+ class PredictionDict(TypedDict):
+     bboxes: torch.Tensor
+     labels: torch.Tensor
+     scores: torch.Tensor
+     ellipse_matrices: torch.Tensor
+
+
+ type ImageTargetTuple = tuple[torch.Tensor, TargetDict]  # Tensor shape: (C, H, W)
+ type CollatedBatchType = tuple[
+     tuple[torch.Tensor, ...], tuple[TargetDict, ...]
+ ]  # Tensor shape: (C, H, W)
+ type UncollatedBatchType = list[ImageTargetTuple]
+
+ type EllipseType = torch.Tensor
+
+
+ class EllipseTuple(NamedTuple):
+     a: float
+     b: float
+     theta: float
+     x: float
+     y: float
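For reference, a `TargetDict` for a single dummy instance can be built as below; this mirrors how the dataset classes above assemble their targets. All tensor values here are made up, and the identity matrix is only a stand-in for a real conic matrix.

import torch
from ellipse_rcnn.utils.types import TargetDict

target = TargetDict(
    boxes=torch.tensor([[10.0, 20.0, 50.0, 60.0]]),  # (N, 4), xyxy format
    labels=torch.ones((1,), dtype=torch.int64),      # single foreground class
    image_id=torch.tensor([0]),
    area=torch.tensor([1600.0]),
    iscrowd=torch.zeros((1,), dtype=torch.int64),
    ellipse_matrices=torch.eye(3).unsqueeze(0),      # (N, 3, 3), dummy conic
)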
ellipse_rcnn/utils/viz.py ADDED
@@ -0,0 +1,106 @@
+ from __future__ import annotations
+
+ from typing import Literal
+ from torch import Tensor
+ import numpy as np
+ import torch
+ from torchvision.ops import boxes as box_ops
+ from matplotlib import pyplot as plt
+ from matplotlib.axes import Axes
+ from matplotlib.collections import EllipseCollection, PatchCollection
+ from matplotlib.patches import Rectangle
+ from ellipse_rcnn.utils.conics import ellipse_angle, conic_center, ellipse_axes
+ from matplotlib.figure import Figure
+
+
+ def plot_single_pred(
+     image: Tensor,
+     prediction,
+     min_score: float = 0.75,
+ ) -> Figure:
+     if isinstance(prediction, list):
+         if len(prediction) > 1:
+             raise ValueError(
+                 "Multiple predictions detected. Please pass a single prediction."
+             )
+         prediction = prediction[0]
+     fig, ax = plt.subplots(1, 1, figsize=(10, 10))
+     fig.patch.set_alpha(0)
+     ax.imshow(image.permute(1, 2, 0), cmap="grey")
+     score_mask = prediction["scores"] > min_score
+
+     plot_ellipses(prediction["ellipse_matrices"][score_mask], ax=ax)
+
+     return fig
+
+
+ def plot_ellipses(
+     A_craters: torch.Tensor,
+     figsize: tuple[float, float] = (15, 15),
+     plot_centers: bool = False,
+     ax: Axes | None = None,
+     rim_color="r",
+     alpha=1.0,
+ ):
+     a_proj, b_proj = ellipse_axes(A_craters)
+     psi_proj = ellipse_angle(A_craters)
+     x_pix_proj, y_pix_proj = conic_center(A_craters)
+
+     a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj = map(
+         lambda t: t.detach().cpu().numpy(),
+         (a_proj, b_proj, psi_proj, x_pix_proj, y_pix_proj),
+     )
+
+     if ax is None:
+         fig, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})
+
+     ec = EllipseCollection(
+         a_proj * 2,
+         b_proj * 2,
+         np.degrees(psi_proj),
+         units="xy",
+         offsets=np.column_stack((x_pix_proj, y_pix_proj)),
+         transOffset=ax.transData,
+         facecolors="None",
+         edgecolors=rim_color,
+         alpha=alpha,
+     )
+     ax.add_collection(ec)
+
+     if plot_centers:
+         crater_centers = conic_center(A_craters)
+         for k, c_i in enumerate(crater_centers):
+             x, y = c_i[0], c_i[1]
+             ax.text(x.item(), y.item(), str(k), color=rim_color)
+
+
+ def plot_bboxes(
+     boxes: torch.Tensor,
+     box_type: Literal["xyxy", "xywh", "cxcywh"] = "xyxy",
+     figsize: tuple[float, float] = (15, 15),
+     plot_centers: bool = False,
+     ax: Axes | None = None,
+     rim_color="r",
+     alpha=1.0,
+ ):
+     if ax is None:
+         fig, ax = plt.subplots(figsize=figsize, subplot_kw={"aspect": "equal"})
+
+     if box_type != "xyxy":
+         boxes = box_ops.box_convert(boxes, box_type, "xyxy")
+
+     boxes = boxes.detach().cpu().numpy()
+     rectangles = []
+     for k, b_i in enumerate(boxes):
+         x1, y1, x2, y2 = b_i
+         rectangles.append(Rectangle((x1, y1), x2 - x1, y2 - y1))
+
+     collection = PatchCollection(
+         rectangles, edgecolor=rim_color, facecolor="none", alpha=alpha
+     )
+     ax.add_collection(collection)
+
+     if plot_centers:
+         for k, b_i in enumerate(boxes):
+             x1, y1, x2, y2 = b_i
+             ax.text(x1, y1, str(k), color=rim_color)
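Sketch of plotting one prediction. It assumes the model follows the torchvision detection convention of taking a list of image tensors and returning a list of prediction dicts; `model` and `image_tensor` (a normalized (C, H, W) tensor) are assumed to exist already, as set up in app.py.

import torch
from ellipse_rcnn.utils.viz import plot_single_pred

model.eval()
with torch.no_grad():
    predictions = model([image_tensor])  # assumed: list with one prediction dict

fig = plot_single_pred(image_tensor, predictions[0], min_score=0.75)
fig.savefig("prediction.png")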
examples/image1.jpg ADDED
examples/image2.jpg ADDED
examples/image3.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ gradio
+ torch
+ torchvision
+ matplotlib
+ Pillow
+ joblib
+ huggingface_hub
+ lightning
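The dependencies can be installed in the usual way. Note that `h5py` and `pandas`, imported by the optional dataset utilities above, are not listed here and would need to be installed separately to use those modules.

pip install -r requirements.txt
pip install h5py pandas  # only needed for the crater / FDDB dataset utilities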