File size: 4,026 Bytes
8453337
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import argparse
from time import perf_counter

import sys
sys.path.append('../')

from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B


def get_args():
    """Parse, echo and return the benchmark's command-line options.

    Returns:
        argparse.Namespace with attributes model_path, tokenizer_path,
        data_type, memopt_mode, quant_type, prompt, max_output_length,
        warmups and avgnums.

    Exits through ``parser.error`` (SystemExit, status 2) when ``--prompt``
    is omitted or ``--avgnums`` is below 1. ``main`` formats the prompt into
    the model input and divides the timed total by ``avgnums``, so either
    case would otherwise benchmark the string "None" or raise
    ZeroDivisionError.
    """
    parser = argparse.ArgumentParser(description="Faster Baichuan Demo")

    parser.add_argument('--model-path', type=str, required=True,
                        help='Model Path, include config.ini and tokenizer files')
    parser.add_argument('--tokenizer-path', type=str, default=None)

    # NOTE: with type=str and default='fp16', argparse can never produce None
    # here, so the None entry in `choices` is not selectable from the CLI.
    parser.add_argument(
        '--data-type', type=str, metavar='TYPE', default='fp16',
        choices=[None, 'fp32', 'fp16', 'bf16', 'int8'],
        help='The data type to inference. If None, the data type follows the '
             'checkpoint data type.')

    # Hyphenated spelling matches the other flags; the legacy underscore
    # spelling stays as an alias so existing command lines keep working.
    parser.add_argument(
        '--memopt-mode', '--memopt_mode', dest='memopt_mode',
        type=int, default=0, choices=[0, 1],
        help='Use MEMOPT mode to increase speed and reduce VRAM usage.'
             ' 0: FP16 mode'
             ' 1: Use MEMOPT mode')

    parser.add_argument(
        '--quant-type', type=str, metavar='TYPE', default='int8',
        choices=['int4', 'int8'],
        help='The data type of quantization. Only used in MEMOPT.')

    # Required: main() formats this into the model input verbatim.
    parser.add_argument("--prompt", type=str, required=True)
    parser.add_argument("--max-output-length", type=int, default=512)
    parser.add_argument("--warmups", type=int, default=10)
    parser.add_argument("--avgnums", type=int, default=10)
    args = parser.parse_args()

    # avgnums is the divisor of the timed total; at least one timed run is needed.
    if args.avgnums < 1:
        parser.error('--avgnums must be >= 1')

    print('\n=================== Arguments ===================')
    for k, v in vars(args).items():
        print(f' - {k.ljust(25, ".")}: {v}')
    print('=================================================')

    return args


def main():
    """Benchmark lyraBaichuan13B text generation across several batch sizes.

    For each batch size the prompt is replicated ``bs`` times.
    ``model.generate`` then runs ``args.warmups`` times untimed and
    ``args.avgnums`` times timed, always with the same generation settings.
    The mean wall-clock time per timed call is printed together with token
    and character throughput. For the first batch size only, up to four
    sample outputs are printed as well.
    """
    args = get_args()

    # Swap in lyraBaichuan7B (same constructor arguments) to benchmark 7B.
    model = lyraBaichuan13B(args.model_path, args.tokenizer_path,
                            args.data_type, args.memopt_mode, args.quant_type)

    # Base-model template. The chat variant would be
    # "<reserved_106>{}\n<reserved_107>".
    prompt_template = "{}"
    prompt = prompt_template.format(args.prompt)

    test_batch_size = [1, 2, 4]  # 8, 16, 32, 64
    print("test_batch_size: ", test_batch_size)

    # Shared by warmup and timed runs so the warmup exercises exactly the
    # configuration being measured.
    gen_kwargs = dict(
        output_length=args.max_output_length,
        top_k=30, top_p=0.85, temperature=1.0,
        repetition_penalty=1.0, do_sample=False)

    max_printed_samples = 4

    for round_idx, bs in enumerate(test_batch_size):
        prompts = [prompt] * bs

        # Untimed warmup runs; results are discarded.
        for _ in range(args.warmups):
            model.generate(prompts, **gen_kwargs)

        start = perf_counter()
        for _ in range(args.avgnums):
            output_texts = model.generate(prompts, **gen_kwargs)
        end = perf_counter()
        cost = (end - start) / args.avgnums

        input_output_texts = [p + ' ' + gtext
                              for p, gtext in zip(prompts, output_texts)]
        input_tokens = len(model.tokenizer.encode(prompt))
        # Token count covers prompt + output for every sequence in the batch.
        tokens = sum(len(model.tokenizer.encode(text)) for text in input_output_texts)
        # NOTE: reported as "words", but this is the character count of
        # prompt + output (len of the str), not a word count.
        words = sum(len(text) for text in input_output_texts)
        print(
            f"\nFaster-Dtype: {args.data_type}, Batch Size: {bs}, All tokens: {tokens}. Input tokens: {input_tokens}. Cost: {cost} seconds. Speed: {tokens/cost} tokens/s."
        )
        print(
            f"Faster-Dtype: {args.data_type}, Batch Size: {bs}, All generated words: {words}. Cost: {cost} seconds. Speed: {words/cost} words/s."
        )

        # Show a few generations from the first batch-size round only.
        if round_idx == 0:
            for k in range(min(bs, max_printed_samples)):
                print(
                    f"The {k} Sample, \n\t\tInputs: {prompts[k]}. \n\t\tOutputs: {output_texts[k].lstrip()}")
                    
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()