import torch.nn as nn | |
from .abstract_loss_func import AbstractLossClass | |
from metrics.registry import LOSSFUNC | |
class CapsuleLoss(AbstractLossClass):
    """Sum of cross-entropy losses taken independently over dim 1 of the input.

    Each slice ``inputs[:, i, :]`` is scored against the same ``targets`` with
    ``nn.CrossEntropyLoss`` (default settings, i.e. mean over the batch), and
    the per-slice losses are added together.
    """

    def __init__(self):
        super().__init__()
        # Shared criterion applied to every slice along dim 1.
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """Compute the capsule loss.

        Args:
            inputs: Tensor of shape (batch_size, num_capsules, num_classes).
                The loss iterates over dim 1; each slice ``inputs[:, i, :]``
                is passed to ``nn.CrossEntropyLoss``, which applies
                log-softmax internally, so the scores should be raw logits.
                NOTE(review): confirm upstream does not already apply softmax.
            targets: Tensor of shape (batch_size,) holding the ground-truth
                class indices, shared by every slice.

        Returns:
            A scalar tensor: the sum over dim 1 of the per-slice
            cross-entropy losses.

        Raises:
            ValueError: If ``inputs`` has fewer than 3 dimensions or an empty
                dim 1 (there would be nothing to sum).
        """
        # Fail with a clear message instead of an opaque IndexError from the
        # ``inputs[:, i, :]`` indexing below.
        if inputs.dim() < 3 or inputs.size(1) == 0:
            raise ValueError(
                "CapsuleLoss expects inputs of shape "
                "(batch_size, num_capsules, num_classes) with num_capsules >= 1, "
                f"got {tuple(inputs.shape)}"
            )
        # Score every slice against the same targets and add the losses.
        # ``sum`` starts from 0, and 0 + tensor yields the same tensor value,
        # so this matches accumulating the per-slice losses by hand.
        return sum(
            self.cross_entropy_loss(inputs[:, i, :], targets)
            for i in range(inputs.size(1))
        )