Mubbashir Ahmed committed
Commit dccb5da · 1 Parent(s): 26fe788

added prompting

Files changed (1)
  1. app.py +63 -10
app.py CHANGED
@@ -4,6 +4,7 @@ import time
import gradio as gr
from huggingface_hub import InferenceClient
from datasets import load_dataset
+ import json

# ------------------------
# Auth
@@ -15,6 +16,22 @@ HF_TOKEN = os.environ.get("HF_TOKEN")
# ------------------------
spider_dataset = load_dataset("spider", split="train")

+ # Load table schemas from Spider
+ with open("spider/tables.json", "r") as f:
+     tables_json = json.load(f)
+
+ # Build db_id → schema_string mapping
+ def extract_schema(db_id):
+     for db in tables_json:
+         if db["db_id"] == db_id:
+             tables = []
+             for table_name, columns in zip(db["table_names_original"], db["column_names_original"]):
+                 col_list = [col[1] for col in db["column_names_original"] if col[0] == db["table_names_original"].index(table_name)]
+                 table_def = f"{table_name}({', '.join(col for col in col_list if col != '*')})"
+                 tables.append(table_def)
+             return "\n".join(tables)
+     return "Schema not found."
+
# ------------------------
# Inference Clients
# ------------------------
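Note on the extract_schema added above: it builds one table(col1, col2, ...) line per table, but the zip over table_names_original and column_names_original is only used for its table names, and table_names_original.index(table_name) is recomputed for every column. A rough equivalent that groups columns by their table index directly is sketched below; this is not part of the commit and assumes Spider's tables.json layout, where each entry of column_names_original is a [table_index, column_name] pair and the synthetic "*" column carries index -1.

import json

def extract_schema_by_index(tables_json_path, db_id):
    # Sketch only: group each column under its owning table index,
    # skipping the synthetic "*" entry (table index -1).
    with open(tables_json_path, "r") as f:
        tables_json = json.load(f)
    for db in tables_json:
        if db["db_id"] != db_id:
            continue
        columns_by_table = {i: [] for i in range(len(db["table_names_original"]))}
        for table_idx, col_name in db["column_names_original"]:
            if table_idx >= 0:
                columns_by_table[table_idx].append(col_name)
        return "\n".join(
            f"{name}({', '.join(cols)})"
            for name, cols in zip(db["table_names_original"], columns_by_table.values())
        )
    return "Schema not found."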
@@ -33,9 +50,42 @@ model_list = {
}

# ------------------------
- # Inference + Evaluation Logic
+ # Few-shot examples
+ # ------------------------
+ few_shot_examples = """Q: Show all department names.
+ A: SELECT name FROM department;
+
+ Q: Count number of students.
+ A: SELECT COUNT(*) FROM student;"""
+
+ # ------------------------
+ # Prompt Constructor
+ # ------------------------
+ def build_sql_prompt(user_question, db_id):
+     schema = extract_schema(db_id)
+     prompt = f"""You are an expert SQL assistant. Convert the given question into a valid SQL query using the database schema provided below.
+
+ Instructions:
+ - Respond with only the SQL query.
+ - Do not include markdown, explanations, or additional formatting.
+ - Use correct table and column names from the schema.
+ - Follow SQL best practices and Spider dataset formatting.
+
+ Schema (db_id: {db_id}):
+ {schema}
+
+ Examples:
+ {few_shot_examples}
+
+ Now answer this:
+ Q: {user_question}
+ A:"""
+     return prompt
+
+ # ------------------------
+ # Evaluate Models with Engineered Prompt
# ------------------------
- def evaluate_all_models(user_input, expected_sql, chat_history):
+ def evaluate_all_models(user_input, expected_sql, db_id, chat_history):
    evaluations = []
    full_chat_transcript = ""

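A quick way to eyeball the engineered prompt is to call the new constructor on a random Spider sample, for example from a Python session alongside app.py (a sketch; it assumes build_sql_prompt and the schema helpers defined above are available in the session):

import random
from datasets import load_dataset

spider_dataset = load_dataset("spider", split="train")
sample = random.choice(spider_dataset)

# build_sql_prompt is the constructor added in this commit; the printed prompt
# should contain the schema block, the two few-shot examples, and the final
# "Q: <question> / A:" stub that the model is asked to complete.
prompt = build_sql_prompt(sample["question"], sample["db_id"])
print(prompt)
print("gold SQL:", sample["query"])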
@@ -43,22 +93,23 @@ def evaluate_all_models(user_input, expected_sql, chat_history):
        client = model_config["client"]
        model_id = model_config["model_id"]

-         messages = chat_history + [{"role": "user", "content": user_input}]
+         prompt = build_sql_prompt(user_input, db_id)
+         messages = [{"role": "user", "content": prompt}]
+
        try:
            start_time = time.time()
-
            result = client.chat.completions.create(
                model=model_id,
                messages=messages
            )
-             model_sql = result.choices[0].message.content
+             model_sql = result.choices[0].message.content.strip()
            latency = int((time.time() - start_time) * 1000)

        except Exception as e:
            model_sql = f"⚠️ Error: {str(e)}"
            latency = -1

-         # Evaluation criteria (simulated)
+         # Evaluation criteria
        sql_gen_accuracy = "✅" if expected_sql.strip().lower() in model_sql.strip().lower() else "❌"
        exec_response_accuracy = "✅" if sql_gen_accuracy == "✅" else "❌"
        intent_clarity = "✅" if len(user_input.strip().split()) < 5 and "SELECT" in model_sql.upper() else "❌"
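The accuracy check above is a plain substring match on the raw completion, so a model that wraps its answer in markdown code fences or appends a trailing semicolon is scored as wrong even when the SQL is identical. A small normalizer along the following lines could make the comparison less brittle; this is a sketch, not part of the commit, and normalize_sql / loose_match are hypothetical helper names.

import re

def normalize_sql(text):
    # Strip markdown code fences and an optional leading "sql" tag, drop a
    # trailing semicolon, then collapse whitespace and lowercase the result.
    text = text.strip()
    text = re.sub(r"^```(?:sql)?", "", text, flags=re.IGNORECASE)
    text = re.sub(r"```$", "", text)
    text = text.strip().rstrip(";")
    return re.sub(r"\s+", " ", text).strip().lower()

def loose_match(expected_sql, model_sql):
    # e.g. loose_match("SELECT name FROM department",
    #                  "```sql\nSELECT name\nFROM department;\n```") -> True
    return normalize_sql(expected_sql) in normalize_sql(model_sql)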
@@ -84,16 +135,17 @@ def evaluate_all_models(user_input, expected_sql, chat_history):
# ------------------------
def get_random_spider_prompt():
    sample = random.choice(spider_dataset)
-     return sample["question"], sample["query"], sample["query"]
+     return sample["question"], sample["query"], sample["query"], sample["db_id"]

# ------------------------
# Gradio UI
# ------------------------
with gr.Blocks() as demo:
-     gr.Markdown("## 🧠 Spider Dataset Model Evaluation")
+     gr.Markdown("## 🧠 Advanced SQL Generation Evaluation (Spider Dataset + Prompt Engineering)")

    prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
    expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)
+     dbid_display = gr.Textbox(label="DB ID", lines=1, interactive=False)

    load_spider_btn = gr.Button("🔀 Load Random Spider Prompt")
    run_button = gr.Button("Send & Evaluate All Models")
@@ -103,16 +155,17 @@ with gr.Blocks() as demo:

    chat_memory = gr.State([])
    expected_sql = gr.State("")
+     db_id = gr.State("")

    load_spider_btn.click(
        fn=get_random_spider_prompt,
        inputs=[],
-         outputs=[prompt_input, expected_sql, expected_sql_display]
+         outputs=[prompt_input, expected_sql, expected_sql_display, db_id, dbid_display]
    )

    run_button.click(
        fn=evaluate_all_models,
-         inputs=[prompt_input, expected_sql, chat_memory],
+         inputs=[prompt_input, expected_sql, db_id, chat_memory],
        outputs=[chat_display, chat_memory, evaluation_display]
    )
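One thing worth flagging in the wiring above: load_spider_btn.click now lists five outputs (prompt_input, expected_sql, expected_sql_display, db_id, dbid_display), but get_random_spider_prompt returns only four values, so Gradio will complain about a missing output value when the button is clicked. A minimal fix, sketched here rather than part of the committed code, is to return the db_id twice, once for the gr.State and once for the read-only textbox (random and spider_dataset as defined in app.py):

def get_random_spider_prompt():
    sample = random.choice(spider_dataset)
    return (
        sample["question"],  # -> prompt_input
        sample["query"],     # -> expected_sql (gr.State)
        sample["query"],     # -> expected_sql_display
        sample["db_id"],     # -> db_id (gr.State)
        sample["db_id"],     # -> dbid_display
    )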
 
 