henry000 commited on
Commit
033231b
·
1 Parent(s): 73207bd

✅ [Pass] Train, Model, Loss Test

Browse files
tests/test_utils/test_loss.py CHANGED
@@ -27,14 +27,13 @@ def loss_function(cfg) -> YOLOLoss:
27
  def data():
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  targets = torch.zeros(1, 20, 5, device=device)
30
- predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
31
  return predicts, targets
32
 
33
 
34
  def test_yolo_loss(loss_function, data):
35
  predicts, targets = data
36
- loss, (loss_iou, loss_dfl, loss_cls) = loss_function(predicts, targets)
37
- assert torch.isnan(loss)
38
  assert torch.isnan(loss_iou)
39
  assert torch.isnan(loss_dfl)
40
  assert torch.isinf(loss_cls)
 
27
  def data():
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  targets = torch.zeros(1, 20, 5, device=device)
30
+ predicts = [torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]]
31
  return predicts, targets
32
 
33
 
34
  def test_yolo_loss(loss_function, data):
35
  predicts, targets = data
36
+ loss_iou, loss_dfl, loss_cls = loss_function(predicts, targets)
 
37
  assert torch.isnan(loss_iou)
38
  assert torch.isnan(loss_dfl)
39
  assert torch.isinf(loss_cls)
yolo/utils/loss.py CHANGED
@@ -80,7 +80,7 @@ class YOLOLoss:
80
  self.strides = cfg.model.anchor.strides
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
 
83
- self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
84
  self.scale_up = torch.tensor(self.image_size * 2, device=device)
85
 
86
  self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)
 
80
  self.strides = cfg.model.anchor.strides
81
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
82
 
83
+ self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float32, device=device)
84
  self.scale_up = torch.tensor(self.image_size * 2, device=device)
85
 
86
  self.anchors, self.scaler = make_anchor(self.image_size, self.strides, device)