import os

import torch

"""
GPU wrappers
"""

# Module-level state describing the current device / distributed configuration.
use_gpu = False
gpu_id = 0
device = None

distributed = False
dist_rank = 0
world_size = 1


def set_gpu_mode(mode, pbs=False):
    """Configure the global device and distributed-rank state.

    mode: whether to run on GPU.
    pbs: if True, read rank information from MPI/PBS environment variables
         instead of SLURM ones.
    """
    global use_gpu
    global device
    global gpu_id
    global distributed
    global dist_rank
    global world_size

    if pbs:
        # PBS / MPI launch: rank info comes from MPI environment variables.
        gpu_id = int(os.environ.get("MPI_LOCALRANKID", 0))
        dist_rank = int(os.environ.get("PMI_RANK", 0))
        world_size = int(os.environ.get("PMI_SIZE", 1))
    else:
        # SLURM launch: rank info comes from SLURM environment variables.
        gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
        dist_rank = int(os.environ.get("SLURM_PROCID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))

    distributed = world_size > 1
    use_gpu = mode

    print('gpu_id: {}, dist_rank: {}, world_size: {}, distributed: {}'.format(
        gpu_id, dist_rank, world_size, distributed))

    device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
    torch.backends.cudnn.benchmark = True
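

# The block below is not part of the original module; it is a minimal usage
# sketch showing how set_gpu_mode() might be called from a launcher script.
# The single-process, CPU-or-single-GPU setting shown here is an assumption
# for illustration only.
if __name__ == "__main__":
    set_gpu_mode(mode=torch.cuda.is_available())
    print(f"selected device: {device}")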