import warnings
from collections.abc import Callable, Sequence
from typing import Any

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss

from monai.losses.dice import DiceLoss
from monai.losses.focal_loss import FocalLoss
from monai.networks import one_hot
from monai.utils import DiceCEReduction, LossReduction, Weight, deprecated_arg, look_up_option, pytorch_after


##### Adapted from Monai DiceFocalLoss
class WeaklyDiceFocalLoss(_Loss):
    """
    Compute Dice loss, Focal loss and a weakly supervised ratio loss, and return
    the weighted sum of these three losses.

    The weak term compares the predicted tumor/liver ratio of every batch item with
    a reference ratio (``weaktarget``, e.g. coming from a clinical predictor). The
    last channel of the prediction is treated as tumor and the second to last
    channel as liver.

    ``gamma`` and ``lambda_focal`` are only used for the focal loss.
    ``include_background``, ``weight`` and ``reduction`` are used for both the Dice
    and the Focal loss; the weak loss is always averaged over the batch. The
    remaining parameters are only used for the Dice loss, except that the
    activation selected through ``sigmoid`` / ``softmax`` / ``other_act`` is also
    applied to the prediction before the weak loss is computed in ``forward``.
    """

    def __init__(
        self,
        include_background: bool = True,
        to_onehot_y: bool = False,
        sigmoid: bool = False,
        softmax: bool = False,
        other_act: Callable | None = None,
        squared_pred: bool = False,
        jaccard: bool = False,
        reduction: str = "mean",
        smooth_nr: float = 1e-5,
        smooth_dr: float = 1e-5,
        batch: bool = False,
        gamma: float = 2.0,
        focal_weight: Sequence[float] | float | int | torch.Tensor | None = None,
        weight: Sequence[float] | float | int | torch.Tensor | None = None,
        lambda_dice: float = 1.0,
        lambda_focal: float = 1.0,
        lambda_weak: float = 1.0,
    ) -> None:
        """
        Args:
            include_background: if False channel index 0 (background category) is excluded from the calculation.
            to_onehot_y: whether to convert the ``target`` into the one-hot format,
                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
            sigmoid: if True, apply a sigmoid function to the prediction for the `DiceLoss` and the
                weak loss; don't need to specify activation function for `FocalLoss`.
            softmax: if True, apply a softmax function to the prediction for the `DiceLoss` and the
                weak loss; don't need to specify activation function for `FocalLoss`.
            other_act: callable function to execute other activation layers, Defaults to ``None``.
                for example: `other_act = torch.tanh`. Used by the `DiceLoss` and the weak loss,
                not for `FocalLoss`.
            squared_pred: use squared versions of targets and predictions in the denominator or not.
            jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
            reduction: {``"none"``, ``"mean"``, ``"sum"``}
                Specifies the reduction to apply to the Dice and Focal outputs. Defaults to ``"mean"``.

                - ``"none"``: no reduction will be applied.
                - ``"mean"``: the sum of the output will be divided by the number of elements in the output.
                - ``"sum"``: the output will be summed.
            smooth_nr: a small constant added to the numerator to avoid zero.
            smooth_dr: a small constant added to the denominator to avoid nan.
            batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                Defaults to False, a Dice loss value is computed independently from each item in the batch
                before any `reduction`.
            gamma: value of the exponent gamma in the definition of the Focal loss.
            focal_weight: alternative way of passing ``weight``; when it is not None it takes
                precedence over ``weight``.
            weight: weights to apply to the voxels of each class. If None no weights are applied.
                The input can be a single value (same weight for all classes), a sequence of values (the length
                of the sequence should be the same as the number of classes).
            lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                Defaults to 1.0.
            lambda_focal: the trade-off weight value for focal loss. The value should be no less than 0.0.
                Defaults to 1.0.
            lambda_weak: the trade-off weight value for weakly supervised loss. The value should be no less
                than 0.0. Defaults to 1.0.

        Raises:
            ValueError: When ``lambda_dice``, ``lambda_focal`` or ``lambda_weak`` is negative.
        """
        super().__init__()
        weight = focal_weight if focal_weight is not None else weight
        # One-hot conversion is done once in ``forward``, hence ``to_onehot_y=False`` for both sub-losses.
        self.dice = DiceLoss(
            include_background=include_background,
            to_onehot_y=False,
            sigmoid=sigmoid,
            softmax=softmax,
            other_act=other_act,
            squared_pred=squared_pred,
            jaccard=jaccard,
            reduction=reduction,
            smooth_nr=smooth_nr,
            smooth_dr=smooth_dr,
            batch=batch,
            weight=weight,
        )
        self.focal = FocalLoss(
            include_background=include_background,
            to_onehot_y=False,
            gamma=gamma,
            weight=weight,
            reduction=reduction,
        )
        if lambda_dice < 0.0:
            raise ValueError("lambda_dice should be no less than 0.0.")
        if lambda_focal < 0.0:
            raise ValueError("lambda_focal should be no less than 0.0.")
        if lambda_weak < 0.0:
            raise ValueError("lambda_weak should be no less than 0.0.")
        self.lambda_dice = lambda_dice
        self.lambda_focal = lambda_focal
        self.to_onehot_y = to_onehot_y
        self.lambda_weak = lambda_weak
        # Remembered so the weak loss can see the same activated prediction as the Dice loss.
        self.sigmoid = sigmoid
        self.softmax = softmax
        self.other_act = other_act

    def _activate(self, input: torch.Tensor) -> torch.Tensor:
        """Apply the activation configured for the Dice loss to ``input`` (identity if none is set)."""
        if self.sigmoid:
            return torch.sigmoid(input)
        if self.softmax:
            # A softmax over a single channel is constant 1, so leave the prediction untouched.
            return input if input.shape[1] == 1 else torch.softmax(input, dim=1)
        if self.other_act is not None:
            return self.other_act(input)
        return input

    def compute_weakly_supervised_loss(self, input: torch.Tensor, weaktarget: torch.Tensor) -> torch.Tensor:
        """
        Mean squared error between the predicted tumor/liver ratio and ``weaktarget``.

        No activation is applied here: ``input`` is used as given (``forward`` passes the
        activated prediction).

        Args:
            input: prediction of shape BNH[WD] with at least two channels; the last channel is
                treated as tumor and the second to last channel as liver.
            weaktarget: reference tumor/liver ratio, presumably one value per batch item
                (e.g. shape ``(B,)`` or ``(B, 1)``).

        Returns:
            Scalar tensor with the batch mean of the squared ratio error.
        """
        tumor_map = input[:, -1, ...]
        liver_map = input[:, -2, ...]
        # Sum over every spatial axis so that 2D and 3D predictions are both supported.
        spatial_dims = tuple(range(1, tumor_map.dim()))
        if spatial_dims:  # an empty ``dim`` tuple may be interpreted as "reduce all axes"
            tumor_pixels = torch.sum(tumor_map, dim=spatial_dims)
            liver_pixels = torch.sum(liver_map, dim=spatial_dims) + tumor_pixels
        else:
            tumor_pixels = tumor_map
            liver_pixels = liver_map + tumor_pixels

        # Tiny dtype-dependent constant so an empty liver+tumor region gives ratio 0 instead of NaN/inf.
        eps = torch.finfo(liver_pixels.dtype).eps if liver_pixels.is_floating_point() else 0.0
        predicted_ratio = tumor_pixels / (liver_pixels + eps)

        # A (B, 1) target would silently broadcast against the (B,) ratio into a (B, B) matrix.
        if (
            isinstance(weaktarget, torch.Tensor)
            and weaktarget.shape != predicted_ratio.shape
            and weaktarget.numel() == predicted_ratio.numel()
        ):
            weaktarget = weaktarget.reshape(predicted_ratio.shape)

        return torch.mean((predicted_ratio - weaktarget) ** 2)

    def forward(self, input: torch.Tensor, target: torch.Tensor, weaktarget: torch.Tensor) -> torch.Tensor:
        """
        Args:
            input: the shape should be BNH[WD]. The input should be the original logits
                due to the restriction of ``monai.losses.FocalLoss``.
            target: the shape should be BNH[WD] or B1H[WD].
            weaktarget: reference tumor/liver ratio used by the weakly supervised loss,
                see ``compute_weakly_supervised_loss``.

        Returns:
            ``lambda_dice * dice + lambda_focal * focal + lambda_weak * weak``. The weak term is
            always a scalar; the shape of the other two terms follows ``reduction``.

        Raises:
            ValueError: When number of dimensions for input and target are different.

        """
        if len(input.shape) != len(target.shape):
            raise ValueError(
                "the number of dimensions for input and target should be the same, "
                f"got shape {input.shape} and {target.shape}."
            )
        if self.to_onehot_y:
            n_pred_ch = input.shape[1]
            if n_pred_ch == 1:
                warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
            else:
                target = one_hot(target, num_classes=n_pred_ch)
        dice_loss = self.dice(input, target)
        focal_loss = self.focal(input, target)
        # The ratio is only meaningful on the activated prediction, not on raw logits;
        # with no activation configured this is the identity.
        weak_loss = self.compute_weakly_supervised_loss(self._activate(input), weaktarget)
        total_loss: torch.Tensor = (
            self.lambda_dice * dice_loss + self.lambda_focal * focal_loss + self.lambda_weak * weak_loss
        )
        return total_loss