File size: 4,716 Bytes
205a7af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
"""Latitude decoder head.
Adapted from https://github.com/jinlinyi/PerspectiveFields
"""
import logging
import torch
from torch import nn
from siclib.models import get_model
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import latitude_error
from siclib.models.utils.perspective_encoding import decode_bin_latitude
from siclib.utils.conversions import deg2rad
# Module-level logger (not referenced by the code visible in this file).
logger = logging.getLogger(__name__)
# File-level linter directives: skip flake8 and mypy error reporting for this module.
# flake8: noqa
# mypy: ignore-errors
class LatitudeDecoder(BaseModel):
    """Decoder head that predicts a dense latitude field from image features.

    ``data["features"]`` goes through a configurable decoder (built with
    ``get_model(conf.decoder.name)``) followed by a 1x1 convolution. The mode
    is selected by ``conf.loss_type``:

    * regression (``"l1"``, ``"l2"``, ``"cauchy"``, ``"huber"``): a single
      output channel, optionally squashed with ``tanh``, clamped and passed
      through ``asin`` so the field lies strictly inside (-pi/2, pi/2);
    * ``"classification"``: 180 output channels whose argmax is decoded with
      ``decode_bin_latitude``; ``loss`` computes no training loss in this mode.

    When the decoder predicts uncertainty, its second output is passed through
    a sigmoid and exposed as ``latitude_confidence``.
    """

    default_conf = {
        "loss_type": "l1",
        "use_loss": True,
        "use_uncertainty_loss": True,
        "loss_weight": 1.0,
        "recall_thresholds": [1, 3, 5, 10],
        "use_tanh": True,  # backward compatibility to original perspective weights
        "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Cache config flags and build the decoder plus the 1x1 prediction conv."""
        self.loss_type = conf.loss_type
        self.loss_weight = conf.loss_weight
        self.use_uncertainty_loss = conf.use_uncertainty_loss
        self.predict_uncertainty = conf.decoder.predict_uncertainty

        # "classification" widens the head from one regressed channel to 180 bins.
        self.is_classification = self.conf.loss_type == "classification"
        self.num_classes = 180 if self.is_classification else 1

        # Submodules are registered decoder-first, then the prediction conv.
        self.decoder = get_model(conf.decoder.name)(conf.decoder)
        self.linear_pred_latitude = nn.Conv2d(
            self.decoder.out_channels, self.num_classes, kernel_size=1
        )

    def calculate_losses(self, predictions, targets, confidence=None):
        """Compute the weighted regression loss for the latitude field.

        Args:
            predictions: Predicted field. Dim 1 is summed and dims 1-2 of the
                result are averaged, so presumably (B, C, H, W) -- TODO confirm.
            targets: Ground-truth field with the same layout as ``predictions``.
            confidence: Optional per-pixel confidence (presumably (B, H, W)).
                When given and ``use_uncertainty_loss`` is enabled, it is
                rescaled to mean 1 over its last two dims and used, detached,
                to reweight the per-pixel loss.

        Returns:
            ``{"latitude-<loss_type>-loss": tensor}`` reduced over dims 1-2
            and multiplied by ``loss_weight``.

        Raises:
            NotImplementedError: if ``loss_type`` is not ``"l1"``, ``"l2"``,
                ``"cauchy"`` or ``"huber"``.
        """
        # Upcast before the arithmetic: https://github.com/pytorch/pytorch/issues/48163
        predictions = predictions.float()
        diff = predictions - targets
        kind = self.loss_type

        if kind == "l2":
            per_pixel = diff.pow(2).sum(dim=1)
        elif kind == "l1":
            per_pixel = diff.abs().sum(dim=1)
        elif kind == "cauchy":
            # NOTE(review): the original note said 0.007 "corresponds to about
            # 5 degrees", but if the field is in radians (as the asin output in
            # _forward is), 0.007 rad is roughly 0.4 degrees -- confirm intent.
            scale_sq = 0.007**2
            squared_norm = diff.pow(2).sum(dim=1)
            per_pixel = scale_sq / 2 * torch.log(1 + squared_norm / scale_sq)
        elif kind == "huber":
            # Huber threshold is 1 degree, converted to radians via deg2rad.
            per_pixel = nn.functional.huber_loss(
                predictions, targets, reduction="none", delta=deg2rad(1)
            ).sum(dim=1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")

        if self.use_uncertainty_loss and confidence is not None:
            # Normalize confidence to sum 1 per sample, then rescale by the
            # pixel count so the weights average to 1; detached so the loss
            # does not push the confidence itself.
            normalized = confidence / confidence.sum(dim=(-2, -1), keepdim=True)
            num_pixels = normalized.size(-2) * normalized.size(-1)
            weights = normalized * num_pixels
            per_pixel = per_pixel * weights.detach()

        key = f"latitude-{self.loss_type}-loss"
        return {key: per_pixel.mean(dim=(1, 2)) * self.loss_weight}

    def _forward(self, data):
        """Predict the latitude field (and optional confidence) from features.

        Returns a dict containing, in this order:
            latitude_confidence: sigmoid of the decoder's second output; only
                present when ``predict_uncertainty`` is enabled.
            latitude_field_logits: raw 1x1-conv output (classification only).
            latitude_field: the decoded latitude field.
        """
        features, log_confidence = self.decoder(data["features"])
        head = self.linear_pred_latitude(features)

        out = {}
        if self.predict_uncertainty:
            out["latitude_confidence"] = torch.sigmoid(log_confidence)

        if self.is_classification:
            out["latitude_field_logits"] = head
            bins = head.argmax(dim=1)
            out["latitude_field"] = decode_bin_latitude(bins, self.num_classes).unsqueeze(1)
            return out

        if self.conf.use_tanh:
            head = torch.tanh(head)
        # Keep asin's argument strictly inside (-1, 1) to avoid NaN in its backward.
        eps = 1e-5
        out["latitude_field"] = head.clamp(-1 + eps, 1 - eps).asin()
        return out

    def loss(self, pred, data):
        """Return ``(losses, metrics)`` for a batch.

        ``losses`` is empty when ``use_loss`` is disabled or in classification
        mode. Otherwise it holds the regression loss entry plus
        ``latitude_total``.
        """
        losses = {}
        if self.conf.use_loss and not self.is_classification:
            losses = self.calculate_losses(
                pred["latitude_field"],
                data["latitude_field"],
                pred.get("latitude_confidence"),
            )
            # `0 +` yields a fresh tensor rather than aliasing the loss entry.
            losses["latitude_total"] = 0 + losses[f"latitude-{self.loss_type}-loss"]
        return losses, self.metrics(pred, data)

    def metrics(self, pred, data):
        """Compute per-sample latitude error statistics.

        Uses ``latitude_error`` (units presumably degrees, given the default
        recall thresholds -- TODO confirm). Produces the mean error, a
        confidence-weighted mean when ``latitude_confidence`` is in ``pred``,
        and the fraction of pixels below each ``recall_thresholds`` value.
        """
        err = latitude_error(pred["latitude_field"], data["latitude_field"])
        spatial = (1, 2)
        results = {"latitude_angle_error": err.mean(dim=spatial)}

        if "latitude_confidence" in pred:
            weights = pred["latitude_confidence"]
            results["latitude_angle_error_weighted"] = (err * weights).sum(
                dim=spatial
            ) / weights.sum(dim=spatial)

        results.update(
            (f"latitude_angle_recall@{th}", (err < th).float().mean(dim=spatial))
            for th in self.conf.recall_thresholds
        )
        return results
|