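# Benchmarks lyraBaichuan generation throughput over several batch sizes,
# reporting average latency plus tokens/s and words/s per run.
#
# Example invocation (a sketch; the script name and paths are placeholders):
#   python demo.py --model-path /path/to/model --tokenizer-path /path/to/tokenizer \
#       --data-type fp16 --memopt_mode 0 --prompt "Hello" --max-output-length 512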
import argparse
from time import perf_counter

import sys
sys.path.append('../')

from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B


def print_list(lines):
    # "\033c" is the ANSI full-reset escape: clear the terminal before redrawing
    print("\033c", end="")

    # Print the list of strings line by line
    print('\n'.join(lines))
      
def get_args():
    parser = argparse.ArgumentParser(description="Faster Baichuan Demo")

    parser.add_argument('--model-path', type=str, required=True,
                        help='Model Path, include config.ini and tokenizer files')
    parser.add_argument('--tokenizer-path', type=str, default=None)

    parser.add_argument(
        '--data-type', type=str, metavar='TYPE', default='fp16',
        choices=[None, 'fp32', 'fp16', 'bf16', 'int8'],
        help='The data type used for inference. If None, the data type '
             'follows the checkpoint data type.')

    parser.add_argument(
        '--memopt_mode', type=int, default=0, choices=[0, 1],
        help='MEMOPT mode increases speed and reduces VRAM usage.'
             ' 0: plain FP16 mode.'
             ' 1: enable MEMOPT mode.')

    parser.add_argument("--prompt", type=str, required=False)
    parser.add_argument("--max-output-length", type=int, default=512)
    parser.add_argument("--warmups", type=int, default=10)
    parser.add_argument("--avgnums", type=int, default=10)
    args = parser.parse_args()

    print('\n=================== Arguments ===================')
    for k, v in vars(args).items():
        print(f' - {k.ljust(25, ".")}: {v}')
    print('=================================================')

    return args


def main():
    args = get_args()

    # Pick the model size to benchmark; swap in the 7B variant if needed.
    # model = lyraBaichuan7B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode)
    model = lyraBaichuan13B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode)

    # prompt_template = "<reserved_106>{}\n<reserved_107>"  # Baichuan chat models
    prompt_template = "{}"  # Baichuan base models
    
    prompt = prompt_template.format(args.prompt)
    
    test_batch_size = [1, 2, 4]  # extend with 8, 16, 32, 64 for larger runs
    print("test_batch_size: ", test_batch_size)

    for bs in test_batch_size:
        prompts = [prompt] * bs

        # Warm up the GPU so one-time initialization is excluded from timing
        for _ in range(args.warmups):
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.1, do_sample=False)

        # Time args.avgnums full streaming generations and average the wall-clock cost
        start = perf_counter()
        for _ in range(args.avgnums):
            for finish, output_texts in model.stream_generate(
                    prompts, output_length=args.max_output_length,
                    top_k=30, top_p=0.85, temperature=1.0,
                    repetition_penalty=1.0, do_sample=False):
                print_list(output_texts)

                if finish:
                    break
        end = perf_counter()
        cost = (end - start) / args.avgnums

        # Throughput is measured over prompt + generated text for the whole batch
        input_output_texts = [p + ' ' + g for p, g in zip(prompts, output_texts)]
        tokens = 0
        input_tokens = len(model.tokenizer.encode(prompt))
        words = 0
        for text in input_output_texts:
            tokens += len(model.tokenizer.encode(text))
            words += len(text)  # character count; a rough proxy for words in CJK text
        print(
            f"\nFaster-Dtype: {args.data_type}, Batch Size: {bs}, All tokens: {tokens}. "
            f"Input tokens: {input_tokens}. Cost: {cost:.3f} seconds. Speed: {tokens/cost:.2f} tokens/s."
        )
        print(
            f"Faster-Dtype: {args.data_type}, Batch Size: {bs}, All generated words: {words}. "
            f"Cost: {cost:.3f} seconds. Speed: {words/cost:.2f} words/s."
        )
                    
if __name__ == "__main__":
    main()