# Spaces: Running
from torch import nn | |
class IGTRLoss(nn.Module):
    """Pass-through "loss" module for IGTR.

    This module computes nothing itself: ``forward`` hands back whatever the
    model produced (unwrapping the first element when given a list).
    NOTE(review): presumably the IGTR model already returns its loss
    value(s) in ``predicts``; confirm against the model/decoder code.
    """

    def __init__(self, **kwargs):
        """Create the module.

        Args:
            **kwargs: Arbitrary configuration options; accepted so the module
                can be built from a generic config, but all are ignored.
        """
        super().__init__()

    def forward(self, predicts, batch):
        """Return the model output unchanged, unwrapping a list.

        Args:
            predicts: Model output. If it is a ``list``, only its first
                element is used (an empty list raises ``IndexError``);
                any other type, including ``tuple``, is returned as-is.
            batch: Unused; kept for interface compatibility with other
                loss modules.

        Returns:
            ``predicts[0]`` when ``predicts`` is a list, else ``predicts``.
        """
        return predicts[0] if isinstance(predicts, list) else predicts