techAInewb's picture
Update app.py
64116c6 verified
raw
history blame
1.92 kB
import os

import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import HfFolder, hf_hub_download, snapshot_download
from transformers import AutoModelForCausalLM, AutoTokenizer
# Hugging Face access token: prefer the locally cached login, then fall back to
# the HF_TOKEN environment variable. May be None (anonymous access).
token = HfFolder.get_token() or os.getenv("HF_TOKEN")

HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
ONNX_MODEL_FILE = "model.onnx"

# Load tokenizer; it is shared by both backends so they see identical token ids.
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, token=token)

# Load PyTorch reference model in fp32 (the ONNX repo name indicates an fp32
# export, so this keeps the numeric precision comparable).
pt_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.float32, token=token)
pt_model.eval()  # switch to inference mode (e.g. disables dropout)

# Load ONNX model.
# Download the whole repo snapshot instead of only model.onnx: ONNX graphs larger
# than protobuf's 2 GB limit keep their weights in external-data files that must
# sit next to model.onnx. An fp32 export of this model is far above that limit,
# so fetching model.onnx alone would leave InferenceSession unable to find the
# weights. NOTE(review): this downloads every file in the repo; narrow it with
# allow_patterns once the external-data file names are confirmed.
onnx_dir = snapshot_download(repo_id=HF_ONNX_REPO, token=token)
onnx_path = os.path.join(onnx_dir, ONNX_MODEL_FILE)
onnx_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
def compare_outputs(prompt, top_k=5):
    """Compare the next-token predictions of the PyTorch and ONNX models.

    The prompt is tokenized once and the same ids are fed to both
    ``pt_model`` and ``onnx_session``. For each backend, the ``top_k``
    highest-scoring tokens at the last position are reported, best first.

    Args:
        prompt: Text to run through both models.
        top_k: Number of candidate tokens to report per model (default 5).

    Returns:
        A pair ``(pytorch_summary, onnx_summary)`` of display strings, one per
        Gradio output textbox.
    """
    # Tokenize once so both backends are guaranteed to see identical ids.
    # Force int64: exported graphs typically declare int64 inputs, and numpy's
    # default integer type can be int32 on some platforms (e.g. Windows).
    encoded = tokenizer(prompt, return_tensors="np")
    input_ids = encoded["input_ids"].astype(np.int64)
    attention_mask = encoded["attention_mask"].astype(np.int64)

    # Run PyTorch (no autograd bookkeeping needed for inference).
    with torch.no_grad():
        pt_logits = pt_model(
            input_ids=torch.from_numpy(input_ids),
            attention_mask=torch.from_numpy(attention_mask),
        ).logits
    pt_top = torch.topk(pt_logits[0, -1], top_k).indices.tolist()

    # Run ONNX. Only feed the inputs the exported graph actually declares,
    # because onnxruntime rejects input names the graph does not have.
    candidates = {"input_ids": input_ids, "attention_mask": attention_mask}
    feed = {
        node.name: candidates[node.name]
        for node in onnx_session.get_inputs()
        if node.name in candidates
    }
    ort_logits = onnx_session.run(None, feed)[0]
    # argsort is ascending, so reverse it to get the highest logits first.
    ort_top = np.argsort(ort_logits[0, -1])[::-1][:top_k].tolist()

    pt_tokens = tokenizer.convert_ids_to_tokens(pt_top)
    ort_tokens = tokenizer.convert_ids_to_tokens(ort_top)
    return f"PyTorch Top Tokens: {pt_tokens}", f"ONNX Top Tokens: {ort_tokens}"
# Gradio UI: one prompt textbox in, two text outputs (PyTorch summary and
# ONNX summary, in the order returned by compare_outputs).
iface = gr.Interface(
    fn=compare_outputs,
    inputs=gr.Textbox(lines=2, placeholder="Enter a prompt..."),
    outputs=["text", "text"],
    title="ONNX vs PyTorch Model Comparison",
    description="Run both PyTorch and ONNX inference on a prompt and compare top predicted tokens.",
)

# Start the server only when run as a script (`python app.py`), so importing
# this module does not block on launch().
if __name__ == "__main__":
    iface.launch()