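# Throughput benchmark for the lyraBaichuan accelerated Baichuan models:
# warms up the GPU, then times batched generation at several batch sizes and
# reports tokens/s and words/s for the chosen data type and MEMOPT settings.
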
import argparse
import sys
from time import perf_counter

sys.path.append('../')
from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B


def get_args():
    parser = argparse.ArgumentParser(description="Faster Baichuan Demo")
    parser.add_argument('--model-path', type=str, required=True,
                        help='Model path; the directory must contain config.ini and the tokenizer files.')
    parser.add_argument('--tokenizer-path', type=str, default=None)
    parser.add_argument(
        '--data-type', type=str, metavar='TYPE', default='fp16',
        choices=[None, 'fp32', 'fp16', 'bf16', 'int8'],
        help='The data type used for inference. If None, the data type follows the '
             'checkpoint data type.')
    parser.add_argument(
        '--memopt_mode', type=int, default=0, choices=[0, 1],
        help='Use MEMOPT mode to increase speed and reduce VRAM usage. '
             '0: plain FP16 mode; 1: MEMOPT mode.')
    parser.add_argument(
        '--quant-type', type=str, metavar='TYPE', default='int8',
        choices=['int4', 'int8'],
        help='The data type of quantization. Only used in MEMOPT mode.')
    parser.add_argument("--prompt", type=str, required=False)
    parser.add_argument("--max-output-length", type=int, default=512)
    parser.add_argument("--warmups", type=int, default=10)
    parser.add_argument("--avgnums", type=int, default=10)
    args = parser.parse_args()

    print('\n=================== Arguments ===================')
    for k, v in vars(args).items():
        print(f' - {k.ljust(25, ".")}: {v}')
    print('=================================================')

    return args
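

# A minimal example invocation (assuming this script is saved as demo_baichuan.py;
# the model path and prompt below are placeholders):
#   python demo_baichuan.py --model-path /path/to/Baichuan-13B --data-type fp16 \
#       --memopt_mode 1 --quant-type int8 --prompt "Hello"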
def main():
    args = get_args()

    # model = lyraBaichuan7B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode, args.quant_type)
    model = lyraBaichuan13B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode, args.quant_type)

    # prompt_template = "<reserved_106>{}\n<reserved_107>"  # Baichuan chat models
    prompt_template = "{}"  # Baichuan base models
    prompt = prompt_template.format(args.prompt)
    test_batch_size = [1, 2, 4]  # 8, 16, 32, 64
    print("test_batch_size: ", test_batch_size)

    for i, bs in enumerate(test_batch_size):
        prompts = [prompt, ] * bs

        # warmup gpu
        for _ in range(args.warmups):
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.1, do_sample=False)
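
        # Timed runs: average the wall-clock latency over args.avgnums generations.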
        start = perf_counter()
        for _ in range(args.avgnums):
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.0, do_sample=False)
        end = perf_counter()
        cost = (end - start) / args.avgnums
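
        # Token counts below cover prompt + generation for every sample in the batch,
        # so the reported speed is whole-batch tokens per second; "words" is len(text),
        # i.e. a character count.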
        input_output_texts = [p + ' ' + g for p, g in zip(prompts, output_texts)]

        tokens = 0
        input_tokens = len(model.tokenizer.encode(prompt))
        words = 0
        for text in input_output_texts:
            tokens += len(model.tokenizer.encode(text))
            words += len(text)
        print(
            f"\nFaster-Dtype: {args.data_type}, Batch Size: {bs}, All tokens: {tokens}. "
            f"Input tokens: {input_tokens}. Cost: {cost} seconds. Speed: {tokens/cost} tokens/s."
        )
        print(
            f"Faster-Dtype: {args.data_type}, Batch Size: {bs}, All generated words: {words}. "
            f"Cost: {cost} seconds. Speed: {words/cost} words/s."
        )
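
        # For the first batch size only, print up to four sample inputs/outputs.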
        if i == 0:
            for k in range(bs):
                print(
                    f"Sample {k},\n\t\tInputs: {prompts[k]}.\n\t\tOutputs: {output_texts[k].lstrip()}")
                if k > 2:
                    break


if __name__ == "__main__":
    main()