Mubbashir Ahmed committed on
Commit 3f087c6 · 1 Parent(s): dccb5da
Files changed (2)
  1. app.py +18 -45
  2. spider +1 -0
app.py CHANGED
@@ -4,7 +4,6 @@ import time
import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
- import json

# ------------------------
# Auth
@@ -16,22 +15,6 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
# ------------------------
spider_dataset = load_dataset("spider", split="train")

- # Load table schemas from Spider
- with open("spider/tables.json", "r") as f:
-     tables_json = json.load(f)
-
- # Build db_id → schema_string mapping
- def extract_schema(db_id):
-     for db in tables_json:
-         if db["db_id"] == db_id:
-             tables = []
-             for table_name, columns in zip(db["table_names_original"], db["column_names_original"]):
-                 col_list = [col[1] for col in db["column_names_original"] if col[0] == db["table_names_original"].index(table_name)]
-                 table_def = f"{table_name}({', '.join(col for col in col_list if col != '*')})"
-                 tables.append(table_def)
-             return "\n".join(tables)
-     return "Schema not found."
-
# ------------------------
# Inference Clients
# ------------------------
@@ -50,42 +33,34 @@ model_list = {
}

# ------------------------
- # Few-shot examples
+ # Prompt Construction
# ------------------------
- few_shot_examples = """Q: Show all department names.
+ FEW_SHOT = """Q: Show all department names.
A: SELECT name FROM department;

Q: Count number of students.
- A: SELECT COUNT(*) FROM student;"""
+ A: SELECT COUNT(*) FROM student;
+ """

- # ------------------------
- # Prompt Constructor
- # ------------------------
- def build_sql_prompt(user_question, db_id):
-     schema = extract_schema(db_id)
-     prompt = f"""You are an expert SQL assistant. Convert the given question into a valid SQL query using the database schema provided below.
+ def build_prompt(user_question):
+     return f"""You are an expert SQL assistant. Convert the given question into a valid SQL query.

Instructions:
- - Respond with only the SQL query.
- - Do not include markdown, explanations, or additional formatting.
- - Use correct table and column names from the schema.
- - Follow SQL best practices and Spider dataset formatting.
-
- Schema (db_id: {db_id}):
- {schema}
+ - Return only the SQL query.
+ - Do not include markdown, explanations, or formatting.
+ - Follow Spider dataset SQL syntax.

Examples:
- {few_shot_examples}
+ {FEW_SHOT}

Now answer this:
Q: {user_question}
A:"""
-     return prompt

# ------------------------
- # Evaluate Models with Engineered Prompt
+ # Inference + Evaluation Logic
# ------------------------
- def evaluate_all_models(user_input, expected_sql, db_id, chat_history):
+ def evaluate_all_models(user_input, expected_sql, chat_history):
    evaluations = []
    full_chat_transcript = ""

@@ -93,7 +68,7 @@ def evaluate_all_models(user_input, expected_sql, db_id, chat_history):
        client = model_config["client"]
        model_id = model_config["model_id"]

-         prompt = build_sql_prompt(user_input, db_id)
+         prompt = build_prompt(user_input)
        messages = [{"role": "user", "content": prompt}]

        try:
@@ -109,7 +84,7 @@ def evaluate_all_models(user_input, expected_sql, db_id, chat_history):
            model_sql = f"⚠️ Error: {str(e)}"
            latency = -1

-         # Evaluation criteria
+         # Evaluation criteria (simulated)
        sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
        exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
        intent_clarity = "✅" if len(user_input.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
@@ -135,17 +110,16 @@ def evaluate_all_models(user_input, expected_sql, db_id, chat_history):
# ------------------------
def get_random_spider_prompt():
    sample = random.choice(spider_dataset)
-     return sample["question"], sample["query"], sample["query"], sample["db_id"]
+     return sample["question"], sample["query"], sample["query"]

# ------------------------
# Gradio UI
# ------------------------
with gr.Blocks() as demo:
-     gr.Markdown("## 🧠 Advanced SQL Generation Evaluation (Spider Dataset + Prompt Engineering)")
+     gr.Markdown("## 🧠 Spider Dataset Model Evaluation with Prompt Engineering")

    prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
    expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)
-     dbid_display = gr.Textbox(label="DB ID", lines=1, interactive=False)

    load_spider_btn = gr.Button("🔀 Load Random Spider Prompt")
    run_button = gr.Button("Send & Evaluate All Models")
@@ -155,17 +129,16 @@ with gr.Blocks() as demo:

    chat_memory = gr.State([])
    expected_sql = gr.State("")
-     db_id = gr.State("")

    load_spider_btn.click(
        fn=get_random_spider_prompt,
        inputs=[],
-         outputs=[prompt_input, expected_sql, expected_sql_display, db_id, dbid_display]
+         outputs=[prompt_input, expected_sql, expected_sql_display]
    )

    run_button.click(
        fn=evaluate_all_models,
-         inputs=[prompt_input, expected_sql, db_id, chat_memory],
+         inputs=[prompt_input, expected_sql, chat_memory],
        outputs=[chat_display, chat_memory, evaluation_display]
    )
 
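Not part of the commit: a minimal, self-contained sketch of the code path this change leaves in place. It builds the few-shot prompt without any schema, queries one hosted model, and scores the output with the same case-insensitive containment check used in evaluate_all_models. The prompt builder below mirrors the new build_prompt in app.py; the model id (HuggingFaceH4/zephyr-7b-beta) and the chat_completion call are illustrative assumptions, since the diff does not show the model_list contents.

```python
import os
import random

from datasets import load_dataset
from huggingface_hub import InferenceClient

# Few-shot examples, copied from the new app.py above.
FEW_SHOT = """Q: Show all department names.
A: SELECT name FROM department;

Q: Count number of students.
A: SELECT COUNT(*) FROM student;
"""

def build_prompt(user_question):
    # Mirrors the simplified builder in app.py: few-shot examples only, no schema injection.
    return f"""You are an expert SQL assistant. Convert the given question into a valid SQL query.

Instructions:
- Return only the SQL query.
- Do not include markdown, explanations, or formatting.
- Follow Spider dataset SQL syntax.

Examples:
{FEW_SHOT}

Now answer this:
Q: {user_question}
A:"""

# Pull one Spider example, query a single model, and score it the way the app does.
spider = load_dataset("spider", split="train")
sample = random.choice(spider)
prompt = build_prompt(sample["question"])

# Model choice and token handling here are illustrative assumptions.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta", token=os.environ.get("HF_TOKEN"))
response = client.chat_completion(messages=[{"role": "user", "content": prompt}], max_tokens=200)
model_sql = response.choices[0].message.content

# The app's "accuracy" signal is a simple case-insensitive containment check.
print("expected: ", sample["query"])
print("generated:", model_sql)
print("match:", sample["query"].strip().lower() in model_sql.strip().lower())
```

With the schema injection removed, the model has to infer table and column names from the question alone, which is exactly what the simplified prompt above exercises.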
spider ADDED
@@ -0,0 +1 @@
+ Subproject commit b7b5b8c890cd30e35427348bb9eb8c6d1350ca7c