# Copyright (c) 2024, EleutherAI
# This file is based on code by the authors denoted below and has been modified from its original version.
#
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Parts of the code here are adapted from PyTorch
# repo: https://github.com/pytorch/pytorch
import math
import torch
import torch.nn.functional as F
import torch.nn.init as init
from torch.nn.parameter import Parameter
from .initialize import get_model_parallel_rank
from .initialize import get_model_parallel_world_size
from .mappings import copy_to_model_parallel_region
from .mappings import gather_from_model_parallel_region
from .mappings import reduce_from_model_parallel_region
from .mappings import scatter_to_model_parallel_region
from .mappings import reduce_scatter_to_sequence_parallel_region
from .mappings import gather_from_sequence_parallel_region
from .random import get_cuda_rng_tracker
from .utils import divide
from .utils import VocabUtility
from functools import partial
def _initialize_affine_weight_gpu(weight, init_method, partition_dim, stride=1):
"""Initialize affine weight for model parallel on GPU."""
weight.model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
with get_cuda_rng_tracker().fork():
init_method(weight)
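
# Illustrative sketch (comments only, not executed): how a partitioned weight is
# typically allocated and handed to the helper above. The concrete sizes and the
# choice of init.xavier_normal_ are assumptions for the example, not part of
# this module's API.
#
#     w = Parameter(
#         torch.empty(1024, 4096, device=torch.cuda.current_device())
#     )
#     _initialize_affine_weight_gpu(w, init.xavier_normal_, partition_dim=0)
#
# Afterwards `w` carries the .model_parallel / .partition_dim / .partition_stride
# tags and was initialized inside the model-parallel RNG tracker, so each rank
# draws from its own reproducible random stream.
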
def _initialize_affine_weight_cpu(
neox_args,
weight,
output_size,
input_size,
per_partition_size,
partition_dim,
init_method,
stride=1,
return_master_weight=False,
):
"""Initialize affine weight for model parallel.
Build the master weight on all processes and scatter
the relevant chunk."""
weight.model_parallel = True
weight.partition_dim = partition_dim
weight.partition_stride = stride
# Initialize master weight
master_weight = torch.empty(
output_size, input_size, dtype=torch.float, requires_grad=False
)
init_method(master_weight)
master_weight = master_weight.to(dtype=neox_args.params_dtype)
# Split and copy
per_partition_per_stride_size = divide(per_partition_size, stride)
weight_list = torch.split(
master_weight, per_partition_per_stride_size, dim=partition_dim
)
rank = get_model_parallel_rank()
world_size = get_model_parallel_world_size()
my_weight_list = weight_list[rank::world_size]
with torch.no_grad():
torch.cat(my_weight_list, dim=partition_dim, out=weight)
if return_master_weight:
return master_weight
return None
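
# Worked example (comments only) of the split-and-scatter above: with
# output_size=8, world_size=2 and stride=2, per_partition_size is 4 and each
# stride chunk holds 2 rows. `weight_list[rank::world_size]` interleaves the
# chunks across ranks:
#
#     master = torch.arange(8.0).reshape(8, 1)   # toy master weight, 8 x 1
#     chunks = torch.split(master, 2, dim=0)     # 4 chunks of 2 rows each
#     rank0 = torch.cat(chunks[0::2], dim=0)     # rows 0-1 and 4-5
#     rank1 = torch.cat(chunks[1::2], dim=0)     # rows 2-3 and 6-7
#
# so each rank ends up with `per_partition_size` rows assembled from every
# stride group, matching the layout expected by the strided parallel layers.
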
class VocabParallelEmbedding(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
This is mainly adapted from torch.nn.Embedding and all the default
values are kept.
Arguments:
num_embeddings: vocabulary size.
embedding_dim: size of hidden state.
init_method: method to initialize weights.
"""
def __init__(
self, neox_args, num_embeddings, embedding_dim, init_method=init.xavier_normal_
):
super(VocabParallelEmbedding, self).__init__()
# Keep the input dimensions.
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
        # Set the defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.model_parallel_size = get_model_parallel_world_size()
# Divide the weight matrix along the vocabulary dimension.
(
self.vocab_start_index,
self.vocab_end_index,
) = VocabUtility.vocab_range_from_global_vocab_size(
self.num_embeddings, get_model_parallel_rank(), self.model_parallel_size
)
self.num_embeddings_per_partition = (
self.vocab_end_index - self.vocab_start_index
)
self.init_method = init_method
# Allocate weights and initialize.
if neox_args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
dtype=neox_args.params_dtype,
)
)
_initialize_affine_weight_cpu(
neox_args,
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
init_method,
)
else:
self.weight = Parameter(
torch.empty(
self.num_embeddings_per_partition,
self.embedding_dim,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=1
)
def mup_reinitialize_weights(self, neox_args):
if neox_args.use_cpu_initialization:
_initialize_affine_weight_cpu(
neox_args,
self.weight,
self.num_embeddings,
self.embedding_dim,
self.num_embeddings_per_partition,
0,
partial(self.init_method, use_mup=True),
)
else:
_initialize_affine_weight_gpu(
self.weight,
partial(self.init_method, use_mup=True),
partition_dim=0,
stride=1,
)
def forward(self, input_):
if self.model_parallel_size > 1:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (
input_ >= self.vocab_end_index
)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(
masked_input,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
# Mask the output embedding.
if self.model_parallel_size > 1:
output_parallel[input_mask, :] = 0.0
# Reduce across all the model parallel GPUs.
output = reduce_from_model_parallel_region(output_parallel)
return output
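
# Hedged usage sketch (comments only; assumes model parallelism has been
# initialized and `neox_args` is a populated NeoX argument object):
#
#     emb = VocabParallelEmbedding(neox_args, num_embeddings=50304, embedding_dim=4096)
#     tokens = torch.randint(0, 50304, (2, 1024), device=torch.cuda.current_device())
#     hidden = emb(tokens)   # shape: tokens.shape + (4096,)
#
# Each rank looks up only the tokens inside its [vocab_start_index,
# vocab_end_index) slice, zeroes the rows for out-of-range tokens, and the
# all-reduce sums the partial embeddings into the full result on every rank.
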
class ParallelRelativePositionBias(torch.nn.Module):
"""T5 Relative Position Bias parallelized in the heads dimension
Based on https://github.com/lucidrains/x-transformers/blob/6b93c21be0d0a679da6f7b9621d9bb638ab18428/x_transformers/x_transformers.py#L106 (14.12.2021)
and adapted for megatron's model parallelism
Arguments:
scale: scaling factor for the bias
causal: flag for causal/non-causal language modelling.
num_buckets: number of rp buckets.
max_distance: max distance in sequence dim for each bucket.
heads: number of attention heads (total)
"""
def __init__(
self,
neox_args,
scale,
causal=True,
num_buckets=32,
max_distance=128,
heads=8,
init_method=init.xavier_normal_,
):
super().__init__()
self.scale = scale
self.causal = causal
self.num_buckets = num_buckets
self.max_distance = max_distance
self.heads = heads
# Set the defaults for compatibility.
self.padding_idx = None
self.max_norm = None
self.norm_type = 2.0
self.scale_grad_by_freq = False
self.sparse = False
self._weight = None
self.model_parallel_size = get_model_parallel_world_size()
self.model_parallel_rank = get_model_parallel_rank()
# Divide the weight matrix along the heads dimension.
self.head_start_index, self.head_end_index = self.get_heads_range(
self.heads, self.model_parallel_rank, self.model_parallel_size
)
self.num_heads_per_partition = self.head_end_index - self.head_start_index
self.init_method = init_method
# Allocate weights and initialize.
if neox_args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.num_buckets,
self.num_heads_per_partition,
dtype=neox_args.params_dtype,
)
)
_initialize_affine_weight_cpu(
neox_args,
self.weight,
self.num_buckets,
self.heads,
self.num_heads_per_partition,
partition_dim=1,
init_method=init_method,
)
else:
self.weight = Parameter(
torch.empty(
self.num_buckets,
self.num_heads_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=1
)
self._q_len_cached = None
self._k_len_cached = None
self._rel_pos_bucket_cached = None
def mup_reinitialize_weights(self, neox_args):
        if neox_args.use_cpu_initialization:
_initialize_affine_weight_cpu(
neox_args,
self.weight,
self.num_buckets,
self.heads,
self.num_heads_per_partition,
partition_dim=1,
init_method=partial(self.init_method, use_mup=True),
)
else:
_initialize_affine_weight_gpu(
self.weight,
partial(self.init_method, use_mup=True),
partition_dim=1,
stride=1,
)
@staticmethod
def get_heads_range(global_n_heads, rank, world_size):
per_partition_n_heads = divide(global_n_heads, world_size)
index_f = rank * per_partition_n_heads
index_l = index_f + per_partition_n_heads
return index_f, index_l
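
    # Example (comments only): with 16 total heads and model_parallel_size=4,
    # get_heads_range(16, rank, 4) returns (0, 4), (4, 8), (8, 12), (12, 16)
    # for ranks 0..3, so each rank owns a contiguous block of 4 heads of the
    # relative-position bias table.
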
def _relative_position_bucket(
self, relative_position, num_buckets=32, max_distance=128
):
ret = 0
n = -relative_position
if not self.causal:
num_buckets //= 2
ret += (n < 0).long() * num_buckets
n = torch.abs(n)
else:
n = torch.max(n, torch.zeros_like(n))
max_exact = num_buckets // 2
is_small = n < max_exact
val_if_large = (
max_exact
+ (
torch.log(n.float() / max_exact)
/ math.log(max_distance / max_exact)
* (num_buckets - max_exact)
).long()
)
val_if_large = torch.min(
val_if_large, torch.full_like(val_if_large, num_buckets - 1)
)
ret += torch.where(is_small, n, val_if_large)
self._rel_pos_bucket_cached = ret
return self._rel_pos_bucket_cached
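
    # Worked example (comments only) for the causal case with num_buckets=32 and
    # max_distance=128: n = max(-relative_position, 0). Distances 0..15 land in
    # buckets 0..15 exactly (is_small); larger distances are placed
    # logarithmically into buckets 16..31 via
    # 16 + floor(log(n / 16) / log(128 / 16) * 16), clamped to 31.
    # For instance n=16 -> bucket 16, n=32 -> bucket 21, n>=128 -> bucket 31.
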
def forward(self, q_len, k_len):
if self._q_len_cached != q_len or self._k_len_cached != k_len:
            # Cache the buckets; they are recomputed only when (q_len, k_len) changes.
self._q_len_cached, self._k_len_cached = q_len, k_len
q_pos = torch.arange(
q_len, dtype=torch.long, device=torch.cuda.current_device()
)
k_pos = torch.arange(
k_len, dtype=torch.long, device=torch.cuda.current_device()
)
rel_pos = k_pos[None, :] - q_pos[:, None]
rp_bucket = self._relative_position_bucket(
rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance
)
else:
rp_bucket = self._rel_pos_bucket_cached
values = F.embedding(
rp_bucket,
self.weight,
self.padding_idx,
self.max_norm,
self.norm_type,
self.scale_grad_by_freq,
self.sparse,
)
bias = values.movedim(2, 0).unsqueeze(0)
return bias * self.scale
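
# Hedged usage sketch (comments only; assumes initialized model parallelism and
# a populated `neox_args`; the head count and scale are placeholders):
#
#     rpe = ParallelRelativePositionBias(neox_args, scale=1.0, causal=True, heads=16)
#     bias = rpe(q_len, k_len)          # [1, heads_per_partition, q_len, k_len]
#     attention_scores = attention_scores + bias
#
# The bucket tensor is cached, so repeated calls with the same (q_len, k_len)
# reuse the cached buckets and only redo the embedding lookup.
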
class ColumnParallelLinear(torch.nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
        gather_output: If true, call all-gather on the output and make Y
                       available to all GPUs; otherwise, every GPU keeps only
                       its own partition of the output, Y_i = XA_i.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       the bias can be fused with other elementwise operations.
                       We skip adding the bias here and return it instead.
"""
def __init__(
self,
neox_args,
input_size,
output_size,
bias=True,
gather_output=True,
init_method=init.xavier_normal_,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
MOE=False,
MoE_mp_size=1,
mup_rescale_parameters=False,
seq_dim=0, # Dimension which is the seq_len dimension. final ParallelLinear overrides this to be 1 ; otherwise, the default is used throughout.
):
super(ColumnParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.gather_output = gather_output
# Divide the weight matrix along the last dimension.
world_size = MoE_mp_size if MOE else get_model_parallel_world_size()
self.output_size_per_partition = divide(output_size, world_size)
self.skip_bias_add = skip_bias_add
self.sequence_parallel = neox_args.sequence_parallel
self.seq_dim = seq_dim
self.init_method = init_method
        self.stride = stride
        self.keep_master_weight_for_test = keep_master_weight_for_test
self.mup_rescale_parameters = mup_rescale_parameters
self.use_mup = neox_args.use_mup
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if neox_args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
dtype=neox_args.params_dtype,
)
)
self.master_weight = _initialize_affine_weight_cpu(
neox_args,
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size_per_partition,
self.input_size,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=0, stride=stride
)
if bias:
if neox_args.use_cpu_initialization:
self.bias = Parameter(
torch.empty(
self.output_size_per_partition, dtype=neox_args.params_dtype
)
)
else:
self.bias = Parameter(
torch.empty(
self.output_size_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
self.bias.model_parallel = True
self.bias.partition_dim = 0
self.bias.stride = stride
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Copied from Mup
def width_mult(self):
assert hasattr(self.weight, "infshape"), (
"Please call set_base_shapes(...). If using torch.nn.DataParallel, "
"switch to distributed training with "
"torch.nn.parallel.DistributedDataParallel instead"
)
return self.weight.infshape.width_mult()
# Copied from Mup
def _rescale_parameters(self):
"""Rescale parameters to convert SP initialization to μP initialization.
Warning: This method is NOT idempotent and should be called only once
unless you know what you are doing.
"""
if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params:
raise RuntimeError(
"`_rescale_parameters` has been called once before already. "
"Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
"If you called `set_base_shapes` on a model loaded from a checkpoint, "
"or just want to re-set the base shapes of an existing model, "
"make sure to set the flag `rescale_params=False`.\n"
"To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call."
)
if self.bias is not None:
self.bias.data *= self.width_mult() ** 0.5
self.weight.data *= self.width_mult() ** 0.5
self._has_rescaled_params = True
def mup_reinitialize_weights(self, neox_args):
if neox_args.use_cpu_initialization:
self.master_weight = _initialize_affine_weight_cpu(
neox_args,
self.weight,
self.output_size,
self.input_size,
self.output_size_per_partition,
0,
partial(self.init_method, use_mup=True),
stride=self.stride,
                return_master_weight=self.keep_master_weight_for_test,
)
else:
_initialize_affine_weight_gpu(
self.weight,
partial(self.init_method, use_mup=True),
partition_dim=0,
stride=self.stride,
)
def set_parallel_output(self, value: bool):
assert isinstance(value, bool)
self.gather_output = (
not value
) # if gather_output is True, parallel output is False, so we set the opposite
def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
if self.sequence_parallel:
input_parallel = input_
else:
# Set up backprop all-reduce.
input_parallel = copy_to_model_parallel_region(input_)
# Matrix multiply.
if self.sequence_parallel:
# do an AG in the fwd pass, RS in bwd pass.
# gather / scatter portion happens across the sequence dim (self.seq_dim)--
# almost always is [s, b, h] and so dim 0, but for lm_head ParallelLinear it is seq_dim=1 and [b, s, h]
input_parallel = gather_from_sequence_parallel_region(
input_parallel, seq_dim=self.seq_dim
)
bias = self.bias if not self.skip_bias_add else None
output_parallel = F.linear(input_parallel, self.weight, bias)
if self.gather_output:
# All-gather across the partitions.
assert (
not self.sequence_parallel
), "sequence_parallel=True and gather_output=True are incompatible!"
output = gather_from_model_parallel_region(output_parallel)
else:
output = output_parallel
output_bias = self.bias if self.skip_bias_add else None
return output, output_bias
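
# Hedged usage sketch (comments only; the sizes and `neox_args` are placeholder
# assumptions): a column-parallel projection that keeps its output partitioned
# across model-parallel ranks.
#
#     dense_h_to_4h = ColumnParallelLinear(
#         neox_args, input_size=4096, output_size=16384, gather_output=False
#     )
#     out, _ = dense_h_to_4h(x)   # out: [s, b, 16384 // model_parallel_size]
#
# With gather_output=False each rank holds its own slice Y_i = X A_i; with
# gather_output=True the slices are all-gathered so every rank sees the full Y.
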
class RowParallelLinear(torch.nn.Module):
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its first dimension and X along its second dimension as:
- -
| A_1 |
| . |
A = | . | X = [X_1, ..., X_p]
| . |
| A_p |
- -
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias. Note that bias is not parallelized.
input_is_parallel: If true, we assume that the input is already
split across the GPUs and we do not split
again.
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
        skip_bias_add: This was added to enable performance optimizations where
                       the bias can be fused with other elementwise operations.
                       We skip adding the bias here and return it instead.
"""
def __init__(
self,
neox_args,
input_size,
output_size,
bias=True,
input_is_parallel=False,
init_method=init.xavier_normal_,
stride=1,
keep_master_weight_for_test=False,
skip_bias_add=False,
MOE=False,
MoE_mp_size=1,
parallel_output=False,
mup_rescale_parameters=False,
):
super(RowParallelLinear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.input_is_parallel = input_is_parallel
# Divide the weight matrix along the last dimension.
world_size = MoE_mp_size if MOE else get_model_parallel_world_size()
self.input_size_per_partition = divide(input_size, world_size)
self.skip_bias_add = skip_bias_add
self.parallel_output = parallel_output
self.sequence_parallel = neox_args.sequence_parallel
assert not (
self.sequence_parallel and not self.input_is_parallel
), "Cannot have self.input_is_parallel=False and self.sequence_parallel=True."
self.init_method = init_method
self.stride = stride
self.keep_master_weight_for_test = keep_master_weight_for_test
self.mup_rescale_parameters = mup_rescale_parameters
self.use_mup = neox_args.use_mup
# Parameters.
# Note: torch.nn.functional.linear performs XA^T + b and as a result
# we allocate the transpose.
# Initialize weight.
if neox_args.use_cpu_initialization:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
dtype=neox_args.params_dtype,
)
)
self.master_weight = _initialize_affine_weight_cpu(
neox_args,
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
init_method,
stride=stride,
return_master_weight=keep_master_weight_for_test,
)
else:
self.weight = Parameter(
torch.empty(
self.output_size,
self.input_size_per_partition,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
_initialize_affine_weight_gpu(
self.weight, init_method, partition_dim=1, stride=stride
)
if bias:
if neox_args.use_cpu_initialization:
self.bias = Parameter(
torch.empty(self.output_size, dtype=neox_args.params_dtype)
)
else:
self.bias = Parameter(
torch.empty(
self.output_size,
device=torch.cuda.current_device(),
dtype=neox_args.params_dtype,
)
)
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter("bias", None)
# Copied from Mup
def width_mult(self):
assert hasattr(self.weight, "infshape"), (
"Please call set_base_shapes(...). If using torch.nn.DataParallel, "
"switch to distributed training with "
"torch.nn.parallel.DistributedDataParallel instead"
)
return self.weight.infshape.width_mult()
# Copied from Mup
def _rescale_parameters(self):
"""Rescale parameters to convert SP initialization to μP initialization.
Warning: This method is NOT idempotent and should be called only once
unless you know what you are doing.
"""
if hasattr(self, "_has_rescaled_params") and self._has_rescaled_params:
raise RuntimeError(
"`_rescale_parameters` has been called once before already. "
"Unless you know what you are doing, usually you should not be calling `_rescale_parameters` more than once.\n"
"If you called `set_base_shapes` on a model loaded from a checkpoint, "
"or just want to re-set the base shapes of an existing model, "
"make sure to set the flag `rescale_params=False`.\n"
"To bypass this error and *still rescale parameters*, set `self._has_rescaled_params=False` before this call."
)
if self.bias is not None:
self.bias.data *= self.width_mult() ** 0.5
self.weight.data *= self.width_mult() ** 0.5
self._has_rescaled_params = True
def mup_reinitialize_weights(self, neox_args):
if neox_args.use_cpu_initialization:
self.master_weight = _initialize_affine_weight_cpu(
neox_args,
self.weight,
self.output_size,
self.input_size,
self.input_size_per_partition,
1,
partial(self.init_method, use_mup=True),
stride=self.stride,
return_master_weight=self.keep_master_weight_for_test,
)
else:
_initialize_affine_weight_gpu(
self.weight,
partial(self.init_method, use_mup=True),
partition_dim=1,
stride=self.stride,
)
def set_parallel_output(self, parallel_output: bool):
assert isinstance(parallel_output, bool)
self.parallel_output = parallel_output
def forward(self, input_):
if self.use_mup and self.mup_rescale_parameters:
input_ /= self.width_mult()
# Set up backprop all-reduce.
if self.input_is_parallel:
input_parallel = input_
else:
input_parallel = scatter_to_model_parallel_region(input_)
# Matrix multiply.
output_parallel = F.linear(input_parallel, self.weight)
# All-reduce across all the partitions.
if self.sequence_parallel and not self.parallel_output:
# do an RS in the fwd pass, AG in bwd pass.
# skip in the gpt-j parallel sublayer case (self.parallel_output=True)
# (user responsible for calling reduce-scatter)
output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
elif not self.parallel_output:
output_ = reduce_from_model_parallel_region(output_parallel)
else:
output_ = output_parallel
if not self.skip_bias_add:
output = output_ + self.bias if self.bias is not None else output_
output_bias = None
else:
output = output_
output_bias = self.bias
return output, output_bias
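
# Hedged sketch of how the two layer flavors are typically paired in an MLP
# block (comments only; the sizes and `neox_args` are placeholder assumptions):
# a ColumnParallelLinear with gather_output=False feeds a RowParallelLinear
# with input_is_parallel=True, so the pair needs a single all-reduce in the
# forward pass (or, under sequence parallelism, an all-gather in the column
# layer and a reduce-scatter in the row layer).
#
#     dense_h_to_4h = ColumnParallelLinear(
#         neox_args, 4096, 16384, gather_output=False
#     )
#     dense_4h_to_h = RowParallelLinear(
#         neox_args, 16384, 4096, input_is_parallel=True
#     )
#     intermediate, _ = dense_h_to_4h(x)
#     intermediate = F.gelu(intermediate)
#     output, _ = dense_4h_to_h(intermediate)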