File size: 956 Bytes
caa56d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch.nn as nn
from .abstract_loss_func import AbstractLossClass
from metrics.registry import LOSSFUNC


@LOSSFUNC.register_module(module_name="capsule_loss")
class CapsuleLoss(AbstractLossClass):
    """Cross-entropy loss summed over the capsule axis of the predictions.

    Registered in ``LOSSFUNC`` under the name ``"capsule_loss"``.
    """

    def __init__(self):
        super().__init__()
        # Default nn.CrossEntropyLoss settings (mean reduction over the batch).
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def forward(self, inputs, targets):
        """
        Computes the capsule loss.

        The predictions are sliced along dimension 1 (``inputs[:, i, :]``);
        a cross-entropy loss against the same ``targets`` is computed for
        every slice and the per-slice losses are added together.

        Args:
            inputs: A PyTorch tensor indexed as ``inputs[:, capsule, :]``,
                i.e. expected to be of size
                (batch_size, num_capsules, num_classes) with at least one
                capsule along dimension 1.
            targets: A PyTorch tensor of size (batch_size) containing the
                ground-truth class indices.

        Returns:
            A scalar tensor: the sum over all capsules of the cross-entropy
            loss between that capsule's scores and ``targets``.
        """
        # The first capsule seeds the running total, so the result is always
        # a tensor produced by the loss itself (same dtype/device/graph).
        total_loss = self.cross_entropy_loss(inputs[:, 0, :], targets)

        # Add the loss of every remaining capsule, in order.
        for capsule_idx in range(1, inputs.size(1)):
            total_loss = total_loss + self.cross_entropy_loss(
                inputs[:, capsule_idx, :], targets
            )
        return total_loss