# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# Based on DETR and MMCV code bases
# https://github.com/facebookresearch/detr
# https://github.com/open-mmlab/mmcv
# --------------------------------------------------------
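"""Distributed training utilities: rank/world-size helpers, main-process-only
printing and saving, and torch.distributed process-group initialization."""
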
import datetime
import os
import pickle
import shutil
import sys
import tempfile

import torch
import torch.distributed as dist


def setup_for_distributed(is_main):
"""
    Monkey-patch the built-in print so that only the main process prints;
    calls with force=True, or that write to sys.stderr, always go through.
"""
import builtins as __builtin__
builtin_print = __builtin__.print
def print(*args, **kwargs):
force = kwargs.pop('force', False)
if is_main or force or kwargs.get('file', None) == sys.stderr:
builtin_print(*args, **kwargs)
__builtin__.print = print


def is_dist_avail_and_initialized():
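    """Return True if torch.distributed is available and the default process group has been initialized."""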
if not dist.is_available():
return False
if not dist.is_initialized():
return False
return True


def get_world_size():
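    """Return the world size, falling back to 1 when torch.distributed is not initialized."""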
if not is_dist_avail_and_initialized():
return 1
return dist.get_world_size()


def get_rank():
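    """Return the global rank of the current process, or 0 when torch.distributed is not initialized."""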
if not is_dist_avail_and_initialized():
return 0
return dist.get_rank()


def is_main_process():
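    """Return True on the global rank-0 (main) process."""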
return get_rank() == 0


def save_on_main(*args, **kwargs):
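    """torch.save wrapper that only writes from the main process."""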
if is_main_process():
torch.save(*args, **kwargs)


def save_on_all(*args, **kwargs):
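    """torch.save wrapper that writes from every process."""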
torch.save(*args, **kwargs)


def init_distributed_mode(args):
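    """
    Initialize the default process group from the RANK / WORLD_SIZE / LOCAL_RANK
    environment variables (as set by torchrun) and record rank, world_size, gpu
    and the distributed flag on args; falls back to non-distributed mode when
    those variables are absent.
    """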
if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.gpu = int(os.environ['LOCAL_RANK'])
else:
print('Not using distributed mode')
args.distributed = False
return
args.distributed = True
torch.cuda.set_device(args.gpu)
args.dist_backend = 'nccl'
print('| distributed init (rank {}): {}, gpu {}'.format(
args.rank, args.dist_url, args.gpu), flush=True)
    # Set the timeout to 1h20 (4800 s) in case a long dataset download has to happen
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank,
                                         timeout=datetime.timedelta(seconds=4800))
torch.distributed.barrier()
    # Unless args.print_all is set, silence print() on all non-main ranks
    if not getattr(args, 'print_all', False):
        setup_for_distributed(args.rank == 0)
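

# Usage sketch (assumptions: launched with torchrun and an argparse-style `args`;
# the training-script details below are illustrative, not part of this module):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--dist_url', default='env://')
#   args = parser.parse_args()
#   init_distributed_mode(args)   # fills in args.rank, args.world_size, args.gpu, args.distributed
#   if is_main_process():
#       print('world size:', get_world_size())
#
# launched as, e.g.:  torchrun --nproc_per_node=8 train.py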