import math
import numpy as np
from typing import Optional
import torch
import torch.nn.functional as F
__all__ = [
"focal_loss_with_logits",
"softmax_focal_loss_with_logits",
"soft_jaccard_score",
"soft_dice_score",
"wing_loss",
    "soft_tversky_score",
    "label_smoothed_nll_loss",
]
def to_tensor(x, dtype=None) -> torch.Tensor:
    """Convert a tensor, numpy array, list, or tuple to a torch.Tensor, optionally casting to ``dtype``."""
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    raise ValueError(f"Unsupported input type: {type(x)}")
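# Minimal usage sketch (values are illustrative assumptions):
#
#   to_tensor([1, 2, 3], dtype=torch.float32)   # tensor([1., 2., 3.])
#   to_tensor(np.zeros((2, 2)))                 # float64 tensor of shape (2, 2)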
def focal_loss_with_logits(
output: torch.Tensor,
target: torch.Tensor,
gamma: float = 2.0,
alpha: Optional[float] = 0.25,
reduction: str = "mean",
normalized: bool = False,
reduced_threshold: Optional[float] = None,
eps: float = 1e-6,
) -> torch.Tensor:
"""Compute binary focal loss between target and output logits.
See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
Args:
output: Tensor of arbitrary shape (predictions of the model)
target: Tensor of the same shape as input
gamma: Focal loss power factor
        alpha: Weight factor to balance positive and negative samples. Must be in the [0, 1]
            range; higher values give more weight to the positive class.
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'.
            'none': no reduction is applied;
            'mean': the output is averaged over all elements;
            'sum': the output is summed;
            'batchwise_mean': the output is summed over the batch dimension (dim 0). Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
        eps (float): Small constant used to clamp the normalization factor when ``normalized=True``.
References:
https://github.com/open-mmlab/mmdetection/blob/master/mmdet/core/loss/losses.py
"""
    target = target.type(output.type())
    # binary_cross_entropy_with_logits returns -log(pt), hence pt = exp(-logpt)
    logpt = F.binary_cross_entropy_with_logits(output, target, reduction="none")
    pt = torch.exp(-logpt)
    # compute the focal term; with reduced_threshold set, use reduced focal loss
    if reduced_threshold is None:
        focal_term = (1.0 - pt).pow(gamma)
    else:
        focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
        focal_term[pt < reduced_threshold] = 1
loss = focal_term * logpt
if alpha is not None:
loss *= alpha * target + (1 - alpha) * (1 - target)
if normalized:
norm_factor = focal_term.sum().clamp_min(eps)
loss /= norm_factor
if reduction == "mean":
loss = loss.mean()
if reduction == "sum":
loss = loss.sum()
if reduction == "batchwise_mean":
loss = loss.sum(0)
return loss
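# Minimal usage sketch (shapes and hyperparameters are illustrative assumptions):
#
#   logits = torch.randn(4, 1, 16, 16)                   # raw model outputs
#   masks = torch.randint(0, 2, (4, 1, 16, 16)).float()  # binary targets
#   loss = focal_loss_with_logits(logits, masks, gamma=2.0, alpha=0.25)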
def softmax_focal_loss_with_logits(
output: torch.Tensor,
target: torch.Tensor,
gamma: float = 2.0,
    reduction: str = "mean",
    normalized: bool = False,
reduced_threshold: Optional[float] = None,
eps: float = 1e-6,
) -> torch.Tensor:
"""Softmax version of focal loss between target and output logits.
See :class:`~pytorch_toolbelt.losses.FocalLoss` for details.
    Args:
        output: Tensor of shape [B, C, *] (Similar to nn.CrossEntropyLoss)
        target: Tensor of shape [B, *] (Similar to nn.CrossEntropyLoss)
        gamma: Focal loss power factor
        reduction (str, optional): Specifies the reduction to apply to the output:
            'none' | 'mean' | 'sum' | 'batchwise_mean'.
            'none': no reduction is applied;
            'mean': the output is averaged over all elements;
            'sum': the output is summed;
            'batchwise_mean': the output is summed over the batch dimension (dim 0). Default: 'mean'
        normalized (bool): Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
        reduced_threshold (float, optional): Compute reduced focal loss (https://arxiv.org/abs/1903.01347).
        eps (float): Small constant used to clamp the normalization factor when ``normalized=True``.
    """
log_softmax = F.log_softmax(output, dim=1)
loss = F.nll_loss(log_softmax, target, reduction="none")
pt = torch.exp(-loss)
    # compute the focal term; with reduced_threshold set, use reduced focal loss
if reduced_threshold is None:
focal_term = (1.0 - pt).pow(gamma)
else:
focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
focal_term[pt < reduced_threshold] = 1
loss = focal_term * loss
if normalized:
norm_factor = focal_term.sum().clamp_min(eps)
loss = loss / norm_factor
if reduction == "mean":
loss = loss.mean()
if reduction == "sum":
loss = loss.sum()
if reduction == "batchwise_mean":
loss = loss.sum(0)
return loss
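# Minimal usage sketch (shapes are illustrative assumptions): targets hold class
# indices, as in nn.CrossEntropyLoss.
#
#   logits = torch.randn(4, 3, 16, 16)            # [B, C, H, W]
#   labels = torch.randint(0, 3, (4, 16, 16))     # [B, H, W]
#   loss = softmax_focal_loss_with_logits(logits, labels, gamma=2.0)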
def soft_jaccard_score(
output: torch.Tensor,
target: torch.Tensor,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
    """Soft Jaccard score (intersection over union) between probabilities and targets.
    ``smooth`` is added to numerator and denominator to avoid division by zero;
    ``dims`` selects the dimensions to reduce over (all elements when ``None``).
    """
    assert output.size() == target.size(), "Output and target must have the same shape"
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)
union = cardinality - intersection
jaccard_score = (intersection + smooth) / (union + smooth).clamp_min(eps)
return jaccard_score
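# Minimal usage sketch (shapes are illustrative assumptions): inputs are expected
# to be probabilities, so apply a sigmoid/softmax to logits first. Reducing over
# dims=(0, 2, 3) yields one score per channel; 1 - score is the usual Jaccard loss.
#
#   probs = torch.rand(4, 1, 16, 16)
#   masks = torch.randint(0, 2, (4, 1, 16, 16)).float()
#   iou = soft_jaccard_score(probs, masks, dims=(0, 2, 3))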
def soft_dice_score(
output: torch.Tensor,
target: torch.Tensor,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
    """Soft Dice score between probabilities and targets.
    ``smooth`` is added to numerator and denominator to avoid division by zero;
    ``dims`` selects the dimensions to reduce over (all elements when ``None``).
    """
    assert output.size() == target.size(), "Output and target must have the same shape"
if dims is not None:
intersection = torch.sum(output * target, dim=dims)
cardinality = torch.sum(output + target, dim=dims)
else:
intersection = torch.sum(output * target)
cardinality = torch.sum(output + target)
dice_score = (2.0 * intersection + smooth) / (cardinality + smooth).clamp_min(eps)
return dice_score
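# Minimal usage sketch (shapes are illustrative assumptions); as above, inputs
# should be probabilities, and 1 - score gives the Dice loss.
#
#   probs = torch.rand(4, 1, 16, 16)
#   masks = torch.randint(0, 2, (4, 1, 16, 16)).float()
#   dice_loss = 1.0 - soft_dice_score(probs, masks, dims=(0, 2, 3))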
def soft_tversky_score(
output: torch.Tensor,
target: torch.Tensor,
alpha: float,
beta: float,
smooth: float = 0.0,
eps: float = 1e-7,
dims=None,
) -> torch.Tensor:
    """Soft Tversky score between probabilities and targets.
    ``alpha`` weights false positives and ``beta`` weights false negatives;
    ``alpha = beta = 0.5`` reduces to the Dice score, ``alpha = beta = 1.0``
    to the Jaccard score.
    """
    assert output.size() == target.size(), "Output and target must have the same shape"
if dims is not None:
intersection = torch.sum(output * target, dim=dims) # TP
fp = torch.sum(output * (1.0 - target), dim=dims)
fn = torch.sum((1 - output) * target, dim=dims)
else:
intersection = torch.sum(output * target) # TP
fp = torch.sum(output * (1.0 - target))
fn = torch.sum((1 - output) * target)
tversky_score = (intersection + smooth) / (
intersection + alpha * fp + beta * fn + smooth
).clamp_min(eps)
return tversky_score
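# Minimal usage sketch (alpha/beta values are illustrative assumptions): with
# alpha < beta, false negatives are penalized more than false positives, a
# common choice for imbalanced segmentation.
#
#   probs = torch.rand(4, 1, 16, 16)
#   masks = torch.randint(0, 2, (4, 1, 16, 16)).float()
#   tversky = soft_tversky_score(probs, masks, alpha=0.3, beta=0.7, dims=(0, 2, 3))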
def wing_loss(
    output: torch.Tensor, target: torch.Tensor, width: float = 5, curvature: float = 0.5, reduction: str = "mean"
) -> torch.Tensor:
    """Wing loss for robust regression, e.g. facial landmark localization.
    Behaves as a scaled log function for absolute errors below ``width`` and as
    L1 (offset by a constant C for continuity) for larger errors.
    References:
        https://arxiv.org/pdf/1711.06753.pdf
    """
diff_abs = (target - output).abs()
loss = diff_abs.clone()
idx_smaller = diff_abs < width
idx_bigger = diff_abs >= width
loss[idx_smaller] = width * torch.log(1 + diff_abs[idx_smaller] / curvature)
C = width - width * math.log(1 + width / curvature)
loss[idx_bigger] = loss[idx_bigger] - C
if reduction == "sum":
loss = loss.sum()
if reduction == "mean":
loss = loss.mean()
return loss
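# Minimal usage sketch (shapes are illustrative assumptions), e.g. for regressing
# landmark coordinates:
#
#   pred = torch.randn(8, 10)
#   gt = torch.randn(8, 10)
#   loss = wing_loss(pred, gt, width=5, curvature=0.5)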
def label_smoothed_nll_loss(
lprobs: torch.Tensor,
target: torch.Tensor,
epsilon: float,
    ignore_index: Optional[int] = None,
    reduction: str = "mean",
    dim: int = -1,
) -> torch.Tensor:
"""NLL loss with label smoothing
    References:
        https://github.com/pytorch/fairseq/blob/master/fairseq/criterions/label_smoothed_cross_entropy.py
    Args:
        lprobs (torch.Tensor): Log-probabilities of predictions (e.g. after log_softmax)
        target (torch.Tensor): Target class indices
        epsilon (float): Smoothing factor; this mass is distributed uniformly over all classes
        ignore_index (int, optional): Target value whose positions are excluded from the loss
        dim (int): Class dimension of ``lprobs``. Default: -1
    """
if target.dim() == lprobs.dim() - 1:
target = target.unsqueeze(dim)
if ignore_index is not None:
pad_mask = target.eq(ignore_index)
target = target.masked_fill(pad_mask, 0)
nll_loss = -lprobs.gather(dim=dim, index=target)
smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
nll_loss = nll_loss.masked_fill(pad_mask, 0.0)
smooth_loss = smooth_loss.masked_fill(pad_mask, 0.0)
else:
nll_loss = -lprobs.gather(dim=dim, index=target)
smooth_loss = -lprobs.sum(dim=dim, keepdim=True)
nll_loss = nll_loss.squeeze(dim)
smooth_loss = smooth_loss.squeeze(dim)
if reduction == "sum":
nll_loss = nll_loss.sum()
smooth_loss = smooth_loss.sum()
if reduction == "mean":
nll_loss = nll_loss.mean()
smooth_loss = smooth_loss.mean()
eps_i = epsilon / lprobs.size(dim)
loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
return loss
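# Minimal usage sketch (shapes and epsilon are illustrative assumptions): the
# function expects log-probabilities, not raw logits.
#
#   logits = torch.randn(4, 10)
#   lprobs = F.log_softmax(logits, dim=-1)
#   labels = torch.randint(0, 10, (4,))
#   loss = label_smoothed_nll_loss(lprobs, labels, epsilon=0.1)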