suvadityamuk committed
Commit 01e09d8 · 1 Parent(s): 0c9c63c

chore: download onnx-data on spaces

Signed-off-by: Suvaditya Mukherjee <[email protected]>

Files changed (2)
  1. app.py +21 -21
  2. utils.py +1 -132
app.py CHANGED
@@ -11,7 +11,7 @@ import psutil
 import pymupdf
 import gradio as gr
 from qdrant_client import QdrantClient
-from utils import download_pdf_from_gdrive, merge_strings_with_prefix, onnx_inference
+from utils import download_pdf_from_gdrive, merge_strings_with_prefix
 from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
 
 def rag_query(query: str):
@@ -101,25 +101,25 @@ if __name__ == "__main__":
     RESUME_PATH = os.path.join(os.getcwd(), "Resume.pdf")
     RESUME_URL = "https://drive.google.com/file/d/1YMF9NNTG5gubwJ7ipI5JfxAJKhlD9h2v/"
 
-    ONNX_MODEL_PATH = "https://huggingface.co/onnx-community/Qwen2.5-1.5B-Instruct/resolve/main/onnx/model.onnx_data"
-    SAVE_PATH = "./model.onnx_data"
+    # ONNX_MODEL_PATH = "https://huggingface.co/onnx-community/Qwen2.5-1.5B-Instruct/resolve/main/onnx/model.onnx_data"
+    # SAVE_PATH = "./model.onnx_data"
 
-    print("Downloading ONNX model...")
-    response = requests.get(ONNX_MODEL_PATH, stream=True)
-    response.raise_for_status()
+    # print("Downloading ONNX model...")
+    # response = requests.get(ONNX_MODEL_PATH, stream=True)
+    # response.raise_for_status()
 
-    total_size = int(response.headers.get('content-length', 0))
+    # total_size = int(response.headers.get('content-length', 0))
 
-    with open(SAVE_PATH, 'wb') as file, tqdm(
-        desc=os.path.basename(SAVE_PATH),
-        total=total_size,
-        unit='iB',
-        unit_scale=True
-    ) as pbar:
-        for data in response.iter_content(chunk_size=8192):
-            size = file.write(data)
-            pbar.update(size)
-    print("Downloaded ONNX model!")
+    # with open(SAVE_PATH, 'wb') as file, tqdm(
+    #     desc=os.path.basename(SAVE_PATH),
+    #     total=total_size,
+    #     unit='iB',
+    #     unit_scale=True
+    # ) as pbar:
+    #     for data in response.iter_content(chunk_size=8192):
+    #         size = file.write(data)
+    #         pbar.update(size)
+    # print("Downloaded ONNX model!")
 
     # Download file
     download_pdf_from_gdrive(RESUME_URL, RESUME_PATH)
@@ -190,8 +190,8 @@ if __name__ == "__main__":
 
     # start_time = time.time()
     # Generate LLM answer
-    # generated_text = generate_answer(chat_history)
-    generated_text = onnx_inference(chat_history, rag_query, tokenizer)
+    generated_text = generate_answer(chat_history)
+    # generated_text = onnx_inference(chat_history, rag_query, tokenizer)
 
     # Detect if tool call is requested by LLM. If yes, then
     # execute tool and use else return None
@@ -204,8 +204,8 @@ if __name__ == "__main__":
         chat_history, tool_query, query_results
     )
     # Generate result from the
-    # generated_text = generate_answer(chat_history)
-    generated_text = onnx_inference(chat_history, rag_query, tokenizer)
+    generated_text = generate_answer(chat_history)
+    # generated_text = onnx_inference(chat_history, rag_query, tokenizer)
 
     # metrics = {
     #     "conversation": {
utils.py CHANGED
@@ -177,135 +177,4 @@ def scrape_website(start_url, delay=1):
 
     # Combine all content into a single string
     combined_content = "\n\n".join(all_content)
-    return combined_content
-
-def onnx_inference(chat_history: list, rag_query: str, tokenizer) -> str:
-    """
-    Performs ONNX inference with dynamic input handling, optimized for conciseness.
-    """
-    session = onnxruntime.InferenceSession("model.onnx")
-    model_inputs = session.get_inputs()
-    model_outputs = session.get_outputs()
-
-    # --- Corrected Chat History and Tool Call Format ---
-    # The tool call needs to be *part* of the chat history.
-    chat_history_with_tool = chat_history + [
-        {"role": "user", "content": rag_query, "tools": [{"type": "retrieval"}]},
-    ]
-
-
-    # Use HF tokenizer for input preparation
-    inputs = tokenizer.apply_chat_template(
-        chat_history_with_tool,
-        return_tensors="np",
-        add_generation_prompt=True
-    )
-    input_ids = inputs["input_ids"]  # Corrected: Access input_ids correctly
-    attention_mask = inputs["attention_mask"]
-
-    # Determine required inputs
-    has_position_ids = "position_ids" in (inp.name for inp in model_inputs)
-    has_past_key_values = any("past_key_values" in inp.name for inp in model_inputs)
-
-    # Prepare initial inputs, including past_key_values if needed
-    ort_inputs = {"input_ids": input_ids, "attention_mask": attention_mask}
-
-    if has_position_ids:
-        ort_inputs["position_ids"] = np.arange(input_ids.shape[1], dtype=np.int64).reshape(1, -1)
-
-    if has_past_key_values:
-        # Dummy run to get past_key_values shape
-        dummy_inputs = {
-            "input_ids": np.array([[0]]),
-            "attention_mask": np.array([[1]]),
-            "position_ids": np.array([[0]]) if has_position_ids else None  # Add if needed
-        }
-        dummy_inputs = {k: v for k, v in dummy_inputs.items() if v is not None}  # Remove None values
-        sample_outputs = session.run(None, dummy_inputs)
-        pkv_shape = list(sample_outputs[1].shape)
-        pkv_shape[2] = 0  # Set sequence length to 0 for initial state
-        num_pkv = len([inp for inp in model_inputs if "past_key_values" in inp.name])
-        past_key_values = tuple([np.zeros(pkv_shape, dtype=np.float32) for _ in range(num_pkv)])
-
-        # Add initial past_key_values to ort_inputs
-        for i in range(len(past_key_values) // 2):
-            ort_inputs[f"past_key_values.{i}.key"] = past_key_values[i * 2]
-            ort_inputs[f"past_key_values.{i}.value"] = past_key_values[i * 2 + 1]
-
-    generated_ids = []
-    input_length = input_ids.shape[1]
-
-    # Generation loop with dynamic input updates
-    for _ in range(512):  # Max new tokens
-        ort_outputs = session.run(None, ort_inputs)
-        next_token_logits = torch.tensor(ort_outputs[0][:, -1, :])
-        next_token = torch.multinomial(torch.softmax(next_token_logits / 1.0, dim=-1), num_samples=1).numpy()
-
-        generated_ids.append(next_token[0, 0])
-        if next_token[0, 0] == tokenizer.eos_token_id:
-            break
-
-        # Update inputs for next iteration
-        ort_inputs["input_ids"] = next_token
-        ort_inputs["attention_mask"] = np.ones_like(next_token)
-
-        if has_position_ids:
-            ort_inputs["position_ids"] = np.array([[input_length]], dtype=np.int64)
-            input_length += 1
-
-        if has_past_key_values:
-            for i in range(len(ort_outputs) - 1):  # Iterate over model outputs, excluding logits
-                ort_inputs[model_inputs[i + 2].name] = ort_outputs[i + 1]  # Use names for robustness
-
-
-    return tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-# def onnx_inference(chat_history, rag_query, tokenizer):
-#     # Create ONNX Runtime session
-#     session = onnxruntime.InferenceSession("model.onnx")
-
-#     # Tokenize input text
-#     inputs = tokenizer.apply_chat_template(
-#         chat_history,
-#         tools=[rag_query],
-#         return_tensors="np",
-#         return_dict=True,
-#         add_generation_prompt=True,
-#         # padding=True
-#     )
-
-#     # Run inference
-#     ort_inputs = {
-#         "input_ids": inputs["input_ids"],
-#         "attention_mask": inputs["attention_mask"]
-#     }
-
-#     input_length = inputs["input_ids"].shape[1]
-#     max_new_tokens = 512
-
-#     # Run generation
-#     for _ in range(max_new_tokens):
-#         ort_outputs = session.run(None, ort_inputs)
-#         next_token_logits = ort_outputs[0][:, -1, :]
-
-#         # Apply sampling
-#         next_token_logits = torch.tensor(next_token_logits)
-#         probs = torch.nn.functional.softmax(next_token_logits / 1.0, dim=-1)
-#         next_token = torch.multinomial(probs, num_samples=1).numpy()
-
-#         # Append to input
-#         ort_inputs["input_ids"] = np.concatenate([ort_inputs["input_ids"], next_token], axis=1)
-#         ort_inputs["attention_mask"] = np.concatenate([
-#             ort_inputs["attention_mask"],
-#             np.ones_like(next_token)
-#         ], axis=1)
-
-#         # Check for EOS token
-#         if next_token[0, 0] == tokenizer.eos_token_id:
-#             break
-
-#     # Decode only the new tokens
-#     generated_ids = ort_inputs["input_ids"][0, input_length:]
-#     generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)
-
-#     return generated_text
+    return combined_content
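
Note: the removed onnx_inference function opens onnxruntime.InferenceSession("model.onnx"), and a model exported with external data expects the companion model.onnx_data file to sit next to the .onnx graph, which is what the download step in app.py provides. A minimal sketch of that loading step, assuming both files are already in the working directory:

import onnxruntime

# Assumption: model.onnx and model.onnx_data both live in the current working directory.
# onnxruntime resolves the external-data file relative to the location of the .onnx graph.
session = onnxruntime.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])
print([inp.name for inp in session.get_inputs()])
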