"""up decoder head.
Adapted from https://github.com/jinlinyi/PerspectiveFields
"""
import logging
import torch
from torch import nn
from torch.nn import functional as F
from siclib.models import get_model
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import up_error
from siclib.models.utils.perspective_encoding import decode_up_bin
from siclib.utils.conversions import deg2rad
logger = logging.getLogger(__name__)
# flake8: noqa
# mypy: ignore-errors
class UpDecoder(BaseModel):
    """Decoder head that predicts a per-pixel up-vector field from backbone features."""

    default_conf = {
        "loss_type": "l1",
        "use_loss": True,
        "use_uncertainty_loss": True,
        "loss_weight": 1.0,
        "recall_thresholds": [1, 3, 5, 10],
        "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
    }

    required_data_keys = ["features"]
    def _init(self, conf):
        self.loss_type = conf.loss_type
        self.loss_weight = conf.loss_weight
        self.use_uncertainty_loss = conf.use_uncertainty_loss
        self.predict_uncertainty = conf.decoder.predict_uncertainty

        # Regression predicts a 2D up vector per pixel; classification predicts
        # a distribution over 73 discretized up-direction bins.
        self.num_classes = 2
        self.is_classification = self.conf.loss_type == "classification"
        if self.is_classification:
            self.num_classes = 73

        self.decoder = get_model(conf.decoder.name)(conf.decoder)
        self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, self.num_classes, kernel_size=1)
    def calculate_losses(self, predictions, targets, confidence=None):
        """Compute the per-pixel up-field loss, optionally re-weighted by the confidence map."""
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        residuals = predictions - targets

        if self.loss_type == "l2":
            loss = (residuals**2).sum(axis=1)
        elif self.loss_type == "l1":
            loss = residuals.abs().sum(axis=1)
        elif self.loss_type == "dot":
            # Cosine loss between the (unit-norm) predicted and target up vectors.
            loss = 1 - (predictions * targets).sum(axis=1)
        elif self.loss_type == "cauchy":
            c = 0.007  # corresponds to an angular error of about 5 degrees
            residuals = (residuals**2).sum(axis=1)
            loss = c**2 / 2 * torch.log(1 + residuals / c**2)
        elif self.loss_type == "huber":
            c = deg2rad(1)
            loss = nn.HuberLoss(reduction="none", delta=c)(predictions, targets).sum(axis=1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")

        if confidence is not None and self.use_uncertainty_loss:
            # Normalize the confidence map so its mean over the image is 1, then use it
            # to re-weight the per-pixel loss; no gradients flow into the confidence here.
            conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
            conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2))
            loss = loss * conf_weight.detach()

        losses = {f"up-{self.loss_type}-loss": loss.mean(axis=(1, 2))}
        losses = {k: v * self.loss_weight for k, v in losses.items()}
        return losses
    def _forward(self, data):
        out = {}

        x, log_confidence = self.decoder(data["features"])
        up = self.linear_pred_up(x)

        if self.predict_uncertainty:
            out["up_confidence"] = torch.sigmoid(log_confidence)

        if self.is_classification:
            # Decode the per-pixel bin with the highest score into an up vector.
            out["up_field"] = decode_up_bin(up.argmax(dim=1), self.num_classes)
            return out

        # Regression: normalize the raw 2D output to unit length.
        up = F.normalize(up, dim=1)
        out["up_field"] = up
        return out
    def loss(self, pred, data):
        """Return the up-field losses and metrics for a batch."""
        if not self.conf.use_loss or self.is_classification:
            return {}, self.metrics(pred, data)

        predictions = pred["up_field"]
        targets = data["up_field"]
        losses = self.calculate_losses(predictions, targets, pred.get("up_confidence"))

        total = losses[f"up-{self.loss_type}-loss"]
        losses |= {"up_total": total}
        return losses, self.metrics(pred, data)
    def metrics(self, pred, data):
        """Compute angular error and recall metrics for the predicted up field."""
        predictions = pred["up_field"]
        targets = data["up_field"]

        # Only evaluate pixels with a non-zero predicted up vector.
        mask = predictions.sum(axis=1) != 0
        error = up_error(predictions, targets) * mask

        out = {"up_angle_error": error.mean(axis=(1, 2))}

        if "up_confidence" in pred:
            weighted_error = (error * pred["up_confidence"]).sum(axis=(1, 2))
            out["up_angle_error_weighted"] = weighted_error / pred["up_confidence"].sum(axis=(1, 2))

        for th in self.conf.recall_thresholds:
            rec = (error < th).float().mean(axis=(1, 2))
            out[f"up_angle_recall@{th}"] = rec

        return out
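

if __name__ == "__main__":
    # Minimal illustrative sketch: reproduces the regression loss and the
    # confidence re-weighting used in calculate_losses on synthetic unit-norm
    # up fields, so the expected tensor shapes are easy to see. The shapes
    # below (batch of 2, 8x8 grid) are assumed for illustration only.
    b, h, w = 2, 8, 8
    pred = F.normalize(torch.randn(b, 2, h, w), dim=1)
    target = F.normalize(torch.randn(b, 2, h, w), dim=1)
    confidence = torch.rand(b, h, w)

    # Per-pixel L1 loss summed over the two vector components -> (B, H, W).
    loss = (pred - target).abs().sum(dim=1)

    # Confidence weighting: normalize the map so its mean over the image is 1.
    weight = confidence / confidence.sum(dim=(-2, -1), keepdim=True)
    weight = weight * (h * w)
    loss = loss * weight.detach()

    print("per-image up loss:", loss.mean(dim=(1, 2)).tolist())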