File size: 1,496 Bytes
cc0dd3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple

import torch
from mmengine.model import BaseModule
from torch import nn

from mmpretrain.registry import MODELS


@MODELS.register_module()
class CAELoss(BaseModule):
    """Loss module for CAE.

    Produces two terms:

    * a main loss, the cross-entropy between the decoder outputs and the
      targets;
    * an align loss, the mean squared error between the regressor's latent
      prediction and the gradient-detached latent target, scaled by
      ``lambd``.

    Args:
        lambd (float): Scaling factor applied to the align loss.
    """

    def __init__(self, lambd: float) -> None:
        super().__init__()
        # Multiplier for the align (MSE) term.
        self.lambd = lambd
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def forward(
            self, logits: torch.Tensor, target: torch.Tensor,
            latent_pred: torch.Tensor,
            latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute the main and align losses of CAE.

        Args:
            logits (torch.Tensor): Decoder outputs, passed as the input of
                ``nn.CrossEntropyLoss``.
            target (torch.Tensor): Targets generated by dalle, passed as the
                cross-entropy target.
            latent_pred (torch.Tensor): Latent prediction from the
                regressor.
            latent_target (torch.Tensor): Latent target from the teacher
                network. It is detached before use, so the align loss sends
                no gradient into it.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: ``(loss_main, loss_align)``.
            ``loss_align`` is already multiplied by ``self.lambd``.
        """
        main_term = self.loss_cross_entropy(logits, target)

        # Stop gradients from flowing back through the teacher branch.
        frozen_target = latent_target.detach()
        mse_term = self.loss_mse(latent_pred, frozen_target)
        weighted_align = mse_term * self.lambd

        return main_term, weighted_align