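"""Benchmark demo for lyraBaichuan.

Loads a Baichuan model, builds prompts from a user-supplied JSON file, and
reports average generation latency and throughput (tokens/s and characters/s)
for several batch sizes.
"""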
import argparse
import json
import random
import numpy as np

from time import perf_counter

import sys
sys.path.append('../')
from lyra_baichuan import lyraBaichuan7B, lyraBaichuan13B


def get_args():
    parser = argparse.ArgumentParser(description="Faster Baichuan Demo")

    parser.add_argument('--model-path', type=str, required=True,
                        help='Path to the model directory; it should contain config.ini and the tokenizer files')
    # parser.add_argument('--tokenizer-path', type=str, default='/group/30063/users/vanewu/LocalModels/ChatGLM6B-Torch/chatglm-6b')
    parser.add_argument('--tokenizer-path', type=str, default=None)

    parser.add_argument(
        '--data-type', type=str, metavar='TYPE', default='fp16',
        choices=[None, 'fp32', 'fp16', 'bf16', 'int8'],
        help='The data type to inference. If None, the data type follows the '
             'checkpoint data type.')

    parser.add_argument(
        '--memopt_mode', type=int, default=0, choices=[0, 1],
        help='Use MEMOPT mode to increase speed and reduce VRAM usage.'
             ' 0: FP16 mode'
             ' 1: Use MEMOPT mode')

    parser.add_argument("--prompt_filepath", type=str, required=True)
    parser.add_argument("--max-output-length", type=int, default=512)
    parser.add_argument("--warmups", type=int, default=10)
    parser.add_argument("--avgnums", type=int, default=10)
    args = parser.parse_args()

    print('\n=================== Arguments ===================')
    for k, v in vars(args).items():
        print(f' - {k.ljust(25, ".")}: {v}')
    print('=================================================')

    return args


def main():
    args = get_args()

    # model = lyraBaichuan7B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode)
    model = lyraBaichuan13B(args.model_path, args.tokenizer_path, args.data_type, args.memopt_mode)

    # The prompt file is expected to be a JSON list; each entry provides a
    # 'prompts' format string and a 'contents' list of argument tuples
    # (inferred from how the first entry is used below).
    with open(args.prompt_filepath, "rb") as f:
        input_datas = json.load(f)

    used_input_data = input_datas[0]
    
    # prompt_template = "<reserved_106>{}\n<reserved_107>"  # Baichuan chat template
    prompt_template = "{}"  # Baichuan base model (no chat template)
    
    test_batch_size = [1, 2, 4,] # 8, 16, 32, 64
    print("test_batch_size: ", test_batch_size)

    for i, bs in enumerate(test_batch_size):
        all_use_prompts = []
        all_output_texts = []

        # warmup gpu        
        for _ in range(args.warmups):
            prompts = [prompt_template.format(used_input_data['prompts'].format(*x))
                       for x in random.choices(used_input_data['contents'], k=bs)]
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.0, do_sample=False)

        all_cost_s = 0.0
        
        for _ in range(args.avgnums):
            prompts = [prompt_template.format(used_input_data['prompts'].format(*x))
                       for x in random.choices(used_input_data['contents'], k=bs)]
            all_use_prompts.extend(prompts)

            start = perf_counter()
            output_texts = model.generate(
                prompts, output_length=args.max_output_length,
                top_k=30, top_p=0.85, temperature=1.0, repetition_penalty=1.0, do_sample=False)
            all_cost_s += perf_counter() - start

            all_output_texts.extend(output_texts)

        cost = all_cost_s / args.avgnums

        input_output_texts = [prompt + ' ' + gtext for prompt, gtext in zip(all_use_prompts, all_output_texts)]

        # Token counts cover prompt + generated text; `words` is a plain
        # character count of the concatenated texts.
        tokens = 0
        avg_input_tokens = np.mean([len(model.tokenizer.encode(prompt)) for prompt in all_use_prompts])

        words = 0
        for text in input_output_texts:
            tokens += len(model.tokenizer.encode(text))
            words += len(text)
        print(
            f"\nFaster-Dtype: {args.data_type}, Batch Size: {bs}, All tokens: {tokens}. "
            f"Avg Input tokens: {avg_input_tokens:.1f}. Cost: {cost:.3f} seconds. "
            f"Speed: {tokens/cost:.2f} tokens/s."
        )
        print(
            f"Faster-Dtype: {args.data_type}, Batch Size: {bs}, All generated words: {words}. "
            f"Cost: {cost:.3f} seconds. Speed: {words/cost:.2f} words/s."
        )
        
        if i == 0:
            # Show a few sample generations from the last timed run.
            for k in range(min(bs, 4)):
                print(
                    f"Sample {k}:\n\t\tInputs: {prompts[k]}\n\t\tOutputs: {output_texts[k].lstrip()}")

if __name__ == "__main__":
    main()