File size: 4,381 Bytes
205a7af
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
"""up decoder head.

Adapted from https://github.com/jinlinyi/PerspectiveFields
"""

import logging

import torch
from torch import nn
from torch.nn import functional as F

from siclib.models import get_model
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import up_error
from siclib.models.utils.perspective_encoding import decode_up_bin
from siclib.utils.conversions import deg2rad

logger = logging.getLogger(__name__)

# flake8: noqa
# mypy: ignore-errors


class UpDecoder(BaseModel):
    """Decoder head that predicts an up field from input features.

    The head runs a configurable decoder on ``data["features"]`` and maps its
    output through a 1x1 convolution to either a 2-channel vector field
    (regression losses) or per-bin logits (``loss_type == "classification"``).
    When the decoder predicts uncertainty, a sigmoid confidence map is returned
    alongside the field.
    """

    default_conf = {
        "loss_type": "l1",
        "use_loss": True,
        "use_uncertainty_loss": True,
        "loss_weight": 1.0,
        "recall_thresholds": [1, 3, 5, 10],
        "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Create the decoder and the 1x1 prediction layer described by ``conf``."""
        self.loss_type = conf.loss_type
        self.loss_weight = conf.loss_weight

        self.use_uncertainty_loss = conf.use_uncertainty_loss
        self.predict_uncertainty = conf.decoder.predict_uncertainty

        self.is_classification = self.loss_type == "classification"
        # Two output channels for a regressed vector; 73 output channels in
        # classification mode (presumably the bin count expected by
        # decode_up_bin -- verify against siclib's perspective_encoding).
        self.num_classes = 73 if self.is_classification else 2

        self.decoder = get_model(conf.decoder.name)(conf.decoder)
        self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, self.num_classes, kernel_size=1)

    def calculate_losses(self, predictions, targets, confidence=None):
        """Compute the configured regression loss between two up fields.

        Args:
            predictions: Predicted field; the channel dimension is axis 1
                (presumably shaped (B, 2, H, W) -- confirm with caller).
            targets: Ground-truth field with the same layout as ``predictions``.
            confidence: Optional confidence map used to re-weight the per-pixel
                loss when ``use_uncertainty_loss`` is enabled.

        Returns:
            Dict with a single entry ``"up-<loss_type>-loss"`` holding the
            per-sample loss scaled by ``loss_weight``.

        Raises:
            NotImplementedError: If ``loss_type`` is not one of the supported
                regression losses.
        """
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163

        residuals = predictions - targets
        if self.loss_type == "l2":
            loss = (residuals**2).sum(axis=1)
        elif self.loss_type == "l1":
            loss = residuals.abs().sum(axis=1)
        elif self.loss_type == "dot":
            # Cosine distance: use the predictions themselves (not the
            # residuals) so that a perfect prediction of a unit-norm target
            # yields a loss of 0 instead of a constant offset of 1.
            loss = 1 - (predictions * targets).sum(axis=1)
        elif self.loss_type == "cauchy":
            c = 0.007  # -> corresponds to about 5 degrees
            squared_norm = (residuals**2).sum(axis=1)
            loss = c**2 / 2 * torch.log(1 + squared_norm / c**2)
        elif self.loss_type == "huber":
            c = deg2rad(1)
            # Functional form avoids building a throw-away nn.HuberLoss module
            # on every call; the result is identical.
            loss = F.huber_loss(predictions, targets, reduction="none", delta=c).sum(axis=1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")

        if confidence is not None and self.use_uncertainty_loss:
            # Normalise the confidence over the last two axes and rescale by
            # the number of elements so the weights average to 1.
            conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
            conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2))
            # Detached: the weighting must not back-propagate into the
            # confidence prediction.
            loss = loss * conf_weight.detach()

        loss_name = f"up-{self.loss_type}-loss"
        return {loss_name: loss.mean(axis=(1, 2)) * self.loss_weight}

    def _forward(self, data):
        """Predict the up field (and optionally its confidence).

        Args:
            data: Dict containing ``"features"``, which is passed to the decoder.

        Returns:
            Dict with ``"up_field"`` and, when ``predict_uncertainty`` is set,
            ``"up_confidence"`` (sigmoid of the decoder's second output).
        """
        out = {}
        x, log_confidence = self.decoder(data["features"])
        up = self.linear_pred_up(x)

        if self.predict_uncertainty:
            out["up_confidence"] = torch.sigmoid(log_confidence)

        if self.is_classification:
            # Pick the most likely bin per location and decode it to a field.
            out["up_field"] = decode_up_bin(up.argmax(dim=1), self.num_classes)
        else:
            # Regression mode: normalise along the channel axis.
            out["up_field"] = F.normalize(up, dim=1)

        return out

    def loss(self, pred, data):
        """Return ``(losses, metrics)`` for a prediction / ground-truth pair.

        No loss is computed (an empty dict is returned) when ``use_loss`` is
        disabled or the head runs in classification mode; metrics are always
        computed.
        """
        if not self.conf.use_loss or self.is_classification:
            return {}, self.metrics(pred, data)

        predictions = pred["up_field"]
        targets = data["up_field"]

        losses = self.calculate_losses(predictions, targets, pred.get("up_confidence"))

        # ``0 + x`` creates a new tensor so "up_total" does not alias the
        # individual loss entry.
        losses["up_total"] = 0 + losses[f"up-{self.loss_type}-loss"]
        return losses, self.metrics(pred, data)

    def metrics(self, pred, data):
        """Compute per-sample error and recall metrics for the up field.

        Returns:
            Dict with ``"up_angle_error"``, ``"up_angle_recall@<th>"`` for each
            configured threshold and, if a confidence map was predicted,
            ``"up_angle_error_weighted"``.
        """
        predictions = pred["up_field"]
        targets = data["up_field"]

        # Locations whose predicted vector is entirely zero are masked out.
        # Testing each component (instead of the channel sum) keeps non-zero
        # vectors whose components cancel, e.g. (a, -a).
        mask = (predictions != 0).any(dim=1)

        error = up_error(predictions, targets) * mask
        out = {"up_angle_error": error.mean(axis=(1, 2))}

        if "up_confidence" in pred:
            confidence = pred["up_confidence"]
            weighted_error = (error * confidence).sum(axis=(1, 2))
            out["up_angle_error_weighted"] = weighted_error / confidence.sum(axis=(1, 2))

        # NOTE(review): thresholds are compared directly against up_error's
        # output, so they must be in the same unit (presumably degrees).
        for th in self.conf.recall_thresholds:
            out[f"up_angle_recall@{th}"] = (error < th).float().mean(axis=(1, 2))

        return out