techAInewb committed on
Commit
8ee35af
·
verified ·
1 Parent(s): 455148e

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +51 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import onnxruntime as ort
4
+ from transformers import AutoTokenizer, AutoModelForCausalLM
5
+ from huggingface_hub import hf_hub_download
6
+ import torch
7
+
8
import os

from huggingface_hub import snapshot_download

# Hub repository ids: the reference PyTorch checkpoint and the ONNX export
# that the app compares against it.
HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
ONNX_MODEL_FILE = "model.onnx"

# Load tokenizer (shared by both backends so token ids are comparable).
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)

# Load PyTorch model in float32 and switch it to evaluation mode.
pt_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.float32)
pt_model.eval()

# Load ONNX model.
# Download the whole repository snapshot rather than only `model.onnx`:
# a single ONNX protobuf file is limited to 2 GB, so an fp32 export of a
# model this size presumably keeps its weights in external-data files stored
# next to the graph (TODO confirm against the repo contents). ONNX Runtime
# resolves those files relative to the model's directory, so they must be
# present locally before the session is created.
onnx_dir = snapshot_download(repo_id=HF_ONNX_REPO)
onnx_path = os.path.join(onnx_dir, ONNX_MODEL_FILE)
onnx_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
22
+
23
def compare_outputs(prompt):
    """Run one prompt through both backends and report their top next tokens.

    Args:
        prompt: Text that is tokenized and fed to the PyTorch model and to
            the ONNX Runtime session.

    Returns:
        A pair of strings. The first lists the five highest-scoring tokens
        at the last sequence position according to the PyTorch model, the
        second lists the same for the ONNX model; both are ordered from
        highest to lowest logit.
    """
    top_k = 5

    # The prompt is tokenized twice: NumPy arrays feed ONNX Runtime,
    # torch tensors feed the PyTorch model.
    np_batch = tokenizer(prompt, return_tensors="np", padding=False)
    pt_batch = tokenizer(prompt, return_tensors="pt")

    # PyTorch forward pass; gradients are not needed for a comparison.
    with torch.no_grad():
        pt_logits = pt_model(**pt_batch).logits
        pt_last_step = pt_logits[0, -1]
        pt_top = torch.topk(pt_last_step, top_k).indices.tolist()

    # ONNX forward pass; the first session output is treated as the logits.
    feed = {
        "input_ids": np_batch["input_ids"],
        "attention_mask": np_batch["attention_mask"],
    }
    ort_logits = onnx_session.run(None, feed)[0]
    ascending = np.argsort(ort_logits[0, -1])
    # Take the tail of the ascending order and flip it to get best-first.
    ort_top = ascending[-top_k:][::-1].tolist()

    pt_tokens = tokenizer.convert_ids_to_tokens(pt_top)
    ort_tokens = tokenizer.convert_ids_to_tokens(ort_top)

    pt_report = f"PyTorch Top Tokens: {pt_tokens}"
    ort_report = f"ONNX Top Tokens: {ort_tokens}"
    return pt_report, ort_report
44
+
45
# Two-line free-text box in which the user types the prompt.
_prompt_box = gr.Textbox(lines=2, placeholder="Enter a prompt...")

# One text input, two text outputs (one per backend), wired to compare_outputs.
iface = gr.Interface(
    fn=compare_outputs,
    inputs=_prompt_box,
    outputs=["text", "text"],
    title="ONNX vs PyTorch Model Comparison",
    description="Run both PyTorch and ONNX inference on a prompt and compare top predicted tokens.",
)

# Start the Gradio app.
iface.launch()
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio
2
+ transformers
3
+ torch
4
+ onnxruntime
5
+ huggingface_hub