|
"""up decoder head. |
|
|
|
Adapted from https://github.com/jinlinyi/PerspectiveFields |
|
""" |
|
|
|
import logging |
|
|
|
import torch |
|
from torch import nn |
|
from torch.nn import functional as F |
|
|
|
from siclib.models import get_model |
|
from siclib.models.base_model import BaseModel |
|
from siclib.models.utils.metrics import up_error |
|
from siclib.models.utils.perspective_encoding import decode_up_bin |
|
from siclib.utils.conversions import deg2rad |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
|
|
|
class UpDecoder(BaseModel):
    """Up-field decoder head.

    Runs the configured decoder on ``data["features"]`` and projects its output
    with a 1x1 convolution into an up-field prediction, optionally together
    with a sigmoid-activated confidence map.

    The output mode is selected by ``loss_type``:

    * ``"classification"``: 73 output channels, reduced with ``argmax`` over
      the channel axis and converted by ``decode_up_bin``. No training loss is
      computed in this mode (see ``loss``).
    * any other value: 2 output channels, normalized along the channel axis.
    """

    default_conf = {
        "loss_type": "l1",
        "use_loss": True,
        "use_uncertainty_loss": True,
        "loss_weight": 1.0,
        "recall_thresholds": [1, 3, 5, 10],
        "decoder": {"name": "decoders.light_hamburger", "predict_uncertainty": True},
    }

    required_data_keys = ["features"]

    def _init(self, conf):
        """Read the configuration and build the decoder and projection layer.

        Args:
            conf: Configuration object exposing the keys of ``default_conf``.
        """
        # Loss settings.
        self.loss_type = conf.loss_type
        self.loss_weight = conf.loss_weight

        # Uncertainty settings.
        self.use_uncertainty_loss = conf.use_uncertainty_loss
        self.predict_uncertainty = conf.decoder.predict_uncertainty

        # 2 output channels for the regression losses, 73 for classification.
        self.is_classification = self.conf.loss_type == "classification"
        self.num_classes = 73 if self.is_classification else 2

        # Sub-modules are created decoder first, then the 1x1 projection.
        decoder_cls = get_model(conf.decoder.name)
        self.decoder = decoder_cls(conf.decoder)
        self.linear_pred_up = nn.Conv2d(
            self.decoder.out_channels,
            self.num_classes,
            kernel_size=1,
        )

    def calculate_losses(self, predictions, targets, confidence=None):
        """Compute the weighted up-field loss for the configured ``loss_type``.

        Args:
            predictions: Predicted up field; converted with ``.float()`` first.
                Presumably shaped (B, C, H, W) -- TODO confirm with callers.
            targets: Ground-truth up field, subtracted from ``predictions``.
            confidence: Optional confidence map. It only re-weights the loss
                when ``use_uncertainty_loss`` is enabled, and is detached so no
                gradient flows into it through this loss.

        Returns:
            Dict with the single key ``"up-{loss_type}-loss"`` holding the loss
            averaged over axes 1 and 2 and scaled by ``loss_weight``.

        Raises:
            NotImplementedError: If ``loss_type`` is not one of ``"l2"``,
                ``"l1"``, ``"dot"``, ``"cauchy"`` or ``"huber"``.
        """
        predictions = predictions.float()
        diff = predictions - targets
        kind = self.loss_type

        # Every branch reduces the channel axis (axis 1).
        if kind == "l2":
            per_pixel = (diff**2).sum(dim=1)
        elif kind == "l1":
            per_pixel = diff.abs().sum(dim=1)
        elif kind == "dot":
            # NOTE(review): uses (residual . target), not (prediction . target).
            # If targets are unit-norm this equals 2 - <prediction, target>,
            # i.e. a constant offset from the usual dot-product loss -- confirm
            # this is intended.
            per_pixel = 1 - (diff * targets).sum(dim=1)
        elif kind == "cauchy":
            scale_sq = 0.007**2
            squared_norm = (diff**2).sum(dim=1)
            per_pixel = scale_sq / 2 * torch.log(1 + squared_norm / scale_sq)
        elif kind == "huber":
            delta = deg2rad(1)
            elementwise = F.huber_loss(predictions, targets, reduction="none", delta=delta)
            per_pixel = elementwise.sum(dim=1)
        else:
            raise NotImplementedError(f"Unknown loss type {self.conf.loss_type}")

        if self.use_uncertainty_loss and confidence is not None:
            # Normalize the confidence over the last two axes, then rescale by
            # the number of elements in those axes so the weights average to 1.
            weights = confidence / confidence.sum(axis=(-2, -1), keepdims=True)
            weights = weights * (weights.size(-1) * weights.size(-2))
            per_pixel = per_pixel * weights.detach()

        loss_key = f"up-{self.loss_type}-loss"
        return {loss_key: per_pixel.mean(dim=(1, 2)) * self.loss_weight}

    def _forward(self, data):
        """Predict the up field, and optionally its confidence, from features.

        Args:
            data: Mapping containing ``"features"``, the decoder input.

        Returns:
            Dict with ``"up_field"`` and, when ``predict_uncertainty`` is set,
            ``"up_confidence"`` (sigmoid of the decoder's second output).
        """
        decoded, log_confidence = self.decoder(data["features"])
        logits = self.linear_pred_up(decoded)

        out = {}
        if self.predict_uncertainty:
            out["up_confidence"] = torch.sigmoid(log_confidence)

        if self.is_classification:
            # Take the highest-scoring class per pixel and decode it.
            up_field = decode_up_bin(logits.argmax(dim=1), self.num_classes)
        else:
            # Regression: normalize the predicted vectors along the channel axis.
            up_field = F.normalize(logits, dim=1)

        out["up_field"] = up_field
        return out

    def loss(self, pred, data):
        """Return the training losses and evaluation metrics.

        Args:
            pred: Model output containing ``"up_field"`` and optionally
                ``"up_confidence"``.
            data: Mapping containing the ground-truth ``"up_field"``.

        Returns:
            Tuple ``(losses, metrics)``. ``losses`` is empty when ``use_loss``
            is disabled or the head runs in classification mode; otherwise it
            holds the ``"up-{loss_type}-loss"`` term and ``"up_total"``.
        """
        loss_enabled = self.conf.use_loss and not self.is_classification
        if not loss_enabled:
            return {}, self.metrics(pred, data)

        losses = self.calculate_losses(
            pred["up_field"],
            data["up_field"],
            pred.get("up_confidence"),
        )

        # The total is the single loss term of this head.
        losses["up_total"] = 0 + losses[f"up-{self.loss_type}-loss"]
        return losses, self.metrics(pred, data)

    def metrics(self, pred, data):
        """Compute up-field error and recall metrics.

        Args:
            pred: Model output containing ``"up_field"`` and optionally
                ``"up_confidence"``.
            data: Mapping containing the ground-truth ``"up_field"``.

        Returns:
            Dict with ``"up_angle_error"``, optionally
            ``"up_angle_error_weighted"``, and one ``"up_angle_recall@{th}"``
            entry per configured threshold, each reduced over axes 1 and 2.
        """
        up_pred = pred["up_field"]
        up_gt = data["up_field"]

        # Locations whose predicted components sum to zero get zero error. They
        # still count in the mean and in the recall statistics below.
        valid = up_pred.sum(dim=1) != 0
        angle_err = up_error(up_pred, up_gt) * valid

        results = {"up_angle_error": angle_err.mean(dim=(1, 2))}

        if "up_confidence" in pred:
            conf_map = pred["up_confidence"]
            weighted_sum = (angle_err * conf_map).sum(dim=(1, 2))
            results["up_angle_error_weighted"] = weighted_sum / conf_map.sum(dim=(1, 2))

        # Fraction of locations whose error is below each threshold
        # (presumably in the unit returned by up_error -- verify).
        results.update(
            {
                f"up_angle_recall@{th}": (angle_err < th).float().mean(dim=(1, 2))
                for th in self.conf.recall_thresholds
            }
        )

        return results
|
|