xylieong glenn-jocher commited on
Commit
f43cd53
·
unverified ·
1 Parent(s): fe1b503

Added Windows cmd to count GPU devices (#7891)

Browse files

* Added Windows cmd to count GPU devices

* Cleanup

Co-authored-by: Glenn Jocher <[email protected]>

Files changed (1) hide show
  1. utils/torch_utils.py +3 -3
utils/torch_utils.py CHANGED
@@ -40,10 +40,10 @@ def torch_distributed_zero_first(local_rank: int):
40
 
41
 
42
  def device_count():
43
- # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Only works on Linux.
44
- assert platform.system() == 'Linux', 'device_count() function only works on Linux'
45
  try:
46
- cmd = 'nvidia-smi -L | wc -l'
47
  return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
48
  except Exception:
49
  return 0
 
40
 
41
 
42
  def device_count():
43
+ # Returns number of CUDA devices available. Safe version of torch.cuda.device_count(). Supports Linux and Windows
44
+ assert platform.system() in ('Linux', 'Windows'), 'device_count() only supported on Linux or Windows'
45
  try:
46
+ cmd = 'nvidia-smi -L | wc -l' if platform.system() == 'Linux' else 'nvidia-smi -L | find /c /v ""' # Windows
47
  return int(subprocess.run(cmd, shell=True, capture_output=True, check=True).stdout.decode().split()[-1])
48
  except Exception:
49
  return 0