File size: 2,162 Bytes
2a13495 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 |
from typing import List, Optional
import torch
from ._functional import soft_tversky_score
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE
from .dice import DiceLoss
__all__ = ["TverskyLoss"]
class TverskyLoss(DiceLoss):
    """Tversky loss for image segmentation tasks.

    Generalizes Dice loss by weighting false positives (FP) and false
    negatives (FN) with ``alpha`` and ``beta`` respectively. With
    ``alpha == beta == 0.5`` (and ``gamma == 1``) this loss is equivalent
    to ``DiceLoss``. Supports binary, multiclass and multilabel modes.

    Args:
        mode: Loss mode, one of ``'binary'``, ``'multiclass'``,
            ``'multilabel'``.
        classes: Optional list of classes that contribute to the loss
            computation. By default, all channels are included.
        log_loss: If True, the loss is computed as ``-log(tversky)``,
            otherwise as ``1 - tversky``.
        from_logits: If True, assumes the input is raw logits.
        smooth: Smoothing constant forwarded to ``DiceLoss`` and, through
            ``compute_score``, to ``soft_tversky_score`` (presumably added
            to numerator and denominator of the score).
        ignore_index: Label that marks ignored pixels; they do not
            contribute to the loss.
        eps: Small epsilon for numerical stability.
        alpha: Weight that penalizes the model for false positives.
        beta: Weight that penalizes the model for false negatives.
        gamma: Exponent applied after averaging in ``aggregate_loss``
            (``loss.mean() ** gamma``). The default ``1.0`` leaves the
            averaged loss unchanged; other values give a focal-Tversky-style
            variant. Note this is not a squaring of the error.

    Raises:
        AssertionError: If ``mode`` is not a supported mode (only while
            assertions are enabled, i.e. not under ``python -O``).

    Return:
        loss: torch.Tensor
    """

    def __init__(
        self,
        mode: str,
        classes: Optional[List[int]] = None,
        log_loss: bool = False,
        from_logits: bool = True,
        smooth: float = 0.0,
        ignore_index: Optional[int] = None,
        eps: float = 1e-7,
        alpha: float = 0.5,
        beta: float = 0.5,
        gamma: float = 1.0,
    ):
        # Kept as an assert (not ValueError) to match the exception type
        # callers already see; the message makes the failure actionable.
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}, (
            f"Unsupported mode {mode!r}; expected one of "
            f"{BINARY_MODE!r}, {MULTILABEL_MODE!r}, {MULTICLASS_MODE!r}"
        )
        # All shared configuration is handled by DiceLoss; only the
        # Tversky-specific weights and exponent are stored here.
        super().__init__(
            mode, classes, log_loss, from_logits, smooth, ignore_index, eps
        )
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma

    def aggregate_loss(self, loss):
        """Reduce ``loss`` to a scalar: its mean raised to ``self.gamma``.

        NOTE(review): presumably invoked by ``DiceLoss`` on the per-class
        loss tensor; the exponent is applied after averaging, not per class.
        """
        return loss.mean() ** self.gamma

    def compute_score(
        self, output, target, smooth=0.0, eps=1e-7, dims=None
    ) -> torch.Tensor:
        """Return the soft Tversky score of ``output`` against ``target``.

        Delegates to ``soft_tversky_score`` with this instance's ``alpha``
        and ``beta``; ``smooth``, ``eps`` and ``dims`` are forwarded as-is.
        """
        return soft_tversky_score(
            output, target, self.alpha, self.beta, smooth, eps, dims
        )
|