Spaces:
Sleeping
Sleeping
File size: 611 Bytes
fa7be76 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
# -*- coding: utf-8 -*-
# @Time : 2024/8/1 2:45 PM
# @Author : xiaoshun
# @Email : [email protected]
# @File : cdnetv2_loss.py
# @Software: PyCharm
import torch
import torch.nn as nn
class CDnetv2Loss(nn.Module):
    """Combine one criterion over a main and an auxiliary prediction.

    The same ``loss_fn`` is evaluated on ``logits`` and on ``logits_aux``
    against a single shared ``target``. The two results are summed, and the
    auxiliary term is optionally scaled by ``aux_weight``.

    Args:
        loss_fn: Criterion called as ``loss_fn(prediction, target)``. The
            shapes and dtypes of the predictions and target are whatever this
            criterion expects; this wrapper does not check them.
        aux_weight: Multiplier applied to the auxiliary-branch loss. The
            default ``1.0`` gives the plain, unweighted sum
            ``loss + loss_aux``.
    """

    def __init__(self, loss_fn: nn.Module, aux_weight: float = 1.0) -> None:
        super().__init__()
        self.loss_fn = loss_fn
        self.aux_weight = aux_weight

    def forward(
        self,
        logits: torch.Tensor,
        logits_aux: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``loss_fn(logits, target) + aux_weight * loss_fn(logits_aux, target)``.

        Args:
            logits: Main-branch prediction.
            logits_aux: Auxiliary-branch prediction, compared against the same
                ``target`` as ``logits``.
            target: Ground truth passed unchanged to both ``loss_fn`` calls.

        Returns:
            The combined loss, as produced by adding the two ``loss_fn`` outputs.
        """
        loss = self.loss_fn(logits, target)
        loss_aux = self.loss_fn(logits_aux, target)
        if self.aux_weight != 1.0:
            # Only scale when a non-default weight is set, so the default path
            # stays exactly the original unweighted sum (no extra op or dtype
            # promotion).
            loss_aux = loss_aux * self.aux_weight
        return loss + loss_aux
|