junbiao.chen
Trellis update
cc0c59d
raw
history blame contribute delete
2.53 kB
import os
import io
from contextlib import contextmanager
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
def setup_dist(rank, local_rank, world_size, master_addr, master_port):
    """
    Initialize the NCCL process group for the calling process.

    The rendezvous settings are exported through environment variables,
    the CUDA device of this process is selected, and then the default
    process group is created with the 'nccl' backend.

    Args:
        rank: Global rank of this process.
        local_rank: Index of the CUDA device this process should use.
        world_size: Total number of processes in the group.
        master_addr: Address of the rendezvous host.
        master_port: Port of the rendezvous host (str or int).
    """
    # os.environ only accepts strings; convert every value so that callers
    # may pass the port (or any other field) as an integer.
    os.environ['MASTER_ADDR'] = str(master_addr)
    os.environ['MASTER_PORT'] = str(master_port)
    os.environ['WORLD_SIZE'] = str(world_size)
    os.environ['RANK'] = str(rank)
    os.environ['LOCAL_RANK'] = str(local_rank)
    # Select the device before creating the group so that later `.cuda()`
    # calls and NCCL collectives in this process target `local_rank`.
    torch.cuda.set_device(local_rank)
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
def read_file_dist(path):
    """
    Load a binary file while touching the filesystem from one process only.

    When a process group with more than one member is active, rank 0 reads
    the file and broadcasts first its length and then its contents to all
    other ranks through CUDA tensors. Otherwise the file is read locally.

    Args:
        path: Path of the file to read.

    Returns:
        data (io.BytesIO): The binary data read from the file.
    """
    multi_process = dist.is_initialized() and dist.get_world_size() > 1
    if not multi_process:
        # Single-process case: nothing to share, just read the file.
        with open(path, 'rb') as fh:
            return io.BytesIO(fh.read())

    is_source = dist.get_rank() == 0

    # One-element tensor carrying the payload length. Only rank 0 fills it
    # in; the other ranks receive the value through the broadcast below.
    num_bytes = torch.LongTensor(1).cuda()
    if is_source:
        with open(path, 'rb') as fh:
            raw = fh.read()
        storage = torch.UntypedStorage.from_buffer(raw, dtype=torch.uint8)
        payload = torch.ByteTensor(storage).cuda()
        num_bytes[0] = payload.shape[0]
    dist.broadcast(num_bytes, src=0)

    # Receivers can only allocate their buffer once the length is known.
    if not is_source:
        payload = torch.ByteTensor(num_bytes[0].item()).cuda()
    dist.broadcast(payload, src=0)

    # Bring the bytes back to the host and wrap them in a file-like object.
    return io.BytesIO(payload.cpu().numpy().tobytes())
def unwrap_dist(model):
    """
    Return the module wrapped by DistributedDataParallel, if any.

    Objects that are not DDP instances are handed back unchanged.
    """
    return model.module if isinstance(model, DDP) else model
@contextmanager
def master_first():
    """
    A context manager that ensures master process executes first.

    Rank 0 runs the managed body and then enters a barrier; every other
    rank waits at the barrier before running its own body. When no process
    group is initialized the body simply runs.
    """
    if not dist.is_initialized():
        yield
        return

    if dist.get_rank() == 0:
        try:
            yield
        finally:
            # Always release the waiting ranks, even when the body raised;
            # otherwise they would stay blocked in their barrier while this
            # rank has already left the context.
            dist.barrier()
    else:
        dist.barrier()
        yield
@contextmanager
def local_master_first():
    """
    A context manager that ensures local master process executes first.

    Processes whose global rank is a multiple of the number of visible CUDA
    devices run the managed body and then enter a barrier; all remaining
    processes wait at the barrier before running their body. When no
    process group is initialized the body simply runs.
    """
    if not dist.is_initialized():
        yield
        return

    # NOTE(review): electing the local master via rank % device_count
    # presumably assumes one process per GPU with contiguous ranks on each
    # node -- confirm against the launcher in use.
    if dist.get_rank() % torch.cuda.device_count() == 0:
        try:
            yield
        finally:
            # Always release the waiting ranks, even when the body raised;
            # otherwise they would stay blocked in their barrier while this
            # rank has already left the context.
            dist.barrier()
    else:
        dist.barrier()
        yield