Commit
·
d1e5716
1
Parent(s):
1dfc285
train with multi-gpu half test bug fix #99
Browse files
test.py
CHANGED
@@ -46,7 +46,7 @@ def test(data,
|
|
46 |
else: # called by train.py
|
47 |
training = True
|
48 |
device = next(model.parameters()).device # get model device
|
49 |
-
half = device.type != 'cpu' # half precision only supported on
|
50 |
if half:
|
51 |
model.half() # to FP16
|
52 |
|
|
|
46 |
else: # called by train.py
|
47 |
training = True
|
48 |
device = next(model.parameters()).device # get model device
|
49 |
+
half = device.type != 'cpu' and torch.cuda.device_count() == 1 # half precision only supported on single-GPU
|
50 |
if half:
|
51 |
model.half() # to FP16
|
52 |
|