# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
import torch.nn as nn

from mmocr.registry import MODELS


@MODELS.register_module()
class MaskedBalancedBCEWithLogitsLoss(nn.Module):
    """This loss combines a Sigmoid layer and a masked balanced BCE loss in
    one single class. It's AMP-eligible.

    All positive samples inside the mask contribute to the loss, while only
    the hardest (highest-loss) negative samples are kept, up to
    ``negative_ratio`` times the number of positives.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. It is forwarded to the
            wrapped criterion. Note that :meth:`forward` multiplies the
            criterion output by element-wise masks, so with any value other
            than 'none' the criterion returns a scalar that is broadcast
            over the masks. Defaults to 'none'.
        negative_ratio (float or int, optional): Maximum ratio of negative
            samples to positive ones. Defaults to 3.
        fallback_negative_num (int, optional): When the mask contains no
            positive samples, the number of negative samples to be sampled.
            Defaults to 0.
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self,
                 reduction: str = 'none',
                 negative_ratio: Union[float, int] = 3,
                 fallback_negative_num: int = 0,
                 eps: float = 1e-6) -> None:
        super().__init__()
        assert reduction in ['none', 'mean', 'sum']
        assert isinstance(negative_ratio, (float, int))
        assert isinstance(fallback_negative_num, int)
        assert isinstance(eps, float)
        self.eps = eps
        self.negative_ratio = negative_ratio
        self.reduction = reduction
        self.fallback_negative_num = fallback_negative_num
        self.loss = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape of
                pred, indicating positive regions to calculate the loss.
                Whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.

        Raises:
            AssertionError: If the shapes of ``pred``, ``gt`` and ``mask``
                differ, if ``gt`` is empty, or if ``gt`` has values outside
                ``[0, 1]``.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt)
        assert mask.size() == gt.size()

        # Per-element weights of positive / negative samples inside the mask.
        positive = (gt * mask).float()
        negative = ((1 - gt) * mask).float()
        positive_count = int(positive.sum())
        negative_total = int(negative.sum())

        # Cap the negatives: relative to the positives when there are any,
        # otherwise by the configured fallback number.
        if positive_count == 0:
            negative_limit = self.fallback_negative_num
        else:
            negative_limit = int(positive_count * self.negative_ratio)
        negative_count = min(negative_total, negative_limit)

        assert gt.max() <= 1 and gt.min() >= 0
        loss = self.loss(pred, gt)
        positive_loss = loss * positive
        negative_loss = loss * negative

        # Keep only the hardest negatives. ``reshape`` is used instead of
        # ``view`` because the product may be non-contiguous (e.g. when the
        # inputs are transposed/permuted), in which case ``view`` raises.
        negative_loss, _ = torch.topk(
            negative_loss.reshape(-1), negative_count)

        balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
            positive_count + negative_count + self.eps)

        return balance_loss


@MODELS.register_module()
class MaskedBalancedBCELoss(MaskedBalancedBCEWithLogitsLoss):
    """Masked Balanced BCE loss.

    Same as :class:`MaskedBalancedBCEWithLogitsLoss`, except that the
    prediction is expected to be a probability in ``[0, 1]`` and
    ``nn.BCELoss`` is used as the criterion.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
        negative_ratio (float or int): Maximum ratio of negative
            samples to positive ones. Defaults to 3.
        fallback_negative_num (int): When the mask contains no
            positive samples, the number of negative samples to be sampled.
            Defaults to 0.
        eps (float): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self,
                 reduction: str = 'none',
                 negative_ratio: Union[float, int] = 3,
                 fallback_negative_num: int = 0,
                 eps: float = 1e-6) -> None:
        # The parent validates and stores every argument; only the criterion
        # differs, so it is the only attribute overridden here.
        super().__init__(
            reduction=reduction,
            negative_ratio=negative_ratio,
            fallback_negative_num=fallback_negative_num,
            eps=eps)
        self.loss = nn.BCELoss(reduction=reduction)

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape, with values
                in ``[0, 1]``.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape of
                pred, indicating positive regions to calculate the loss.
                Whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.

        Raises:
            AssertionError: If ``pred`` has values outside ``[0, 1]``, or
                if any check of the parent ``forward`` fails.
        """
        assert pred.max() <= 1 and pred.min() >= 0
        return super().forward(pred, gt, mask)


@MODELS.register_module()
class MaskedBCEWithLogitsLoss(nn.Module):
    """This loss combines a Sigmoid layer and a masked BCE loss in one
    single class. It's AMP-eligible.

    Args:
        eps (float): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        assert isinstance(eps, float)
        self.eps = eps
        # Element-wise loss is required so that the mask can be applied
        # before averaging.
        self.loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape of
                pred, indicating positive regions to calculate the loss.
                Whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value, i.e. the masked loss sum divided
            by ``mask.sum() + eps``.

        Raises:
            AssertionError: If the shapes of ``pred``, ``gt`` and ``mask``
                differ, if ``gt`` is empty, or if ``gt`` has values outside
                ``[0, 1]``.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt)
        assert mask.size() == gt.size()

        assert gt.max() <= 1 and gt.min() >= 0
        loss = self.loss(pred, gt)

        masked_total = (loss * mask).sum()
        return masked_total / (mask.sum() + self.eps)


@MODELS.register_module()
class MaskedBCELoss(MaskedBCEWithLogitsLoss):
    """Masked BCE loss.

    Same as :class:`MaskedBCEWithLogitsLoss`, except that the prediction is
    expected to be a probability in ``[0, 1]`` and ``nn.BCELoss`` is used
    as the criterion.

    Args:
        eps (float): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        # The parent validates and stores ``eps``; only the criterion is
        # overridden here.
        super().__init__(eps=eps)
        self.loss = nn.BCELoss(reduction='none')

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape, with values
                in ``[0, 1]``.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape of
                pred, indicating positive regions to calculate the loss.
                Whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.

        Raises:
            AssertionError: If ``pred`` has values outside ``[0, 1]``, or
                if any check of the parent ``forward`` fails.
        """
        assert pred.max() <= 1 and pred.min() >= 0
        return super().forward(pred, gt, mask)