henry000 commited on
Commit
710e371
Β·
1 Parent(s): 80efe00

πŸ› [Fix] device bug, make test can run on cpu

Browse files
tests/test_utils/test_loss.py CHANGED
@@ -25,11 +25,9 @@ def loss_function(cfg) -> YOLOLoss:
25
 
26
  @pytest.fixture
27
  def data():
28
- [[torch.zeros]]
29
- targets = torch.zeros(20, 6, device=torch.device("cuda"))
30
- predicts = [
31
- [torch.zeros(1, 144, 80 // i, 80 // i, device=torch.device("cuda")) for i in [1, 2, 4]] for _ in range(2)
32
- ]
33
  return predicts, targets
34
 
35
 
 
25
 
26
  @pytest.fixture
27
  def data():
28
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
+ targets = torch.zeros(20, 6, device=device)
30
+ predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
 
 
31
  return predicts, targets
32
 
33
 
yolo/utils/loss.py CHANGED
@@ -81,7 +81,7 @@ class YOLOLoss:
81
  self.class_num = cfg.hyper.data.class_num
82
  self.image_size = list(cfg.hyper.data.image_size)
83
  self.strides = cfg.model.anchor.strides
84
- device = torch.device("cuda")
85
 
86
  self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
87
  self.scale_up = torch.tensor(self.image_size * 2, device=device)
 
81
  self.class_num = cfg.hyper.data.class_num
82
  self.image_size = list(cfg.hyper.data.image_size)
83
  self.strides = cfg.model.anchor.strides
84
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
85
 
86
  self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
87
  self.scale_up = torch.tensor(self.image_size * 2, device=device)