# Source: uploaded by veichta via huggingface_hub (commit 205a7af, 4.38 kB).
"""up decoder head.
Adapted from https://github.com/jinlinyi/PerspectiveFields
"""
import logging
import torch
from torch import nn
from torch.nn import functional as F
from siclib.models import get_model
from siclib.models.base_model import BaseModel
from siclib.models.utils.metrics import up_error
from siclib.models.utils.perspective_encoding import decode_up_bin
from siclib.utils.conversions import deg2rad
# flake8: noqa
# mypy: ignore-errors

# Logger keyed on this module's import path.
logger = logging.getLogger(name=__name__)
class UpDecoder(BaseModel):
    """Decoder head that predicts a dense up-vector field from image features.

    Regression mode (any ``loss_type`` other than ``"classification"``): a 1x1 conv
    outputs 2 channels per pixel, which are L2-normalized along dim 1.

    Classification mode (``loss_type == "classification"``): a 1x1 conv outputs 73
    logits per pixel; the arg-max bin is decoded with ``decode_up_bin``. No training
    loss is computed in this mode (``loss`` returns an empty dict).
    """

    default_conf = {
        # One of: "l1", "l2", "dot", "cauchy", "huber", "classification".
        "loss_type": "l1",
        "use_loss": True,
        "use_uncertainty_loss": True,
        "loss_weight": 1.0,
        # Thresholds compared against ``up_error`` output (presumably degrees).
        "recall_thresholds": [1, 3, 5, 10],
        "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Build the feature decoder and the 1x1 prediction layer from ``conf``."""
        self.loss_type = conf.loss_type
        self.loss_weight = conf.loss_weight
        self.use_uncertainty_loss = conf.use_uncertainty_loss
        self.predict_uncertainty = conf.decoder.predict_uncertainty

        # Regression predicts a 2-component vector per pixel; classification
        # predicts logits over 73 bins that ``decode_up_bin`` turns into vectors.
        self.is_classification = self.loss_type == "classification"
        self.num_classes = 73 if self.is_classification else 2

        self.decoder = get_model(conf.decoder.name)(conf.decoder)
        self.linear_pred_up = nn.Conv2d(self.decoder.out_channels, self.num_classes, kernel_size=1)

    def calculate_losses(self, predictions, targets, confidence=None):
        """Compute the per-sample up-field regression loss.

        Args:
            predictions: predicted up field; channels on dim 1 (presumably
                (B, 2, H, W) — TODO confirm against the decoder).
            targets: ground-truth up field, same layout as ``predictions``.
            confidence: optional per-pixel confidence. When given and
                ``use_uncertainty_loss`` is set, it re-weights the per-pixel loss.

        Returns:
            Dict with the single key ``"up-<loss_type>-loss"`` mapping to the
            per-sample loss (mean over the last two dims) times ``loss_weight``.

        Raises:
            NotImplementedError: if ``loss_type`` is not a supported regression loss.
        """
        predictions = predictions.float()  # https://github.com/pytorch/pytorch/issues/48163
        residuals = predictions - targets

        if self.loss_type == "l2":
            loss = (residuals**2).sum(axis=1)
        elif self.loss_type == "l1":
            loss = residuals.abs().sum(axis=1)
        elif self.loss_type == "dot":
            # 1 - <p, t>: equals 1 - cos(angle) for unit vectors and is 0 when the
            # prediction matches the target (assumes unit-norm targets). Using the
            # residual here would offset the value by |t|^2.
            loss = 1 - (predictions * targets).sum(axis=1)
        elif self.loss_type == "cauchy":
            c = 0.007  # -> corresponds to about 5 degrees
            sq_norm = (residuals**2).sum(axis=1)
            loss = c**2 / 2 * torch.log(1 + sq_norm / c**2)
        elif self.loss_type == "huber":
            c = deg2rad(1)
            loss = F.huber_loss(predictions, targets, reduction="none", delta=c).sum(axis=1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")

        if confidence is not None and self.use_uncertainty_loss:
            # Normalize per sample so the weights average to 1 over the last two
            # dims. Detach so this term sends no gradient into the confidence.
            conf_weight = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
            conf_weight = conf_weight * (conf_weight.size(-1) * conf_weight.size(-2))
            loss = loss * conf_weight.detach()

        losses = {f"up-{self.loss_type}-loss": loss.mean(axis=(1, 2))}
        return {k: v * self.loss_weight for k, v in losses.items()}

    def _forward(self, data):
        """Predict the up field (and, if enabled, its confidence) from ``data["features"]``."""
        out = {}
        x, log_confidence = self.decoder(data["features"])
        up = self.linear_pred_up(x)

        if self.predict_uncertainty:
            out["up_confidence"] = torch.sigmoid(log_confidence)

        if self.is_classification:
            # Hard decode of the most likely bin per pixel.
            out["up_field"] = decode_up_bin(up.argmax(dim=1), self.num_classes)
            return out

        out["up_field"] = F.normalize(up, dim=1)
        return out

    def loss(self, pred, data):
        """Return ``(losses, metrics)``; losses are empty when disabled or in classification mode."""
        if not self.conf.use_loss or self.is_classification:
            return {}, self.metrics(pred, data)

        losses = self.calculate_losses(pred["up_field"], data["up_field"], pred.get("up_confidence"))
        # ``0 +`` yields a new tensor, so "up_total" does not alias the component loss.
        total = 0 + losses[f"up-{self.loss_type}-loss"]
        losses |= {"up_total": total}
        return losses, self.metrics(pred, data)

    def metrics(self, pred, data):
        """Compute per-sample angular error, confidence-weighted error and recall metrics."""
        predictions = pred["up_field"]
        targets = data["up_field"]

        # Ignore pixels whose predicted vector is all zeros (e.g. F.normalize of a
        # zero vector). Summing absolute values avoids also masking valid vectors
        # whose components cancel, such as (a, -a).
        mask = predictions.abs().sum(axis=1) != 0
        error = up_error(predictions, targets)
        # torch.where instead of ``error * mask``: NaN/inf on ignored pixels would
        # otherwise survive the multiplication by 0 and poison the means below.
        error = torch.where(mask, error, torch.zeros_like(error))

        # NOTE(review): ignored pixels still count as error 0 in the means and as
        # hits in the recall below — confirm that is intended.
        out = {"up_angle_error": error.mean(axis=(1, 2))}

        if "up_confidence" in pred:
            confidence = pred["up_confidence"]
            weighted_error = (error * confidence).sum(axis=(1, 2))
            out["up_angle_error_weighted"] = weighted_error / confidence.sum(axis=(1, 2))

        for th in self.conf.recall_thresholds:
            out[f"up_angle_recall@{th}"] = (error < th).float().mean(axis=(1, 2))
        return out