# Spaces:
# Runtime error
# Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple | |
import torch | |
from mmengine.model import BaseModule | |
from torch import nn | |
from mmpretrain.registry import MODELS | |
# NOTE(review): this file imports ``MODELS`` but no registration decorator is
# visible on this class -- confirm whether one was dropped or is intentionally
# absent.
class CAELoss(BaseModule):
    """Loss function for CAE.

    Compute the align loss and the main loss. The main loss is a
    cross-entropy between the decoder logits and the target; the align loss
    is a mean-squared error between the latent prediction and the latent
    target, multiplied by ``lambd``.

    Args:
        lambd (float): The weight for the align loss.
    """

    def __init__(self, lambd: float) -> None:
        super().__init__()
        self.lambd = lambd
        # Both criteria are built with their default arguments.
        self.loss_cross_entropy = nn.CrossEntropyLoss()
        self.loss_mse = nn.MSELoss()

    def forward(
            self, logits: torch.Tensor, target: torch.Tensor,
            latent_pred: torch.Tensor,
            latent_target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward function of CAE Loss.

        Args:
            logits (torch.Tensor): The outputs from the decoder.
            target (torch.Tensor): The targets generated by dalle.
            latent_pred (torch.Tensor): The latent prediction from the
                regressor.
            latent_target (torch.Tensor): The latent target from the teacher
                network.

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: The main loss and the align
            loss. The align loss is already scaled by ``lambd``.
        """
        loss_main = self.loss_cross_entropy(logits, target)
        # Detach the latent target so the align loss back-propagates only
        # through ``latent_pred``.
        loss_align = self.loss_mse(latent_pred,
                                   latent_target.detach()) * self.lambd
        return loss_main, loss_align