|
from typing import Optional |
|
from functools import partial |
|
|
|
import torch |
|
from torch.nn.modules.loss import _Loss |
|
from ._functional import focal_loss_with_logits |
|
from .constants import BINARY_MODE, MULTICLASS_MODE, MULTILABEL_MODE |
|
|
|
__all__ = ["FocalLoss"] |
|
|
|
|
|
class FocalLoss(_Loss):
    """Focal loss module for binary, multilabel and multiclass targets.

    The constructor binds the loss hyper-parameters to
    ``focal_loss_with_logits`` once; ``forward`` then prepares the tensors
    according to ``mode`` and delegates the actual computation to it.
    """

    def __init__(
        self,
        mode: str,
        alpha: Optional[float] = None,
        gamma: Optional[float] = 2.0,
        ignore_index: Optional[int] = None,
        reduction: Optional[str] = "mean",
        normalized: bool = False,
        reduced_threshold: Optional[float] = None,
    ):
        """Compute Focal loss

        Args:
            mode: Loss mode 'binary', 'multiclass' or 'multilabel'
            alpha: Prior probability of having positive value in target.
            gamma: Power factor for dampening weight (focal strength).
            ignore_index: If not None, targets may contain values to be ignored.
                Target values equal to ignore_index will be ignored from loss computation.
            reduction: Reduction setting forwarded unchanged to
                ``focal_loss_with_logits`` (default ``"mean"``).
            normalized: Compute normalized focal loss (https://arxiv.org/pdf/1909.07829.pdf).
            reduced_threshold: Switch to reduced focal loss. Note, when using this mode you
                should use `reduction="sum"`.

        Shape
             - **y_pred** - torch.Tensor of shape (N, C, H, W)
             - **y_true** - torch.Tensor of shape (N, H, W) or (N, C, H, W)

        Reference
            https://github.com/BloodAxe/pytorch-toolbelt

        """
        assert mode in {BINARY_MODE, MULTILABEL_MODE, MULTICLASS_MODE}, (
            f"Unsupported mode: {mode!r}"
        )
        super().__init__()

        self.mode = mode
        self.ignore_index = ignore_index
        # Bind the hyper-parameters once so forward() only supplies tensors.
        self.focal_loss_fn = partial(
            focal_loss_with_logits,
            alpha=alpha,
            gamma=gamma,
            reduced_threshold=reduced_threshold,
            reduction=reduction,
            normalized=normalized,
        )

    def forward(self, y_pred: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
        """Compute the focal loss between predictions and targets.

        In binary and multilabel modes both tensors are flattened and scored
        in a single call. In multiclass mode the loss is accumulated over the
        class channels of ``y_pred`` (dimension 1), each scored one-vs-rest
        against ``y_true == cls``. When ``ignore_index`` is set, positions
        whose target equals it are removed before the loss is evaluated.

        Args:
            y_pred: Predictions, passed to ``focal_loss_with_logits``.
            y_true: Targets; class indices in multiclass mode.

        Returns:
            The value produced by ``focal_loss_with_logits`` (summed over
            classes in multiclass mode).

        Raises:
            ValueError: If ``self.mode`` is not one of the supported modes.
        """
        if self.mode in {BINARY_MODE, MULTILABEL_MODE}:
            # reshape() behaves like view() for contiguous tensors but also
            # accepts non-contiguous ones (e.g. permuted targets), where
            # view() would raise.
            y_true = y_true.reshape(-1)
            y_pred = y_pred.reshape(-1)

            if self.ignore_index is not None:
                # Drop ignored positions from both tensors with one mask.
                not_ignored = y_true != self.ignore_index
                y_pred = y_pred[not_ignored]
                y_true = y_true[not_ignored]

            return self.focal_loss_fn(y_pred, y_true)

        if self.mode == MULTICLASS_MODE:
            num_classes = y_pred.size(1)
            loss = 0

            # The mask depends only on the targets, so build it once
            # outside the per-class loop.
            not_ignored = None
            if self.ignore_index is not None:
                not_ignored = y_true != self.ignore_index

            for cls in range(num_classes):
                # One-vs-rest target and the matching prediction channel.
                cls_y_true = (y_true == cls).long()
                cls_y_pred = y_pred[:, cls, ...]

                if not_ignored is not None:
                    cls_y_true = cls_y_true[not_ignored]
                    cls_y_pred = cls_y_pred[not_ignored]

                loss += self.focal_loss_fn(cls_y_pred, cls_y_true)

            return loss

        # Reachable only if the constructor assert was stripped (python -O)
        # or self.mode was reassigned; fail clearly instead of hitting an
        # unbound local.
        raise ValueError(f"Unsupported mode: {self.mode!r}")
|
|