π [Merge] branch 'TRAIN' into TEST
Browse files
yolo/tools/loss_functions.py
CHANGED
@@ -39,9 +39,9 @@ class BoxLoss(nn.Module):
|
|
39 |
|
40 |
|
41 |
class DFLoss(nn.Module):
|
42 |
-
def __init__(self,
|
43 |
super().__init__()
|
44 |
-
self.anchors_norm =
|
45 |
self.reg_max = reg_max
|
46 |
|
47 |
def forward(
|
@@ -72,7 +72,7 @@ class YOLOLoss:
|
|
72 |
self.vec2box = vec2box
|
73 |
|
74 |
self.cls = BCELoss()
|
75 |
-
self.dfl = DFLoss(vec2box
|
76 |
self.iou = BoxLoss()
|
77 |
|
78 |
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
|
|
|
39 |
|
40 |
|
41 |
class DFLoss(nn.Module):
|
42 |
+
def __init__(self, vec2box: Vec2Box, reg_max: int) -> None:
|
43 |
super().__init__()
|
44 |
+
self.anchors_norm = (vec2box.anchor_grid / vec2box.scaler[:, None])[None]
|
45 |
self.reg_max = reg_max
|
46 |
|
47 |
def forward(
|
|
|
72 |
self.vec2box = vec2box
|
73 |
|
74 |
self.cls = BCELoss()
|
75 |
+
self.dfl = DFLoss(vec2box, reg_max)
|
76 |
self.iou = BoxLoss()
|
77 |
|
78 |
self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
|