# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Union

import torch
import torch.nn as nn

from mmocr import digit_version
from mmocr.registry import MODELS


@MODELS.register_module()
class SmoothL1Loss(nn.SmoothL1Loss):
    """Smooth L1 loss."""


@MODELS.register_module()
class MaskedSmoothL1Loss(nn.Module):
    """Masked Smooth L1 loss.

    Args:
        beta (float, optional): The threshold in the piecewise function.
            Defaults to 1.
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self,
                 beta: Union[float, int] = 1,
                 eps: float = 1e-6) -> None:
        super().__init__()
        # nn.SmoothL1Loss only accepts the ``beta`` argument on
        # PyTorch > 1.6.0; older versions fall back to the manual piecewise
        # computation in ``forward``.
        if digit_version(torch.__version__) > digit_version('1.6.0'):
            # Shift beta == 0 by eps to avoid issues with how
            # PyTorch >= 1.13.0 handles a zero beta.
            if digit_version(torch.__version__) >= digit_version(
                    '1.13.0') and beta == 0:
                beta = beta + eps
            self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta, reduction='none')
        self.eps = eps
        self.beta = beta
    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss.
                The whole region is taken into account if not provided.
                Defaults to None.

        Returns:
            torch.Tensor: The loss value.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt).bool()
        assert mask.size() == gt.size()

        # Zero out non-positive regions so they contribute no loss.
        x = pred * mask
        y = gt * mask
        if digit_version(torch.__version__) > digit_version('1.6.0'):
            loss = self.smooth_l1_loss(x, y)
        else:
            # Manual piecewise Smooth L1 for PyTorch <= 1.6.0:
            # 0.5 * diff^2 / beta if diff < beta, else diff - 0.5 * beta.
            loss = torch.zeros_like(gt)
            diff = torch.abs(x - y)
            mask_beta = diff < self.beta
            loss[mask_beta] = 0.5 * torch.square(diff)[mask_beta] / self.beta
            loss[~mask_beta] = diff[~mask_beta] - 0.5 * self.beta
        # Normalize by the number of masked elements; eps guards against
        # division by zero when the mask is empty.
        return loss.sum() / (mask.sum() + self.eps)
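

# --- Minimal usage sketch (not part of the upstream file) ---
# The shapes and values below are illustrative assumptions; they show how
# the mask restricts the loss to positive regions, and that omitting the
# mask averages over the whole tensor.
if __name__ == '__main__':
    loss_fn = MaskedSmoothL1Loss(beta=1.0)
    pred = torch.rand(2, 1, 4, 4)
    gt = torch.rand(2, 1, 4, 4)

    # Mark only the first two columns as positive regions.
    mask = torch.zeros_like(gt).bool()
    mask[..., :2] = True
    print(loss_fn(pred, gt, mask))  # averaged over the 16 masked elements

    # Without a mask, every element contributes to the loss.
    print(loss_fn(pred, gt))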