Commit 01e09d8
Parent(s): 0c9c63c
chore: download onnx-data on spaces
Signed-off-by: Suvaditya Mukherjee <[email protected]>
app.py
CHANGED
@@ -11,7 +11,7 @@ import psutil
 import pymupdf
 import gradio as gr
 from qdrant_client import QdrantClient
-from utils import download_pdf_from_gdrive, merge_strings_with_prefix
+from utils import download_pdf_from_gdrive, merge_strings_with_prefix
 from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
 
 def rag_query(query: str):
@@ -101,25 +101,25 @@ if __name__ == "__main__":
     RESUME_PATH = os.path.join(os.getcwd(), "Resume.pdf")
     RESUME_URL = "https://drive.google.com/file/d/1YMF9NNTG5gubwJ7ipI5JfxAJKhlD9h2v/"
 
-    ONNX_MODEL_PATH = "https://huggingface.co/onnx-community/Qwen2.5-1.5B-Instruct/resolve/main/onnx/model.onnx_data"
-    SAVE_PATH = "./model.onnx_data"
+    # ONNX_MODEL_PATH = "https://huggingface.co/onnx-community/Qwen2.5-1.5B-Instruct/resolve/main/onnx/model.onnx_data"
+    # SAVE_PATH = "./model.onnx_data"
 
-    print("Downloading ONNX model...")
-    response = requests.get(ONNX_MODEL_PATH, stream=True)
-    response.raise_for_status()
+    # print("Downloading ONNX model...")
+    # response = requests.get(ONNX_MODEL_PATH, stream=True)
+    # response.raise_for_status()
 
-    total_size = int(response.headers.get('content-length', 0))
+    # total_size = int(response.headers.get('content-length', 0))
 
-    with open(SAVE_PATH, 'wb') as file, tqdm(
-        desc=os.path.basename(SAVE_PATH),
-        total=total_size,
-        unit='iB',
-        unit_scale=True
-    ) as pbar:
-        for data in response.iter_content(chunk_size=8192):
-            size = file.write(data)
-            pbar.update(size)
-    print("Downloaded ONNX model!")
+    # with open(SAVE_PATH, 'wb') as file, tqdm(
+    #     desc=os.path.basename(SAVE_PATH),
+    #     total=total_size,
+    #     unit='iB',
+    #     unit_scale=True
+    # ) as pbar:
+    #     for data in response.iter_content(chunk_size=8192):
+    #         size = file.write(data)
+    #         pbar.update(size)
+    # print("Downloaded ONNX model!")
 
     # Download file
     download_pdf_from_gdrive(RESUME_URL, RESUME_PATH)
@@ -190,8 +190,8 @@ if __name__ == "__main__":
 
        # start_time = time.time()
        # Generate LLM answer
-       # generated_text = generate_answer(chat_history)
-       generated_text = onnx_inference(chat_history, rag_query, tokenizer)
+       generated_text = generate_answer(chat_history)
+       # generated_text = onnx_inference(chat_history, rag_query, tokenizer)
 
        # Detect if tool call is requested by LLM. If yes, then
        # execute tool and use else return None
@@ -204,8 +204,8 @@ if __name__ == "__main__":
            chat_history, tool_query, query_results
        )
        # Generate result from the
-       # generated_text = generate_answer(chat_history)
-       generated_text = onnx_inference(chat_history, rag_query, tokenizer)
+       generated_text = generate_answer(chat_history)
+       # generated_text = onnx_inference(chat_history, rag_query, tokenizer)
 
        # metrics = {
        #     "conversation": {
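For context, model.onnx_data is the external-data file that carries the weights for model.onnx, so ONNX Runtime expects it to sit next to the graph file. The block commented out above follows the common requests streaming-download pattern with a tqdm progress bar. Below is a minimal self-contained sketch of that same pattern; the URL and save path are taken from the diff, while the imports, the helper wrapper, and its name are illustrative additions and not part of app.py:

import os
import requests
from tqdm import tqdm

ONNX_MODEL_PATH = "https://huggingface.co/onnx-community/Qwen2.5-1.5B-Instruct/resolve/main/onnx/model.onnx_data"
SAVE_PATH = "./model.onnx_data"

def download_onnx_data(url: str = ONNX_MODEL_PATH, save_path: str = SAVE_PATH) -> None:
    """Stream a large file to disk with a progress bar (illustrative helper)."""
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get("content-length", 0))
    with open(save_path, "wb") as file, tqdm(
        desc=os.path.basename(save_path),
        total=total_size,
        unit="iB",
        unit_scale=True,
    ) as pbar:
        for chunk in response.iter_content(chunk_size=8192):
            # file.write returns the number of bytes written, which drives the bar
            pbar.update(file.write(chunk))

if __name__ == "__main__":
    download_onnx_data()

If the Space already depends on huggingface_hub, an alternative is hf_hub_download(repo_id="onnx-community/Qwen2.5-1.5B-Instruct", filename="onnx/model.onnx_data"), which fetches the same file into the local HF cache and returns its path.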
utils.py
CHANGED
@@ -177,135 +177,4 @@ def scrape_website(start_url, delay=1):
 
     # Combine all content into a single string
     combined_content = "\n\n".join(all_content)
-    return combined_content
-
-def onnx_inference(chat_history: list, rag_query: str, tokenizer) -> str:
-    """
-    Performs ONNX inference with dynamic input handling, optimized for conciseness.
-    """
-    session = onnxruntime.InferenceSession("model.onnx")
-    model_inputs = session.get_inputs()
-    model_outputs = session.get_outputs()
-
-    # --- Corrected Chat History and Tool Call Format ---
-    # The tool call needs to be *part* of the chat history.
-    chat_history_with_tool = chat_history + [
-        {"role": "user", "content": rag_query, "tools": [{"type": "retrieval"}]},
-    ]
-
-
-    # Use HF tokenizer for input preparation
-    inputs = tokenizer.apply_chat_template(
-        chat_history_with_tool,
-        return_tensors="np",
-        add_generation_prompt=True
-    )
-    input_ids = inputs["input_ids"]  # Corrected: Access input_ids correctly
-    attention_mask = inputs["attention_mask"]
-
-    # Determine required inputs
-    has_position_ids = "position_ids" in (inp.name for inp in model_inputs)
-    has_past_key_values = any("past_key_values" in inp.name for inp in model_inputs)
-
-    # Prepare initial inputs, including past_key_values if needed
-    ort_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
-    if has_position_ids:
-        ort_inputs["position_ids"] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1)
-
-    if has_past_key_values:
-        # Dummy run to get past_key_values shape
-        dummy_inputs = {
-            "input_ids": np.array([[0]]),
-            "attention_mask": np.array([[1]]),
-            "position_ids": np.array([[0]]) if has_position_ids else None  # Add if needed
-        }
-        dummy_inputs = {k: v for k, v in dummy_inputs.items() if v is not None}  # Remove None values
-        sample_outputs = session.run(None, dummy_inputs)
-        pkv_shape = list(sample_outputs[1].shape)
-        pkv_shape[2] = 0  # Set sequence length to 0 for initial state
-        num_pkv = len([inp for inp in model_inputs if "past_key_values" in inp.name])
-        past_key_values = tuple([np.zeros(pkv_shape, dtype=np.float32) for _ in range(num_pkv)])
-
-        # Add initial past_key_values to ort_inputs
-        for i in range(len(past_key_values) // 2):
-            ort_inputs[f"past_key_values.{i}.key"] = past_key_values[i * 2]
-            ort_inputs[f"past_key_values.{i}.value"] = past_key_values[i * 2 + 1]
-
-    generated_ids = []
-    input_length = input_ids.shape[1]
-
-    # Generation loop with dynamic input updates
-    for _ in range(512):  # Max new tokens
-        ort_outputs = session.run(None, ort_inputs)
-        next_token_logits = torch.tensor(ort_outputs[0][:, -1, :])
-        next_token = torch.multinomial(torch.softmax(next_token_logits / 1.0, dim=-1), num_samples=1).numpy()
-
-        generated_ids.append(next_token[0, 0])
-        if next_token[0, 0] == tokenizer.eos_token_id:
-            break
-
-        # Update inputs for next iteration
-        ort_inputs["input_ids"] = next_token
-        ort_inputs["attention_mask"] = np.ones_like(next_token)
-
-        if has_position_ids:
-            ort_inputs["position_ids"] = np.array([[input_length]], dtype=np.int64)
-            input_length += 1
-
-        if has_past_key_values:
-            for i in range(len(ort_outputs) - 1):  # Iterate over model outputs, excluding logits
-                ort_inputs[model_inputs[i+2].name] = ort_outputs[i+1]  # Use names for robustness
-
-
-    return tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-# def onnx_inference(chat_history, rag_query, tokenizer):
-#     # Create ONNX Runtime session
-#     session = onnxruntime.InferenceSession("model.onnx")
-
-#     # Tokenize input text
-#     inputs = tokenizer.apply_chat_template(
-#         chat_history,
-#         tools=[rag_query],
-#         return_tensors="np",
-#         return_dict=True,
-#         add_generation_prompt=True,
-#         # padding=True
-#     )
-
-#     # Run inference
-#     ort_inputs = {
-#         "input_ids": inputs["input_ids"],
-#         "attention_mask": inputs["attention_mask"]
-#     }
-
-#     input_length = inputs["input_ids"].shape[1]
-#     max_new_tokens = 512
-
-#     # Run generation
-#     for _ in range(max_new_tokens):
-#         ort_outputs = session.run(None, ort_inputs)
-#         next_token_logits = ort_outputs[0][:, -1, :]
-
-#         # Apply sampling
-#         next_token_logits = torch.tensor(next_token_logits)
-#         probs = torch.nn.functional.softmax(next_token_logits / 1.0, dim=-1)
-#         next_token = torch.multinomial(probs, num_samples=1).numpy()
-
-#         # Append to input
-#         ort_inputs["input_ids"] = np.concatenate([ort_inputs["input_ids"], next_token], axis=1)
-#         ort_inputs["attention_mask"] = np.concatenate([
-#             ort_inputs["attention_mask"],
-#             np.ones_like(next_token)
-#         ], axis=1)
-
-#         # Check for EOS token
-#         if next_token[0, 0] == tokenizer.eos_token_id:
-#             break
-
-#     # Decode only the new tokens
-#     generated_ids = ort_inputs["input_ids"][0, input_length:]
-#     generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-#     return generated_text
+    return combined_content
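One detail worth noting about the removed onnx_inference(): tokenizer.apply_chat_template() returns only the token ids by default, so indexing its result with inputs["input_ids"] (as the first removed version does) only works if return_dict=True is passed, which the commented-out second version did include. A minimal sketch of that input-preparation step, assuming a plain transformers tokenizer; the checkpoint id and sample chat history below are illustrative and not taken from the diff:

import numpy as np
from transformers import AutoTokenizer

# Illustrative checkpoint; app.py loads its own tokenizer elsewhere.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-1.5B-Instruct")

chat_history = [{"role": "user", "content": "Summarize the resume."}]  # sample input

# Default call: a flat list of token ids, so ids_only["input_ids"] would fail.
ids_only = tokenizer.apply_chat_template(chat_history, add_generation_prompt=True)

# With return_dict=True the call returns a mapping carrying both arrays that an
# ONNX Runtime feed dict needs.
enc = tokenizer.apply_chat_template(
    chat_history,
    add_generation_prompt=True,
    return_dict=True,
    return_tensors="np",
)
input_ids = enc["input_ids"].astype(np.int64)
attention_mask = enc["attention_mask"].astype(np.int64)
print(input_ids.shape, attention_mask.shape)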