# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# Based on DETR and MMCV code bases
# https://github.com/facebookresearch/detr
# https://github.com/open-mmlab/mmcv
# --------------------------------------------------------
import os
import pickle
import shutil
import sys
import tempfile
import datetime

import torch
import torch.distributed as dist


def setup_for_distributed(is_main):
    """
    Disables printing when not in the main process, unless the call is forced
    (print(..., force=True)) or writes to stderr.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_main or force or kwargs.get('file', None) == sys.stderr:
            builtin_print(*args, **kwargs)

    __builtin__.print = print


def is_dist_avail_and_initialized():
    """Returns True only if torch.distributed is both available and initialized."""
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    """Number of processes in the group; 1 when not running distributed."""
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    """Rank of the current process; 0 when not running distributed."""
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    """True on rank 0 (or when not running distributed)."""
    return get_rank() == 0


def save_on_main(*args, **kwargs):
    """torch.save only on the main process, to avoid concurrent writes."""
    if is_main_process():
        torch.save(*args, **kwargs)


def save_on_all(*args, **kwargs):
    """torch.save on every process."""
    torch.save(*args, **kwargs)


def init_distributed_mode(args):
    """
    Initializes the default process group from the RANK / WORLD_SIZE /
    LOCAL_RANK environment variables (as set by torchrun). Expects
    args.dist_url and optionally args.print_all; sets args.rank,
    args.world_size, args.gpu, args.distributed and args.dist_backend.
    """
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    # Set timeout to 1h20 (4800 s) in case a long dataset download has to happen
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank,
                                         timeout=datetime.timedelta(seconds=4800))
    torch.distributed.barrier()
    if ("print_all" not in args) or (not args.print_all):
        setup_for_distributed(args.rank == 0)
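

if __name__ == "__main__":
    # Illustrative sketch (not part of the original module): a minimal
    # argparse setup showing how `init_distributed_mode` is typically wired
    # into a training script launched with `torchrun`, which exports RANK,
    # WORLD_SIZE and LOCAL_RANK. Only the attribute names read above
    # (`dist_url`, `print_all`) come from this file; the flag defaults and
    # the demo printout are assumptions.
    import argparse

    parser = argparse.ArgumentParser(description='Distributed init sketch')
    parser.add_argument('--dist_url', default='env://',
                        help='URL used to set up distributed training')
    parser.add_argument('--print_all', action='store_true',
                        help='Print from every rank instead of only rank 0')
    args = parser.parse_args()

    # Sets args.rank, args.world_size, args.gpu, args.distributed (or falls
    # back to non-distributed mode if the env vars are missing).
    init_distributed_mode(args)
    print('world size: {}, rank: {}, main process: {}'.format(
        get_world_size(), get_rank(), is_main_process()))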