---
tags:
- text-generation-inference
- gemma
---
# Gemma 7B instruct with Key-Value-Cache enabled in ONNX fp16 format
- Model creator: [Google](https://huggingface.co/google)
- Original model: [Gemma 7B instruct](https://huggingface.co/google/gemma-7b-it)

## Description

This repo contains the ONNX files for the ONNX conversion of Gemma 7B instruct done by Esperanto Technologies.
The model is in the fp16 format and has the key-value cache (KVC) enabled.

## How to download ONNX model and weight files

The easiest way to obtain the model is to clone this whole repo.
Alternatively, you can download the files using the `huggingface-hub` Python library.

```shell
pip3 install "huggingface-hub>=0.17.1"
```

Then you can download all the files in this repo to a local directory with a command like this:

```shell
huggingface-cli download Esperanto/gemma-7b-it-kvc-fp16-onnx --local-dir gemma-7b-it-kvc-fp16-onnx --local-dir-use-symlinks False
```

For more documentation on downloading with `huggingface-cli`, please see: [HF -> Hub Python Library -> Download files -> Download from the CLI](https://huggingface.co/docs/huggingface_hub/guides/download#download-from-the-cli).

## How to run from Python code using ONNXRuntime

This model can easily be run on a CPU using [ONNXRuntime](https://onnxruntime.ai/).
#### First install the packages

```bash
pip3 install onnx==1.16.1
pip3 install onnxruntime==1.17.1
```

#### Example code: generate text with this model

We define the loop with greedy decoding:

```python
import numpy as np
import onnxruntime
import onnx
from transformers import AutoTokenizer

def generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context):
    model = onnx.load(model_path)

    # we create the inputs for the first iteration
    # (NumPy tensors are requested so that PyTorch is not needed)
    input_tensor = tokenizer(prompt, return_tensors="np")
    prompt_size = len(input_tensor['input_ids'][0])
    actual_input = input_tensor['input_ids'].astype('int64')
    if prompt_size < window:
        # short prompts are left-padded with BOS tokens up to one full window
        actual_input = np.concatenate((tokenizer.bos_token_id*np.ones([1, window - prompt_size], dtype = 'int64'),
                                       actual_input), axis=1)
    if prompt_size + max_gen_tokens > total_sequence:
        print("ERROR: Longer total sequence is needed!")
        return
    first_attention = np.concatenate((np.zeros([1, total_sequence - window], dtype = 'int64'),
                                      np.ones((1, window), dtype = 'int64')), axis=1)
    max_gen_tokens += prompt_size # we need to generate on top of parsing the prompt
    inputs_names = [node.name for node in model.graph.input]
    output_names = [node.name for node in model.graph.output]
    n_heads = 16 # gqa-heads of the kvc
    inputs_dict = {}
    inputs_dict['input_ids'] = actual_input[:, :window].reshape(1, window)
    inputs_dict['attention_mask'] = first_attention
    for name in inputs_names:
        if name == 'input_ids' or name == 'attention_mask': continue
        # the key-value cache starts empty
        inputs_dict[name] = np.zeros([1, n_heads, context-window, 256], dtype="float16")
    new_token = None # last generated token, shape (1, 1) once generation starts
    next_index = window
    old_j = 0
    total_input = actual_input
    rt_session = onnxruntime.InferenceSession(model_path)

    ## We run the inferences
    while next_index < max_gen_tokens:
        # stop as soon as the model emits the end-of-sequence token
        if new_token is not None and new_token[0, 0] == tokenizer.eos_token_id:
            break
        # inference
        output = rt_session.run(output_names, inputs_dict)
        outs_dictionary = {name: content for (name, content) in zip(output_names, output)}
        # we prepare the inputs for the next inference
        # ('input_ids' must come first in the model inputs: it updates old_j and next_index)
        for name in inputs_names:
            if name == 'input_ids':
                old_j = next_index
                if next_index < prompt_size:
                    # still parsing the prompt: move on to its next window
                    if prompt_size - next_index >= window: next_index += window
                    else: next_index = prompt_size
                else:
                    # the whole prompt has been parsed: keep the greedy (argmax) token
                    new_token = outs_dictionary['logits'].argmax(-1).reshape(1, window)[:, -1:]
                    total_input = np.concatenate((total_input, new_token), axis = 1)
                    next_index += 1
                j = next_index - window
                inputs_dict['input_ids'] = total_input[:, j:next_index].reshape(1, window)
            elif name == 'attention_mask':
                inputs_dict['attention_mask'] = np.concatenate((np.zeros((1, total_sequence-next_index), dtype = 'int64'),
                                                                np.ones((1, next_index), dtype = 'int64')), axis=1)
            else:
                # shift the cache by the number of tokens that are new in the next window
                old_name = name.replace("past_key_values", "present")
                inputs_dict[name] = outs_dictionary[old_name][:, :, next_index-old_j:context-window+(next_index - old_j), :]

    answer = tokenizer.decode(total_input[0], skip_special_tokens=True, clean_up_tokenization_spaces=False)
    return answer
```

We now run the inferences:

```python
tokenizer = AutoTokenizer.from_pretrained("Esperanto/gemma-7b-it-kvc-fp16-onnx")
model_path = "gemma-7b-it-kvc-fp16-onnx/model.onnx"

max_gen_tokens = 20    # number of tokens we want to generate
total_sequence = 128   # total sequence_length
context = 1024         # the context to extend the kvc
window = 16            # number of tokens we want to parse at a time

# Note: the original Gemma chat template does not accept a "system" role.
# If apply_chat_template raises an error, move the instruction into the user message.
messages = [
    {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
    {"role": "user", "content": "Who are you?"},
]

prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

generated = generate_text(model_path, prompt, tokenizer, max_gen_tokens, total_sequence, window, context)
print(generated)
```
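
The windowing arithmetic in `generate_text` is easier to follow in isolation. The sketch below reproduces only the index bookkeeping of the loop: which token span is fed as `input_ids` at each inference, and by how many positions the key-value cache is shifted. `window_schedule` is a hypothetical helper written for illustration; it is not part of this repo and it runs no model. When the prompt is shorter than `window`, the indices refer to the BOS-padded sequence.

```python
def window_schedule(prompt_size, max_gen_tokens, window):
    """Yield (start, end, n_new) for each inference of the sliding-window loop.

    start and end delimit the token span fed as input_ids. n_new is the number
    of tokens in that span that were not in the previous span, which is also
    the amount by which the key-value cache was shifted before this inference.
    """
    total = prompt_size + max_gen_tokens
    end = window      # the first inference always covers one full window
    n_new = window
    while end < total:
        yield (end - window, end, n_new)
        previous_end = end
        if end < prompt_size:
            # still parsing the prompt: advance a full window, or stop at its end
            end = min(end + window, prompt_size)
        else:
            # generating: the window slides forward by a single token
            end += 1
        n_new = end - previous_end


# A 40-token prompt, 3 generated tokens, 16-token window
for start, end, n_new in window_schedule(prompt_size=40, max_gen_tokens=3, window=16):
    print(f"input_ids = tokens[{start}:{end}]  new tokens = {n_new}")
# input_ids = tokens[0:16]  new tokens = 16
# input_ids = tokens[16:32]  new tokens = 16
# input_ids = tokens[24:40]  new tokens = 8
# input_ids = tokens[25:41]  new tokens = 1
# input_ids = tokens[26:42]  new tokens = 1
```

The third step shows why consecutive windows can overlap: only 8 prompt tokens remain, so the window is anchored at the end of the prompt and 8 already-seen tokens are recomputed. During generation every step recomputes `window - 1` tokens and adds one new one, so a smaller `window` reduces the work per generated token while a larger one parses long prompts in fewer inferences.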