Make `select_device()` robust to `batch_size=-1` (#5940)
Browse files* Find out a bug. When set batch_size = -1 to use the autobatch.
reproduce:
* Fix type conflict
Co-authored-by: Glenn Jocher <[email protected]>
- utils/torch_utils.py +2 -2
utils/torch_utils.py
CHANGED
@@ -53,7 +53,7 @@ def git_describe(path=Path(__file__).parent): # path must be a directory
|
|
53 |
return '' # not a git repository
|
54 |
|
55 |
|
56 |
-
def select_device(device='', batch_size=
|
57 |
# device = 'cpu' or '0' or '0,1,2,3'
|
58 |
s = f'YOLOv5 π {git_describe() or date_modified()} torch {torch.__version__} ' # string
|
59 |
device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
|
@@ -68,7 +68,7 @@ def select_device(device='', batch_size=None, newline=True):
|
|
68 |
if cuda:
|
69 |
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
70 |
n = len(devices) # device count
|
71 |
-
if n > 1 and batch_size: # check batch_size is divisible by device_count
|
72 |
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
73 |
space = ' ' * (len(s) + 1)
|
74 |
for i, d in enumerate(devices):
|
|
|
53 |
return '' # not a git repository
|
54 |
|
55 |
|
56 |
+
def select_device(device='', batch_size=0, newline=True):
|
57 |
# device = 'cpu' or '0' or '0,1,2,3'
|
58 |
s = f'YOLOv5 π {git_describe() or date_modified()} torch {torch.__version__} ' # string
|
59 |
device = str(device).strip().lower().replace('cuda:', '') # to string, 'cuda:0' to '0'
|
|
|
68 |
if cuda:
|
69 |
devices = device.split(',') if device else '0' # range(torch.cuda.device_count()) # i.e. 0,1,6,7
|
70 |
n = len(devices) # device count
|
71 |
+
if n > 1 and batch_size > 0: # check batch_size is divisible by device_count
|
72 |
assert batch_size % n == 0, f'batch-size {batch_size} not multiple of GPU count {n}'
|
73 |
space = ' ' * (len(s) + 1)
|
74 |
for i, d in enumerate(devices):
|