ghlee94's picture
Init
2a13495
from typing import List, Optional
import torch
from ._functional import soft_tversky_score
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from .dice import DiceLoss
# Public API of this module: only the loss class is exported by star-imports.
__all__ = ["TverskyLoss"]
class TverskyLoss(DiceLoss):
    """Tversky loss for image segmentation tasks.

    Generalizes ``DiceLoss`` by weighting false positives (FP) and false
    negatives (FN) separately through ``alpha`` and ``beta``. With
    ``alpha == beta == 0.5`` and the default ``gamma == 1.0``, this loss is
    equivalent to ``DiceLoss``.

    Supports binary, multiclass and multilabel cases.

    Args:
        mode: Loss mode, one of ``'binary'``, ``'multiclass'``,
            ``'multilabel'``.
        classes: Optional list of class indices that contribute to the loss
            computation. By default (``None``), all channels are included.
        log_loss: If True, the loss is computed as ``-log(tversky)``,
            otherwise as ``1 - tversky``.
        from_logits: If True, assumes the input is raw logits.
        smooth: Smoothing constant forwarded to ``DiceLoss`` and passed on to
            the score computation (presumably added to the numerator and
            denominator of the Tversky index; confirm in
            ``soft_tversky_score``).
        ignore_index: Label that indicates ignored pixels. These pixels do
            not contribute to the loss.
        eps: Small epsilon for numerical stability.
        alpha: Weight constant that penalizes the model for false positives.
        beta: Weight constant that penalizes the model for false negatives.
        gamma: Exponent applied to the mean loss in ``aggregate_loss``
            (the "focal Tversky" variant). Defaults to ``1.0``, which leaves
            the mean unchanged.

    Return:
        loss: torch.Tensor
    """

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        # Reject unknown modes early. The exception type stays AssertionError,
        # but the message now names the accepted values.
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}, (
            f"Unsupported mode {mode!r}; expected one of "
            f"{BINARY_MODE!r}, {MULTILABEL_MODE!r}, {MULTICLASS_MODE!r}"
        )
        # Shared options (mode, class filtering, logits handling, smoothing,
        # ignore_index, eps) are handled by the DiceLoss base class.
        super().__init__(
            mode, classes, log_loss, from_logits, smooth, ignore_index, eps
        )
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss: torch.Tensor) -> torch.Tensor:
        """Reduce the loss tensor to a scalar as ``mean(loss) ** gamma``.

        This appears to override an aggregation hook that ``DiceLoss``
        invokes (confirm in the base class). With ``gamma == 1.0`` it is
        a plain mean.
        """
        return loss.mean() ** self.gamma

    def compute_score(
        self,
        output: torch.Tensor,
        target: torch.Tensor,
        smooth: float = 0.0,
        eps: float = 1e-7,
        dims=None,
    ) -> torch.Tensor:
        """Compute the soft Tversky score between ``output`` and ``target``.

        Delegates to ``soft_tversky_score`` with this instance's ``alpha``
        (FP weight) and ``beta`` (FN weight). ``smooth``, ``eps`` and ``dims``
        are passed through unchanged. ``dims`` presumably selects the
        reduction axes; confirm in ``soft_tversky_score``.
        """
        return soft_tversky_score(
            output, target, self.alpha, self.beta, smooth, eps, dims
        )