|
""" |
|
Lovasz-Softmax and Jaccard hinge loss in PyTorch |
|
Maxim Berman 2018 ESAT-PSI KU Leuven (MIT License) |
|
""" |
|
|
|
from __future__ import print_function, division |
|
from typing import Optional |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
from torch.nn.modules.loss import _Loss |
|
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE |
|
|
|
try: |
|
from itertools import ifilterfalse |
|
except ImportError: |
|
from itertools import filterfalse as ifilterfalse |
|
|
|
# Public API: only ``LovaszLoss`` is exported by ``from <module> import *``;
# the functions in this module are implementation helpers.
__all__ = ["LovaszLoss"]
|
|
|
|
|
def _lovasz_grad(gt_sorted): |
|
"""Compute gradient of the Lovasz extension w.r.t sorted errors |
|
See Alg. 1 in paper |
|
""" |
|
p = len(gt_sorted) |
|
gts = gt_sorted.sum() |
|
intersection = gts - gt_sorted.float().cumsum(0) |
|
union = gts + (1 - gt_sorted).float().cumsum(0) |
|
jaccard = 1.0 - intersection / union |
|
if p > 1: |
|
jaccard[1:p] = jaccard[1:p] - jaccard[0:-1] |
|
return jaccard |
|
|
|
|
|
def _lovasz_hinge(logits, labels, per_image=True, ignore=None):
    """Binary Lovasz hinge loss.

    Args:
        logits: [B, H, W] logits at each pixel (between -infinity and +infinity).
        labels: [B, H, W] binary ground-truth masks (0 or 1).
        per_image: if True, compute the loss for each image separately and
            average; otherwise compute a single loss over the whole batch.
        ignore: label value (void class id) whose pixels are excluded.
    """
    if not per_image:
        return _lovasz_hinge_flat(*_flatten_binary_scores(logits, labels, ignore))

    def _image_loss(img_logits, img_labels):
        # Iterating over the batch drops the batch axis; restore it before flattening.
        flat_logits, flat_labels = _flatten_binary_scores(
            img_logits.unsqueeze(0), img_labels.unsqueeze(0), ignore
        )
        return _lovasz_hinge_flat(flat_logits, flat_labels)

    return mean(
        _image_loss(img_logits, img_labels)
        for img_logits, img_labels in zip(logits, labels)
    )
|
|
|
|
|
def _lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss on already-flattened predictions.

    Args:
        logits: [P] logits at each prediction (between -infinity and +infinity).
        labels: [P] binary ground-truth labels (0 or 1); any numeric or bool dtype.

    Returns:
        Scalar tensor with the loss. For empty input a zero computed from
        ``logits`` is returned so the result stays attached to the graph.
    """
    if len(labels) == 0:
        # Only void pixels: return a zero that still depends on logits.
        return logits.sum() * 0.0
    # Work on float labels so bool masks are accepted by _lovasz_grad (1 - bool fails).
    labels_f = labels.float()
    signs = 2.0 * labels_f - 1.0
    errors = 1.0 - logits * signs
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    gt_sorted = labels_f[perm]
    # torch.dot requires matching dtypes; _lovasz_grad yields float32 while
    # errors follow the logits dtype (e.g. float64).
    grad = _lovasz_grad(gt_sorted).to(errors_sorted.dtype)
    return torch.dot(F.relu(errors_sorted), grad)
|
|
|
|
|
def _flatten_binary_scores(scores, labels, ignore=None): |
|
"""Flattens predictions in the batch (binary case) |
|
Remove labels equal to 'ignore' |
|
""" |
|
scores = scores.view(-1) |
|
labels = labels.view(-1) |
|
if ignore is None: |
|
return scores, labels |
|
valid = labels != ignore |
|
vscores = scores[valid] |
|
vlabels = labels[valid] |
|
return vscores, vlabels |
|
|
|
|
|
|
|
|
|
|
|
def _lovasz_softmax(probas, labels, classes="present", per_image=False, ignore=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probas: [B, C, H, W] class probabilities at each prediction (between 0 and 1).
            A [B, H, W] tensor is treated as a single (sigmoid) channel.
        labels: [B, H, W] ground-truth labels (between 0 and C - 1).
        classes: 'all' for all classes, 'present' for classes present in labels,
            or a list of classes to average.
        per_image: if True, compute the loss for each image separately and
            average; otherwise compute a single loss over the whole batch.
        ignore: label value (void class) whose pixels are excluded.
    """
    if not per_image:
        flat_probas, flat_labels = _flatten_probas(probas, labels, ignore)
        return _lovasz_softmax_flat(flat_probas, flat_labels, classes=classes)

    def _image_loss(img_probas, img_labels):
        # Iterating over the batch drops the batch axis; restore it before flattening.
        return _lovasz_softmax_flat(
            *_flatten_probas(img_probas.unsqueeze(0), img_labels.unsqueeze(0), ignore),
            classes=classes,
        )

    return mean(
        _image_loss(img_probas, img_labels)
        for img_probas, img_labels in zip(probas, labels)
    )
|
|
|
|
|
def _lovasz_softmax_flat(probas, labels, classes="present"):
    """Multi-class Lovasz-Softmax loss on already-flattened predictions.

    Args:
        probas: [P, C] class probabilities at each prediction (between 0 and 1).
        labels: [P] ground-truth labels (between 0 and C - 1).
        classes: 'all' for all classes, 'present' for classes present in labels,
            or a list of classes to average.

    Returns:
        Scalar tensor: the mean of the per-class losses. When there is nothing
        to average (empty input, or no selected class contributes), a zero
        computed from ``probas`` is returned so the result stays a scalar on the graph.

    Raises:
        ValueError: if ``probas`` has a single channel and ``classes`` is not a
            one-element list.
    """
    if probas.numel() == 0:
        # Only void pixels: scalar zero (``probas * 0.0`` would be an empty [0, C] tensor).
        return probas.sum() * 0.0
    C = probas.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ["all", "present"] else classes
    for c in class_to_sum:
        fg = (labels == c).type_as(probas)
        if classes == "present" and fg.sum() == 0:
            continue
        if C == 1:
            # Single (sigmoid) channel: an explicit one-element ``classes`` list is
            # required; the strings 'all'/'present' also fail this length check.
            if len(classes) > 1:
                raise ValueError("Sigmoid output possible only with 1 class")
            class_pred = probas[:, 0]
        else:
            class_pred = probas[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        fg_sorted = fg[perm]
        # torch.dot requires matching dtypes (e.g. float16 errors vs float32 grad).
        grad = _lovasz_grad(fg_sorted).to(errors_sorted.dtype)
        losses.append(torch.dot(errors_sorted, grad))
    if not losses:
        # No class contributed: keep returning a tensor rather than the int 0.
        return probas.sum() * 0.0
    return mean(losses)
|
|
|
|
|
def _flatten_probas(probas, labels, ignore=None): |
|
"""Flattens predictions in the batch""" |
|
if probas.dim() == 3: |
|
|
|
B, H, W = probas.size() |
|
probas = probas.view(B, 1, H, W) |
|
|
|
C = probas.size(1) |
|
probas = torch.movedim(probas, 1, -1) |
|
probas = probas.contiguous().view(-1, C) |
|
|
|
labels = labels.view(-1) |
|
if ignore is None: |
|
return probas, labels |
|
valid = labels != ignore |
|
vprobas = probas[valid] |
|
vlabels = labels[valid] |
|
return vprobas, vlabels |
|
|
|
|
|
|
|
def isnan(x):
    """Return ``x != x``: true exactly for NaN (elementwise for tensors)."""
    is_nan = x != x
    return is_nan
|
|
|
|
|
def mean(values, ignore_nan=False, empty=0):
    """Arithmetic mean of an iterable; works with generators.

    Args:
        values: iterable of numbers or tensors.
        ignore_nan: if True, skip items for which ``isnan`` is true.
        empty: value returned when there is nothing to average; pass ``"raise"``
            to raise ``ValueError`` instead.

    Returns:
        The mean of the items; a single item is returned as-is.

    Raises:
        ValueError: if there are no items and ``empty == "raise"``.
    """
    values = iter(values)
    if ignore_nan:
        values = ifilterfalse(isnan, values)
    try:
        acc = next(values)
    except StopIteration:
        if empty == "raise":
            raise ValueError("Empty mean") from None
        return empty
    n = 1
    for n, v in enumerate(values, 2):
        # Out-of-place add: ``acc += v`` would modify the caller's first tensor in
        # place (and fails for leaf tensors that require grad).
        acc = acc + v
    if n == 1:
        return acc
    return acc / n
|
|
|
|
|
class LovaszLoss(_Loss):
    def __init__(
        self,
        mode: str,
        per_image: bool = False,
        ignore_index: Optional[int] = None,
        from_logits: bool = True,
    ):
        """Lovasz loss for image segmentation task.

        It supports binary, multiclass and multilabel cases.

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            per_image: If True loss computed per each image and then averaged, else computed per whole batch
            ignore_index: Label that indicates ignored pixels (does not contribute to loss)
            from_logits: Multiclass mode only. If True (default), ``y_pred`` holds raw
                scores and softmax over dim 1 is applied; if False, ``y_pred`` is used
                as class probabilities directly. In binary/multilabel mode ``y_pred``
                is always passed to the hinge loss as logits.

        Shape
            - **y_pred** - torch.Tensor of shape (N, C, H, W)
            - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt
        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        self.per_image = per_image
        # Previously accepted but ignored; now controls the softmax in multiclass mode.
        self.from_logits = from_logits

    def forward(self, y_pred, y_true):
        """Compute the Lovasz loss for ``y_pred`` against ``y_true`` according to ``self.mode``.

        Raises:
            ValueError: if ``self.mode`` is not a supported mode.
        """
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            loss = _lovasz_hinge(
                y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
            )
        elif self.mode == MULTICLASS_MODE:
            if self.from_logits:
                # Convert raw scores to class probabilities along the channel axis.
                y_pred = y_pred.softmax(dim=1)
            loss = _lovasz_softmax(
                y_pred, y_true, per_image=self.per_image, ignore=self.ignore_index
            )
        else:
            raise ValueError("Wrong mode {}.".format(self.mode))
        return loss
|
|