caixiaoshun commited on
Commit
ae47333
·
1 Parent(s): 799cb75

新增loss模块

Browse files
Files changed (2) hide show
  1. src/loss/__init__.py +6 -0
  2. src/loss/cdnetv2_loss.py +20 -0
src/loss/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/1 下午2:45
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : __init__.py.py
6
+ # @Software: PyCharm
src/loss/cdnetv2_loss.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # @Time : 2024/8/1 下午2:45
3
+ # @Author : xiaoshun
4
+ # @Email : [email protected]
5
+ # @File : cdnetv2_loss.py
6
+ # @Software: PyCharm
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ class CDnetv2Loss(nn.Module):
12
+ def __init__(self, loss_fn: nn.Module) -> None:
13
+ super().__init__()
14
+ self.loss_fn = loss_fn
15
+
16
+ def forward(self, logits: torch.Tensor, logits_aux,target: torch.Tensor) -> torch.Tensor:
17
+ loss = self.loss_fn(logits, target)
18
+ loss_aux = self.loss_fn(logits_aux, target)
19
+ total_loss = loss + loss_aux
20
+ return total_loss