File size: 611 Bytes
fa7be76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# -*- coding: utf-8 -*-
# @Time    : 2024/8/1 14:45
# @Author  : xiaoshun
# @Email   : [email protected]
# @File    : cdnetv2_loss.py
# @Software: PyCharm
import torch
import torch.nn as nn


class CDnetv2Loss(nn.Module):
    """Combine a criterion over a main output and an auxiliary output.

    The wrapped ``loss_fn`` is evaluated twice against the same target,
    once on the main logits and once on the auxiliary logits, and the two
    values are added together.

    Args:
        loss_fn: Callable criterion invoked as ``loss_fn(prediction, target)``.
        aux_weight: Multiplier applied to the auxiliary term before it is
            added to the main term. The default of ``1.0`` yields the plain
            unweighted sum ``loss + loss_aux``.
    """

    def __init__(self, loss_fn: nn.Module, aux_weight: float = 1.0) -> None:
        super().__init__()
        self.loss_fn = loss_fn
        self.aux_weight = aux_weight

    def forward(
        self,
        logits: torch.Tensor,
        logits_aux: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``loss_fn(logits, target) + aux_weight * loss_fn(logits_aux, target)``.

        Args:
            logits: Main prediction, passed unchanged to ``loss_fn``.
            logits_aux: Auxiliary prediction, passed unchanged to ``loss_fn``.
                NOTE(review): presumably must be compatible with ``target``
                in the same way as ``logits`` -- confirm against the model.
            target: Ground truth shared by both terms.

        Returns:
            The main term plus the weighted auxiliary term.
        """
        main_term = self.loss_fn(logits, target)
        aux_term = self.loss_fn(logits_aux, target)
        return main_term + self.aux_weight * aux_term