import math
from typing import Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F

__all__ = [
    "focal_loss_with_logits",
    "softmax_focal_loss_with_logits",
    "soft_jaccard_score",
    "soft_dice_score",
    "soft_tversky_score",
    "wing_loss",
    "label_smoothed_nll_loss",
]


def to_tensor(x, dtype=None) -> torch.Tensor:
    """Convert a tensor, numpy array, list or tuple to torch.Tensor, optionally casting to dtype."""
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = torch.from_numpy(np.array(x))
        if dtype is not None:
            x = x.type(dtype)
        return x
    raise ValueError(f"Unsupported input type: {type(x)}")


def focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    alpha: Optional[float] = 0.25,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Compute binary focal loss between target and output logits.

    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of arbitrary shape (raw predictions of the model)
        target: Tensor of the same shape as ``output``
        gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples; must be
            in the [0, 1] range. Higher values give more weight to the positive class.
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'.
            'none': no reduction will be applied;
            'mean': the output will be averaged over all elements;
            'sum': the output will be summed;
            'batchwise_mean': the loss is averaged over all elements of each
            sample, giving one value per sample in the batch.
            Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
        eps (float): Small constant that guards against division by zero when ``normalized=True``.

    References:
        https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
    """
    target = target.type(output.type())

    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)

    # Compute the modulating (focal) term
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * logpt

    if alpha is not None:
        loss *= alpha * target + (1 - alpha) * (1 - target)

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss /= norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        # Mean loss per sample, matching the documented behavior above
        loss = loss.reshape(loss.size(0), -1).mean(dim=1)

    return loss
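
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API): how the binary
# focal loss above is typically called. The shapes and hyper-parameters below
# are assumptions chosen for the demo, not values mandated by the function.
def _demo_focal_loss_with_logits() -> torch.Tensor:
    logits = torch.randn(4, 1, 8, 8, requires_grad=True)  # raw model outputs
    targets = torch.randint(0, 2, (4, 1, 8, 8)).float()  # binary ground truth
    loss = focal_loss_with_logits(logits, targets, gamma=2.0, alpha=0.25)
    loss.backward()  # reduction="mean" yields a scalar, so backward() works directly
    return loss
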
def softmax_focal_loss_with_logits(
    output: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    reduction: str = "mean",
    normalized: bool = False,
    reduced_threshold: Optional[float] = None,
    eps: float = 1e-6,
) -> torch.Tensor:
    """Softmax version of focal loss between target and output logits.

    See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.

    Args:
        output: Tensor of shape [B, C, *] (similar to nn.CrossEntropyLoss)
        target: Tensor of shape [B, *] (similar to nn.CrossEntropyLoss)
        gamma: Focal loss power factor
        reduction (string, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'.
            'none': no reduction will be applied;
            'mean': the output will be averaged over all elements;
            'sum': the output will be summed;
            'batchwise_mean': the loss is averaged over all elements of each
            sample, giving one value per sample in the batch.
            Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
        eps (float): Small constant that guards against division by zero when ``normalized=True``.
    """
    log_softmax = F.log_softmax(output, dim=1)

    loss = F.nll_loss(log_softmax, target, reduction="none")
    pt = torch.exp(-loss)

    # Compute the modulating (focal) term
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1

    loss = focal_term * loss

    if normalized:
        norm_factor = focal_term.sum().clamp_min(eps)
        loss = loss / norm_factor

    if reduction == "mean":
        loss = loss.mean()
    if reduction == "sum":
        loss = loss.sum()
    if reduction == "batchwise_mean":
        # Mean loss per sample, matching the documented behavior above
        loss = loss.reshape(loss.size(0), -1).mean(dim=1)

    return loss


def soft_jaccard_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    """Soft (differentiable) Jaccard score between probabilities and targets."""
    assert output.size() == target.size()

    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)

    union = cardinality - intersection
    jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
    return jaccard_score


def soft_dice_score(
    output: torch.Tensor,
    target: torch.Tensor,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    """Soft (differentiable) Dice score between probabilities and targets."""
    assert output.size() == target.size()

    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)
        cardinality = torch.sum(output + target, dim=dims)
    else:
        intersection = torch.sum(output * target)
        cardinality = torch.sum(output + target)

    dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
    return dice_score


def soft_tversky_score(
    output: torch.Tensor,
    target: torch.Tensor,
    alpha: float,
    beta: float,
    smooth: float = 0.0,
    eps: float = 1e-7,
    dims: Optional[Tuple[int, ...]] = None,
) -> torch.Tensor:
    """Soft (differentiable) Tversky score; alpha weighs false positives, beta weighs false negatives."""
    assert output.size() == target.size()

    if dims is not None:
        intersection = torch.sum(output * target, dim=dims)  # TP
        fp = torch.sum(output * (1.0 - target), dim=dims)
        fn = torch.sum((1.0 - output) * target, dim=dims)
    else:
        intersection = torch.sum(output * target)  # TP
        fp = torch.sum(output * (1.0 - target))
        fn = torch.sum((1.0 - output) * target)

    tversky_score = (intersection + smooth) / (
        intersection + alpha * fp + beta * fn + smooth
    ).clamp_min(eps)
    return tversky_score


def wing_loss(
    output: torch.Tensor,
    target: torch.Tensor,
    width: float = 5,
    curvature: float = 0.5,
    reduction: str = "mean",
) -> torch.Tensor:
    """Wing loss for robust regression (e.g. landmark localization).

    References:
        https://arxiv.org/pdf/1711.06753.pdf
    """
    diff_abs = (target - output).abs()
    loss = diff_abs.clone()

    idx_smaller = diff_abs < width
    idx_bigger = diff_abs >= width

    # Logarithmic branch for small errors
    loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature)

    # Linear branch for large errors; C makes the two branches continuous at |x| == width
    C = width - width * math.log(1 + width / curvature)
    loss[idx_bigger] = loss[idx_bigger] - C

    if reduction == "sum":
        loss = loss.sum()
    if reduction == "mean":
        loss = loss.mean()

    return loss
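
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API): the soft Dice /
# Jaccard / Tversky scores above expect probabilities rather than logits, so
# raw model outputs are passed through a sigmoid first. The [B, C, H, W] shape
# and the dims=(0, 2, 3) per-class reduction are assumptions for this demo.
def _demo_soft_dice_score() -> torch.Tensor:
    logits = torch.randn(4, 3, 8, 8)  # raw model outputs, 3 classes
    targets = torch.randint(0, 2, (4, 3, 8, 8)).float()  # one-hot style masks
    probs = logits.sigmoid()
    # dims=(0, 2, 3) reduces over batch and spatial dims, giving one score per class
    return soft_dice_score(probs, targets, smooth=0.0, eps=1e-7, dims=(0, 2, 3))
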
def label_smoothed_nll_loss(
    lprobs: torch.Tensor,
    target: torch.Tensor,
    epsilon: float,
    ignore_index=None,
    reduction: str = "mean",
    dim: int = -1,
) -> torch.Tensor:
    """NLL loss with label smoothing.

    References:
        https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py

    Args:
        lprobs (torch.Tensor): Log-probabilities of predictions (e.g. after log_softmax)
        target (torch.Tensor): Integer class indices
        epsilon (float): Smoothing factor in [0, 1]; 0 recovers plain NLL loss
        ignore_index (int, optional): Target value whose positions are excluded from the loss
        reduction (string, optional): 'none' | 'mean' | 'sum'. Default: 'mean'
        dim (int): Dimension of ``lprobs`` that holds the class log-probabilities
    """
    if target.dim() == lprobs.dim() - 1:
        target = target.unsqueeze(dim)

    if ignore_index is not None:
        pad_mask = target.eq(ignore_index)
        target = target.masked_fill(pad_mask, 0)
        nll_loss = -lprobs.gather(dim=dim, index=target)
        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
        # Zero out the loss at ignored positions
        nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
        smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
    else:
        nll_loss = -lprobs.gather(dim=dim, index=target)
        smooth_loss = -lprobs.sum(dim=dim, keepdim=True)

    nll_loss = nll_loss.squeeze(dim)
    smooth_loss = smooth_loss.squeeze(dim)

    if reduction == "sum":
        nll_loss = nll_loss.sum()
        smooth_loss = smooth_loss.sum()
    if reduction == "mean":
        nll_loss = nll_loss.mean()
        smooth_loss = smooth_loss.mean()

    eps_i = epsilon / lprobs.size(dim)
    loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
    return loss
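
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the library API): label_smoothed_nll_loss
# consumes log-probabilities over the class dimension. The [B, C] shape and
# epsilon value below are assumptions chosen for this demo.
def _demo_label_smoothed_nll_loss() -> torch.Tensor:
    logits = torch.randn(8, 5)  # [batch, num_classes]
    target = torch.randint(0, 5, (8,))  # integer class indices
    lprobs = F.log_softmax(logits, dim=-1)
    # epsilon=0.1 spreads 10% of the probability mass uniformly over all classes
    return label_smoothed_nll_loss(lprobs, target, epsilon=0.1, dim=-1)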