# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional

import torch
import torch.nn as nn

from mmocr.registry import MODELS


@MODELS.register_module()
class MaskedDiceLoss(nn.Module):
    """Masked dice loss.

    Args:
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-6.
    """

    def __init__(self, eps: float = 1e-6) -> None:
        super().__init__()
        assert isinstance(eps, float)
        self.eps = eps
    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss. Whole
                region will be taken into account if not provided. Defaults to
                None.

        Returns:
            torch.Tensor: The loss value.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt)
        assert mask.size() == gt.size()

        # Flatten each sample and zero out regions excluded by the mask.
        pred = pred.contiguous().view(pred.size(0), -1)
        gt = gt.contiguous().view(gt.size(0), -1)
        mask = mask.contiguous().view(mask.size(0), -1)
        pred = pred * mask
        gt = gt * mask

        # Dice coefficient over the whole masked batch:
        # 2 * sum(pred * gt) / (sum(pred) + sum(gt)).
        dice_coeff = (2 * (pred * gt).sum()) / (
            pred.sum() + gt.sum() + self.eps)
        return 1 - dice_coeff
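

# A minimal usage sketch (illustrative only, not part of the upstream module;
# the tensor shapes and the 0.5 threshold below are assumptions):
#   loss_fn = MaskedDiceLoss()
#   pred = torch.sigmoid(torch.randn(2, 1, 32, 32))   # probabilities in [0, 1]
#   gt = (torch.rand(2, 1, 32, 32) > 0.5).float()     # binary target
#   loss = loss_fn(pred, gt)   # mask defaults to all ones -> scalar tensor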


@MODELS.register_module()
class MaskedSquareDiceLoss(nn.Module):
    """Masked square dice loss.

    Args:
        eps (float, optional): Eps to avoid zero-division error. Defaults to
            1e-3.
    """

    def __init__(self, eps: float = 1e-3) -> None:
        super().__init__()
        assert isinstance(eps, float)
        self.eps = eps
    def forward(self,
                pred: torch.Tensor,
                gt: torch.Tensor,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward function.

        Args:
            pred (torch.Tensor): The prediction in any shape.
            gt (torch.Tensor): The learning target of the prediction in the
                same shape as pred.
            mask (torch.Tensor, optional): Binary mask in the same shape as
                pred, indicating positive regions to calculate the loss. Whole
                region will be taken into account if not provided. Defaults to
                None.

        Returns:
            torch.Tensor: The loss value.
        """
        assert pred.size() == gt.size() and gt.numel() > 0
        if mask is None:
            mask = torch.ones_like(gt)
        assert mask.size() == gt.size()

        batch_size = pred.size(0)
        # Flatten each sample and zero out regions excluded by the mask.
        pred = pred.contiguous().view(batch_size, -1)
        gt = gt.contiguous().view(batch_size, -1).float()
        mask = mask.contiguous().view(batch_size, -1).float()
        pred = pred * mask
        gt = gt * mask

        # Per-sample dice with squared terms in the denominator,
        # averaged over the batch.
        a = torch.sum(pred * gt, dim=1)
        b = torch.sum(pred * pred, dim=1) + self.eps
        c = torch.sum(gt * gt, dim=1) + self.eps
        d = (2 * a) / (b + c)
        loss = 1 - d
        loss = torch.mean(loss)
        return loss
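

if __name__ == '__main__':
    # Minimal sanity-check sketch, not part of the upstream module; the
    # tensor shapes and the 0.5 threshold are illustrative assumptions.
    pred = torch.sigmoid(torch.randn(2, 1, 32, 32))
    gt = (torch.rand(2, 1, 32, 32) > 0.5).float()
    mask = torch.ones_like(gt)

    # Both losses return a scalar tensor; values near 0 indicate good overlap
    # between pred and gt, values near 1 indicate poor overlap.
    print('MaskedDiceLoss:', MaskedDiceLoss()(pred, gt, mask).item())
    print('MaskedSquareDiceLoss:', MaskedSquareDiceLoss()(pred, gt, mask).item())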