suvadityamuk committed
Commit 66e5432 · 1 Parent(s): c711155

chore: add onnx

Signed-off-by: Suvaditya Mukherjee <[email protected]>

Files changed (4)
  1. app.py +3 -4
  2. model.onnx +3 -0
  3. requirements.txt +2 -1
  4. utils.py +53 -0
app.py CHANGED
@@ -9,7 +9,7 @@ import psutil
 import pymupdf
 import gradio as gr
 from qdrant_client import QdrantClient
-from utils import download_pdf_from_gdrive, merge_strings_with_prefix, scrape_website
+from utils import download_pdf_from_gdrive, merge_strings_with_prefix, onnx_inference
 from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
 
 def rag_query(query: str):
@@ -120,8 +120,6 @@ if __name__ == "__main__":
         vectors_config=client.get_fastembed_vector_params(),
     )
 
-    print("fulltext", fulltext)
-
     _ = client.add(
         collection_name="resume",
         documents=fulltext,
@@ -170,7 +168,8 @@ if __name__ == "__main__":
 
         # start_time = time.time()
         # Generate LLM answer
-        generated_text = generate_answer(chat_history)
+        # generated_text = generate_answer(chat_history)
+        generated_text = onnx_inference(chat_history, rag_query, tokenizer)
 
         # Detect if tool call is requested by LLM. If yes, then
         # execute tool and use else return None
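The tool-call handling referenced in the trailing comment is not part of this hunk. For readers, a hedged sketch of one shape such a check can take with chat-template tool calling; the tag format and function name below are assumptions, not the app's actual code:

import json
import re

def maybe_tool_call(generated_text: str):
    # Hypothetical parser: assumes the model emits a JSON object like
    # {"name": "rag_query", "arguments": {"query": "..."}} inside <tool_call> tags.
    match = re.search(r"<tool_call>(.*?)</tool_call>", generated_text, re.DOTALL)
    if match is None:
        return None  # plain answer, no tool requested
    return json.loads(match.group(1))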
model.onnx ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f7da5d0bb5e5b6eba0ba0e9c006fcbc8b670134405ef5e02aaaf738361e2074
+size 1057843
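The model is stored as a Git LFS pointer, so the diff shows only the pointer, not how the graph was produced. As a sketch of one common export route for a Hugging Face causal LM, using optimum's ONNX wrapper (the checkpoint name is a placeholder; the commit does not show the actual export method):

# Hypothetical export sketch -- not the commit's actual export script.
from optimum.onnxruntime import ORTModelForCausalLM

model_id = "org/checkpoint-name"  # placeholder; the real model is not named in the diff
ort_model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
ort_model.save_pretrained("onnx_dir/")  # writes model.onnx alongside the config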
requirements.txt CHANGED
@@ -14,4 +14,5 @@ psutil
 optimum-quanto
 pynvml
 beautifulsoup4
-requests
+requests
+onnxruntime
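One caveat on the new dependency: the plain onnxruntime wheel ships CPU support only, so the CUDAExecutionProvider requested in utils.py will silently fall back to CPU unless onnxruntime-gpu is installed instead. A quick runtime check:

import onnxruntime

# The stock "onnxruntime" package typically reports only CPUExecutionProvider;
# CUDAExecutionProvider appears when onnxruntime-gpu (plus CUDA) is installed.
print(onnxruntime.get_available_providers())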
utils.py CHANGED
@@ -1,5 +1,8 @@
 import gdown
 import os
+import numpy as np
+import torch
+import onnxruntime
 from urllib.parse import urlparse, parse_qs, urljoin
 import requests
 from bs4 import BeautifulSoup
@@ -175,3 +178,53 @@ def scrape_website(start_url, delay=1):
     # Combine all content into a single string
     combined_content = "\n\n".join(all_content)
     return combined_content
+
+def onnx_inference(chat_history, rag_query, tokenizer):
+    # Create an ONNX Runtime session; CUDA is tried first, with CPU as fallback
+    session = onnxruntime.InferenceSession("model.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+
+    # Render the chat history (plus the rag_query tool schema) into token IDs
+    inputs = tokenizer.apply_chat_template(
+        chat_history,
+        tools=[rag_query],
+        return_tensors="np",
+        return_dict=True,
+        add_generation_prompt=True,
+        # padding=True
+    )
+
+    # Prepare the ONNX Runtime input feed
+    ort_inputs = {
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"]
+    }
+
+    input_length = inputs["input_ids"].shape[1]
+    max_new_tokens = 512
+
+    # Autoregressive generation loop (re-runs the full sequence each step; no KV cache)
+    for _ in range(max_new_tokens):
+        ort_outputs = session.run(None, ort_inputs)
+        next_token_logits = ort_outputs[0][:, -1, :]
+
+        # Sample the next token (temperature 1.0)
+        next_token_logits = torch.tensor(next_token_logits)
+        probs = torch.nn.functional.softmax(next_token_logits / 1.0, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1).numpy()
+
+        # Append the sampled token to the running input
+        ort_inputs["input_ids"] = np.concatenate([ort_inputs["input_ids"], next_token], axis=1)
+        ort_inputs["attention_mask"] = np.concatenate([
+            ort_inputs["attention_mask"],
+            np.ones_like(next_token)
+        ], axis=1)
+
+        # Stop at the end-of-sequence token
+        if next_token[0, 0] == tokenizer.eos_token_id:
+            break
+
+    # Decode only the newly generated tokens
+    generated_ids = ort_inputs["input_ids"][0, input_length:]
+    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    return generated_text
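For reference, a minimal usage sketch of the new helper, assuming a tokenizer that matches the exported model (the checkpoint name and messages below are placeholders) and the list-of-dicts message format that apply_chat_template expects:

from transformers import AutoTokenizer
from utils import onnx_inference

def rag_query(query: str):
    """Retrieve resume passages relevant to `query`."""  # schema only; the real body lives in app.py
    ...

tokenizer = AutoTokenizer.from_pretrained("org/checkpoint-name")  # placeholder checkpoint
chat_history = [{"role": "user", "content": "Summarize the candidate's experience."}]
print(onnx_inference(chat_history, rag_query, tokenizer))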