File size: 822 Bytes
caa56d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
import torch.nn as nn
from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC
@LOSSFUNC.register_module(module_name="cross_entropy")
class CrossEntropyLoss(AbstractLossClass):
    """Cross-entropy loss registered in ``LOSSFUNC`` as ``"cross_entropy"``.

    Thin wrapper that delegates to ``torch.nn.CrossEntropyLoss`` built with
    its default arguments.
    """

    def __init__(self):
        super().__init__()
        # Underlying PyTorch criterion; constructed once and reused per call.
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Evaluate the cross-entropy loss for a batch of predictions.

        Args:
            inputs: A PyTorch tensor of size (batch_size, num_classes)
                holding the predicted scores.
            targets: A PyTorch tensor of size (batch_size) holding the
                ground-truth class indices.

        Returns:
            A scalar tensor with the cross-entropy loss.
        """
        # Delegate directly to the wrapped criterion.
        return self.loss_fn(inputs, targets)