import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
import gc
import os
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download, HfFolder

# Prefer a cached Hugging Face login; fall back to the HF_TOKEN env var.
token = HfFolder.get_token() or os.getenv("HF_TOKEN")
HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
ONNX_MODEL_FILE = "model.onnx"
# Shared tokenizer
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, token=token)
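
# Both backends encode and decode with this same tokenizer, so differences in
# output text come from the models themselves rather than from preprocessing.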
def compare_outputs(prompt, show_tokens):
    summary_log = []
    pt_output_text = ""
    ort_output_text = ""
    pt_tokens = []
    ort_tokens = []

    # Report overall RAM usage if psutil is available
    try:
        import psutil
        ram_used = f"{psutil.virtual_memory().used / 1e9:.2f} GB"
    except Exception:
        ram_used = "Unavailable"
    # 🔹 PyTorch generate
    pt_start = time.time()
    pt_model = None
    try:
        torch_inputs = tokenizer(prompt, return_tensors="pt")
        pt_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.float32, token=token)
        pt_model.eval()
        with torch.no_grad():
            pt_outputs = pt_model.generate(**torch_inputs, max_new_tokens=50)
        pt_output_ids = pt_outputs[0].tolist()
        pt_output_text = tokenizer.decode(pt_output_ids, skip_special_tokens=True)
        pt_tokens = tokenizer.convert_ids_to_tokens(pt_output_ids)
        pt_time = time.time() - pt_start
    finally:
        # Release the PyTorch model before loading ONNX so the two never
        # coexist in memory (sequential loading keeps peak RAM down).
        del pt_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    # 🔹 ONNX generate (greedy): this export has no KV cache, so the full
    # sequence is re-fed through the model at every decoding step.
    ort_start = time.time()
    ort_inputs = tokenizer(prompt, return_tensors="np")
    onnx_path = hf_hub_download(repo_id=HF_ONNX_REPO, filename=ONNX_MODEL_FILE)
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])

    ort_output_ids = []
    generated = ort_inputs["input_ids"]
    attention_mask = ort_inputs["attention_mask"]
    for _ in range(50):
        ort_outputs = ort_session.run(None, {
            "input_ids": generated,
            "attention_mask": attention_mask
        })
        next_token_logits = ort_outputs[0][:, -1, :]  # logits at the last position
        next_token = np.argmax(next_token_logits, axis=-1).reshape(-1, 1)
        ort_output_ids.append(int(next_token[0][0]))
        generated = np.concatenate((generated, next_token), axis=1)
        attention_mask = np.concatenate((attention_mask, np.ones((1, 1), dtype=np.int64)), axis=1)
        if next_token[0][0] == tokenizer.eos_token_id:
            break
    ort_time = time.time() - ort_start

    prompt_ids = ort_inputs["input_ids"][0].tolist()
    ort_tokens = tokenizer.convert_ids_to_tokens(prompt_ids + ort_output_ids)
    ort_output_text = tokenizer.decode(prompt_ids + ort_output_ids, skip_special_tokens=True)
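    # With no sampling flags set, pt_model.generate typically falls back to
    # greedy decoding (subject to the model's generation_config), so close
    # token-level agreement with the manual ONNX loop above is the expected
    # baseline; FP32 numeric differences can still flip near-tied argmaxes.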
    # 📊 Summary (token counts include the prompt tokens)
    summary_log.append("| Model | Tokens | Time (s) | Time/Token |")
    summary_log.append("|---------|--------|----------|------------|")
    summary_log.append(f"| PyTorch | {len(pt_tokens)} | {pt_time:.2f} | {pt_time / max(1, len(pt_tokens)):.4f} |")
    summary_log.append(f"| ONNX | {len(ort_tokens)} | {ort_time:.2f} | {ort_time / max(1, len(ort_tokens)):.4f} |")
    summary_log.append(f"\n📦 RAM Used: {ram_used}")
    summary_log.append(f"📚 Tokenizer: {tokenizer.name_or_path} | Vocab size: {tokenizer.vocab_size}")
    summary_log.append("🛠️ Note: This ONNX export is FP32. INT8 + Vitis AI variants coming soon.")

    outputs = [pt_output_text, ort_output_text, "\n".join(summary_log)]
    if show_tokens:
        outputs += [", ".join(pt_tokens), ", ".join(ort_tokens)]
    else:
        outputs += ["", ""]
    return outputs
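
# Illustrative helper, not wired into the UI: a minimal sketch of how one
# could locate the first token position where the two traces diverge. The
# name and behavior are assumptions for debugging, not part of the app flow.
def first_divergence(pt_toks, ort_toks):
    """Return the index of the first differing token, or -1 if the shorter
    trace is a prefix of the longer one."""
    for i, (a, b) in enumerate(zip(pt_toks, ort_toks)):
        if a != b:
            return i
    return -1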
example_prompts = [
    "Who was the first president of the United States?",
    "If you have 3 apples and eat 1, how many are left?",
    "Write a short poem about memory and time.",
    "Explain the laws of motion in simple terms.",
    "What happens when you mix baking soda and vinegar?"
]
iface = gr.Interface(
    fn=compare_outputs,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter a prompt..."),
        gr.Checkbox(label="Show Token IDs")
    ],
    outputs=[
        gr.Textbox(label="PyTorch Output"),
        gr.Textbox(label="ONNX Output"),
        gr.Textbox(label="Evaluation Summary"),
        gr.Textbox(label="PyTorch Tokens"),
        gr.Textbox(label="ONNX Tokens")
    ],
    title="ONNX vs PyTorch (Full Output + Token Trace)",
    description="Run both models on your prompt and compare output text, timing, and token traces. Sequential model loading avoids OOM.",
    examples=[[p, False] for p in example_prompts]
)
iface.launch()
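
# Example local run (dependency list is an assumption; exact pins are not
# specified upstream):
#   pip install gradio torch transformers onnxruntime huggingface_hub psutil numpy
#   HF_TOKEN=<your-token> python app.py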