File size: 16,095 Bytes
8e9893b
afc2161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e9893b
afc2161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
from typing import Dict, List, Tuple, Optional, TypedDict, NamedTuple

import torch
from torch import nn
from torch.nn import functional as F
from torchvision.models.detection.roi_heads import RoIHeads, fastrcnn_loss

from .kld import SymmetricKLDLoss
from .wd import WassersteinLoss
from ..utils.conics import (
    ellipse_to_conic_matrix,
    ellipse_axes,
    ellipse_angle,
    conic_center,
)


class RegressorPrediction(NamedTuple):
    """
    Represents the processed outputs of a regression model as a named tuple.

    This class encapsulates regression model outputs in a structured format, where
    each attribute corresponds to a specific component of the regression output.
    These outputs can be directly used for post-processing steps such as transformation
    into conic matrices or further evaluations of ellipse geometry.

    Attributes
    ----------
    d_a : torch.Tensor
        The normalized semi-major axis scale factor (logarithmic) used to compute
        the actual semi-major axis length of ellipses.
    d_b : torch.Tensor
        The normalized semi-minor axis scale factor (logarithmic) used to compute
        the actual semi-minor axis length of ellipses.
    d_x : torch.Tensor
        The normalized x-coordinate translation factor, specifying the adjustment
        to the center of bounding boxes for ellipse placement.
    d_y : torch.Tensor
        The normalized y-coordinate translation factor, specifying the adjustment
        to the center of bounding boxes for ellipse placement.
    d_theta : torch.Tensor
        The normalized rotation angle factor which is processed to derive the
        actual rotation angle (in radians) of ellipses.

    Notes
    -----
    - The attributes `d_a` and `d_b`, representing scale factors for the semi-major
      and semi-minor axes, are typically bounded between 0 and 1 using a sigmoid activation.
    - The attributes `d_x` and `d_y` serve as adjustments to bounding box centers, normalized
      with respect to the bounding box diagonals.
    - The attribute `d_theta` is normalized to ensure the rotation angle lies within
      a valid range (after transformation, typically between -π/2 and π/2 radians).
    - These normalized outputs are post-processed together with bounding box information
      to construct actionable ellipse parameters such as their axes lengths, centers,
      and angles.
    - This structure simplifies downstream regression tasks, such as conversion into
      conic matrices or calculation of geometrical losses.
    """

    d_a: torch.Tensor
    d_b: torch.Tensor
    d_theta: torch.Tensor

    @property
    def device(self) -> torch.device:
        return self.d_a.device

    @property
    def dtype(self) -> torch.dtype:
        return self.d_a.dtype

    def split(self, split_size: list[int] | int, dim: int = 0):
        return [
            RegressorPrediction(*tensors)
            for tensors in zip(
                *[torch.split(attr, split_size, dim=dim) for attr in self]
            )
        ]


class EllipseRegressor(nn.Module):
    """
    Feed-forward head that predicts normalized ellipse parameters from RoI
    features.

    Given flattened per-RoI features, the network regresses three normalized
    quantities — the logarithmic semi-major (``d_a``) and semi-minor (``d_b``)
    axis scale factors and a rotation factor (``d_theta``) — returned as a
    :class:`RegressorPrediction`. The final ``Tanh`` bounds every output to
    (-1, 1); Xavier initialization with a small gain keeps initial
    predictions near zero for training stability. (The ellipse center is not
    regressed here: it is derived from the proposal box during
    post-processing.)

    Attributes
    ----------
    ffnn : nn.Sequential
        Two-layer feed-forward network with a ReLU hidden activation and a
        Tanh output activation producing the three parameters.
    """

    def __init__(self, in_channels: int = 1024, hidden_size: int = 64):
        super().__init__()
        self.ffnn = nn.Sequential(
            nn.Linear(in_channels, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, 3),
            nn.Tanh(),
        )

        # Small-gain Xavier init so the head starts out predicting ~0,
        # i.e. ellipses close to the box-derived prior.
        for lin in self.ffnn:
            if isinstance(lin, nn.Linear):
                nn.init.xavier_uniform_(lin.weight, gain=0.01)
                nn.init.zeros_(lin.bias)

    def forward(self, x: torch.Tensor) -> RegressorPrediction:
        """Flatten per-RoI features and regress the three ellipse factors."""
        x = x.flatten(start_dim=1)
        x = self.ffnn(x)

        d_a, d_b, d_theta = x.unbind(dim=-1)

        return RegressorPrediction(d_a=d_a, d_b=d_b, d_theta=d_theta)


def postprocess_ellipse_predictor(
    pred: RegressorPrediction,
    box_proposals: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Processes elliptical predictor outputs and converts them into conic matrices.

    Parameters
    ----------
    pred : RegressorPrediction
        The output of the elliptical predictor model.
    box_proposals : torch.Tensor
        Tensor containing proposed bounding box information, with shape (N, 4). Each box
        is represented as a 4-tuple (x_min, y_min, x_max, y_max).

    Returns
    -------
    tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
        A tuple containing:
        - a (torch.Tensor): Computed semi-major axis of the ellipses.
        - b (torch.Tensor): Computed semi-minor axis of the ellipses.
        - x (torch.Tensor): X-coordinates of the ellipse centers.
        - y (torch.Tensor): Y-coordinates of the ellipse centers.
        - theta (torch.Tensor): Rotation angles (in radians) for the ellipses.

    """
    d_a, d_b, d_theta = pred

    # Pre-compute box width, height, and diagonal
    box_width = box_proposals[:, 2] - box_proposals[:, 0]
    box_height = box_proposals[:, 3] - box_proposals[:, 1]
    box_diag = torch.sqrt(box_width**2 + box_height**2)

    a = box_diag * d_a.exp()
    b = box_diag * d_b.exp()

    box_x = box_proposals[:, 0] + box_width * 0.5
    box_y = box_proposals[:, 1] + box_height * 0.5

    theta = (d_theta * 2.0 - 1.0) * (torch.pi / 2)

    cos_theta = torch.cos(theta)
    sin_theta = torch.sin(theta)
    theta = torch.where(
        cos_theta >= 0,
        torch.atan2(sin_theta, cos_theta),
        torch.atan2(-sin_theta, -cos_theta),
    )

    return a, b, box_x, box_y, theta


# Functional TypedDict form: the three loss terms returned by ``ellipse_loss``
# (symmetric-KLD, smooth-L1, and Wasserstein components).
EllipseLossDict = TypedDict(
    "EllipseLossDict",
    {
        "loss_ellipse_kld": torch.Tensor,
        "loss_ellipse_smooth_l1": torch.Tensor,
        "loss_ellipse_wasserstein": torch.Tensor,
    },
)


def ellipse_loss(
    pred: RegressorPrediction,
    A_target: List[torch.Tensor],
    pos_matched_idxs: List[torch.Tensor],
    box_proposals: List[torch.Tensor],
    kld_loss_fn: SymmetricKLDLoss,
    wd_loss_fn: WassersteinLoss,
) -> EllipseLossDict:
    """Compute training losses for the ellipse regression head.

    Combines a smooth-L1 loss on the normalized regression targets with a
    symmetric KLD loss between predicted and target conic matrices. The
    Wasserstein term is currently disabled (see below) and returned as zero.

    Parameters
    ----------
    pred : RegressorPrediction
        Concatenated predictions for all positive proposals across the batch.
    A_target : List[torch.Tensor]
        Per-image target ellipse conic matrices.
    pos_matched_idxs : List[torch.Tensor]
        Per-image indices matching each positive proposal to its target.
    box_proposals : List[torch.Tensor]
        Per-image positive proposal boxes, (x_min, y_min, x_max, y_max) rows.
    kld_loss_fn : SymmetricKLDLoss
        Symmetric KL-divergence loss over conic matrices.
    wd_loss_fn : WassersteinLoss
        Wasserstein loss over conic matrices (currently unused).

    Returns
    -------
    EllipseLossDict
        The three loss terms; all zeros when there are no positive targets.
    """
    # Gather the matched target conics for every positive proposal in the batch.
    pos_matched_idxs_batched = torch.cat(pos_matched_idxs, dim=0)
    A_target = torch.cat(A_target, dim=0)[pos_matched_idxs_batched]

    box_proposals = torch.cat(box_proposals, dim=0)

    # No positive samples: return zero losses (avoids div-by-zero below).
    if A_target.numel() == 0:
        return {
            "loss_ellipse_kld": torch.tensor(0.0, device=pred.device, dtype=pred.dtype),
            "loss_ellipse_smooth_l1": torch.tensor(
                0.0, device=pred.device, dtype=pred.dtype
            ),
            "loss_ellipse_wasserstein": torch.tensor(
                0.0, device=pred.device, dtype=pred.dtype
            ),
        }

    # Recover target axes/angle from the conic matrices (project helpers).
    a_target, b_target = ellipse_axes(A_target)
    theta_target = ellipse_angle(A_target)

    # Box proposal parameters; clamp the diagonal to keep the log finite.
    box_width = box_proposals[:, 2] - box_proposals[:, 0]
    box_height = box_proposals[:, 3] - box_proposals[:, 1]
    box_diag = torch.sqrt(box_width**2 + box_height**2).clamp(min=1e-6)

    # Normalize target variables into the same space the regressor predicts:
    # log-ratio of axes to the box diagonal, and angle mapped to [0, 1].
    da_target = (a_target / box_diag).log()
    db_target = (b_target / box_diag).log()
    dtheta_target = (theta_target / (torch.pi / 2) + 1) / 2

    # Direct parameter losses
    d_a, d_b, d_theta = pred

    pred_t = torch.stack([d_a, d_b, d_theta], dim=1)
    target_t = torch.stack([da_target, db_target, dtheta_target], dim=1)

    # Sum-then-normalize by proposal count; beta=1/9 follows the Fast R-CNN
    # convention for smooth-L1. NaNs are zeroed and the loss clipped for
    # numerical stability.
    loss_smooth_l1 = F.smooth_l1_loss(pred_t, target_t, beta=(1 / 9), reduction="sum")
    loss_smooth_l1 /= box_proposals.shape[0]
    loss_smooth_l1 = loss_smooth_l1.nan_to_num(nan=0.0).clip(max=float(1e4))

    # Geometric loss on the full conic matrices.
    a, b, x, y, theta = postprocess_ellipse_predictor(pred, box_proposals)

    A_pred = ellipse_to_conic_matrix(a=a, b=b, theta=theta, x=x, y=y)

    # KLD term is clipped and down-weighted by 0.1.
    loss_kld = kld_loss_fn.forward(A_pred, A_target).clip(max=float(1e4)).mean() * 0.1
    # NOTE(review): Wasserstein term intentionally disabled; a zero tensor is
    # returned so the loss dict keeps a stable set of keys.
    loss_wd = torch.zeros(1, device=pred.device, dtype=pred.dtype)
    # loss_wd = wd_loss_fn.forward(A_pred, A_target).clip(max=float(1e4)).mean() * 0.1

    return {
        "loss_ellipse_kld": loss_kld,
        "loss_ellipse_smooth_l1": loss_smooth_l1,
        "loss_ellipse_wasserstein": loss_wd,
    }


class EllipseRoIHeads(RoIHeads):
    """RoI heads extending torchvision's ``RoIHeads`` with an ellipse branch.

    In addition to the standard box classification/regression heads, this
    module pools RoI features through ``ellipse_roi_pool`` / ``ellipse_head``
    and regresses normalized ellipse parameters with ``ellipse_predictor``.
    During training the branch contributes KLD / smooth-L1 / Wasserstein loss
    terms (see ``ellipse_loss``); during inference it attaches per-detection
    conic matrices under the ``"ellipse_matrices"`` result key.
    """

    def __init__(
        self,
        box_roi_pool: nn.Module,
        box_head: nn.Module,
        box_predictor: nn.Module,
        fg_iou_thresh: float,
        bg_iou_thresh: float,
        batch_size_per_image: int,
        positive_fraction: float,
        bbox_reg_weights: Optional[Tuple[float, float, float, float]],
        score_thresh: float,
        nms_thresh: float,
        detections_per_img: int,
        ellipse_roi_pool: nn.Module,
        ellipse_head: nn.Module,
        ellipse_predictor: nn.Module,
        # Loss parameters
        kld_shape_only: bool = False,
        kld_normalize: bool = False,
        # Numerical stability parameters
        nan_to_num: float = 10.0,
        loss_scale: float = 1.0,
    ):
        # Box-branch arguments are forwarded unchanged to torchvision's RoIHeads.
        super().__init__(
            box_roi_pool,
            box_head,
            box_predictor,
            fg_iou_thresh,
            bg_iou_thresh,
            batch_size_per_image,
            positive_fraction,
            bbox_reg_weights,
            score_thresh,
            nms_thresh,
            detections_per_img,
        )

        self.ellipse_roi_pool = ellipse_roi_pool
        self.ellipse_head = ellipse_head
        self.ellipse_predictor = ellipse_predictor

        # NOTE(review): ``kld_normalize`` is reused for the Wasserstein loss's
        # ``normalize`` flag — confirm this sharing is intentional.
        self.kld_loss = SymmetricKLDLoss(
            shape_only=kld_shape_only,
            normalize=kld_normalize,
            nan_to_num=nan_to_num,
        )
        self.wd_loss = WassersteinLoss(
            nan_to_num=nan_to_num,
            normalize=kld_normalize,
        )
        self.loss_scale = loss_scale

    def has_ellipse_reg(self) -> bool:
        """Return True when all three ellipse-branch modules are present."""
        if self.ellipse_roi_pool is None:
            return False
        if self.ellipse_head is None:
            return False
        if self.ellipse_predictor is None:
            return False
        return True

    def postprocess_ellipse_regressions(self):
        # Placeholder: ellipse post-processing happens inline in forward().
        pass

    def forward(
        self,
        features: Dict[str, torch.Tensor],
        proposals: List[torch.Tensor],
        image_shapes: List[Tuple[int, int]],
        targets: Optional[List[Dict[str, torch.Tensor]]] = None,
    ) -> Tuple[List[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]:
        """Run the box branch, then the ellipse branch.

        Returns ``(result, losses)``: per-image detection dicts (with
        ``"ellipse_matrices"`` added at inference time) and, during training,
        the combined box + ellipse loss dict.
        """
        # Validate the target dtypes before sampling.
        if targets is not None:
            for t in targets:
                floating_point_types = (torch.float, torch.double, torch.half)
                if t["boxes"].dtype not in floating_point_types:
                    raise TypeError("target boxes must be of float type")
                if t["ellipse_matrices"].dtype not in floating_point_types:
                    raise TypeError("target ellipse_offsets must be of float type")
                if t["labels"].dtype != torch.int64:
                    raise TypeError("target labels must be of int64 type")

        if self.training:
            # Subsample proposals into positives/negatives per image.
            proposals, matched_idxs, labels, regression_targets = (
                self.select_training_samples(proposals, targets)
            )
        else:
            labels = None
            regression_targets = None
            matched_idxs = None

        # Standard box branch (shared with the parent implementation).
        box_features = self.box_roi_pool(features, proposals, image_shapes)
        box_features = self.box_head(box_features)
        class_logits, box_regression = self.box_predictor(box_features)

        result: List[Dict[str, torch.Tensor]] = []
        losses = {}
        if self.training:
            if labels is None or regression_targets is None:
                raise ValueError(
                    "Labels and regression targets must not be None during training"
                )
            loss_classifier, loss_box_reg = fastrcnn_loss(
                class_logits, box_regression, labels, regression_targets
            )
            losses = {"loss_classifier": loss_classifier, "loss_box_reg": loss_box_reg}
        else:
            # Inference: decode, score-threshold and NMS-filter detections.
            boxes, scores, labels = self.postprocess_detections(
                class_logits, box_regression, proposals, image_shapes
            )
            num_images = len(boxes)
            for i in range(num_images):
                result.append(
                    {
                        "boxes": boxes[i],
                        "labels": labels[i],
                        "scores": scores[i],
                    }
                )

        if self.has_ellipse_reg():
            # At inference the ellipse branch runs on the final detections.
            ellipse_box_proposals = [p["boxes"] for p in result]
            if self.training:
                if matched_idxs is None:
                    raise ValueError("matched_idxs must not be None during training")
                # during training, only focus on positive boxes
                num_images = len(proposals)
                ellipse_box_proposals = []
                pos_matched_idxs = []
                for img_id in range(num_images):
                    pos = torch.where(labels[img_id] > 0)[0]
                    ellipse_box_proposals.append(proposals[img_id][pos])
                    pos_matched_idxs.append(matched_idxs[img_id][pos])
            else:
                pos_matched_idxs = None  # type: ignore

            if self.ellipse_roi_pool is not None:
                ellipse_features = self.ellipse_roi_pool(
                    features, ellipse_box_proposals, image_shapes
                )
                ellipse_features = self.ellipse_head(ellipse_features)
                ellipse_shapes_normalised = self.ellipse_predictor(ellipse_features)
            else:
                raise Exception("Expected ellipse_roi_pool to be not None")

            loss_ellipse_regressor = {}
            if self.training:
                if targets is None:
                    raise ValueError("Targets must not be None during training")
                if pos_matched_idxs is None:
                    raise ValueError(
                        "pos_matched_idxs must not be None during training"
                    )
                if ellipse_shapes_normalised is None:
                    raise ValueError(
                        "ellipse_shapes_normalised must not be None during training"
                    )

                ellipse_matrix_targets = [t["ellipse_matrices"] for t in targets]
                rcnn_loss_ellipse = ellipse_loss(
                    ellipse_shapes_normalised,
                    ellipse_matrix_targets,
                    pos_matched_idxs,
                    ellipse_box_proposals,
                    self.kld_loss,
                    self.wd_loss,
                )

                # Optional global scaling of the ellipse losses (the
                # Wasserstein term is currently a constant zero, so only the
                # KLD and smooth-L1 terms are scaled).
                if self.loss_scale != 1.0:
                    rcnn_loss_ellipse["loss_ellipse_kld"] *= self.loss_scale
                    rcnn_loss_ellipse["loss_ellipse_smooth_l1"] *= self.loss_scale

                loss_ellipse_regressor.update(rcnn_loss_ellipse)
            else:
                # Split batched predictions back per-image and attach the
                # decoded conic matrices to each result dict.
                ellipses_per_image = [lbl.shape[0] for lbl in labels]
                for pred, r, box in zip(
                    ellipse_shapes_normalised.split(ellipses_per_image, dim=0),
                    result,
                    ellipse_box_proposals,
                ):
                    a, b, x, y, theta = postprocess_ellipse_predictor(pred, box)
                    A_pred = ellipse_to_conic_matrix(a=a, b=b, theta=theta, x=x, y=y)
                    r["ellipse_matrices"] = A_pred
                    # r["boxes"] = bbox_ellipse(A_pred)

            losses.update(loss_ellipse_regressor)

        return result, losses