Mubbashir Ahmed committed
Commit 26fe788 · 1 Parent(s): f34278a

evaluating each model together

Files changed (1)
  1. app.py +48 -55
app.py CHANGED
@@ -21,71 +21,70 @@ spider_dataset = load_dataset("spider", split="train")
 llama_client = InferenceClient(provider="fireworks-ai", api_key=HF_TOKEN)
 qwen_client = InferenceClient(provider="featherless-ai", api_key=HF_TOKEN)
 
+model_list = {
+    "LLaMA 4": {
+        "client": llama_client,
+        "model_id": "meta-llama/Llama-4-Maverick-17B-128E-Instruct"
+    },
+    "Qwen3 14B": {
+        "client": qwen_client,
+        "model_id": "Qwen/Qwen3-14B"
+    }
+}
+
 # ------------------------
 # Inference + Evaluation Logic
 # ------------------------
-def evaluate_model(model_name, user_input, expected_sql, chat_history):
-    messages = chat_history + [{"role": "user", "content": user_input}]
-
-    try:
-        start_time = time.time()
-
-        if model_name == "LLaMA 4":
-            result = llama_client.chat.completions.create(
-                model="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
-                messages=messages
-            )
-            model_sql = result.choices[0].message.content
-
-        elif model_name == "Qwen3 14B":
-            result = qwen_client.chat.completions.create(
-                model="Qwen/Qwen3-14B",
-                messages=messages
-            )
-            model_sql = result.choices[0].message.content
-
-        else:
-            model_sql = "❌ Invalid model selected."
-
-        end_time = time.time()
-        latency = int((end_time - start_time) * 1000)  # ms
-
-    except Exception as e:
-        model_sql = f"⚠️ Error: {str(e)}"
-        latency = -1
-
-    # Evaluation criteria (simulated, can be replaced with real validation)
-    sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
-    exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
-    intent_clarity = "✅" if len(user_input.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
-    semantic_clarity = "✅" if any(word in model_sql.lower() for word in ["from", "join", "group by"]) else "❌"
-    latency_status = "✅" if latency <= 1000 else "❌"
-
-    evaluation_summary = (
-        f"📊 **Evaluation Summary**\n"
-        f"- SQL Generation Match: {sql_gen_accuracy}\n"
-        f"- Execution Accuracy: {exec_response_accuracy}\n"
-        f"- Intent Clarification: {intent_clarity}\n"
-        f"- Semantic Mapping: {semantic_clarity}\n"
-        f"- Response Latency: {latency} ms ({latency_status})\n"
-    )
-
-    chat_history.append({"role": "user", "content": user_input})
-    chat_history.append({"role": "assistant", "content": model_sql})
-
-    chat_transcript = "\n".join([
-        f"👤 User: {msg['content']}" if msg["role"] == "user" else f"🤖 Assistant: {msg['content']}"
-        for msg in chat_history
-    ])
-
-    return chat_transcript, chat_history, evaluation_summary
+def evaluate_all_models(user_input, expected_sql, chat_history):
+    evaluations = []
+    full_chat_transcript = ""
+
+    for model_name, model_config in model_list.items():
+        client = model_config["client"]
+        model_id = model_config["model_id"]
+
+        messages = chat_history + [{"role": "user", "content": user_input}]
+        try:
+            start_time = time.time()
+
+            result = client.chat.completions.create(
+                model=model_id,
+                messages=messages
+            )
+            model_sql = result.choices[0].message.content
+            latency = int((time.time() - start_time) * 1000)
+
+        except Exception as e:
+            model_sql = f"⚠️ Error: {str(e)}"
+            latency = -1
+
+        # Evaluation criteria (simulated)
+        sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
+        exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
+        intent_clarity = "✅" if len(user_input.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
+        semantic_clarity = "✅" if any(word in model_sql.lower() for word in ["from", "join", "group by"]) else "❌"
+        latency_status = "✅" if latency <= 1000 else "❌"
+
+        summary = (
+            f"### 🤖 {model_name} Evaluation\n"
+            f"- SQL Generation Match: {sql_gen_accuracy}\n"
+            f"- Execution Accuracy: {exec_response_accuracy}\n"
+            f"- Intent Clarification: {intent_clarity}\n"
+            f"- Semantic Mapping: {semantic_clarity}\n"
+            f"- Response Latency: {latency} ms ({latency_status})\n"
+        )
+        evaluations.append(summary)
+
+        full_chat_transcript += f"\n👤 User: {user_input}\n🤖 {model_name}: {model_sql}\n"
+
+    return full_chat_transcript.strip(), chat_history, "\n\n".join(evaluations)
 
 # ------------------------
 # Load Random Spider Prompt
 # ------------------------
 def get_random_spider_prompt():
     sample = random.choice(spider_dataset)
-    return sample["question"], sample["query"], sample["query"]  # Return expected SQL twice
+    return sample["question"], sample["query"], sample["query"]
 
 # ------------------------
 # Gradio UI
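The `# Evaluation criteria (simulated)` block above marks a model correct when the expected SQL appears verbatim, case-insensitively, inside the model's reply. A minimal sketch of a slightly stricter variant that also normalizes whitespace and trailing semicolons before comparing; the helper name `normalized_sql_match` is illustrative, not part of this repo:

    import re

    def normalized_sql_match(expected_sql: str, model_reply: str) -> bool:
        # Collapse runs of whitespace, lowercase, and drop trailing
        # semicolons so cosmetic differences don't fail the match.
        # Still a heuristic: semantically equivalent but differently
        # written SQL (reordered joins, aliases) won't be recognized.
        def normalize(sql: str) -> str:
            s = re.sub(r"\s+", " ", sql).strip().lower()
            return s.rstrip("; ")
        return normalize(expected_sql) in normalize(model_reply)

For Spider specifically, execution accuracy (running both queries against the sample's database and comparing result sets) is the standard stricter metric that the code's "simulated" caveat points toward.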
@@ -93,17 +92,11 @@ def get_random_spider_prompt():
 with gr.Blocks() as demo:
     gr.Markdown("## 🧠 Spider Dataset Model Evaluation")
 
-    model_choice = gr.Dropdown(
-        choices=["LLaMA 4", "Qwen3 14B"],
-        label="Select Model",
-        value="LLaMA 4"
-    )
-
     prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
     expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)
 
     load_spider_btn = gr.Button("🔀 Load Random Spider Prompt")
-    run_button = gr.Button("Send & Evaluate")
+    run_button = gr.Button("Send & Evaluate All Models")
 
     chat_display = gr.Textbox(label="Chat History", lines=20, interactive=False)
     evaluation_display = gr.Markdown()
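The `expected_sql` and `chat_memory` components referenced by `run_button.click` in the next hunk are defined outside the changed region. In a Gradio Blocks app they would typically be invisible `gr.State` holders; a sketch of the assumed wiring (an assumption, not shown in this commit):

    import gradio as gr

    with gr.Blocks() as demo:
        # Assumed definitions: hidden per-session value holders threaded
        # through the callbacks; gr.State renders no UI of its own.
        expected_sql = gr.State("")   # filled by get_random_spider_prompt
        chat_memory = gr.State([])    # list of {"role", "content"} messages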
@@ -118,8 +111,8 @@ with gr.Blocks() as demo:
     )
 
     run_button.click(
-        fn=evaluate_model,
-        inputs=[model_choice, prompt_input, expected_sql, chat_memory],
+        fn=evaluate_all_models,
+        inputs=[prompt_input, expected_sql, chat_memory],
         outputs=[chat_display, chat_memory, evaluation_display]
     )
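With both providers folded into `model_list`, the loop in `evaluate_all_models` picks up new models automatically, so adding one is a single dictionary entry with no UI changes. A hypothetical example (the Mistral entry and `together` provider choice below are illustrative, not part of this commit):

    # Hypothetical third model; any provider/model pair accepted by
    # huggingface_hub.InferenceClient slots in the same way.
    mistral_client = InferenceClient(provider="together", api_key=HF_TOKEN)

    model_list["Mistral Small"] = {
        "client": mistral_client,
        "model_id": "mistralai/Mistral-Small-24B-Instruct-2501"
    }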