|
from typing import List, Optional |
|
|
|
import torch |
|
from ._functional import soft_tversky_score |
|
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE |
|
from .dice import DiceLoss |
|
|
|
# Public API of this module: only the Tversky loss class is exported on
# ``from <module> import *``.
__all__ = ["TverskyLoss"]
|
|
|
|
|
class TverskyLoss(DiceLoss):
    """Tversky loss for image segmentation tasks.

    Generalizes ``DiceLoss`` by weighting false positives (FP) and false
    negatives (FN) separately through the ``alpha`` and ``beta`` params.
    With ``alpha == beta == 0.5`` and ``gamma == 1.0`` this loss is equal to
    DiceLoss. A ``gamma`` other than ``1.0`` raises the aggregated (mean)
    loss to that power, giving a "focal" Tversky variant.

    It supports binary, multiclass and multilabel cases.

    Args:
        mode: Metric mode {'binary', 'multiclass', 'multilabel'}.
        classes: Optional list of classes that contribute to the loss
            computation. By default, all channels are included.
        log_loss: If True, the loss is computed as ``-log(tversky)``,
            otherwise as ``1 - tversky``.
        from_logits: If True, assumes the input is raw logits.
        smooth: Smoothing constant forwarded to ``DiceLoss``; when
            ``compute_score`` receives it, it is passed unchanged to
            ``soft_tversky_score``.
        ignore_index: Label that indicates ignored pixels (does not
            contribute to the loss).
        eps: Small epsilon for numerical stability.
        alpha: Weight constant that penalizes the model for FPs
            (False Positives).
        beta: Weight constant that penalizes the model for FNs
            (False Negatives).
        gamma: Exponent applied to the mean loss, i.e.
            ``loss.mean() ** gamma``. Defaults to ``1.0``. Values below
            ``1.0`` keep finite gradients even when the mean loss is
            exactly 0 (see ``aggregate_loss``).

    Return:
        loss: torch.Tensor
    """

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}
        # Arguments are forwarded positionally; this assumes DiceLoss.__init__
        # accepts them in exactly this order -- keep the two in sync.
        super().__init__(
            mode, classes, log_loss, from_logits, smooth, ignore_index, eps
        )
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Reduce ``loss`` to a scalar as ``loss.mean() ** self.gamma``.

        For ``gamma < 1`` the derivative ``gamma * x ** (gamma - 1)`` is
        infinite at ``x == 0``, so a plain ``x ** gamma`` back-propagates
        ``inf`` (and ``nan`` if an upstream op multiplies that by zero)
        whenever the mean loss is exactly 0. In that regime the mean is
        routed through a safe base: the forward value is identical to
        ``loss.mean() ** gamma``, but a non-positive mean receives a zero
        gradient instead of ``inf``/``nan``.

        Args:
            loss: Tensor of per-element losses to be averaged.

        Returns:
            Scalar tensor ``loss.mean() ** self.gamma``.
        """
        mean_loss = loss.mean()
        if self.gamma >= 1:
            # Derivative is finite everywhere on [0, inf); no guard needed.
            return mean_loss ** self.gamma

        positive = mean_loss > 0
        # Substitute 1 for non-positive bases so the differentiable branch
        # never evaluates the singular derivative at 0. A single torch.where
        # is not enough: the unselected branch would still contribute
        # 0 * inf = nan to the gradient.
        safe_base = torch.where(positive, mean_loss, torch.ones_like(mean_loss))
        return torch.where(
            positive,
            safe_base ** self.gamma,
            # Detached: reproduces the original forward value, no gradient.
            mean_loss.detach() ** self.gamma,
        )

    def compute_score(
        self, output, target, smooth=0.0, eps=1e-7, dims=None
    ) -> torch.Tensor:
        """Return the soft Tversky score of ``output`` against ``target``.

        Delegates to ``soft_tversky_score`` with this instance's ``alpha``
        and ``beta``; ``smooth``, ``eps`` and ``dims`` are passed through
        unchanged.
        """
        return soft_tversky_score(
            output, target, self.alpha, self.beta, smooth, eps, dims
        )
|
|