"""Run MPT model with FT. |
|
|
|
This script is a modified version of |
|
https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/multi_gpu_gpt_example.py |
|
""" |
|
|
|
import argparse |
|
import configparser |
|
import os |
|
import sys |
|
import timeit |
|
|
|
import torch |
|
from torch.nn.utils.rnn import pad_sequence |
|
from transformers import AutoTokenizer |
|
|
|
dir_path = os.path.dirname(os.path.realpath(__file__)) |
|
sys.path.append(os.path.join(dir_path, '../../..')) |
|
from examples.pytorch.gpt.utils import comm, gpt_decoder |
|
from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT |
|
|
|
|
|
@torch.no_grad() |
|
def main(): |
|
    parser = argparse.ArgumentParser()
    parser.add_argument('--layer_num',
                        type=int,
                        default=32,
                        help='number of layers')
    parser.add_argument('--input_len',
                        type=int,
                        default=128,
                        help='input sequence length to generate.')
    parser.add_argument('--output_len',
                        type=int,
                        default=64,
                        help='output sequence length to generate.')
    parser.add_argument('--head_num', type=int, default=32, help='head number')
    parser.add_argument('--size_per_head',
                        type=int,
                        default=128,
                        help='size per head')
    parser.add_argument('--vocab_size',
                        type=int,
                        default=50432,
                        help='vocab size')
    parser.add_argument(
        '--beam_width',
        type=int,
        default=1,
        help='beam width for beam search. Sampling is used when beam width is 1.')
    parser.add_argument('--top_k',
                        type=int,
                        default=1,
                        help='top k candidate num')
    parser.add_argument('--top_p',
                        type=float,
                        default=0.95,
                        help='top p probability threshold')
    parser.add_argument('--temperature',
                        type=float,
                        default=0.8,
                        help='temperature')
    parser.add_argument('--len_penalty',
                        type=float,
                        default=0.,
                        help='len_penalty')
    parser.add_argument('--beam_search_diversity_rate',
                        type=float,
                        default=0.,
                        help='beam_search_diversity_rate')
    parser.add_argument('--tensor_para_size',
                        type=int,
                        default=1,
                        help='tensor parallel size')
    parser.add_argument('--pipeline_para_size',
                        type=int,
                        default=1,
                        help='pipeline parallel size')
    parser.add_argument('--ckpt_path',
                        type=str,
                        default='mpt-ft-7b/1-gpu',
                        help='path to the FT checkpoint directory.')
    parser.add_argument(
        '--tokenizer_name_or_path',
        type=str,
        default='EleutherAI/gpt-neox-20b',
        help='Name of the tokenizer or the directory where the tokenizer file is located.')
    parser.add_argument(
        '--lib_path',
        type=str,
        help='path to the libth_transformer dynamic lib file (e.g., build/lib/libth_transformer.so).')
    parser.add_argument('--start_id',
                        type=int,
                        default=0,
                        help='start token id.')
    parser.add_argument('--end_id', type=int, default=0, help='end token id.')
    parser.add_argument(
        '--max_batch_size',
        type=int,
        default=8,
        help='Max batch size. If sample_input_file is given, the batch is truncated to this size; otherwise, this value is used as the batch size.')
    parser.add_argument('--repetition_penalty',
                        type=float,
                        default=5.,
                        help='repetition penalty')
    parser.add_argument(
        '--presence_penalty',
        type=float,
        default=0.,
        help='presence penalty. Similar to repetition penalty, but additive rather than multiplicative.')
    parser.add_argument('--min_length',
                        type=int,
                        default=0,
                        help='minimum number of tokens to generate')
    parser.add_argument(
        '--max_seq_len',
        type=int,
        default=2048,
        help='max sequence length for the position embedding table.')
    parser.add_argument('--inference_data_type',
                        '--data_type',
                        type=str,
                        choices=['fp32', 'fp16', 'bf16'],
                        default='bf16')
    parser.add_argument('--time',
                        action='store_true',
                        help='whether or not to measure time elapsed.')
    parser.add_argument(
        '--sample_input_file',
        type=str,
        default=None,
        help='path to sample input file. If not set, the script runs with no context inputs.')
    parser.add_argument('--sample_output_file',
                        type=str,
                        default=None,
                        help='path to sample output file.')
    parser.add_argument(
        '--disable_random_seed',
        dest='random_seed',
        action='store_false',
        help='Disable the use of random seeds for sentences in a batch.')
    parser.add_argument('--skip_end_tokens',
                        dest='skip_end_tokens',
                        action='store_false',
                        help='Whether or not to remove end tokens from outputs.')
    parser.add_argument('--no_detokenize',
                        dest='detokenize',
                        action='store_false',
                        help='Skip detokenizing output token ids.')
    parser.add_argument(
        '--int8_mode',
        type=int,
        default=0,
        choices=[0, 1],
        help='The level of quantization to perform.'
        ' 0: No quantization. All computation in data_type.'
        ' 1: Quantize weights to int8; all compute occurs in fp16/bf16. Not supported when data_type is fp32.')
    parser.add_argument(
        '--weights_data_type',
        type=str,
        default='fp32',
        choices=['fp32', 'fp16'],
        help='Data type of FT checkpoint weights',
    )
    parser.add_argument(
        '--return_cum_log_probs',
        type=int,
        default=0,
        choices=[0, 1, 2],
        help='Whether to compute the cumulative log probability of sentences.'
        ' 0: do not return the cumulative log probs.'
        ' 1: return the cumulative log probs of generated sequences.'
        ' 2: return the cumulative log probs of sequences.')
    parser.add_argument('--shared_contexts_ratio',
                        type=float,
                        default=0.0,
                        help='Triggers the shared context optimization when '
                        'compact_size <= shared_contexts_ratio * batch_size. '
                        'A value of 0.0 deactivates the optimization.')
    parser.add_argument(
        '--use_gpt_decoder_ops',
        action='store_true',
        help='Use separate decoder FT operators instead of the end-to-end model op.')
    parser.add_argument(
        '--no-alibi',
        dest='alibi',
        action='store_false',
        help='Do not use ALiBi (aka use_attention_linear_bias).')
    parser.add_argument(
        '--layernorm_eps',
        type=float,
        default=1e-5,
        help='layernorm epsilon. The PyTorch default is 1e-5; the FT default is 1e-6.')
    args = parser.parse_args()

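    # If the converted checkpoint ships a config.ini (written by the FT
    # conversion script), prefer its recorded hyperparameters over the
    # command-line defaults so the runtime shapes match the weights.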
    ckpt_config = configparser.ConfigParser()
    ckpt_config_path = os.path.join(args.ckpt_path, 'config.ini')
    if os.path.isfile(ckpt_config_path):
        ckpt_config.read(ckpt_config_path)
    if 'gpt' in ckpt_config.keys():
        for args_key, config_key, func in [
                ('layer_num', 'num_layer', ckpt_config.getint),
                ('max_seq_len', 'max_pos_seq_len', ckpt_config.getint),
                ('weights_data_type', 'weight_data_type', ckpt_config.get),
                ('layernorm_eps', 'layernorm_eps', ckpt_config.getfloat),
                ('alibi', 'use_attention_linear_bias', ckpt_config.getboolean),
        ]:
            if config_key in ckpt_config['gpt'].keys():
                prev_val = args.__dict__[args_key]
                args.__dict__[args_key] = func('gpt', config_key)
                print('Loading {} from config.ini, previous: {}, current: {}'
                      .format(args_key, prev_val, args.__dict__[args_key]))
            else:
                print('Not loading {} from config.ini'.format(args_key))
        for key in ['head_num', 'size_per_head', 'tensor_para_size']:
            if key in ckpt_config['gpt'].keys():
                prev_val = args.__dict__[key]
                args.__dict__[key] = ckpt_config.getint('gpt', key)
                print('Loading {} from config.ini, previous: {}, current: {}'
                      .format(key, prev_val, args.__dict__[key]))
            else:
                print('Not loading {} from config.ini'.format(key))

    layer_num = args.layer_num
    output_len = args.output_len
    head_num = args.head_num
    size_per_head = args.size_per_head
    vocab_size = args.vocab_size
    beam_width = args.beam_width
    top_k = args.top_k
    top_p = args.top_p
    temperature = args.temperature
    len_penalty = args.len_penalty
    beam_search_diversity_rate = args.beam_search_diversity_rate
    tensor_para_size = args.tensor_para_size
    pipeline_para_size = args.pipeline_para_size
    start_id = args.start_id
    end_id = args.end_id
    max_batch_size = args.max_batch_size
    max_seq_len = args.max_seq_len
    repetition_penalty = args.repetition_penalty
    presence_penalty = args.presence_penalty
    min_length = args.min_length
    weights_data_type = args.weights_data_type
    return_cum_log_probs = args.return_cum_log_probs
    return_output_length = return_cum_log_probs > 0
    shared_contexts_ratio = args.shared_contexts_ratio
    layernorm_eps = args.layernorm_eps
    use_attention_linear_bias = args.alibi
    has_positional_encoding = not args.alibi

    print('\n=================== Arguments ===================')
    for k, v in vars(args).items():
        print(f'{k.ljust(30, ".")}: {v}')
    print('=================================================\n')

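    # MPT checkpoints reuse the EleutherAI/gpt-neox-20b tokenizer (see the
    # --tokenizer_name_or_path default), so no MPT-specific tokenizer files
    # are needed here.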
    tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
    torch.manual_seed(0)

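    # Initialize the tensor-/pipeline-parallel process groups. Each rank is
    # bound to its own GPU; only rank 0 prints and saves outputs below.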
    comm.initialize_model_parallel(args.tensor_para_size,
                                   args.pipeline_para_size)
    rank = comm.get_rank()
    device = comm.get_device()

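    # Build the input batch: tokenize prompts from --sample_input_file when
    # given, otherwise synthesize input_len end_id tokens per sequence, then
    # right-pad everything with end_id to a common length.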
    contexts = []
    if args.sample_input_file:
        with open(args.sample_input_file, 'r') as f:
            contexts = f.read().splitlines()
            batch_size = min(len(contexts), max_batch_size)
        contexts = contexts[:batch_size]
        start_ids = [
            torch.tensor(tokenizer.encode(c), dtype=torch.int32, device=device)
            for c in contexts
        ]
    else:
        batch_size = max_batch_size
        contexts = ['<|endoftext|>'] * batch_size
        start_ids = [torch.IntTensor([end_id for _ in range(args.input_len)])
                     ] * batch_size

    start_lengths = [len(ids) for ids in start_ids]

    start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id)
    start_lengths = torch.IntTensor(start_lengths)

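    # Construct the model. MPT relies on ALiBi attention biases rather than a
    # learned position embedding, hence use_attention_linear_bias=True and
    # has_positional_encoding=False unless --no-alibi is passed.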
    if not args.use_gpt_decoder_ops:
        gpt = ParallelGPT(head_num,
                          size_per_head,
                          vocab_size,
                          start_id,
                          end_id,
                          layer_num,
                          max_seq_len,
                          tensor_para_size,
                          pipeline_para_size,
                          lib_path=args.lib_path,
                          inference_data_type=args.inference_data_type,
                          int8_mode=args.int8_mode,
                          weights_data_type=weights_data_type,
                          layernorm_eps=layernorm_eps,
                          use_attention_linear_bias=use_attention_linear_bias,
                          has_positional_encoding=has_positional_encoding,
                          shared_contexts_ratio=shared_contexts_ratio)
        if not gpt.load(ckpt_path=args.ckpt_path):
            print('[WARNING] Checkpoint file not found. Model loading is skipped.')
    else:
        gpt = gpt_decoder.Gpt(num_heads=head_num,
                              size_per_head=size_per_head,
                              num_layers=layer_num,
                              vocab_size=vocab_size,
                              start_id=start_id,
                              end_id=end_id,
                              tensor_para_size=tensor_para_size,
                              pipeline_para_size=pipeline_para_size,
                              lib_path=args.lib_path,
                              max_seq_len=max_seq_len,
                              int8_mode=args.int8_mode,
                              weights_data_type=args.weights_data_type)
        gpt.load(args.ckpt_path, args.inference_data_type)

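    # One random seed per sequence so sampled outputs differ across the batch;
    # --disable_random_seed pins all seeds to zero for reproducible runs.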
    if args.random_seed:
        random_seed_tensor = torch.randint(0,
                                           10000,
                                           size=[batch_size],
                                           dtype=torch.int64)
    else:
        random_seed_tensor = torch.zeros([batch_size], dtype=torch.int64)

    repetition_penalty_vec = None if repetition_penalty == 1. else repetition_penalty * torch.ones(
        batch_size, dtype=torch.float32)
    presence_penalty_vec = None if presence_penalty == 0. else presence_penalty * torch.ones(
        batch_size, dtype=torch.float32)

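    # FT expects most decoding knobs as per-sequence tensors, so broadcast the
    # scalar command-line settings across the batch.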
    infer_decode_args = {
        'beam_width': beam_width,
        'top_k': top_k * torch.ones(batch_size, dtype=torch.int32),
        'top_p': top_p * torch.ones(batch_size, dtype=torch.float32),
        'temperature': temperature * torch.ones(batch_size, dtype=torch.float32),
        'repetition_penalty': repetition_penalty_vec,
        'presence_penalty': presence_penalty_vec,
        'beam_search_diversity_rate':
            beam_search_diversity_rate * torch.ones(batch_size, dtype=torch.float32),
        'len_penalty': len_penalty * torch.ones(size=[batch_size], dtype=torch.float32),
        'bad_words_list': None,
        'min_length': min_length * torch.ones(size=[batch_size], dtype=torch.int32),
        'random_seed': random_seed_tensor
    }

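    # Wrap generation in a single callable so the timing loop below can drive
    # either the end-to-end model op or the decoder-ops path uniformly.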
    if not args.use_gpt_decoder_ops:

        def gpt_generate_fn():
            tokens_batch = gpt(start_ids,
                               start_lengths,
                               output_len,
                               return_output_length=return_output_length,
                               return_cum_log_probs=return_cum_log_probs,
                               **infer_decode_args)
            return tokens_batch
    else:

        def gpt_generate_fn():
            output_dict = gpt.generate(
                input_token_ids=start_ids,
                input_lengths=start_lengths,
                gen_length=output_len,
                eos_token_id=end_id,
                return_output_length=return_output_length,
                return_log_probs=return_cum_log_probs,
                **infer_decode_args)
            return output_dict

    gen_outputs = gpt_generate_fn()

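    # Unpack results: the end-to-end op returns a token tensor (plus optional
    # lengths and log probs), while the decoder-ops path returns a dict.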
    if rank == 0:
        if not args.use_gpt_decoder_ops:
            if return_cum_log_probs > 0:
                tokens_batch, _, cum_log_probs = gen_outputs
            else:
                tokens_batch, cum_log_probs = gen_outputs, None
        else:
            tokens_batch = gen_outputs['output_token_ids']
            cum_log_probs = gen_outputs[
                'cum_log_probs'] if return_cum_log_probs > 0 else None
        if cum_log_probs is not None:
            print('[INFO] Log probs of sentences:', cum_log_probs)

        outputs = []
        tokens_batch = tokens_batch.cpu().numpy()
        for i, (context, tokens) in enumerate(zip(contexts, tokens_batch)):
            for beam_id in range(beam_width):
                token = tokens[beam_id][start_lengths[i]:]
                if args.skip_end_tokens:
                    token = token[token != end_id]
                output = tokenizer.decode(
                    token) if args.detokenize else ' '.join(
                        str(t) for t in token.tolist())
                outputs.append(output)
                print(
                    f'[INFO] batch {i}, beam {beam_id}:\n[Context]\n{context}\n\n[Output]\n{output}\n'
                )

        if args.sample_output_file:
            with open(args.sample_output_file, 'w+') as f:
                outputs = [o.replace('\n', '\\n') for o in outputs]
                f.writelines('\n'.join(outputs))

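    # Optional benchmark: 10 warmup iterations followed by 10 timed ones;
    # reports the mean per-iteration latency in milliseconds.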
    if args.time:
        warmup_iterations = 10
        for _ in range(warmup_iterations):
            gpt_generate_fn()
        torch.cuda.synchronize()
        measurement_iterations = 10
        time = timeit.default_timer()
        for _ in range(measurement_iterations):
            gpt_generate_fn()
        torch.cuda.synchronize()
        time_elapsed = timeit.default_timer() - time
        if rank == 0:
            print('[INFO] MPT time costs:')
            print(
                'model_name, gpu_type, gpu_count, batch_size, input_tokens, output_tokens, latency_ms'
            )
            print(
                f'{ckpt_config.get("gpt", "model_name")}, {torch.cuda.get_device_name().replace(" ", "-")}, {torch.cuda.device_count()}, {batch_size}, {args.input_len}, {args.output_len}, {time_elapsed * 1000 / measurement_iterations:.2f}'
            )


if __name__ == '__main__':
    main()