|
import os |
|
import torch |
|
|
|
|
|
def get_world_size(): |
|
"""Find OMPI world size without calling mpi functions |
|
:rtype: int |
|
""" |
|
if os.environ.get('PMI_SIZE') is not None: |
|
return int(os.environ.get('PMI_SIZE') or 1) |
|
elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: |
|
return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) |
|
else: |
|
return torch.cuda.device_count() |
|
|
|
|
|
def get_global_rank(): |
|
"""Find OMPI world rank without calling mpi functions |
|
:rtype: int |
|
""" |
|
if os.environ.get('PMI_RANK') is not None: |
|
return int(os.environ.get('PMI_RANK') or 0) |
|
elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: |
|
return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) |
|
else: |
|
return 0 |
|
|
|
|
|
def get_local_rank(): |
|
"""Find OMPI local rank without calling mpi functions |
|
:rtype: int |
|
""" |
|
if os.environ.get('MPI_LOCALRANKID') is not None: |
|
return int(os.environ.get('MPI_LOCALRANKID') or 0) |
|
elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: |
|
return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) |
|
else: |
|
return 0 |
|
|
|
|
|
def get_master_ip(): |
|
if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: |
|
return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] |
|
elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: |
|
return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') |
|
else: |
|
return "127.0.0.1" |
|
|
|
|