|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Transformer.""" |
|
import math |
|
from contextlib import nullcontext |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from megatron import get_timers, get_args, get_global_memory_buffer |
|
from megatron import mpu |
|
from .module import MegatronModule |
|
from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType |
|
from megatron.model import LayerNorm |
|
from megatron.model.fused_softmax import FusedScaleMaskSoftmax |
|
from megatron.model.fused_bias_gelu import bias_gelu_impl |
|
from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu |
|
|
|
|
|
""" We use the following notation throughout this file: |
|
h: hidden size |
|
n: number of attention heads |
|
p: number of model parallel partitions |
|
np: n/p |
|
hp: h/p |
|
hn: h/n |
|
b: batch size |
|
s: sequence length |
|
l: number of layers |
|
Transformer takes input of size [s, b, h] and returns a |
|
tensor of the same size. We use the following arguments: |
|
hyperparameters: transformer hyperparameters |
|
""" |
|
|
|
class DropPath(MegatronModule):
    """Stochastic depth: randomly zero whole residual-branch outputs.

    During training, the contribution of each index along dim 0 is dropped
    with probability ``drop_prob``; survivors are rescaled by
    ``1 / (1 - drop_prob)`` so the expected value is unchanged.
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        # Identity when dropping is disabled or at inference time.
        if not self.training or self.drop_prob == 0.:
            return hidden_state
        keep_prob = 1 - self.drop_prob

        # One Bernoulli draw per index of dim 0; trailing singleton dims let
        # the mask broadcast over every remaining dimension.
        mask_shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
        mask = torch.rand(mask_shape,
                          dtype=hidden_state.dtype,
                          device=hidden_state.device)
        # keep_prob + U[0,1) floored -> 1 with prob keep_prob, else 0.
        mask.add_(keep_prob).floor_()
        return hidden_state.div(keep_prob) * mask
|
|
|
|
|
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, init_method, output_layer_init_method):
        super(ParallelMLP, self).__init__()
        args = get_args()

        # Column-parallel expansion h -> ffn_hidden_size; bias add is
        # deferred (skip_bias_add) so it can be fused with the GELU below.
        self.dense_h_to_4h = mpu.ColumnParallelLinear(
            args.hidden_size,
            args.ffn_hidden_size,
            gather_output=False,
            init_method=init_method,
            skip_bias_add=True)

        self.bias_gelu_fusion = args.bias_gelu_fusion
        # Pick the activation: OpenAI's tanh-approx GELU, an
        # ONNX-exportable erf-based GELU, or the default F.gelu.
        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        else:
            self.activation_func = F.gelu

        # Row-parallel contraction ffn_hidden_size -> h; bias is returned
        # separately so callers can fuse it with dropout/residual-add.
        self.dense_4h_to_h = mpu.RowParallelLinear(
            args.ffn_hidden_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def forward(self, hidden_states):
        # [s, b, 4h/p]
        intermediate, intermediate_bias = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            intermediate = bias_gelu_impl(intermediate, intermediate_bias)
        else:
            intermediate = self.activation_func(intermediate + intermediate_bias)

        # [s, b, h] plus the (deferred) output bias.
        return self.dense_4h_to_h(intermediate)
|
|
|
class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"

    Top-1 (Switch-style) routing: every token is sent to exactly one
    expert, and each expert's output is scaled by the token's routing
    probability.
    """
    def __init__(self, init_method, output_layer_init_method):
        super(SwitchMLP, self).__init__()
        args = get_args()
        # One logit per expert for every token.
        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
        self.experts = torch.nn.ModuleList()
        for i in range(args.num_experts):
            self.experts.append(ParallelMLP(init_method, output_layer_init_method))

    def forward(self, hidden_states):
        # hidden_states: [s, b, h]
        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states)
        route = torch.nn.functional.softmax(route, dim=2)
        # Top-1 routing: probability and index of the winning expert.
        max_prob, max_ind = torch.max(route, dim=2)
        max_prob = torch.unsqueeze(max_prob, 2)  # [s, b, 1]

        # Flatten tokens so each expert can gather its assigned rows.
        # [s, b, h] -> [s*b, h]; [s, b, 1] -> [s*b, 1]; [s, b] -> [s*b]
        hidden_states = hidden_states.view(-1, hidden_states.size(2))
        max_prob = max_prob.view(-1, max_prob.size(2))
        max_ind = max_ind.view(-1)

        output_total = torch.empty_like(hidden_states)
        output_bias_total = torch.empty_like(hidden_states)

        for expert_num, expert in enumerate(self.experts):
            # nonzero() yields [k, 1] indices, so `hidden` is [k, 1, h];
            # the same [k, 1] index tensor scatters the result back below.
            local_indices = (max_ind == expert_num).nonzero()
            hidden = hidden_states[local_indices,:]
            output, output_bias = expert(hidden)
            output_bias = output_bias.expand_as(output)
            output_total[local_indices,:] = output
            output_bias_total[local_indices,:] = output_bias

        # Scale each token's output by its routing probability.
        output_total = output_total*max_prob
        output_bias_total = output_bias_total*max_prob
        output_total = output_total.view(s, b, h)
        output_bias_total = output_bias_total.view(s, b, h)

        return output_total, output_bias_total
|
|
|
|
|
class CoreAttention(MegatronModule):
    """Scaled dot-product attention over already-projected Q/K/V.

    Consumes query/key/value shaped [s, b, np, hn] (see the notation note
    at the top of this file) and returns the context layer shaped
    [sq, b, hp], where hp = np * hn for this tensor-parallel partition.
    """

    def __init__(self, layer_number,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        args = get_args()
        self.fp16 = args.fp16
        self.bf16 = args.bf16

        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
        # Layer scaling requires the softmax to run in fp32 (the scores are
        # pre-divided by the layer number and re-scaled inside the softmax).
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = args.sequence_parallel

        projection_size = args.kv_channels * args.num_attention_heads

        # Per-partition sizes: heads are split across tensor-parallel ranks.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = mpu.divide(projection_size,
                                                    world_size)
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        coeff = None
        # Standard 1/sqrt(d) scaling, optionally multiplied by the layer
        # number for numerical stability; `coeff` tells the fused softmax
        # to undo that extra factor.
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            args.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout on attention probabilities. Note that its RNG state is
        # handled specially in forward() (model-parallel RNG tracker).
        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):
        """Compute softmax(QK^T / norm + mask) V.

        Inputs are [sq|sk, b, np, hn]; the result is [sq, b, hp].
        """

        # ===================================
        # Raw attention scores. [b, np, sq, sk]
        # ===================================
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.view(output_size[2],
                                       output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # Reuse a preallocated global buffer for the [b*np, sq, sk] scores
        # to avoid re-allocating it every layer/step.
        matmul_input_buffer = get_global_memory_buffer().get_tensor(
            (output_size[0]*output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores: alpha * Q @ K^T (beta=0.0 means the buffer
        # contents are ignored, it is only used as the output storage).
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),                  # [b*np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),    # [b*np, hn, sk]
            beta=0.0, alpha=(1.0/self.norm_factor))

        # Change view to [b, np, sq, sk].
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================
        # Fused scale + mask + softmax; [b, np, sq, sk].
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # Dropout under the model-parallel RNG tracker unless sequence
        # parallelism is on — presumably to keep dropout patterns
        # consistent across tensor-parallel ranks (TODO confirm).
        if not self.sequence_parallel:
            with mpu.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================
        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # [sk, b, np, hn] -> [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # [b, np, sq, sk] -> [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # Change view to [b, np, sq, hn].
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] -> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] -> [sq, b, hp] (merge heads).
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer
|
|
|
|
|
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size. Supports both self-attention
    (fused QKV projection) and cross-attention (Q from `hidden_states`,
    KV from `encoder_output`), plus a KV cache for incremental inference.
    """

    def __init__(self, init_method,
                 output_layer_init_method, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = args.params_dtype

        projection_size = args.kv_channels * args.num_attention_heads

        # Per-partition sizes: heads are split across tensor-parallel ranks.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = mpu.divide(
            projection_size, args.num_attention_heads)
        self.num_attention_heads_per_partition = mpu.divide(
            args.num_attention_heads, world_size)

        if attention_type == AttnType.self_attn:
            # Fused Q, K, V projection in a single column-parallel GEMM.
            self.query_key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                3 * projection_size,
                gather_output=False,
                init_method=init_method)
        else:
            assert attention_type == AttnType.cross_attn
            # Cross-attention: Q projected from decoder hidden states...
            self.query = mpu.ColumnParallelLinear(
                args.hidden_size,
                projection_size,
                gather_output=False,
                init_method=init_method)

            # ...and fused K, V projected from the encoder output.
            self.key_value = mpu.ColumnParallelLinear(
                args.hidden_size,
                2 * projection_size,
                gather_output=False,
                init_method=init_method)

        self.core_attention = CoreAttention(self.layer_number,
                                            self.attn_mask_type)
        # 'selective' recompute granularity checkpoints only core attention.
        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

        # Output projection back to h; bias is returned separately
        # (skip_bias_add) so callers can fuse it with dropout/residual-add.
        self.dense = mpu.RowParallelLinear(
            projection_size,
            args.hidden_size,
            input_is_parallel=True,
            init_method=output_layer_init_method,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        # Second argument (False): do not distribute saved activations.
        hidden_states = mpu.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size):
        """Allocate a KV-cache buffer of shape [max_seq, batch, np, hn]."""
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            self.num_attention_heads_per_partition,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        if inference_params:
            # First call for this layer allocates the cache; later calls
            # reuse it from inference_params.key_value_memory_dict.
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_len
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size)
                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:
            # [sq, b, h] -> [sq, b, np * 3 * hn]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, np * 3 * hn] -> [sq, b, np, 3 * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 3 * self.hidden_size_per_attention_head)
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, np, 3 * hn] -> 3 x [sq, b, np, hn]
            (query_layer,
             key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_x_layer, 3)
        else:
            # [sk, b, h] -> [sk, b, np * 2 * hn]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, np * 2 * hn] -> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] -> 2 x [sk, b, np, hn]
            (key_layer,
             value_layer) = mpu.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Queries come from the decoder-side hidden states.
            # [sq, b, h] -> [sq, b, np, hn]
            query_layer, _ = self.query(hidden_states)

            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================
        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)

            # Write new keys/values into the cache, then attend over the
            # whole cached prefix (positions 0..sequence_end).
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

        # ==================================
        # Core attention computation
        # ==================================
        if self.checkpoint_core_attention:
            context_layer = self._checkpointed_attention_forward(
                query_layer, key_layer, value_layer, attention_mask)
        else:
            context_layer = self.core_attention(
                query_layer, key_layer, value_layer, attention_mask)

        # =================
        # Output. [sq, b, h]
        # =================
        output, bias = self.dense(context_layer)

        return output, bias
|
|
|
|
|
def bias_dropout_add(x, bias, residual, prob, training):
    """Return residual + dropout(x + bias, p=prob).

    Kept as a single expression so the @torch.jit.script wrappers below
    can script it.
    """
    return residual + torch.nn.functional.dropout(
        x + bias, p=prob, training=training)


def get_bias_dropout_add(training):
    """Return bias_dropout_add with the training flag bound."""
    def _bound(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bound
|
|
|
|
|
@torch.jit.script
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: torch.Tensor,
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    """TorchScript-compiled bias+dropout+residual-add, training=True.

    The training flag is hard-coded (rather than passed) so TorchScript
    can specialize/fuse the dropout path.
    """
    return bias_dropout_add(x, bias, residual, prob, True)
|
|
|
|
|
@torch.jit.script
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: torch.Tensor,
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    """TorchScript-compiled bias+dropout+residual-add, training=False.

    With training=False the dropout is an identity, so this reduces to
    residual + x + bias at inference time.
    """
    return bias_dropout_add(x, bias, residual, prob, False)
|
|
|
|
|
class ParallelTransformerLayer(MegatronModule):
    """A single transformer layer.

    Transformer layer takes input with size [s, b, h] and returns an
    output of the same size. Structure: LN -> self-attention ->
    bias/dropout/residual -> LN -> (optional cross-attention for decoder
    layers) -> MLP -> bias/dropout/residual.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_number, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 drop_path_rate=0.):
        args = get_args()

        super(ParallelTransformerLayer, self).__init__()
        self.layer_number = layer_number
        self.layer_type = layer_type

        # If set, residuals are taken post-layernorm instead of from the
        # raw layer input.
        self.apply_residual_connection_post_layernorm \
            = args.apply_residual_connection_post_layernorm

        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection

        # Layernorm on the layer input.
        self.input_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Self attention.
        self.self_attention = ParallelAttention(
            init_method,
            output_layer_init_method,
            layer_number,
            attention_type=AttnType.self_attn,
            attn_mask_type=self_attn_mask_type)
        self.hidden_dropout = args.hidden_dropout
        self.bias_dropout_fusion = args.bias_dropout_fusion
        # Stochastic depth on the residual branches (0 disables it).
        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

        # Layernorm on the attention output.
        self.post_attention_layernorm = LayerNorm(
            args.hidden_size,
            eps=args.layernorm_epsilon,
            no_persist_layer_norm=args.no_persist_layer_norm,
            sequence_parallel=args.sequence_parallel)

        # Decoder layers additionally cross-attend to the encoder output.
        if self.layer_type == LayerType.decoder:
            self.inter_attention = ParallelAttention(
                init_method,
                output_layer_init_method,
                layer_number,
                attention_type=AttnType.cross_attn)

            self.post_inter_attention_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

        # MLP: mixture-of-experts variant when num_experts is configured.
        if args.num_experts is not None:
            self.mlp = SwitchMLP(init_method, output_layer_init_method)
        else:
            self.mlp = ParallelMLP(init_method, output_layer_init_method)

        # On torch >= 1.10 run the fused bias-dropout-add under
        # nullcontext; older versions wrap it in torch.enable_grad
        # (presumably required by the pre-nvfuser JIT fusion — TODO confirm).
        TORCH_MAJOR = int(torch.__version__.split('.')[0])
        TORCH_MINOR = int(torch.__version__.split('.')[1])
        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
        self.bias_dropout_add_exec_handler = \
            nullcontext if use_nvfuser else torch.enable_grad

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Layer norm at the beginning of the transformer layer.
        layernorm_output = self.input_layernorm(hidden_states)

        # Self attention (output bias is returned separately for fusion).
        attention_output, attention_bias = \
            self.self_attention(
                layernorm_output,
                attention_mask,
                inference_params=inference_params)

        # Residual connection source.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = hidden_states

        if self.drop_path is None:
            # Select the bias+dropout+add implementation: jit-fused
            # variants (training/inference specialized) or the plain
            # closure with the training flag bound.
            if self.bias_dropout_fusion:
                if self.training:
                    bias_dropout_add_func = bias_dropout_add_fused_train
                else:
                    bias_dropout_add_func = bias_dropout_add_fused_inference
            else:
                bias_dropout_add_func = get_bias_dropout_add(self.training)

            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)
        else:
            # Drop-path branch: dropout then stochastic-depth on the
            # whole residual branch.
            out = torch.nn.functional.dropout(attention_output + attention_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            layernorm_input = residual + self.drop_path(out)

        # Layer norm post the self attention.
        layernorm_output = self.post_attention_layernorm(layernorm_input)

        if self.layer_type == LayerType.decoder:
            attention_output, attention_bias = \
                self.inter_attention(layernorm_output,
                                     enc_dec_attn_mask,
                                     encoder_output=encoder_output)

            # Residual connection source.
            if self.apply_residual_connection_post_layernorm:
                residual = layernorm_output
            else:
                residual = layernorm_input

            # NOTE(review): bias_dropout_add_func is only bound in the
            # `self.drop_path is None` branch above, so a decoder layer
            # constructed with drop_path_rate > 0 would raise NameError
            # here — confirm decoder layers never enable drop_path.
            with self.bias_dropout_add_exec_handler():
                layernorm_input = bias_dropout_add_func(
                    attention_output,
                    attention_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Layer norm post the decoder attention.
            layernorm_output = self.post_inter_attention_layernorm(layernorm_input)

        # MLP.
        mlp_output, mlp_bias = self.mlp(layernorm_output)

        # Second residual connection source.
        if self.apply_residual_connection_post_layernorm:
            residual = layernorm_output
        else:
            residual = layernorm_input

        if self.drop_path is None:
            with self.bias_dropout_add_exec_handler():
                output = bias_dropout_add_func(
                    mlp_output,
                    mlp_bias.expand_as(residual),
                    residual,
                    self.hidden_dropout)

            # Disconnect any view relationship on the output so later
            # in-place deallocation of it is safe (see
            # mpu.make_viewless_tensor).
            output = mpu.make_viewless_tensor(inp = output,
                                              requires_grad = output.requires_grad,
                                              keep_graph = True)

        else:
            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                              p=self.hidden_dropout,
                                              training=self.training)
            output = residual + self.drop_path(out)

        return output
|
|
|
|
|
class NoopTransformerLayer(MegatronModule):
    """A single 'no-op' transformer layer that just clones its input.

    This exists for the standalone-embedding-stage configuration
    (args.standalone_embedding_stage == True), where pipeline rank 0 is
    assigned zero transformer layers (and virtual ranks >= 1 hold zero
    parameters, since virtual rank 0 contains the input embedding). Input
    and output tensors would then be the same object, which breaks memory
    optimizations that deallocate the output tensor; the clone() below
    disconnects them. Ranks holding this layer are generally
    under-utilized, so the extra copy causes no meaningful performance
    degradation.
    """

    def __init__(self, layer_number):
        super().__init__()
        self.layer_number = layer_number

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # Clone so the output is a distinct tensor from the input.
        return hidden_states.clone()
|
|
|
|
|
class ParallelTransformer(MegatronModule):
    """Transformer class.

    Holds the stack of ParallelTransformerLayer instances assigned to this
    pipeline stage (or model chunk, with interleaved virtual pipelining),
    plus the optional final layernorm on the last stage.
    """

    def __init__(self, init_method, output_layer_init_method,
                 layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 trans_layer_type="encoder",
                 post_layer_norm=True,
                 pre_process=True, post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = args.model_type
        self.bf16 = args.bf16
        self.fp32_residual_connection = args.fp32_residual_connection
        self.post_layer_norm = post_layer_norm
        self.pre_process = pre_process
        self.post_process = post_process
        # Set by set_input_tensor() when pipeline parallelism supplies the
        # input from communication instead of forward()'s argument.
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate

        # Activation-recomputation (checkpointing) configuration.
        self.recompute_granularity = args.recompute_granularity
        self.recompute_method = args.recompute_method
        self.recompute_num_layers = args.recompute_num_layers
        # Distributing saved activations is incompatible with sequence
        # parallelism, so it is disabled in that case.
        self.distribute_saved_activations = \
            args.distribute_saved_activations and not args.sequence_parallel

        self.sequence_parallel = args.sequence_parallel

        # Number of layers assigned to this stage, computed separately for
        # encoder and decoder stacks.
        if trans_layer_type == "encoder":
            self.num_layers = mpu.get_num_layers(
                args, args.model_type == ModelType.encoder_and_decoder)
        elif trans_layer_type == "decoder":
            self.num_layers = mpu.get_num_layers_decoder(
                args, args.model_type == ModelType.encoder_and_decoder)
        else:
            # NOTE(review): an invalid trans_layer_type exits the process
            # with status 0 instead of raising — consider ValueError.
            print("No support layer type")
            import sys;sys.exit(0)

        # Linearly increasing drop-path rate over the *global* layer count
        # (args.num_layers); indexed below by (global) layer_number - 1.
        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]

        def build_layer(layer_number):
            # layer_number is the 1-based global layer index.
            return ParallelTransformerLayer(
                init_method,
                output_layer_init_method,
                layer_number,
                layer_type=layer_type,
                self_attn_mask_type=self_attn_mask_type,
                drop_path_rate=self.drop_path_rates[layer_number - 1])
        if args.virtual_pipeline_model_parallel_size is not None:
            assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder

            # Layers per model chunk = layers per stage / chunks per stage.
            self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size

            # Interleaved schedule: the global layer offset for this chunk
            # combines the virtual pipeline rank (which chunk) with the
            # pipeline rank (which stage).
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                args.num_layers // args.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Non-interleaved: each stage owns a contiguous block of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    # Decoder stages start counting after the encoder ranks.
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # With a standalone embedding stage, some ranks are assigned
            # zero layers; the model's input and output tensors would then
            # be the same object, which breaks output-tensor memory
            # optimizations (e.g., deallocation). Insert a no-op layer that
            # clones the input to disconnect it from the output.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        # Final layer norm only on the last stage (and when enabled).
        if self.post_process and self.post_layer_norm:
            self.final_layernorm = LayerNorm(
                args.hidden_size,
                eps=args.layernorm_epsilon,
                no_persist_layer_norm=args.no_persist_layer_norm,
                sequence_parallel=args.sequence_parallel)

    def _get_layer(self, layer_number):
        # 0-based index into this stage's local layer list.
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            # Build a function that runs layers [start, end) so the
            # checkpoint API can re-run them in the backward pass.
            def custom_forward(*inputs):
                x_ = inputs[0]
                attention_mask = inputs[1]
                encoder_output = inputs[2]
                enc_dec_attn_mask = inputs[3]
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, attention_mask, encoder_output, enc_dec_attn_mask)
                return x_
            return custom_forward

        if self.recompute_method == 'uniform':
            # Uniformly group the layers into chunks of
            # recompute_num_layers and checkpoint each chunk's input.
            l = 0
            while l < self.num_layers:
                hidden_states = mpu.checkpoint(
                    custom(l, l + self.recompute_num_layers),
                    self.distribute_saved_activations,
                    hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint only the first recompute_num_layers individual
            # layers; run the rest without recomputation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    hidden_states = mpu.checkpoint(
                        custom(l, l + 1),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                else:
                    hidden_states = custom(l, l + 1)(
                        hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states

    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                inference_params=None):
        # hidden_states: [s, b, h]

        # Inference incrementally mutates the KV cache, which activation
        # recomputation would re-run; the two are incompatible.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # Input arrives via pipeline communication; see set_input_tensor().
            hidden_states = self.input_tensor

        # Disconnect any view relationship on the input so downstream
        # deallocation of output tensors is safe (see
        # mpu.make_viewless_tensor).
        hidden_states = mpu.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # Under sequence parallelism, run the whole stack inside the
        # model-parallel RNG tracker fork (presumably so dropout state is
        # coordinated across tensor-parallel ranks — TODO confirm).
        if self.sequence_parallel:
            rng_context = mpu.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        with rng_context:
            # Forward pass: either full-granularity recomputation or a
            # plain layer-by-layer pass.
            if self.recompute_granularity == 'full':
                hidden_states = self._checkpointed_forward(hidden_states,
                                                           attention_mask,
                                                           encoder_output,
                                                           enc_dec_attn_mask)
            else:
                for index in range(self.num_layers):
                    layer = self._get_layer(index)
                    hidden_states = layer(
                        hidden_states,
                        attention_mask,
                        encoder_output=encoder_output,
                        enc_dec_attn_mask=enc_dec_attn_mask,
                        inference_params=inference_params)

        # Final layer norm on the last pipeline stage.
        if self.post_process and self.post_layer_norm:
            hidden_states = self.final_layernorm(hidden_states)

        return hidden_states
|
|