Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
# @Time : 2024/8/1 下午2:45 | |
# @Author : xiaoshun | |
# @Email : [email protected] | |
# @File : cdnetv2_loss.py | |
# @Software: PyCharm | |
import torch | |
import torch.nn as nn | |
class CDnetv2Loss(nn.Module): | |
def __init__(self, loss_fn: nn.Module) -> None: | |
super().__init__() | |
self.loss_fn = loss_fn | |
def forward(self, logits: torch.Tensor, logits_aux,target: torch.Tensor) -> torch.Tensor: | |
loss = self.loss_fn(logits, target) | |
loss_aux = self.loss_fn(logits_aux, target) | |
total_loss = loss + loss_aux | |
return total_loss | |