Commit
·
a1748a8
1
Parent(s):
bd3e389
test during training default to FP16
Browse files
test.py
CHANGED
@@ -23,6 +23,7 @@ def test(data,
|
|
23 |
verbose=False):
|
24 |
# Initialize/load model and set device
|
25 |
if model is None:
|
|
|
26 |
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
27 |
half = device.type != 'cpu' # half precision only supported on CUDA
|
28 |
|
@@ -42,11 +43,12 @@ def test(data,
|
|
42 |
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
43 |
model = nn.DataParallel(model)
|
44 |
|
45 |
-
training = False
|
46 |
else: # called by train.py
|
47 |
-
device = next(model.parameters()).device # get model device
|
48 |
-
half = False
|
49 |
training = True
|
|
|
|
|
|
|
|
|
50 |
|
51 |
# Configure
|
52 |
model.eval()
|
|
|
23 |
verbose=False):
|
24 |
# Initialize/load model and set device
|
25 |
if model is None:
|
26 |
+
training = False
|
27 |
device = torch_utils.select_device(opt.device, batch_size=batch_size)
|
28 |
half = device.type != 'cpu' # half precision only supported on CUDA
|
29 |
|
|
|
43 |
if device.type != 'cpu' and torch.cuda.device_count() > 1:
|
44 |
model = nn.DataParallel(model)
|
45 |
|
|
|
46 |
else: # called by train.py
|
|
|
|
|
47 |
training = True
|
48 |
+
device = next(model.parameters()).device # get model device
|
49 |
+
half = device.type != 'cpu' # half precision only supported on CUDA
|
50 |
+
if half:
|
51 |
+
model.half() # to FP16
|
52 |
|
53 |
# Configure
|
54 |
model.eval()
|