# Origin: "openocr demo" by topdu, commit 29f689c (1.15 kB).
from torch import nn
class MGPLoss(nn.Module):
    """Summed cross-entropy loss over the char / dpe / wp prediction heads.

    Every head is scored with the same ``nn.CrossEntropyLoss``
    (``reduction='mean'``, ``ignore_index=0``). Target positions equal to 0
    therefore contribute nothing. Presumably 0 is the padding index; confirm
    against the label converter.

    Args:
        only_char: When True, ``forward`` expects a single logits tensor and
            scores only the char head against ``batch[1]``.
        **kwargs: Accepted and ignored, so extra config keys do not break
            construction.
    """

    def __init__(self, only_char=False, **kwargs):
        super().__init__()
        self.ce = nn.CrossEntropyLoss(reduction='mean', ignore_index=0)
        self.only_char = only_char

    def _head_loss(self, logits, target):
        """Return the CE loss for one head after merging dims 0 and 1.

        Assumes ``logits`` is (batch, seq, num_classes) and ``target`` is
        (batch, seq) of class indices (TODO confirm). Flattening both lets
        CrossEntropyLoss score every position as an independent sample.
        """
        return self.ce(logits.flatten(0, 1), target.flatten(0, 1))

    def forward(self, pred, batch):
        """Compute the loss dict for one batch.

        Args:
            pred: The char-head logits tensor when ``only_char`` is set.
                Otherwise a 3-sequence ``(char, dpe, wp)`` of logits tensors.
            batch: Indexable batch. ``batch[1]``, ``batch[2]`` and
                ``batch[3]`` hold the char, dpe and wp targets.

        Returns:
            ``{'loss': char_loss}`` when ``only_char`` is set. Otherwise the
            dict returned by ``forward_all``.
        """
        if self.only_char:
            return {'loss': self._head_loss(pred, batch[1])}
        return self.forward_all(pred, batch)

    def forward_all(self, pred, batch):
        """Score all three heads and return their sum plus each term.

        Args:
            pred: Sequence of exactly three logits tensors, ``(char, dpe, wp)``.
            batch: Indexable batch whose items 1, 2 and 3 are the matching
                targets.

        Returns:
            Dict with ``'loss'`` (the sum of the three head losses) and
            ``'char_loss'``, ``'dpe_loss'``, ``'wp_loss'``.
        """
        char_feats, dpe_feats, wp_feats = pred
        char_loss = self._head_loss(char_feats, batch[1])
        dpe_loss = self._head_loss(dpe_feats, batch[2])
        wp_loss = self._head_loss(wp_feats, batch[3])
        loss = char_loss + dpe_loss + wp_loss
        return {
            'loss': loss,
            'char_loss': char_loss,
            'dpe_loss': dpe_loss,
            'wp_loss': wp_loss,
        }