π [Fix] device bug, make test can run on cpu
Browse files- tests/test_utils/test_loss.py +3 -5
- yolo/utils/loss.py +1 -1
tests/test_utils/test_loss.py
CHANGED
@@ -25,11 +25,9 @@ def loss_function(cfg) -> YOLOLoss:
|
|
25 |
|
26 |
@pytest.fixture
|
27 |
def data():
|
28 |
-
|
29 |
-
targets = torch.zeros(20, 6, device=
|
30 |
-
predicts = [
|
31 |
-
[torch.zeros(1, 144, 80 // i, 80 // i, device=torch.device("cuda")) for i in [1, 2, 4]] for _ in range(2)
|
32 |
-
]
|
33 |
return predicts, targets
|
34 |
|
35 |
|
|
|
25 |
|
26 |
@pytest.fixture
|
27 |
def data():
|
28 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
29 |
+
targets = torch.zeros(20, 6, device=device)
|
30 |
+
predicts = [[torch.zeros(1, 144, 80 // i, 80 // i, device=device) for i in [1, 2, 4]] for _ in range(2)]
|
|
|
|
|
31 |
return predicts, targets
|
32 |
|
33 |
|
yolo/utils/loss.py
CHANGED
@@ -81,7 +81,7 @@ class YOLOLoss:
|
|
81 |
self.class_num = cfg.hyper.data.class_num
|
82 |
self.image_size = list(cfg.hyper.data.image_size)
|
83 |
self.strides = cfg.model.anchor.strides
|
84 |
-
device = torch.device("cuda")
|
85 |
|
86 |
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
|
87 |
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|
|
|
81 |
self.class_num = cfg.hyper.data.class_num
|
82 |
self.image_size = list(cfg.hyper.data.image_size)
|
83 |
self.strides = cfg.model.anchor.strides
|
84 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
85 |
|
86 |
self.reverse_reg = torch.arange(self.reg_max, dtype=torch.float16, device=device)
|
87 |
self.scale_up = torch.tensor(self.image_size * 2, device=device)
|