# Source: OpenOCR-Demo / openrec / losses / abinet_loss.py
# Last commit: "openocr demo" (29f689c), author avatar caption: "topdu's picture"
import torch
from torch import nn
class ABINetLoss(nn.Module):
    """Cross-entropy loss applied to every entry of a prediction dict.

    Each entry of ``pred`` is scored against the targets in ``batch[1]`` with
    ``nn.CrossEntropyLoss(reduction='mean')``. The per-entry losses are
    returned under ``'<name>_loss'`` and their sum under ``'loss'``.

    Args:
        smoothing: Stored on the instance as ``self.smoothing``. It is not
            read anywhere in this class.
        ignore_index: When ``>= 0`` it is forwarded to ``CrossEntropyLoss`` as
            ``ignore_index``. A negative value leaves ``CrossEntropyLoss`` at
            its own default for that argument.
        align_weight: Multiplier applied to the loss of the entry whose name
            is exactly ``'align'``; every other entry uses a weight of 1.0.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 smoothing=False,
                 ignore_index=100,
                 align_weight=1.0,
                 **kwargs):
        super().__init__()
        ce_kwargs = {'reduction': 'mean'}
        if ignore_index >= 0:
            ce_kwargs['ignore_index'] = ignore_index
        self.loss_func = nn.CrossEntropyLoss(**ce_kwargs)
        self.smoothing = smoothing
        self.align_weight = align_weight

    def forward(self, pred, batch):
        """Compute the per-entry and total cross-entropy losses.

        Args:
            pred: Mapping from a name to either a logits tensor or a list of
                logits tensors. The last dimension is treated as the class
                dimension (presumably ``(batch, seq_len, num_classes)`` --
                confirm against the model). An empty list is skipped.
            batch: Indexable container; ``batch[1]`` holds the target tensor.
                For a list entry the targets are repeated once per list
                element along dim 0 to line up with the concatenated logits.

        Returns:
            dict: ``'<name>_loss'`` for every entry that was scored, plus
            ``'loss'`` holding their sum. ``'loss'`` is a 0-dim zero tensor
            when nothing was scored, so its type does not vary between calls.
        """
        loss = {}
        loss_terms = []
        for name, logits in pred.items():
            if isinstance(logits, list):
                if not logits:
                    # Nothing to score for this entry.
                    continue
                # Concatenate along dim 0 and repeat the targets to match.
                merged = torch.cat(logits, 0)
                flat_target = torch.cat([batch[1]] * len(logits),
                                        0).reshape([-1])
            else:
                merged = logits
                flat_target = batch[1].reshape([-1])

            # Collapse every leading dimension; the last one is the class dim.
            flat_logits = merged.reshape([-1, merged.shape[-1]])
            weight = self.align_weight if name == 'align' else 1.0
            key = name + '_loss'
            loss[key] = self.loss_func(flat_logits, flat_target) * weight
            loss_terms.append(loss[key])

        if loss_terms:
            loss['loss'] = sum(loss_terms)
        else:
            # sum([]) would yield the Python int 0; return a tensor instead.
            loss['loss'] = torch.zeros(())
        return loss