cloudseg / src /loss /cdnetv2_loss.py
caixiaoshun's picture
使用huggingface hub尝试更新
fa7be76 verified
raw
history blame contribute delete
611 Bytes
# -*- coding: utf-8 -*-
# @Time : 2024/8/1 2:45 PM
# @Author : xiaoshun
# @Email : [email protected]
# @File : cdnetv2_loss.py
# @Software: PyCharm
import torch
import torch.nn as nn
class CDnetv2Loss(nn.Module):
    """Sum the main and auxiliary losses produced by a CDnetV2-style model.

    The same criterion ``loss_fn`` is applied to the main prediction and to
    the auxiliary prediction, both against one shared ``target``. The
    auxiliary term is scaled by ``aux_weight`` before the two are added.

    Args:
        loss_fn: Criterion called as ``loss_fn(prediction, target)``.
        aux_weight: Multiplier applied to the auxiliary loss. The default of
            1.0 gives the plain unweighted sum ``loss + loss_aux``.
    """

    def __init__(self, loss_fn: nn.Module, *, aux_weight: float = 1.0) -> None:
        super().__init__()
        self.loss_fn = loss_fn
        self.aux_weight = aux_weight

    def forward(
        self,
        logits: torch.Tensor,
        logits_aux: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return ``loss_fn(logits, target) + aux_weight * loss_fn(logits_aux, target)``.

        Args:
            logits: Main-head prediction passed to ``loss_fn``.
            logits_aux: Auxiliary-head prediction; it is scored against the
                same ``target`` as ``logits``, so ``loss_fn`` must accept
                both with that target.
            target: Ground truth shared by both loss terms.

        Returns:
            The combined loss as returned by adding the two ``loss_fn``
            results.
        """
        loss = self.loss_fn(logits, target)
        loss_aux = self.loss_fn(logits_aux, target)
        # Multiplying by the default 1.0 is exact, so the default result
        # matches the original unweighted sum.
        return loss + self.aux_weight * loss_aux