ericsorides committed · Commit 4f07050 · 1 Parent(s): 8456488

Added new inputs and README

Files changed (6)
  1. README.md +146 -0
  2. config.json +0 -0
  3. model.onnx +2 -2
  4. special_tokens_map.json +24 -0
  5. tokenizer.json +0 -0
  6. tokenizer.model +3 -0
README.md ADDED
@@ -0,0 +1,146 @@
---
tags:
- text-generation-inference
- llama2
base_model:
- meta-llama/Llama-2-7b-hf
---

# Llama 2 7B with Key-Value-Cache enabled in ONNX fp16 format
- Model creator: [Meta-Llama](https://huggingface.co/meta-llama)
- Original model: [Meta-Llama Llama 2 7B](https://huggingface.co/meta-llama/Llama-2-7b-hf)

<!-- description start -->
## Description

This repo contains the ONNX files for the conversion of Llama 2 7B done by Esperanto Technologies.
The model is in fp16 format and has the key-value cache (KVC) enabled.

<!-- description end -->

## How to download ONNX model and weight files

The easiest way to obtain the model is to clone this whole repo.
Alternatively, you can download individual files using the `huggingface-hub` Python library.

```shell
pip3 install "huggingface-hub>=0.17.1"
```

Then you can download the model files into a local directory with a command like this:

```shell
huggingface-cli download Esperanto/llama2-7b-kvc-fp16-onnx --local-dir llama2-7b-kvc-fp16-onnx --local-dir-use-symlinks False
```

For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).
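
If you prefer to stay in Python, a minimal sketch using `huggingface_hub.snapshot_download` (fetching the same repo into the same local directory as the CLI command above) looks like this:

```python
from huggingface_hub import snapshot_download

# Download the whole repo (model.onnx, tokenizer files, ...) into a local folder.
local_path = snapshot_download(
    repo_id="Esperanto/llama2-7b-kvc-fp16-onnx",
    local_dir="llama2-7b-kvc-fp16-onnx",
)
print(local_path)  # directory that now contains model.onnx
```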

## How to run from Python code using ONNXRuntime

This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).

#### First install the packages

```bash
pip3 install onnx==1.16.1
pip3 install onnxruntime==1.17.1
```
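
Before running the full generation loop, a quick sanity check is to open an inference session and list the inputs the exported graph expects; this is a minimal sketch assuming the download location used above:

```python
import onnxruntime

# A CPU session is the default when no execution providers are specified.
session = onnxruntime.InferenceSession("llama2-7b-kvc-fp16-onnx/model.onnx")

# Besides input_ids / attention_mask / position_ids / tree_attention,
# a KVC export exposes one past_key_values.* input per layer.
for inp in session.get_inputs():
    print(inp.name, inp.shape, inp.type)
```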

#### Example code: generate text with this model

We define the loop with greedy decoding:
```python
import numpy as np
import onnxruntime
import onnx
from transformers import AutoTokenizer

def generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context):
    model = onnx.load(model_path)

    # we create the inputs for the first iteration (numpy tensors, since the ONNX session expects numpy)
    input_tensor = tokenizer(prompt, return_tensors="np")
    prompt_size = len(input_tensor['input_ids'][0])
    actual_input = input_tensor['input_ids']
    if prompt_size < window:
        # left-pad the prompt with BOS tokens so it fills one window
        actual_input = np.concatenate((tokenizer.bos_token_id*np.ones([1, window - prompt_size], dtype='int64'),
                                       actual_input), axis=1)
    if prompt_size + max_gen_tokens > total_sequence:
        print("ERROR: Longer total sequence is needed!")
        return
    first_attention = np.concatenate((np.zeros([1, total_sequence - window], dtype='int64'),
                                      np.ones((1, window), dtype='int64')), axis=1)
    max_gen_tokens += prompt_size  # we need to generate on top of parsing the prompt
    inputs_names = [node.name for node in model.graph.input]
    output_names = [node.name for node in model.graph.output]
    n_heads = 32  # number of attention heads in the KV cache (32 for Llama 2 7B)
    inputs_dict = {}
    inputs_dict['input_ids'] = actual_input[:, :window].reshape(1, window)
    inputs_dict['attention_mask'] = first_attention
    index_pos = sum(first_attention[0])
    inputs_dict['position_ids'] = np.concatenate((np.zeros([1, total_sequence - index_pos], dtype='int64'), np.arange(index_pos, dtype='int64').reshape(1, index_pos)), axis=1)
    # causal mask: -65504 (the most negative fp16 value) above the diagonal, 0 elsewhere
    inputs_dict['tree_attention'] = np.triu(-65504*np.ones((total_sequence, total_sequence)), k=1).astype('float16').reshape(1, 1, total_sequence, total_sequence)
    for name in inputs_names:
        if name == 'input_ids' or name == 'attention_mask' or name == 'position_ids' or name == 'tree_attention': continue
        # empty KV cache entries for the first iteration
        inputs_dict[name] = np.zeros([1, n_heads, context-window, 128], dtype="float16")
    new_token = np.array([10])  # placeholder so the EOS check passes on the first iteration
    next_index = window
    old_j = 0
    total_input = actual_input

    rt_session = onnxruntime.InferenceSession(model_path)
    ## We run the inferences
    while next_index < max_gen_tokens:
        if (new_token == tokenizer.eos_token_id).any():
            break
        # inference
        output = rt_session.run(output_names, inputs_dict)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}
        # we prepare the inputs for the next inference
        for name in inputs_names:
            if name == 'input_ids':
                old_j = next_index
                if next_index < prompt_size:
                    # still parsing the prompt: advance by a full window when possible
                    if prompt_size - next_index >= window: next_index += window
                    else: next_index = prompt_size
                    j = next_index - window
                else:
                    # generating: advance one token at a time and append the argmax token
                    next_index += 1
                    j = next_index - window
                    new_token = outs_dictionary['logits'].argmax(-1).reshape(1, window)
                    total_input = np.concatenate((total_input, new_token[:, -1:]), axis=1)
                inputs_dict['input_ids'] = total_input[:, j:next_index].reshape(1, window)
            elif name == 'attention_mask':
                inputs_dict['attention_mask'] = np.concatenate((np.zeros((1, total_sequence-next_index), dtype='int64'), np.ones((1, next_index), dtype='int64')), axis=1)
            elif name == 'position_ids':
                inputs_dict['position_ids'] = np.concatenate((np.zeros([1, total_sequence - next_index], dtype='int64'), np.arange(next_index, dtype='int64').reshape(1, next_index)), axis=1)
            elif name == 'tree_attention': continue
            else:
                # shift the returned present.* KV cache into the past_key_values.* input
                old_name = name.replace("past_key_values", "present")
                inputs_dict[name] = outs_dictionary[old_name][:, :, next_index-old_j:context-window+(next_index - old_j), :]

    answer = tokenizer.decode(total_input[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answer
```
We now run the inferences:

```python
tokenizer = AutoTokenizer.from_pretrained("Esperanto/llama2-7b-kvc-fp16-onnx")
model_path = "llama2-7b-kvc-fp16-onnx/model.onnx"

max_gen_tokens = 20  # number of tokens we want to generate
total_sequence = 128  # total sequence length
context = 1024  # the context length the KV cache was exported with
window = 16  # number of tokens we parse at a time
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

generated = generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context)
print(generated)
```
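
The loop above decodes greedily (argmax over the logits). If more varied generations are wanted, the argmax step can be swapped for temperature sampling; this is a sketch of a hypothetical helper, not part of this repo:

```python
import numpy as np

def sample_last_token(logits, temperature=0.8, rng=np.random.default_rng()):
    # logits: the raw 'logits' output of one inference, shape [1, window, vocab_size].
    # Softmax the last position's temperature-scaled logits and draw one token id.
    scores = logits[0, -1].astype("float64") / temperature
    probs = np.exp(scores - scores.max())
    probs /= probs.sum()
    return int(rng.choice(probs.size, p=probs))
```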
config.json ADDED
File without changes
model.onnx CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:77d77f5eeedd977b50681bc823379bf7819f3e4ba770175c8b645506b840893a
- size 34245486
+ oid sha256:4eedfacaa72095936f44c4361c7d18f7c1b34fd1446201ef4c68da5844d45ec6
+ size 34237881
special_tokens_map.json ADDED
@@ -0,0 +1,24 @@
{
  "bos_token": {
    "content": "<s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "eos_token": {
    "content": "</s>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  },
  "pad_token": "<unk>",
  "unk_token": {
    "content": "<unk>",
    "lstrip": false,
    "normalized": true,
    "rstrip": false,
    "single_word": false
  }
}
tokenizer.json ADDED
The diff for this file is too large to render.
 
tokenizer.model ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:9e556afd44213b6bd1be2b850ebbbd98f5481437a8021afaf58ee7fb1818d347
size 499723