Added Windows cmd to count GPU devices (#7891)
Browse files* Added Windows cmd to count GPU devices
* Cleanup
Co-authored-by: Glenn Jocher <[email protected]>
- utils/torch_utils.py +3 -3
utils/torch_utils.py
CHANGED
@@ -40,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int):
|
|
40 |
|
41 |
|
42 |
def device_count():
|
43 |
-
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count().
|
44 |
-
assert platform.system()
|
45 |
try:
|
46 |
-
cmd = 'nvidia-smi -L | wc -l'
|
47 |
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
|
48 |
except Exception:
|
49 |
return 0
|
|
|
40 |
|
41 |
|
42 |
def device_count():
|
43 |
+
# Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
|
44 |
+
assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
|
45 |
try:
|
46 |
+
cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
|
47 |
return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
|
48 |
except Exception:
|
49 |
return 0
|