OpenOCR-Demo / openrec /losses /igtr_loss.py
topdu's picture
openocr demo
29f689c
raw
history blame
265 Bytes
from torch import nn
class IGTRLoss(nn.Module):
    """Pass-through loss module.

    ``forward`` performs no computation of its own: it hands back the
    value it was given (or the first element, when given a list).
    NOTE(review): presumably the model already produces the loss value
    itself and this class only adapts it to the loss-module interface --
    confirm against the caller.
    """

    def __init__(self, **kwargs):
        """Initialise the module; any keyword arguments are accepted and ignored."""
        super().__init__()

    def forward(self, predicts, batch):
        """Return ``predicts``, unwrapping the first element if it is a list.

        Args:
            predicts: Either a value returned as-is, or a ``list`` whose
                first element is returned.
            batch: Unused; accepted only to keep the call signature.

        Returns:
            ``predicts[0]`` when ``predicts`` is a ``list``, otherwise
            ``predicts`` itself.
        """
        if not isinstance(predicts, list):
            return predicts
        # Only the leading entry of a list output is used.
        return predicts[0]