Mubbashir Ahmed committed
Commit 6725b24 · 1 Parent(s): 1aa45ea
Files changed (1)
  1. app.py +63 -73
app.py CHANGED
@@ -1,137 +1,127 @@
  import os
  import random
  import gradio as gr
  from huggingface_hub import InferenceClient
  from datasets import load_dataset
- # from transformers import AutoTokenizer, AutoModelForCausalLM
- # import torch
-
- HF_TOKEN = os.environ.get("HF_TOKEN")

  # ------------------------
- # Load Spider Dataset (Hugging Face Datasets)
  # ------------------------
- spider_dataset = load_dataset("spider", split="train")

  # ------------------------
- # API Clients
  # ------------------------
- llama_client = InferenceClient(
-     provider="fireworks-ai",
-     api_key=HF_TOKEN,
- )
-
- qwen_client = InferenceClient(
-     provider="featherless-ai",
-     api_key=HF_TOKEN,
- )

  # ------------------------
- # Mixtral Local Setup (DISABLED)
  # ------------------------
- # mixtral_model_id = "mistralai/Mixtral-8x7B-Instruct-v0.1"
- # mixtral_tokenizer = AutoTokenizer.from_pretrained(mixtral_model_id)
- # mixtral_model = AutoModelForCausalLM.from_pretrained(
- #     mixtral_model_id, torch_dtype=torch.float16
- # ).to("cuda")

  # ------------------------
- # Unified Inference Function with Chat History
  # ------------------------
- def run_model_with_history(model_name, user_input, chat_history):
      messages = chat_history + [{"role": "user", "content": user_input}]

      try:
          if model_name == "LLaMA 4":
              result = llama_client.chat.completions.create(
                  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
                  messages=messages
              )
-             reply = result.choices[0].message.content

          elif model_name == "Qwen3 14B":
              result = qwen_client.chat.completions.create(
                  model="Qwen/Qwen3-14B",
                  messages=messages
              )
-             reply = result.choices[0].message.content
-
-         # Mixtral section disabled due to space constraints
-         # elif model_name == "Mixtral 8x7B":
-         #     full_prompt = ""
-         #     for msg in messages:
-         #         prefix = "User: " if msg["role"] == "user" else "Assistant: "
-         #         full_prompt += f"{prefix}{msg['content']}\n"
-         #     inputs = mixtral_tokenizer(full_prompt, return_tensors="pt").to("cuda")
-         #     outputs = mixtral_model.generate(
-         #         **inputs,
-         #         max_new_tokens=512,
-         #         do_sample=True,
-         #         temperature=0.7,
-         #         top_k=50,
-         #         top_p=0.95
-         #     )
-         #     reply = mixtral_tokenizer.decode(outputs[0], skip_special_tokens=True)

          else:
-             reply = "❌ Invalid model selection."

      except Exception as e:
-         reply = f"⚠️ Error: {str(e)}"

-     # Update chat history
      chat_history.append({"role": "user", "content": user_input})
-     chat_history.append({"role": "assistant", "content": reply})

-     # Format display
      chat_transcript = "\n".join([
          f"👤 User: {msg['content']}" if msg["role"] == "user" else f"🤖 Assistant: {msg['content']}"
          for msg in chat_history
      ])

-     return chat_transcript, chat_history

  # ------------------------
- # Get Random Spider Question
  # ------------------------
- def get_random_spider_question():
      sample = random.choice(spider_dataset)
-     return sample["question"]

  # ------------------------
  # Gradio UI
  # ------------------------
  with gr.Blocks() as demo:
-     gr.Markdown("## 🧠 Generative AI Model Evaluation with Context")

-     with gr.Row():
-         model_choice = gr.Dropdown(
-             choices=["LLaMA 4", "Qwen3 14B"],
-             label="Select Model",
-             value="LLaMA 4"
-         )
-         load_spider_btn = gr.Button("🔀 Load Random Spider Prompt")

-     chat_display = gr.Textbox(label="Chat History", lines=20, interactive=False)
      prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")

-     run_button = gr.Button("Send")

-     # Hidden chat history state
      chat_memory = gr.State([])

-     run_button.click(
-         fn=run_model_with_history,
-         inputs=[model_choice, prompt_input, chat_memory],
-         outputs=[chat_display, chat_memory]
-     )
-
      load_spider_btn.click(
-         fn=get_random_spider_question,
          inputs=[],
-         outputs=prompt_input
      )
-
- # Launch app
  demo.launch()
 
  import os
  import random
+ import time
  import gradio as gr
  from huggingface_hub import InferenceClient
  from datasets import load_dataset

  # ------------------------
+ # Auth
  # ------------------------
+ HF_TOKEN = os.environ.get("HF_TOKEN")

  # ------------------------
+ # Load Spider Dataset
  # ------------------------
+ spider_dataset = load_dataset("spider", split="train")

  # ------------------------
+ # Inference Clients
  # ------------------------
+ llama_client = InferenceClient(provider="fireworks-ai", api_key=HF_TOKEN)
+ qwen_client = InferenceClient(provider="featherless-ai", api_key=HF_TOKEN)

  # ------------------------
+ # Inference + Evaluation Logic
  # ------------------------
+ def evaluate_model(model_name, user_input, expected_sql, chat_history):
      messages = chat_history + [{"role": "user", "content": user_input}]

      try:
+         start_time = time.time()
+
          if model_name == "LLaMA 4":
              result = llama_client.chat.completions.create(
                  model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
                  messages=messages
              )
+             model_sql = result.choices[0].message.content

          elif model_name == "Qwen3 14B":
              result = qwen_client.chat.completions.create(
                  model="Qwen/Qwen3-14B",
                  messages=messages
              )
+             model_sql = result.choices[0].message.content

          else:
+             model_sql = "❌ Invalid model selected."
+
+         end_time = time.time()
+         latency = int((end_time - start_time) * 1000)  # ms

      except Exception as e:
+         model_sql = f"⚠️ Error: {str(e)}"
+         latency = -1
+
+     # Evaluation criteria (simulated; can be replaced with real validation)
+     sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
+     exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
+     intent_clarity = "✅" if len(user_input.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
+     semantic_clarity = "✅" if any(word in model_sql.lower() for word in ["from", "join", "group by"]) else "❌"
+     latency_status = "✅" if 0 <= latency <= 1000 else "❌"  # guard against latency = -1 on error
+
+     evaluation_summary = (
+         f"📊 **Evaluation Summary**\n"
+         f"- SQL Generation Match: {sql_gen_accuracy}\n"
+         f"- Execution Accuracy: {exec_response_accuracy}\n"
+         f"- Intent Clarification: {intent_clarity}\n"
+         f"- Semantic Mapping: {semantic_clarity}\n"
+         f"- Response Latency: {latency} ms ({latency_status})\n"
+     )

      chat_history.append({"role": "user", "content": user_input})
+     chat_history.append({"role": "assistant", "content": model_sql})

      chat_transcript = "\n".join([
          f"👤 User: {msg['content']}" if msg["role"] == "user" else f"🤖 Assistant: {msg['content']}"
          for msg in chat_history
      ])

+     return chat_transcript, chat_history, evaluation_summary

  # ------------------------
+ # Load Random Spider Prompt
  # ------------------------
+ def get_random_spider_prompt():
      sample = random.choice(spider_dataset)
+     return sample["question"], sample["query"], sample["query"]  # fills prompt box, gold-SQL state, and display box

  # ------------------------
  # Gradio UI
  # ------------------------
  with gr.Blocks() as demo:
+     gr.Markdown("## 🧠 Spider Dataset Model Evaluation")

+     model_choice = gr.Dropdown(
+         choices=["LLaMA 4", "Qwen3 14B"],
+         label="Select Model",
+         value="LLaMA 4"
+     )

      prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
+     expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)

+     load_spider_btn = gr.Button("🔀 Load Random Spider Prompt")
+     run_button = gr.Button("Send & Evaluate")
+
+     chat_display = gr.Textbox(label="Chat History", lines=20, interactive=False)
+     evaluation_display = gr.Markdown()

      chat_memory = gr.State([])
+     expected_sql = gr.State("")

      load_spider_btn.click(
+         fn=get_random_spider_prompt,
          inputs=[],
+         outputs=[prompt_input, expected_sql, expected_sql_display]
+     )
+
+     run_button.click(
+         fn=evaluate_model,
+         inputs=[model_choice, prompt_input, expected_sql, chat_memory],
+         outputs=[chat_display, chat_memory, evaluation_display]
      )

+ # Launch
  demo.launch()
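
One caveat on the substring check in evaluate_model: chat models (Qwen3 in particular) usually wrap their SQL in explanation text or markdown fences, so `expected_sql in model_sql` will rarely fire even on a correct answer. A minimal pre-processing sketch, not part of this commit (the extract_sql name and both regexes are assumptions), that isolates the statement before any comparison:

import re

def extract_sql(reply: str) -> str:
    """Pull a bare SQL statement out of a chat-style model reply (hypothetical helper)."""
    # Prefer a fenced ```sql block if the model produced one
    fenced = re.search(r"```(?:sql)?\s*(.+?)```", reply, re.DOTALL | re.IGNORECASE)
    if fenced:
        return fenced.group(1).strip()
    # Otherwise take the first SELECT ... run, up to a blank line or end of text
    stmt = re.search(r"(SELECT\b.+?)(?:\n\s*\n|\Z)", reply, re.DOTALL | re.IGNORECASE)
    return stmt.group(1).strip() if stmt else reply.strip()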
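
As for the "real validation" the new evaluation-criteria comment alludes to, Spider's standard metric is execution accuracy: run predicted and gold queries against the example's SQLite database and compare result sets. Each sample carries a db_id, and the usual Spider distribution stores databases as database/<db_id>/<db_id>.sqlite; this Space does not bundle them, so the sketch below is hedged on that local layout:

import sqlite3

def execution_match(pred_sql: str, gold_sql: str, db_path: str) -> bool:
    """True when both queries return the same result set on one SQLite DB (sketch, not in the commit)."""
    conn = sqlite3.connect(db_path)
    try:
        pred_rows = conn.execute(pred_sql).fetchall()
        gold_rows = conn.execute(gold_sql).fetchall()
    except sqlite3.Error:
        return False  # predicted SQL failed to parse or execute
    finally:
        conn.close()
    if "order by" in gold_sql.lower():
        return pred_rows == gold_rows  # row order is part of the answer here
    return sorted(map(repr, pred_rows)) == sorted(map(repr, gold_rows))

Wired in, sql_gen_accuracy would become `"✅" if execution_match(extract_sql(model_sql), expected_sql, db_path) else "❌"`, with db_path threaded through a third gr.State the same way expected_sql already is.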