henry000 commited on
Commit
ba7a683
Β·
2 Parent(s): 5bbfada 36eb083

πŸ”€ [Merge] branch 'TRAIN' into TEST

Browse files
Files changed (1) hide show
  1. yolo/tools/loss_functions.py +3 -3
yolo/tools/loss_functions.py CHANGED
@@ -39,9 +39,9 @@ class BoxLoss(nn.Module):
39
 
40
 
41
  class DFLoss(nn.Module):
42
- def __init__(self, anchors_norm: Tensor, reg_max: int) -> None:
43
  super().__init__()
44
- self.anchors_norm = anchors_norm
45
  self.reg_max = reg_max
46
 
47
  def forward(
@@ -72,7 +72,7 @@ class YOLOLoss:
72
  self.vec2box = vec2box
73
 
74
  self.cls = BCELoss()
75
- self.dfl = DFLoss(vec2box.anchor_norm, reg_max)
76
  self.iou = BoxLoss()
77
 
78
  self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)
 
39
 
40
 
41
  class DFLoss(nn.Module):
42
+ def __init__(self, vec2box: Vec2Box, reg_max: int) -> None:
43
  super().__init__()
44
+ self.anchors_norm = (vec2box.anchor_grid / vec2box.scaler[:, None])[None]
45
  self.reg_max = reg_max
46
 
47
  def forward(
 
72
  self.vec2box = vec2box
73
 
74
  self.cls = BCELoss()
75
+ self.dfl = DFLoss(vec2box, reg_max)
76
  self.iou = BoxLoss()
77
 
78
  self.matcher = BoxMatcher(loss_cfg.matcher, self.class_num, vec2box.anchor_grid)