Commit 66e5432
Parent(s): c711155

chore: add onnx

Signed-off-by: Suvaditya Mukherjee <[email protected]>
- app.py +3 -4
- model.onnx +3 -0
- requirements.txt +2 -1
- utils.py +53 -0
app.py
CHANGED
@@ -9,7 +9,7 @@ import psutil
 import pymupdf
 import gradio as gr
 from qdrant_client import QdrantClient
-from utils import download_pdf_from_gdrive, merge_strings_with_prefix
+from utils import download_pdf_from_gdrive, merge_strings_with_prefix, onnx_inference
 from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
 
 def rag_query(query: str):
@@ -120,8 +120,6 @@ if __name__ == "__main__":
         vectors_config=client.get_fastembed_vector_params(),
     )
 
-    print("fulltext", fulltext)
-
     _ = client.add(
         collection_name="resume",
         documents=fulltext,
@@ -170,7 +168,8 @@ if __name__ == "__main__":
 
     # start_time = time.time()
    # Generate LLM answer
-    generated_text = generate_answer(chat_history)
+    # generated_text = generate_answer(chat_history)
+    generated_text = onnx_inference(chat_history, rag_query, tokenizer)
 
     # Detect if tool call is requested by LLM. If yes, then
    # execute tool and use else return None
model.onnx
ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8f7da5d0bb5e5b6eba0ba0e9c006fcbc8b670134405ef5e02aaaf738361e2074
+size 1057843
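The change above only checks in the Git LFS pointer for model.onnx; the export step that produced the file is not part of this commit. As a hedged sketch under that assumption, a Transformers checkpoint can be converted to ONNX with Hugging Face Optimum roughly like this (the helper name, model ID, and output path are placeholders, not taken from this repository):

    # export_model.py -- hypothetical helper, not included in this commit
    from optimum.onnxruntime import ORTModelForCausalLM

    def export_to_onnx(model_id: str, output_dir: str = "onnx_export") -> None:
        # export=True converts the PyTorch weights to an ONNX graph on the fly
        model = ORTModelForCausalLM.from_pretrained(model_id, export=True)
        # Writes model.onnx (plus config files) into output_dir
        model.save_pretrained(output_dir)

The resulting model.onnx would then be copied to the Space root, which is where utils.py loads it from (see onnx_inference below).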
requirements.txt
CHANGED
@@ -14,4 +14,5 @@ psutil
 optimum-quanto
 pynvml
 beautifulsoup4
-requests
+requests
+onnxruntime
utils.py
CHANGED
@@ -1,5 +1,8 @@
 import gdown
 import os
+import numpy as np
+import torch
+import onnxruntime
 from urllib.parse import urlparse, parse_qs, urljoin
 import requests
 from bs4 import BeautifulSoup
@@ -175,3 +178,53 @@ def scrape_website(start_url, delay=1):
     # Combine all content into a single string
     combined_content = "\n\n".join(all_content)
     return combined_content
+
+def onnx_inference(chat_history, rag_query, tokenizer):
+    # Create ONNX Runtime session
+    session = onnxruntime.InferenceSession("model.onnx", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+
+    # Tokenize input text
+    inputs = tokenizer.apply_chat_template(
+        chat_history,
+        tools=[rag_query],
+        return_tensors="np",
+        return_dict=True,
+        add_generation_prompt=True,
+        # padding=True
+    )
+
+    # Run inference
+    ort_inputs = {
+        "input_ids": inputs["input_ids"],
+        "attention_mask": inputs["attention_mask"]
+    }
+
+    input_length = inputs["input_ids"].shape[1]
+    max_new_tokens = 512
+
+    # Run generation
+    for _ in range(max_new_tokens):
+        ort_outputs = session.run(None, ort_inputs)
+        next_token_logits = ort_outputs[0][:, -1, :]
+
+        # Apply sampling
+        next_token_logits = torch.tensor(next_token_logits)
+        probs = torch.nn.functional.softmax(next_token_logits / 1.0, dim=-1)
+        next_token = torch.multinomial(probs, num_samples=1).numpy()
+
+        # Append to input
+        ort_inputs["input_ids"] = np.concatenate([ort_inputs["input_ids"], next_token], axis=1)
+        ort_inputs["attention_mask"] = np.concatenate([
+            ort_inputs["attention_mask"],
+            np.ones_like(next_token)
+        ], axis=1)
+
+        # Check for EOS token
+        if next_token[0, 0] == tokenizer.eos_token_id:
+            break
+
+    # Decode only the new tokens
+    generated_ids = ort_inputs["input_ids"][0, input_length:]
+    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
+
+    return generated_text
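For context, a minimal usage sketch of the new onnx_inference helper, mirroring the call added in app.py. The model ID, chat history, and the stand-in rag_query below are illustrative assumptions; only the onnx_inference signature comes from this commit:

    from transformers import AutoTokenizer
    from utils import onnx_inference

    def rag_query(query: str):
        # Stand-in for the retrieval tool defined in app.py; it is passed so the
        # chat template can advertise it to the model as a callable tool.
        return "retrieved context for: " + query

    model_id = "..."  # placeholder: the same checkpoint app.py loads its tokenizer from
    tokenizer = AutoTokenizer.from_pretrained(model_id)

    chat_history = [
        {"role": "system", "content": "You answer questions about a resume."},
        {"role": "user", "content": "Where did the candidate work last?"},
    ]

    answer = onnx_inference(chat_history, rag_query, tokenizer)
    print(answer)

Note that each iteration of the loop in onnx_inference re-runs the full, growing sequence through session.run without passing cached key/values, so generation cost grows with output length; that keeps the exported graph simple at the expense of speed.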