import os
import torch

"""
GPU wrappers
"""

use_gpu = False
gpu_id = 0
device = None

distributed = False
dist_rank = 0
world_size = 1


def set_gpu_mode(mode, pbs=False):
    """Initialize the module globals from the scheduler environment.

    mode: whether to use the GPU (CUDA).
    pbs: read PBS / Intel MPI (PMI) variables instead of SLURM ones.
    """
    global use_gpu
    global device
    global gpu_id
    global distributed
    global dist_rank
    global world_size
    if pbs:
        # PBS jobs launched with Intel MPI expose rank info via PMI variables.
        gpu_id = int(os.environ.get("MPI_LOCALRANKID", 0))
        dist_rank = int(os.environ.get("PMI_RANK", 0))
        world_size = int(os.environ.get("PMI_SIZE", 1))
    else:
        # SLURM exposes the local rank, global rank, and task count directly.
        gpu_id = int(os.environ.get("SLURM_LOCALID", 0))
        dist_rank = int(os.environ.get("SLURM_PROCID", 0))
        world_size = int(os.environ.get("SLURM_NTASKS", 1))

    distributed = world_size > 1
    use_gpu = mode
    print(
        f"gpu_id: {gpu_id}, dist_rank: {dist_rank}, "
        f"world_size: {world_size}, distributed: {distributed}"
    )
    device = torch.device(f"cuda:{gpu_id}" if use_gpu else "cpu")
    # Let cuDNN benchmark kernels; fastest for fixed input shapes.
    torch.backends.cudnn.benchmark = True
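

# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original module): how training
# code typically consumes the globals above. Assumptions: the job script sets
# MASTER_ADDR / MASTER_PORT when more than one process is launched (required
# by torch.distributed's default "env://" init), and the toy model below is
# purely hypothetical.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.distributed as dist

    # Populate use_gpu / gpu_id / device / distributed / dist_rank / world_size.
    set_gpu_mode(torch.cuda.is_available())

    if distributed:
        # Default init_method="env://" reads MASTER_ADDR and MASTER_PORT.
        dist.init_process_group(
            backend="nccl" if use_gpu else "gloo",
            rank=dist_rank,
            world_size=world_size,
        )

    # Move a toy model to the selected device.
    model = torch.nn.Linear(8, 2).to(device)
    print(f"model on {device}, rank {dist_rank}/{world_size}")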