# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2021-2023, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2021, NAVER Corp. Authored by CLOVA.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Run MPT model with FT.
This script is a modified version of
https://github.com/NVIDIA/FasterTransformer/blob/main/examples/pytorch/gpt/multi_gpu_gpt_example.py
"""
import argparse
import configparser
import os
import sys
import timeit
import torch
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
dir_path = os.path.dirname(os.path.realpath(__file__))
sys.path.append(os.path.join(dir_path, '../../..'))
from examples.pytorch.gpt.utils import comm, gpt_decoder
from examples.pytorch.gpt.utils.parallel_gpt import ParallelGPT
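# Inference only: disable gradient tracking for the whole run.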
@torch.no_grad()
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--layer_num',
type=int,
default=32,
help='number of layers')
parser.add_argument('--input_len',
type=int,
default=128,
help='input sequence length (used for dummy inputs when no sample input file is given).')
parser.add_argument('--output_len',
type=int,
default=64,
help='output sequence length to generate.')
parser.add_argument('--head_num', type=int, default=32, help='head number')
parser.add_argument('--size_per_head',
type=int,
default=128,
help='size per head')
parser.add_argument('--vocab_size',
type=int,
default=50432,
help='vocab size')
parser.add_argument(
'--beam_width',
type=int,
default=1,
help='beam width for beam search. Sampling is used when beam width is 1.')
parser.add_argument('--top_k',
type=int,
default=1,
help='top k candidate num')
parser.add_argument('--top_p',
type=float,
default=0.95,
help='top p probability threshold')
parser.add_argument('--temperature',
type=float,
default=0.8,
help='temperature')
parser.add_argument('--len_penalty',
type=float,
default=0.,
help='len_penalty')
parser.add_argument('--beam_search_diversity_rate',
type=float,
default=0.,
help='beam_search_diversity_rate')
parser.add_argument('--tensor_para_size',
type=int,
default=1,
help='tensor parallel size')
parser.add_argument('--pipeline_para_size',
type=int,
default=1,
help='pipeline parallel size')
parser.add_argument('--ckpt_path',
type=str,
default='mpt-ft-7b/1-gpu',
help='path to the FT checkpoint file.')
parser.add_argument(
'--tokenizer_name_or_path',
type=str,
default='EleutherAI/gpt-neox-20b',
help=
'Name of the tokenizer or the directory where the tokenizer file is located.'
)
parser.add_argument(
'--lib_path',
type=str,
help=
'path to the libth_transformer dynamic lib file (e.g., build/lib/libth_transformer.so).'
)
parser.add_argument('--start_id',
type=int,
default=0,
help='start token id.')
parser.add_argument('--end_id', type=int, default=0, help='end token id.')
parser.add_argument(
'--max_batch_size',
type=int,
default=8,
help=
'Max batch size. If sample_input_file is given, the inputs are truncated to this max_batch_size; otherwise, this value is used as the batch size.'
)
parser.add_argument('--repetition_penalty',
type=float,
default=5.,
help='repetition penalty')
parser.add_argument(
'--presence_penalty',
type=float,
default=0.,
help=
'presence penalty. Similar to repetition penalty, but additive rather than multiplicative.'
)
parser.add_argument('--min_length',
type=int,
default=0,
help='A minimum number of tokens to generate')
parser.add_argument(
'--max_seq_len',
type=int,
default=2048,
help='max sequence length for position embedding table.')
parser.add_argument('--inference_data_type',
'--data_type',
type=str,
choices=['fp32', 'fp16', 'bf16'],
default='bf16')
parser.add_argument('--time',
action='store_true',
help='whether or not to measure time elapsed.')
parser.add_argument(
'--sample_input_file',
type=str,
default=None,
help=
'path to sample input file. If not set, it runs with no context inputs.'
)
parser.add_argument('--sample_output_file',
type=str,
default=None,
help='path to sample output file.')
parser.add_argument(
'--disable_random_seed',
dest='random_seed',
action='store_false',
help='Disable the use of random seed for sentences in a batch.')
parser.add_argument('--skip_end_tokens',
dest='skip_end_tokens',
action='store_false',
help='Whether or not to remove end tokens from the outputs.')
parser.add_argument('--no_detokenize',
dest='detokenize',
action='store_false',
help='Skip detokenizing output token ids.')
parser.add_argument(
'--int8_mode',
type=int,
default=0,
choices=[0, 1],
help='The level of quantization to perform.' +
' 0: No quantization. All computation in data_type' +
' 1: Quantize weights to int8, all compute occurs in fp16/bf16. Not supported when data_type is fp32'
)
parser.add_argument(
'--weights_data_type',
type=str,
default='fp32',
choices=['fp32', 'fp16'],
help='Data type of FT checkpoint weights',
)
parser.add_argument(
'--return_cum_log_probs',
type=int,
default=0,
choices=[0, 1, 2],
help='Whether to compute the cumulative log probability of sentences.' +
' 0: do not return the cumulative log probs' +
' 1: return the cumulative log probs of generated sequences' +
' 2: return the cumulative log probs of sequences')
parser.add_argument('--shared_contexts_ratio',
type=float,
default=0.0,
help='Triggers the shared context optimization when ' +
'compact_size <= shared_contexts_ratio * batch_size. ' +
'A value of 0.0 deactivates the optimization.')
parser.add_argument(
'--use_gpt_decoder_ops',
action='store_true',
help='Use separate decoder FT operators instead of end-to-end model op.'
)
parser.add_argument(
'--no-alibi',
dest='alibi',
action='store_false',
help='Do not use ALiBi (aka use_attention_linear_bias).')
parser.add_argument(
'--layernorm_eps',
type=float,
default=1e-5,
help='layernorm eps (the PyTorch default is 1e-5; the FT default is 1e-6).')
args = parser.parse_args()
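# If the checkpoint directory contains a config.ini, its values override the
# command-line defaults so the model hyperparameters match the converted FT checkpoint.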
ckpt_config = configparser.ConfigParser()
ckpt_config_path = os.path.join(args.ckpt_path, 'config.ini')
if os.path.isfile(ckpt_config_path):
ckpt_config.read(ckpt_config_path)
if 'gpt' in ckpt_config.keys():
for args_key, config_key, func in [
('layer_num', 'num_layer', ckpt_config.getint),
('max_seq_len', 'max_pos_seq_len', ckpt_config.getint),
('weights_data_type', 'weight_data_type', ckpt_config.get),
('layernorm_eps', 'layernorm_eps', ckpt_config.getfloat),
('alibi', 'use_attention_linear_bias', ckpt_config.getboolean),
]:
if config_key in ckpt_config['gpt'].keys():
prev_val = args.__dict__[args_key]
args.__dict__[args_key] = func('gpt', config_key)
print(
'Loading {} from config.ini, previous: {}, current: {}'
.format(args_key, prev_val, args.__dict__[args_key]))
else:
print('Not loading {} from config.ini'.format(args_key))
for key in ['head_num', 'size_per_head', 'tensor_para_size']:
if key in args.__dict__:
prev_val = args.__dict__[key]
args.__dict__[key] = ckpt_config.getint('gpt', key)
print(
'Loading {} from config.ini, previous: {}, current: {}'
.format(key, prev_val, args.__dict__[key]))
else:
print('Not loading {} from config.ini'.format(key))
layer_num = args.layer_num
output_len = args.output_len
head_num = args.head_num
size_per_head = args.size_per_head
vocab_size = args.vocab_size
beam_width = args.beam_width
top_k = args.top_k
top_p = args.top_p
temperature = args.temperature
len_penalty = args.len_penalty
beam_search_diversity_rate = args.beam_search_diversity_rate
tensor_para_size = args.tensor_para_size
pipeline_para_size = args.pipeline_para_size
start_id = args.start_id
end_id = args.end_id
max_batch_size = args.max_batch_size
max_seq_len = args.max_seq_len
repetition_penalty = args.repetition_penalty
presence_penalty = args.presence_penalty
min_length = args.min_length
weights_data_type = args.weights_data_type
return_cum_log_probs = args.return_cum_log_probs
return_output_length = return_cum_log_probs > 0
shared_contexts_ratio = args.shared_contexts_ratio
layernorm_eps = args.layernorm_eps
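# MPT uses ALiBi attention biases, so the learned positional-embedding table is
# disabled whenever ALiBi is enabled.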
use_attention_linear_bias = args.alibi
has_positional_encoding = not args.alibi
print('\n=================== Arguments ===================')
for k, v in vars(args).items():
print(f'{k.ljust(30, ".")}: {v}')
print('=================================================\n')
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name_or_path)
torch.manual_seed(0)
comm.initialize_model_parallel(args.tensor_para_size,
args.pipeline_para_size)
rank = comm.get_rank()
device = comm.get_device()
# Inputs
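# With a sample input file, each line is tokenized as one context; otherwise
# dummy contexts of length input_len filled with end_id are used.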
contexts = []
if args.sample_input_file:
with open(args.sample_input_file, 'r') as f:
contexts = f.read().splitlines()
batch_size = min(len(contexts), max_batch_size)
contexts = contexts[:batch_size]
start_ids = [
torch.tensor(tokenizer.encode(c), dtype=torch.int32, device=device)
for c in contexts
]
else:
batch_size = max_batch_size
contexts = ['<|endoftext|>'] * batch_size
start_ids = [torch.IntTensor([end_id for _ in range(args.input_len)])
] * batch_size
start_lengths = [len(ids) for ids in start_ids]
start_ids = pad_sequence(start_ids, batch_first=True, padding_value=end_id)
start_lengths = torch.IntTensor(start_lengths)
# Prepare model.
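# Two execution paths: the end-to-end ParallelGPT op (default), or separate
# decoder operators when --use_gpt_decoder_ops is set.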
if not args.use_gpt_decoder_ops:
gpt = ParallelGPT(head_num,
size_per_head,
vocab_size,
start_id,
end_id,
layer_num,
max_seq_len,
tensor_para_size,
pipeline_para_size,
lib_path=args.lib_path,
inference_data_type=args.inference_data_type,
int8_mode=args.int8_mode,
weights_data_type=weights_data_type,
layernorm_eps=layernorm_eps,
use_attention_linear_bias=use_attention_linear_bias,
has_positional_encoding=has_positional_encoding,
shared_contexts_ratio=shared_contexts_ratio)
if not gpt.load(ckpt_path=args.ckpt_path):
print(
'[WARNING] Checkpoint file not found. Model loading is skipped.'
)
else:
gpt = gpt_decoder.Gpt(num_heads=head_num,
size_per_head=size_per_head,
num_layers=layer_num,
vocab_size=vocab_size,
start_id=start_id,
end_id=end_id,
tensor_para_size=tensor_para_size,
pipeline_para_size=pipeline_para_size,
lib_path=args.lib_path,
max_seq_len=max_seq_len,
int8_mode=args.int8_mode,
weights_data_type=args.weights_data_type)
gpt.load(args.ckpt_path, args.inference_data_type)
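# One random seed per sentence in the batch; all zeros when --disable_random_seed is given.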
if args.random_seed:
random_seed_tensor = torch.randint(0,
10000,
size=[batch_size],
dtype=torch.int64)
else:
random_seed_tensor = torch.zeros([batch_size], dtype=torch.int64)
repetition_penalty_vec = None if repetition_penalty == 1. else repetition_penalty * torch.ones(
batch_size, dtype=torch.float32)
presence_penalty_vec = None if presence_penalty == 0. else presence_penalty * torch.ones(
batch_size, dtype=torch.float32)
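# Most sampling and beam-search parameters are passed to FT as per-sentence tensors.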
infer_decode_args = {
'beam_width':
beam_width,
'top_k':
top_k * torch.ones(batch_size, dtype=torch.int32),
'top_p':
top_p * torch.ones(batch_size, dtype=torch.float32),
'temperature':
temperature * torch.ones(batch_size, dtype=torch.float32),
'repetition_penalty':
repetition_penalty_vec,
'presence_penalty':
presence_penalty_vec,
'beam_search_diversity_rate':
beam_search_diversity_rate *
torch.ones(batch_size, dtype=torch.float32),
'len_penalty':
len_penalty * torch.ones(size=[batch_size], dtype=torch.float32),
'bad_words_list':
None,
'min_length':
min_length * torch.ones(size=[batch_size], dtype=torch.int32),
'random_seed':
random_seed_tensor
}
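# Wrap generation in a function so the same call can be reused for timing below.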
if not args.use_gpt_decoder_ops:
def gpt_generate_fn():
tokens_batch = gpt(start_ids,
start_lengths,
output_len,
return_output_length=return_output_length,
return_cum_log_probs=return_cum_log_probs,
**infer_decode_args)
return tokens_batch
else:
def gpt_generate_fn():
output_dict = gpt.generate(
input_token_ids=start_ids,
input_lengths=start_lengths,
gen_length=output_len,
eos_token_id=end_id,
return_output_length=return_output_length,
return_log_probs=return_cum_log_probs,
**infer_decode_args)
return output_dict
# Generate tokens.
gen_outputs = gpt_generate_fn()
if rank == 0:
if not args.use_gpt_decoder_ops:
if return_cum_log_probs > 0:
tokens_batch, _, cum_log_probs = gen_outputs
else:
tokens_batch, cum_log_probs = gen_outputs, None
else:
tokens_batch = gen_outputs['output_token_ids']
cum_log_probs = gen_outputs[
'cum_log_probs'] if return_cum_log_probs > 0 else None
if cum_log_probs is not None:
print('[INFO] Log probs of sentences:', cum_log_probs)
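# Decode each beam of every batch entry, stripping the context tokens (and
# optionally the end tokens) from the generated ids.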
outputs = []
tokens_batch = tokens_batch.cpu().numpy()
for i, (context, tokens) in enumerate(zip(contexts, tokens_batch)):
for beam_id in range(beam_width):
token = tokens[beam_id][
start_lengths[i]:] # exclude context input from the output
if args.skip_end_tokens:
token = token[token != end_id]
output = tokenizer.decode(
token) if args.detokenize else ' '.join(
str(t) for t in token.tolist())
outputs.append(output)
print(
f'[INFO] batch {i}, beam {beam_id}:\n[Context]\n{context}\n\n[Output]\n{output}\n'
)
if args.sample_output_file:
with open(args.sample_output_file, 'w+') as f:
outputs = [o.replace('\n', '\\n') for o in outputs]
f.writelines('\n'.join(outputs))
# Measure inference time.
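# Warm up first, then time several iterations (with CUDA synchronization) and report the mean latency.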
if args.time:
warmup_iterations = 10
for _ in range(warmup_iterations):
gpt_generate_fn()
torch.cuda.synchronize()
measurement_iterations = 10
time = timeit.default_timer()
for _ in range(measurement_iterations):
gpt_generate_fn()
torch.cuda.synchronize()
time_elapsed = timeit.default_timer() - time
if rank == 0:
print('[INFO] MPT time costs:')
print(
'model_name, gpu_type, gpu_count, batch_size, input_tokens, output_tokens, latency_ms'
)
print(
f'{ckpt_config.get("gpt", "model_name")}, {torch.cuda.get_device_name().replace(" ", "-")}, {torch.cuda.device_count()}, {batch_size}, {args.input_len}, {args.output_len}, {time_elapsed * 1000 / measurement_iterations:.2f}'
)
if __name__ == '__main__':
main()