"""Forward step utilities.""" |
|
|
|
from collections.abc import Iterable |
|
|
|
import torch |
|
|
|
from megatron import ( |
|
get_args, |
|
mpu) |
|
from .communication import ( |
|
send_to_next_pipeline_rank, |
|
recv_from_prev_pipeline_rank_) |
|
|
|
|
|
|
|
class InferenceParams: |
|
"""Inference parameters that are passed to the main model in order |
|
to efficienly calculate and store the context during inference.""" |
|
|
|
def __init__(self, max_batch_size, max_sequence_len): |
|
"""Note that offsets are set to zero and we always set the |
|
flag to allocate memory. After the first call, make sure to |
|
set this flag to False.""" |
|
self.max_sequence_len = max_sequence_len |
|
self.max_batch_size = max_batch_size |
|
self.sequence_len_offset = 0 |
|
self.batch_size_offset = 0 |
|
self.key_value_memory_dict = {} |
|
    def swap_key_value_dict(self, batch_idx):
        """Reorder the cached keys/values along the batch dimension."""
        if len(self.key_value_memory_dict) == 0:
            raise ValueError("should not swap when dict is empty")

        for layer_number in self.key_value_memory_dict.keys():
            inference_key_memory, inference_value_memory = \
                self.key_value_memory_dict[layer_number]
            # The cache is stored as [sequence, batch, ...], so the
            # batch dimension is dim 1.
            assert len(batch_idx) == inference_key_memory.shape[1]
            new_inference_key_memory = inference_key_memory[:, batch_idx]
            new_inference_value_memory = inference_value_memory[:, batch_idx]
            self.key_value_memory_dict[layer_number] = (
                new_inference_key_memory, new_inference_value_memory)
|
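# A minimal usage sketch (illustrative only; the index tensor below and
# the assumption that the attention layers populate key_value_memory_dict
# with [sequence, batch, ...] tensors on the first call are not part of
# this module):
#
#     params = InferenceParams(max_batch_size=4, max_sequence_len=2048)
#     # ... run one forward step so each layer allocates its cache ...
#     params.swap_key_value_dict(torch.tensor([2, 0, 1, 3]))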
|
|
class ForwardStep:
    """Forward step function with all the communications.
    We use a class here to hide the inference parameters
    from the outside caller."""

    def __init__(self, model, max_batch_size, max_sequence_len):
        """Set values so we don't need to do it multiple times."""
        # Make sure the model is in eval mode.
        assert not isinstance(model, Iterable), \
            'interleaving schedule is not supported for inference'
        model.eval()
        self.model = model
        # Initialize inference parameters.
        self.inference_params = InferenceParams(max_batch_size,
                                                max_sequence_len)
        # Pipelining arguments.
        args = get_args()
        self.pipeline_size_larger_than_one = (
            args.pipeline_model_parallel_size > 1)
        # Threshold for switching to pipelined micro batches.
        self.pipelining_batch_x_seqlen = \
            args.inference_batch_times_seqlen_threshold
|
    def __call__(self, tokens, position_ids, attention_mask):
        """Invocation of the forward methods. Note that self.inference_params
        is being modified by the forward step."""
        # Pipelining is used only when the batch-size-times-sequence-length
        # product is at or above the configured threshold.
        if self.pipeline_size_larger_than_one:
            current_batch_x_seqlen = tokens.size(0) * tokens.size(1)
            if current_batch_x_seqlen >= self.pipelining_batch_x_seqlen:
                micro_batch_size = \
                    max(1, self.pipelining_batch_x_seqlen // tokens.size(1))
                return _with_pipelining_forward_step(self.model,
                                                     tokens,
                                                     position_ids,
                                                     attention_mask,
                                                     self.inference_params,
                                                     micro_batch_size)

        return _no_pipelining_forward_step(self.model,
                                           tokens,
                                           position_ids,
                                           attention_mask,
                                           self.inference_params)
|
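# A hedged usage sketch (illustrative only; `model`, `tokens`,
# `position_ids`, and `attention_mask` are assumed to come from the
# surrounding text-generation code):
#
#     forward_step = ForwardStep(model, max_batch_size=8,
#                                max_sequence_len=1024)
#     logits = forward_step(tokens, position_ids, attention_mask)
#
# With pipeline_model_parallel_size > 1 and batch * seqlen at or above
# args.inference_batch_times_seqlen_threshold, the call above
# transparently splits the batch into micro batches.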
|
|
def _get_recv_buffer_dtype(args):
    """Receive happens between the layers."""
    if args.fp32_residual_connection:
        return torch.float
    return args.params_dtype
|
|
def _allocate_recv_buffer(batch_size, sequence_length):
    """Receive happens between the layers with size [s, b, h]."""
    if mpu.is_pipeline_first_stage():
        return None
    args = get_args()
    recv_size = (sequence_length, batch_size, args.hidden_size)
    return torch.empty(recv_size,
                       dtype=_get_recv_buffer_dtype(args),
                       device=torch.cuda.current_device())
|
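# For example (illustrative numbers), sequence_length=32, batch_size=4,
# and args.hidden_size=1024 give a buffer of shape (32, 4, 1024),
# matching the [s, b, h] activation layout used between pipeline stages.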
|
|
def _forward_step_helper(model, tokens, position_ids, attention_mask,
                         inference_params, recv_buffer=None):
    """Single forward step: receive the activations from the previous
    pipeline stage, run the model, and send the output to the next
    stage. Key/value memory is allocated only on the first call."""
    batch_size = tokens.size(0)
    sequence_length = tokens.size(1)
    if recv_buffer is None:
        recv_buffer = _allocate_recv_buffer(batch_size, sequence_length)

    # Receive from previous stage.
    recv_from_prev_pipeline_rank_(recv_buffer)

    # Forward pass through the model.
    model.set_input_tensor(recv_buffer)
    output_tensor = model(tokens, position_ids, attention_mask,
                          inference_params=inference_params)

    # Send output to the next stage.
    send_to_next_pipeline_rank(output_tensor)

    return output_tensor
|
|
def _no_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                inference_params, recv_buffer=None):
    """If recv_buffer is None, we will allocate one on the fly."""
    # Run a simple forward pass.
    output_tensor = _forward_step_helper(model, tokens, position_ids,
                                         attention_mask, inference_params,
                                         recv_buffer=recv_buffer)
    # Update the sequence length offset.
    inference_params.sequence_len_offset += tokens.size(1)

    logits = None
    if mpu.is_pipeline_last_stage():
        logits = output_tensor

    return logits
|
|
def _with_pipelining_forward_step(model, tokens, position_ids, attention_mask,
                                  inference_params, micro_batch_size):
    """No interleaving is supported."""
    sequence_length = tokens.size(1)
    batch_size = tokens.size(0)

    # Divide the batch dimension into micro batches; a non-zero
    # remainder becomes one final, smaller micro batch.
    num_micro_batches, last_chunk = divmod(batch_size,
                                           micro_batch_size)
    if last_chunk > 0:
        num_micro_batches += 1
|
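    # Worked example (illustrative numbers): batch_size=10 with
    # micro_batch_size=4 gives divmod(10, 4) == (2, 2); since
    # last_chunk > 0, num_micro_batches becomes 3 (two full micro
    # batches of 4 and a final one of 2).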
    # Preallocate memory for output logits on the last stage.
    logits = None
    if mpu.is_pipeline_last_stage():
        args = get_args()
        logits = torch.empty(
            (batch_size, sequence_length, args.padded_vocab_size),
            dtype=torch.float32, device=torch.cuda.current_device())

    # Preallocate the recv buffer once and reuse it across micro batches.
    recv_buffer = _allocate_recv_buffer(micro_batch_size, sequence_length)

    for micro_batch_index in range(num_micro_batches):
        # Slice along the batch dimension.
        start = micro_batch_index * micro_batch_size
        end = min(start + micro_batch_size, batch_size)
        this_micro_batch_size = end - start
        tokens2use = tokens[start:end, ...]
        position_ids2use = position_ids[start:end, ...]

        # Run a simple forward pass. If this is the last, smaller micro
        # batch, let the helper allocate a recv buffer of the right size.
        if this_micro_batch_size != micro_batch_size:
            recv_buffer = None
        output = _forward_step_helper(model, tokens2use, position_ids2use,
                                      attention_mask, inference_params,
                                      recv_buffer=recv_buffer)

        # Adjust the batch size offset to account for this micro batch.
        inference_params.batch_size_offset += this_micro_batch_size

        # Copy logits.
        if mpu.is_pipeline_last_stage():
            logits[start:end, ...] = output

    # Once all the micro batches are done, adjust the sequence length
    # offset and reset the batch size offset for the next step.
    inference_params.sequence_len_offset += sequence_length
    inference_params.batch_size_offset = 0

    return logits
|