zamroni111 committed
Commit 0bf203c · verified · 1 Parent(s): bfe115c

Upload 2 files

Files changed (2)
  1. dml-device-specific-optim.py +17 -0
  2. onnxgenairun.py +108 -0
dml-device-specific-optim.py ADDED
@@ -0,0 +1,17 @@
+ import onnxruntime as rt
+
+ # Enable every graph optimization so ONNX Runtime rewrites the graph for the target device.
+ sess_options = rt.SessionOptions()
+ sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_ENABLE_ALL
+
+ #########################################
+ ## Change the path accordingly: the optimized graph is serialized to this file.
+ sess_options.optimized_model_filepath = "optimized_model.onnx"
+
+ #########################################
+ ## Change the model.onnx path accordingly.
+ # Creating the session with the DirectML execution provider runs the
+ # device-specific optimization and writes the optimized model to disk.
+ session = rt.InferenceSession("model.onnx", sess_options,
+                               # providers=['DmlExecutionProvider', 'CPUExecutionProvider'],  # alternative with CPU fallback
+                               providers=['DmlExecutionProvider'])
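For reference, a later run can load the optimized_model.onnx written above directly and skip re-optimization at startup. A minimal sketch, assuming the output path used above and that the DirectML-enabled onnxruntime package is installed:

import onnxruntime as rt

sess_options = rt.SessionOptions()
# The graph was already optimized offline, so further optimization work can be skipped here.
sess_options.graph_optimization_level = rt.GraphOptimizationLevel.ORT_DISABLE_ALL

session = rt.InferenceSession("optimized_model.onnx", sess_options,
                              providers=['DmlExecutionProvider'])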
onnxgenairun.py ADDED
@@ -0,0 +1,108 @@
+ import onnxruntime_genai as og
+ import argparse
+ import time
+ import re
+
+
+ def main(args):
+     if args.verbose: print("Loading model...")
+     if args.timings:
+         started_timestamp = 0
+         first_token_timestamp = 0
+
+     model = og.Model(f'{args.model}')
+     # model = og.Model(".\\")  # alternative: load the model from the current folder
+     if args.verbose: print("Model loaded")
+     tokenizer = og.Tokenizer(model)
+     tokenizer_stream = tokenizer.create_stream()
+     if args.verbose: print("Tokenizer created")
+     if args.verbose: print()
+     search_options = {name: getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
+
+     # Set the max length to something sensible by default, unless it is specified by the user,
+     # since otherwise it will be set to the entire context length
+     if 'max_length' not in search_options:
+         search_options['max_length'] = 2048
+
+     chat_template = '<|user|>\n{input} <|end|>\n<|assistant|>'
+
+     # Keep asking for input prompts in a loop
+     while True:
+         text = input("Input: ")
+         if not text:
+             print("Error, input cannot be empty")
+             continue
+
+         if args.timings: started_timestamp = time.time()
+
+         # Apply the chat template to the raw user input
+         prompt = chat_template.format(input=text)
+
+         input_tokens = tokenizer.encode(prompt)
+
+         params = og.GeneratorParams(model)
+         params.set_search_options(**search_options)
+         params.input_ids = input_tokens
+         generator = og.Generator(model, params)
+         if args.verbose: print("Generator created")
+
+         if args.verbose: print("Running generation loop ...")
+         if args.timings:
+             first = True
+             new_tokens = []
+
+         print()
+         print("Output:\n", end='', flush=True)
+
+         try:
+             vPreviousDecoded = ""
+             vNewDecoded = ""
+             while not generator.is_done():
+                 generator.compute_logits()
+                 generator.generate_next_token()
+                 if args.timings:
+                     if first:
+                         first_token_timestamp = time.time()
+                         first = False
+
+                 new_token = generator.get_next_tokens()[0]
+
+                 # Simpler alternative: print(tokenizer_stream.decode(new_token), end='', flush=True)
+                 # Insert a line break when the previous chunk ended with '.', ':' or ';'
+                 # and the new chunk starts with a space that is not a bullet (" *").
+                 vNewDecoded = tokenizer_stream.decode(new_token)
+                 if re.findall("^[\x2E\x3A\x3B]$", vPreviousDecoded) and vNewDecoded.startswith(" ") and (not vNewDecoded.startswith(" *")):
+                     vNewDecoded = "\n" + vNewDecoded.replace(" ", "", 1)
+
+                 print(vNewDecoded, end='', flush=True)
+                 vPreviousDecoded = vNewDecoded
+
+                 if args.timings: new_tokens.append(new_token)
+         except KeyboardInterrupt:
+             print(" --control+c pressed, aborting generation--")
+         print()
+         print()
+
+         # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
+         del generator
+
+         if args.timings:
+             prompt_time = first_token_timestamp - started_timestamp
+             run_time = time.time() - first_token_timestamp
+             print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
+     parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)')
+     parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
+     parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
+     parser.add_argument('-ds', '--do_sample', action='store_true', default=False, help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
+     parser.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
+     parser.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
+     parser.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
+     parser.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
+     parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
+     parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
+     args = parser.parse_args()
+     main(args)
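For reference, a typical invocation of onnxgenairun.py looks like the following; the model folder name is illustrative and must point at a directory containing config.json and the ONNX model, as the -m help text states:

python onnxgenairun.py -m .\my-onnx-model-folder -v -g

Here -v prints progress messages and -g prints per-prompt timing statistics (time to first token, prompt and generation tokens per second).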