# Web-page residue from the file viewer (not part of the module):
# topdu's picture | openocr demo | 29f689c | raw | history blame | 1.32 kB
import torch
from torch import nn
class CTCLoss(nn.Module):
    """CTC loss averaged over a batch, with optional focal re-weighting.

    Wraps ``nn.CTCLoss`` configured with blank index 0 and
    ``reduction='none'`` so that the per-sample losses can be re-weighted
    before they are averaged into a single scalar.

    Args:
        use_focal_loss: When true, every per-sample loss is scaled by
            ``(1 - exp(-loss)) ** 2`` before averaging.
        zero_infinity: Forwarded unchanged to ``nn.CTCLoss``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self, use_focal_loss=False, zero_infinity=False, **kwargs):
        super().__init__()
        self.loss_func = nn.CTCLoss(
            blank=0,
            reduction='none',
            zero_infinity=zero_infinity,
        )
        self.use_focal_loss = use_focal_loss

    def forward(self, predicts, batch):
        """Compute the mean CTC loss for one batch.

        Args:
            predicts: 3-D tensor of raw scores. Dimension 0 is treated as
                the batch, dimension 1 as the sequence steps and dimension 2
                is the one normalised with log-softmax.
            batch: Indexable container; ``batch[1]`` is used as the labels
                and ``batch[2]`` as the label lengths.

        Returns:
            A dict ``{'loss': scalar_tensor}``.
        """
        num_samples = predicts.size(0)
        targets, target_lengths = batch[1], batch[2]

        # nn.CTCLoss wants log-probabilities laid out as (steps, batch, classes).
        log_probs = predicts.log_softmax(dim=2).permute(1, 0, 2)

        # Every sample is assumed to span the full sequence dimension.
        input_lengths = torch.full(
            (num_samples,), log_probs.size(0), dtype=torch.long)

        per_sample = self.loss_func(
            log_probs, targets, input_lengths, target_lengths)

        if self.use_focal_loss:
            # Bound the loss before exponentiating to keep exp() from
            # overflowing.
            bounded = per_sample.clamp(min=-20, max=20)
            focal_weight = torch.square(1 - torch.exp(-bounded))
            # Where the weight is not positive the unweighted loss is kept,
            # rather than multiplying by a zero weight.
            per_sample = torch.where(
                focal_weight > 0, per_sample * focal_weight, per_sample)

        return {'loss': per_sample.mean()}