|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Communications utilities.""" |
|
|
|
|
|
import torch |
|
|
|
from megatron import mpu |
|
|
|
|
|
|
|
|
|
def recv_from_prev_pipeline_rank_(recv_buffer=None):
    """Fill `recv_buffer` in place with data received from the
    previous pipeline stage. No-op on the first stage."""
    if mpu.is_pipeline_first_stage():
        # First stage has no predecessor; nothing to receive.
        return
    assert recv_buffer is not None
    # Post the receive as a batched P2P op and wait for completion.
    recv_op = torch.distributed.P2POp(
        torch.distributed.irecv,
        recv_buffer,
        mpu.get_pipeline_model_parallel_prev_rank())
    for request in torch.distributed.batch_isend_irecv([recv_op]):
        request.wait()
    # Guard against a race condition after batch_isend_irecv().
    torch.cuda.synchronize()
|
|
|
|
|
|
|
|
|
def send_to_next_pipeline_rank(tensor=None):
    """Send `tensor` to the next pipeline stage. No-op on the last stage."""
    if mpu.is_pipeline_last_stage():
        # Last stage has no successor; nothing to send.
        return
    assert tensor is not None
    # Post the send as a batched P2P op and wait for completion.
    send_op = torch.distributed.P2POp(
        torch.distributed.isend,
        tensor,
        mpu.get_pipeline_model_parallel_next_rank())
    for request in torch.distributed.batch_isend_irecv([send_op]):
        request.wait()
    # Guard against a race condition after batch_isend_irecv().
    torch.cuda.synchronize()
|
|
|
|
|
|
|
def _is_cuda(tensor): |
|
"""Check if a tensor is not none and is cuda.""" |
|
assert tensor is not None |
|
assert tensor.is_cuda |
|
|
|
|
|
|
|
def _is_cuda_contiguous(tensor):
    """Assert that `tensor` is non-None, on CUDA, and contiguous."""
    _is_cuda(tensor)
    # Contiguity is required before handing the tensor to collectives.
    assert tensor.is_contiguous()
|
|
|
|
|
|
|
def broadcast_from_last_pipeline_stage(size, dtype, tensor=None):
    """Broadcast `tensor` from the last pipeline stage to every rank
    in the pipeline group; returns the (possibly newly allocated) tensor."""
    is_last_stage = mpu.is_pipeline_last_stage()
    # Single-stage pipeline: the caller already holds the data.
    if is_last_stage and mpu.is_pipeline_first_stage():
        return tensor

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        # Receiving ranks allocate a buffer of the advertised shape/dtype.
        tensor = torch.empty(size, dtype=dtype,
                             device=torch.cuda.current_device())

    torch.distributed.broadcast(
        tensor,
        mpu.get_pipeline_model_parallel_last_rank(),
        mpu.get_pipeline_model_parallel_group())
    return tensor
|
|
|
|
|
|
|
def broadcast_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Broadcast `tensor` from the last pipeline stage to the first stage
    via the embedding group; other ranks receive None."""
    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()

    # Single-stage pipeline: nothing to transfer.
    if is_first_stage and is_last_stage:
        return tensor

    # Ranks outside first/last stage do not participate in the broadcast.
    if not (is_last_stage or is_first_stage):
        return None

    if is_last_stage:
        _is_cuda_contiguous(tensor)
    else:
        # First stage allocates a receive buffer.
        tensor = torch.empty(size, dtype=dtype,
                             device=torch.cuda.current_device())
    # The embedding group contains exactly the first and last stages.
    torch.distributed.broadcast(tensor,
                                mpu.get_pipeline_model_parallel_last_rank(),
                                mpu.get_embedding_group())
    return tensor
|
|
|
|
|
|
|
def copy_from_last_to_first_pipeline_stage(size, dtype, tensor=None):
    """Copy tensor values from last stage into the first stage.
    Note that the input tensor is updated in place."""

    is_last_stage = mpu.is_pipeline_last_stage()
    is_first_stage = mpu.is_pipeline_first_stage()

    # Single-stage pipeline: source and destination are the same rank,
    # so the tensor already holds the right values.
    if is_first_stage and is_last_stage:
        return

    # Only the first and last stages participate; all other ranks no-op.
    if is_last_stage or is_first_stage:
        _is_cuda(tensor)
        is_contiguous = tensor.is_contiguous()
        src = mpu.get_pipeline_model_parallel_last_rank()
        # The embedding group links the first and last pipeline stages.
        group = mpu.get_embedding_group()
        if is_contiguous:
            # Broadcast can operate directly on the caller's tensor.
            tensor_ = tensor
        else:
            if is_last_stage:
                # Sender: broadcast a contiguous copy of the data.
                tensor_ = tensor.contiguous()
            else:
                # Receiver: stage into a fresh contiguous buffer, then
                # copy back into the caller's tensor below.
                tensor_ = torch.empty(size,
                                      dtype=dtype,
                                      device=torch.cuda.current_device())

        # Broadcast from last stage into the first stage.
        torch.distributed.broadcast(tensor_, src, group)

        # First stage with a non-contiguous tensor: copy the staged
        # result back in place so the caller's tensor is updated.
        if is_first_stage and not is_contiguous:
            tensor[...] = tensor_
|
|
|
|
|
|
|
def broadcast_tensor(size, dtype, tensor=None, rank=0):
    """Broadcast a tensor whose size and dtype are known on all ranks but
    whose values live only on `rank`; every rank returns the tensor."""
    if torch.distributed.get_rank() != rank:
        # Non-source ranks receive into a freshly allocated buffer.
        tensor = torch.empty(size, dtype=dtype,
                             device=torch.cuda.current_device())
    else:
        _is_cuda_contiguous(tensor)

    torch.distributed.broadcast(tensor, rank)
    return tensor
|
|
|
|
|
|
|
def broadcast_list(size, dtype, list_values=None, rank=0):
    """Broadcast a Python list of values (given as `list_values` on `rank`)
    to all ranks as a tensor of the requested dtype."""
    values_tensor = None
    # Only the source rank materializes the list on the GPU.
    if torch.distributed.get_rank() == rank:
        values_tensor = torch.tensor(list_values, dtype=dtype,
                                     device=torch.cuda.current_device())
    return broadcast_tensor(size, dtype, tensor=values_tensor, rank=rank)
|
|
|
|
|
|
|
def broadcast_int_list(size, int_list=None, rank=0):
    """Broadcast a list of integer values from `rank` to all ranks.

    Thin wrapper over `broadcast_list` that pins the dtype to int64;
    returns the broadcast values as a tensor on every rank.
    """
    return broadcast_list(size, torch.int64, list_values=int_list, rank=rank)
|
|
|
|
|
|
|
def broadcast_float_list(size, float_list=None, rank=0):
    """Broadcast a list of float values from `rank` to all ranks.

    Thin wrapper over `broadcast_list` that pins the dtype to float32;
    returns the broadcast values as a tensor on every rank.
    """
    return broadcast_list(
        size, torch.float32, list_values=float_list, rank=rank)
|
|