# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------
# Based on DETR and MMCV code bases
# https://github.com/facebookresearch/detr
# https://github.com/open-mmlab/mmcv
# --------------------------------------------------------

import datetime
import os
import pickle
import shutil
import sys
import tempfile

import torch
import torch.distributed as dist


def setup_for_distributed(is_main):
    """
    Disables printing when not in the main process.
    print(..., force=True) and print(..., file=sys.stderr) are still shown on all ranks.
    """
    import builtins as __builtin__
    builtin_print = __builtin__.print

    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        if is_main or force or kwargs.get('file', None) == sys.stderr:
            builtin_print(*args, **kwargs)

    __builtin__.print = print
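
# Illustrative usage sketch (not part of the original module): after the call
# below, plain print() is silenced on non-main ranks, while force=True or
# printing to sys.stderr still goes through on every rank.
#
#   setup_for_distributed(is_main=(get_rank() == 0))
#   print("visible only on rank 0")
#   print("visible on every rank", force=True)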


def is_dist_avail_and_initialized():
    if not dist.is_available():
        return False
    if not dist.is_initialized():
        return False
    return True


def get_world_size():
    if not is_dist_avail_and_initialized():
        return 1
    return dist.get_world_size()


def get_rank():
    if not is_dist_avail_and_initialized():
        return 0
    return dist.get_rank()


def is_main_process():
    return get_rank() == 0
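
# Illustrative sketch (assumed usage, not part of the original module): these
# helpers are commonly used to guard side effects such as logging or file
# writes so they run only once per job, and they fall back to single-process
# defaults when torch.distributed is not initialized.
#
#   if is_main_process():
#       print(f"training on {get_world_size()} process(es), this is rank {get_rank()}")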


def save_on_main(*args, **kwargs):
    if is_main_process():
        torch.save(*args, **kwargs)


def save_on_all(*args, **kwargs):
    torch.save(*args, **kwargs)
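
# Illustrative sketch (model, epoch and the file name below are placeholders,
# not defined in this module): save_on_main writes a checkpoint once per job
# instead of once per rank, while save_on_all writes from every process.
#
#   save_on_main({"model": model.state_dict(), "epoch": epoch}, "checkpoint.pth")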


def init_distributed_mode(args):
    if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
        args.rank = int(os.environ["RANK"])
        args.world_size = int(os.environ['WORLD_SIZE'])
        args.gpu = int(os.environ['LOCAL_RANK'])
    else:
        print('Not using distributed mode')
        args.distributed = False
        return

    args.distributed = True

    torch.cuda.set_device(args.gpu)
    args.dist_backend = 'nccl'
    print('| distributed init (rank {}): {}, gpu {}'.format(
        args.rank, args.dist_url, args.gpu), flush=True)
    # Set timeout to 1h20 (4800 s) in case some long dataset download has to happen
    torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank,
                                         timeout=datetime.timedelta(seconds=4800))
    torch.distributed.barrier()
    if ("print_all" not in args) or (not args.print_all):
        setup_for_distributed(args.rank == 0)
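
# Illustrative launch sketch (the entry point and flags below are assumptions,
# not part of this module): init_distributed_mode relies on the RANK,
# WORLD_SIZE and LOCAL_RANK environment variables that launchers such as
# torchrun set, e.g.
#
#   torchrun --nproc_per_node=8 train.py
#
# and inside train.py, after argparse has populated args (including a
# dist_url argument, typically defaulting to 'env://'):
#
#   init_distributed_mode(args)
#   print(f"rank {get_rank()} of {get_world_size()} initialized")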