# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
import torch.nn as nn

from mmocr.registry import MODELS


@MODELS.register_module()
class MaskedBalancedBCEWithLogitsLoss(nn.Module):
    """This loss combines a Sigmoid layer and a masked balanced BCE loss in
    one single class. It's AMP-eligible.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
        negative_ratio (float or int, optional): Maximum ratio of negative
            samples to positive ones. Defaults to 3.
        fallback_negative_num (int, optional): When the mask contains no
            positive samples, the number of negative samples to be sampled.
            Defaults to 0.
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self,
                 reduction: str = 'none',
                 negative_ratio: Union[float, int] = 3,
                 fallback_negative_num: int = 0,
                 eps: float = 1e-6) -> None:
        super().__init__()
        assert reduction in ['none', 'mean', 'sum']
        assert isinstance(negative_ratio, (float, int))
        assert isinstance(fallback_negative_num, int)
        assert isinstance(eps, float)
        self.eps = eps
        self.negative_ratio = negative_ratio
        self.reduction = reduction
        self.fallback_negative_num = fallback_negative_num
        self.loss = nn.BCEWithLogitsLoss(reduction=reduction)

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss. The
                whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt)
        assert mask.size() == gt.size()

        # Split the masked region into positive and negative samples.
        positive = (gt * mask).float()
        negative = ((1 - gt) * mask).float()
        positive_count = int(positive.sum())

        if positive_count == 0:
            # No positives: fall back to a fixed number of negatives so the
            # loss can still provide a training signal.
            negative_count = min(
                int(negative.sum()), self.fallback_negative_num)
        else:
            # Cap the negatives at `negative_ratio` times the positives.
            negative_count = min(
                int(negative.sum()), int(positive_count * self.negative_ratio))

        assert gt.max() <= 1 and gt.min() >= 0
        loss = self.loss(pred, gt)
        positive_loss = loss * positive
        negative_loss = loss * negative
        # Online hard example mining: keep only the `negative_count` hardest
        # (highest-loss) negatives.
        negative_loss, _ = torch.topk(negative_loss.view(-1), negative_count)

        balance_loss = (positive_loss.sum() + negative_loss.sum()) / (
            positive_count + negative_count + self.eps)

        return balance_loss
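

# A minimal usage sketch (illustrative, not from the original file): the
# shapes below are arbitrary, and `pred` holds raw logits because the
# sigmoid is applied inside nn.BCEWithLogitsLoss.
#
# >>> loss_fn = MaskedBalancedBCEWithLogitsLoss(negative_ratio=3)
# >>> pred = torch.randn(1, 1, 32, 32)                # raw logits
# >>> gt = (torch.rand(1, 1, 32, 32) > 0.9).float()   # sparse positives
# >>> mask = torch.ones_like(gt)                      # score every pixel
# >>> loss = loss_fn(pred, gt, mask)                  # scalar tensor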


@MODELS.register_module()
class MaskedBalancedBCELoss(MaskedBalancedBCEWithLogitsLoss):
    """Masked Balanced BCE loss.

    Args:
        reduction (str, optional): The method to reduce the loss.
            Options are 'none', 'mean' and 'sum'. Defaults to 'none'.
        negative_ratio (float or int, optional): Maximum ratio of negative
            samples to positive ones. Defaults to 3.
        fallback_negative_num (int, optional): When the mask contains no
            positive samples, the number of negative samples to be sampled.
            Defaults to 0.
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self,
                 reduction: str = 'none',
                 negative_ratio: Union[float, int] = 3,
                 fallback_negative_num: int = 0,
                 eps: float = 1e-6) -> None:
        super().__init__()
        assert reduction in ['none', 'mean', 'sum']
        assert isinstance(negative_ratio, (float, int))
        assert isinstance(fallback_negative_num, int)
        assert isinstance(eps, float)
        self.eps = eps
        self.negative_ratio = negative_ratio
        self.reduction = reduction
        self.fallback_negative_num = fallback_negative_num
        # Same balancing logic as the parent, but on probability inputs.
        self.loss = nn.BCELoss(reduction=reduction)

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape, with values in
                the range [0, 1].
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss. The
                whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.
        """
        # Unlike the logits variant, `pred` must already be a probability.
        assert pred.max() <= 1 and pred.min() >= 0
        return super().forward(pred, gt, mask)
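

# The same sketch for the probability-input variant (illustrative): `pred`
# must already lie in [0, 1], e.g. the output of a sigmoid.
#
# >>> loss_fn = MaskedBalancedBCELoss()
# >>> logits = torch.randn(1, 1, 32, 32)
# >>> pred = torch.sigmoid(logits)                    # probabilities
# >>> gt = (torch.rand(1, 1, 32, 32) > 0.9).float()
# >>> loss = loss_fn(pred, gt)                        # mask defaults to all ones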


@MODELS.register_module()
class MaskedBCEWithLogitsLoss(nn.Module):
    """This loss combines a Sigmoid layer and a masked BCE loss in one single
    class. It's AMP-eligible.

    Args:
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        assert isinstance(eps, float)
        self.eps = eps
        self.loss = nn.BCEWithLogitsLoss(reduction='none')

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss. The
                whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt)
        assert mask.size() == gt.size()
        assert gt.max() <= 1 and gt.min() >= 0

        loss = self.loss(pred, gt)
        # Average the per-element loss over the masked region only.
        return (loss * mask).sum() / (mask.sum() + self.eps)
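

# Illustrative sketch: with a binary mask only the unmasked region
# contributes, e.g. to ignore padded pixels in a batch.
#
# >>> loss_fn = MaskedBCEWithLogitsLoss()
# >>> pred = torch.randn(2, 1, 32, 32)                # raw logits
# >>> gt = (torch.rand(2, 1, 32, 32) > 0.5).float()
# >>> mask = torch.zeros_like(gt)
# >>> mask[..., :16] = 1.                             # only the left half counts
# >>> loss = loss_fn(pred, gt, mask)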


@MODELS.register_module()
class MaskedBCELoss(MaskedBCEWithLogitsLoss):
    """Masked BCE loss.

    Args:
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        assert isinstance(eps, float)
        self.eps = eps
        # Same masked averaging as the parent, but on probability inputs.
        self.loss = nn.BCELoss(reduction='none')

    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape, with values in
                the range [0, 1].
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss. The
                whole region will be taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.
        """
        # Unlike the logits variant, `pred` must already be a probability.
        assert pred.max() <= 1 and pred.min() >= 0
        return super().forward(pred, gt, mask)
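

# Consistency sketch (illustrative): on the same inputs, the probability
# variant applied to sigmoid(logits) should match the logits variant up to
# floating-point error.
#
# >>> logits = torch.randn(1, 1, 8, 8)
# >>> gt = (torch.rand(1, 1, 8, 8) > 0.5).float()
# >>> a = MaskedBCEWithLogitsLoss()(logits, gt)
# >>> b = MaskedBCELoss()(torch.sigmoid(logits), gt)
# >>> torch.allclose(a, b, atol=1e-5)
# True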