File size: 479 Bytes
caa56d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
import torch.nn as nn
class AbstractLossClass(nn.Module):
    """Abstract base class for loss functions.

    Subclasses must override :meth:`forward`; the base implementation only
    raises ``NotImplementedError``. Because this is an ``nn.Module``, a loss
    instance is invoked like any module, e.g. ``loss_fn(pred, label)``.
    """

    def __init__(self):
        # Zero-argument super() is the Python 3 idiom. It still runs
        # nn.Module.__init__, which must happen before a subclass assigns
        # parameters, buffers or submodules as attributes.
        super().__init__()

    def forward(self, pred, label):
        """Compute the loss between a prediction and its ground truth.

        Args:
            pred: Prediction of the model.
            label: Ground-truth label.

        Returns:
            The loss value computed by the subclass.

        Raises:
            NotImplementedError: Always; subclasses must override this method.
        """
        raise NotImplementedError('Each subclass should implement the forward method.')
|