Spaces:
Sleeping
Sleeping
Commit
·
ae47333
1
Parent(s):
799cb75
新增loss模块
Browse files- src/loss/__init__.py +6 -0
- src/loss/cdnetv2_loss.py +20 -0
src/loss/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/8/1 下午2:45
|
3 |
+
# @Author : xiaoshun
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @File : __init__.py.py
|
6 |
+
# @Software: PyCharm
|
src/loss/cdnetv2_loss.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
# @Time : 2024/8/1 下午2:45
|
3 |
+
# @Author : xiaoshun
|
4 |
+
# @Email : [email protected]
|
5 |
+
# @File : cdnetv2_loss.py
|
6 |
+
# @Software: PyCharm
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
|
10 |
+
|
11 |
+
class CDnetv2Loss(nn.Module):
|
12 |
+
def __init__(self, loss_fn: nn.Module) -> None:
|
13 |
+
super().__init__()
|
14 |
+
self.loss_fn = loss_fn
|
15 |
+
|
16 |
+
def forward(self, logits: torch.Tensor, logits_aux,target: torch.Tensor) -> torch.Tensor:
|
17 |
+
loss = self.loss_fn(logits, target)
|
18 |
+
loss_aux = self.loss_fn(logits_aux, target)
|
19 |
+
total_loss = loss + loss_aux
|
20 |
+
return total_loss
|