import gradio as gr
import numpy as np
import onnxruntime as ort
import torch
import gc
import os
import time

from transformers import AutoTokenizer, AutoModelForCausalLM
from huggingface_hub import hf_hub_download, HfFolder

# Resolve the Hugging Face access token from a cached login or the HF_TOKEN env var.
token = HfFolder.get_token() or os.getenv("HF_TOKEN")

HF_MODEL_ID = "mistralai/Mistral-Nemo-Instruct-2407"
HF_ONNX_REPO = "techAInewb/mistral-nemo-2407-fp32"
ONNX_MODEL_FILE = "model.onnx"

# Shared tokenizer (the same vocabulary is used for both the PyTorch and ONNX runs)
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_ID, token=token)

def compare_outputs(prompt, show_tokens):
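    """Run the prompt through the PyTorch model and the ONNX export sequentially and
    return both generations, a timing summary, and (optionally) decoded token traces."""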
    summary_log = []
    pt_output_text = ""
    ort_output_text = ""
    pt_tokens = []
    ort_tokens = []

    # psutil is optional; report RAM usage only if it is installed.
    try:
        import psutil
        ram_used = f"{psutil.virtual_memory().used / 1e9:.2f} GB"
    except Exception:
        ram_used = "Unavailable"

    # 🔹 PyTorch Generate
    pt_start = time.time()
    pt_model = None  # defined up front so the cleanup in `finally` is safe if loading fails
    try:
        torch_inputs = tokenizer(prompt, return_tensors="pt")
        pt_model = AutoModelForCausalLM.from_pretrained(HF_MODEL_ID, torch_dtype=torch.float32, token=token)
        pt_model.eval()
        with torch.no_grad():
            pt_outputs = pt_model.generate(**torch_inputs, max_new_tokens=50)
        pt_output_ids = pt_outputs[0].tolist()
        pt_output_text = tokenizer.decode(pt_output_ids, skip_special_tokens=True)
        pt_tokens = tokenizer.convert_ids_to_tokens(pt_output_ids)
        pt_time = time.time() - pt_start
    finally:
        # Release the PyTorch model before the ONNX session is created so both never coexist in RAM.
        if pt_model is not None:
            del pt_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # 🔹 ONNX Generate (Greedy)
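    # Download the FP32 ONNX export from the Hub and run it on CPU with onnxruntime.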
    ort_start = time.time()
    ort_inputs = tokenizer(prompt, return_tensors="np")
    onnx_path = hf_hub_download(repo_id=HF_ONNX_REPO, filename=ONNX_MODEL_FILE)
    ort_session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_output_ids = []
    generated = ort_inputs["input_ids"]
    attention_mask = ort_inputs["attention_mask"]
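    # Greedy decoding loop without a KV cache: each step feeds the full sequence
    # (prompt plus everything generated so far) back through the session, stopping
    # at EOS or after 50 new tokens.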
    for _ in range(50):
        ort_outputs = ort_session.run(None, {
            "input_ids": generated,
            "attention_mask": attention_mask
        })
        next_token_logits = ort_outputs[0][:, -1, :]
        next_token = np.argmax(next_token_logits, axis=-1).reshape(-1, 1)
        ort_output_ids.append(next_token[0][0])
        generated = np.concatenate((generated, next_token), axis=1)
        attention_mask = np.concatenate((attention_mask, np.ones((1, 1), dtype=np.int64)), axis=1)
        if next_token[0][0] == tokenizer.eos_token_id:
            break
    ort_time = time.time() - ort_start
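    # Decode prompt + continuation together so the ONNX text mirrors the PyTorch output.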
    ort_tokens = tokenizer.convert_ids_to_tokens(ort_inputs["input_ids"][0].tolist() + ort_output_ids)
    ort_output_text = tokenizer.decode(ort_inputs["input_ids"][0].tolist() + ort_output_ids, skip_special_tokens=True)

    # 📊 Summary
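    # Token counts below include the prompt tokens, so Time/Token is only a rough per-token figure.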
    summary_log.append("| Model   | Tokens | Time (s) | Time/Token |")
    summary_log.append("|---------|--------|----------|------------|")
    summary_log.append(f"| PyTorch | {len(pt_tokens)} | {pt_time:.2f} | {pt_time / max(1, len(pt_tokens)):.4f} |")
    summary_log.append(f"| ONNX    | {len(ort_tokens)} | {ort_time:.2f} | {ort_time / max(1, len(ort_tokens)):.4f} |")
    summary_log.append(f"\n📦 RAM Used: {ram_used}")
    summary_log.append(f"📚 Tokenizer: {tokenizer.name_or_path} | Vocab size: {tokenizer.vocab_size}")
    summary_log.append("🛠️ Note: This ONNX export is FP32. INT8 + Vitis AI variants coming soon.")

    outputs = [pt_output_text, ort_output_text, "\n".join(summary_log)]

    if show_tokens:
        outputs += [
            ", ".join(pt_tokens),
            ", ".join(ort_tokens)
        ]
    else:
        outputs += ["", ""]

    return outputs

example_prompts = [
    "Who was the first president of the United States?",
    "If you have 3 apples and eat 1, how many are left?",
    "Write a short poem about memory and time.",
    "Explain the laws of motion in simple terms.",
    "What happens when you mix baking soda and vinegar?"
]

iface = gr.Interface(
    fn=compare_outputs,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter a prompt..."),
        gr.Checkbox(label="Show Token IDs")
    ],
    outputs=[
        gr.Textbox(label="PyTorch Output"),
        gr.Textbox(label="ONNX Output"),
        gr.Textbox(label="Evaluation Summary"),
        gr.Textbox(label="PyTorch Tokens"),
        gr.Textbox(label="ONNX Tokens")
    ],
    title="ONNX vs PyTorch (Full Output + Token Trace)",
    description="Run both models on your prompt and compare output text, timing, and token traces. The models are loaded sequentially (PyTorch first, then ONNX) to avoid running out of memory.",
    examples=[[p, False] for p in example_prompts]
)

iface.launch()