techAInewb commited on
Commit
ee89b44
·
verified ·
1 Parent(s): 649b695

Delete app[1].py

Browse files
Files changed (1) hide show
  1. app[1].py +0 -51
app[1].py DELETED
@@ -1,51 +0,0 @@
1
- import gradio as gr
2
- import numpy as np
3
- import onnxruntime as ort
4
- from transformers import AutoTokenizer, AutoModelForCausalLM
5
- from huggingface_hub import hf_hub_download
6
- import torch
7
-
8
- HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
9
- HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
10
- ONNX_MODEL_FILE = "model.onnx"
11
-
12
- # Load tokenizer
13
- tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID)
14
-
15
- # Load PyTorch model
16
- pt_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.float32)
17
- pt_model.eval()
18
-
19
- # Load ONNX model
20
- onnx_path = hf_hub_download(repo_id=HF_ONNX_REPO, filename=ONNX_MODEL_FILE)
21
- onnx_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
22
-
23
- def compare_outputs(prompt):
24
- inputs = tokenizer(prompt, return_tensors="np", padding=False)
25
- torch_inputs = tokenizer(prompt, return_tensors="pt")
26
-
27
- # Run PyTorch
28
- with torch.no_grad():
29
- pt_outputs = pt_model(**torch_inputs).logits
30
- pt_top = torch.topk(pt_outputs[0, -1], 5).indices.tolist()
31
-
32
- # Run ONNX
33
- ort_outputs = onnx_session.run(None, {
34
- "input_ids": inputs["input_ids"],
35
- "attention_mask": inputs["attention_mask"]
36
- })
37
- ort_logits = ort_outputs[0]
38
- ort_top = np.argsort(ort_logits[0, -1])[::-1][:5].tolist()
39
-
40
- pt_tokens = tokenizer.convert_ids_to_tokens(pt_top)
41
- ort_tokens = tokenizer.convert_ids_to_tokens(ort_top)
42
-
43
- return f"PyTorch Top Tokens: {pt_tokens}", f"ONNX Top Tokens: {ort_tokens}"
44
-
45
- iface = gr.Interface(fn=compare_outputs,
46
- inputs=gr.Textbox(lines=2, placeholder="Enter a prompt..."),
47
- outputs=["text", "text"],
48
- title="ONNX vs PyTorch Model Comparison",
49
- description="Run both PyTorch and ONNX inference on a prompt and compare top predicted tokens.")
50
-
51
- iface.launch()