Mubbashir Ahmed committed on
Commit
ec635d6
·
1 Parent(s): a1b742b
Files changed (1)
  1. app.py +74 -75
app.py CHANGED
@@ -3,6 +3,8 @@ import random
 import time
 import json
 import gradio as gr
+import csv
+from datetime import datetime
 from huggingface_hub import InferenceClient
 
 # ------------------------
@@ -36,6 +38,7 @@ model_list = {
 # ------------------------
 # Prompt Template for SQL Generation
 # ------------------------
+
 def build_prompt(user_question):
     return f"""You are an expert SQL assistant. Convert the given question into a valid SQL query.
 
@@ -57,89 +60,85 @@ Q: {user_question}
 A:"""
 
 # ------------------------
-# Inference + Evaluation Logic
+# Evaluation + Batch Logic
 # ------------------------
-def evaluate_all_models(user_input, expected_sql, chat_history):
-    evaluations = []
-    full_chat_transcript = ""
-    prompt = build_prompt(user_input)
-
-    for model_name, model_config in model_list.items():
-        client = model_config["client"]
-        model_id = model_config["model_id"]
-
-        messages = chat_history + [{"role": "user", "content": prompt}]
-        try:
-            start_time = time.time()
-
-            result = client.chat.completions.create(
-                model=model_id,
-                messages=messages
-            )
-            model_sql = result.choices[0].message.content
-            latency = int((time.time() - start_time) * 1000)
-
-        except Exception as e:
-            model_sql = f"⚠️ Error: {str(e)}"
-            latency = -1
-
-        # Evaluation criteria (simulated)
-        sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
-        exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
-        intent_clarity = "✅" if len(user_input.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
-        semantic_clarity = "✅" if any(word in model_sql.lower() for word in ["from", "join", "group by"]) else "❌"
-        latency_status = "✅" if latency <= 1000 else "❌"
-
-        summary = (
-            f"### 🤖 {model_name} Evaluation\n"
-            f"- SQL Generation Match: {sql_gen_accuracy}\n"
-            f"- Execution Accuracy: {exec_response_accuracy}\n"
-            f"- Intent Clarification: {intent_clarity}\n"
-            f"- Semantic Mapping: {semantic_clarity}\n"
-            f"- Response Latency: {latency} ms ({latency_status})\n"
-        )
-        evaluations.append(summary)
-
-        full_chat_transcript += f"\n👤 User: {user_input}\n🤖 {model_name}: {model_sql}\n"
-
-    return full_chat_transcript.strip(), chat_history, "\n\n".join(evaluations)
-
-# ------------------------
-# Load Random Spider Prompt
-# ------------------------
-def get_random_spider_prompt():
-    sample = random.choice(spider_dataset)
-    return sample["question"], sample["query"], sample["query"]
+def evaluate_batch(n=50):
+    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+    output_path = f"evaluation_results_{timestamp}.csv"
+
+    results = []
+    selected_samples = random.sample(spider_dataset, n)
+
+    for idx, sample in enumerate(selected_samples):
+        user_question = sample["question"]
+        expected_sql = sample["query"]
+        prompt = build_prompt(user_question)
+
+        row = {
+            "question": user_question,
+            "gold_sql": expected_sql
+        }
+
+        for model_name, model_config in model_list.items():
+            client = model_config["client"]
+            model_id = model_config["model_id"]
+            try:
+                start_time = time.time()
+                result = client.chat.completions.create(
+                    model=model_id,
+                    messages=[{"role": "user", "content": prompt}]
+                )
+                model_sql = result.choices[0].message.content
+                latency = int((time.time() - start_time) * 1000)
+            except Exception as e:
+                model_sql = f"ERROR: {str(e)}"
+                latency = -1
+
+            sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
+            exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
+            intent_clarity = "✅" if len(user_question.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
+            semantic_clarity = "✅" if any(word in model_sql.lower() for word in ["from", "join", "group by"]) else "❌"
+            latency_status = "✅" if latency <= 1000 else "❌"
+
+            row.update({
+                f"{model_name}_sql": model_sql,
+                f"{model_name}_sql_match": sql_gen_accuracy,
+                f"{model_name}_exec_match": exec_response_accuracy,
+                f"{model_name}_intent_clarity": intent_clarity,
+                f"{model_name}_semantic_clarity": semantic_clarity,
+                f"{model_name}_latency_ms": latency,
+                f"{model_name}_latency_status": latency_status
+            })
+
+        results.append(row)
+        print(f"[{idx+1}/{n}] Done: {user_question}")
+
+    # Save to CSV
+    fieldnames = results[0].keys()
+    with open(output_path, "w", newline="") as f:
+        writer = csv.DictWriter(f, fieldnames=fieldnames)
+        writer.writeheader()
+        writer.writerows(results)
+
+    print(f"\n✅ Evaluation completed and saved to: {output_path}")
+    return output_path
 
 # ------------------------
-# Gradio UI
+# Gradio UI for batch evaluation
 # ------------------------
 with gr.Blocks() as demo:
-    gr.Markdown("## 🧠 Spider Dataset Model Evaluation")
-
-    prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
-    expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)
-
-    load_spider_btn = gr.Button("🔀 Load Random Spider Prompt")
-    run_button = gr.Button("Send & Evaluate All Models")
-
-    chat_display = gr.Textbox(label="Chat History", lines=20, interactive=False)
-    evaluation_display = gr.Markdown()
-
-    chat_memory = gr.State([])
-    expected_sql = gr.State("")
-
-    load_spider_btn.click(
-        fn=get_random_spider_prompt,
-        inputs=[],
-        outputs=[prompt_input, expected_sql, expected_sql_display]
-    )
-
-    run_button.click(
-        fn=evaluate_all_models,
-        inputs=[prompt_input, expected_sql, chat_memory],
-        outputs=[chat_display, chat_memory, evaluation_display]
-    )
+    gr.Markdown("## 🧠 Run Batch Evaluation on Spider Dataset")
+
+    num_samples = gr.Slider(10, 100, value=50, step=10, label="Number of Random Samples")
+    run_button = gr.Button("🚀 Run Evaluation")
+    download_output = gr.File(label="Download Evaluation CSV")
+
+    def run_eval(n):
+        return evaluate_batch(n)
+
+    run_button.click(fn=run_eval, inputs=[num_samples], outputs=[download_output])
 
 # Launch
-demo.launch()
+if __name__ == "__main__":
+    demo.launch()
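
A minimal sketch (not part of this commit) of how the `evaluation_results_*.csv` file written by `evaluate_batch()` could be summarized after a run. It assumes pandas is installed and relies only on the `{model_name}_sql_match` / `{model_name}_latency_ms` column pattern from the diff above; the filename is hypothetical.

```python
import pandas as pd

# Hypothetical output file from a previous evaluate_batch() run.
df = pd.read_csv("evaluation_results_20240101120000.csv")

for col in df.columns:
    if col.endswith("_sql_match"):
        model = col[: -len("_sql_match")]
        # Share of rows where the gold SQL appeared inside the model output.
        match_rate = (df[col] == "✅").mean()
        # Mean latency, ignoring the -1 sentinel written on request errors.
        lat = df.loc[df[f"{model}_latency_ms"] >= 0, f"{model}_latency_ms"]
        print(f"{model}: match rate {match_rate:.0%}, mean latency {lat.mean():.0f} ms")
```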