Mubbashir Ahmed committed on
Commit
609f0d2
·
1 Parent(s): e960914

removed spider folder and updated code

Browse files
Files changed (3) hide show
  1. app.py +11 -17
  2. spider +0 -1
  3. train_spider.json +0 -3
app.py CHANGED
@@ -3,7 +3,6 @@ import random
3
  import time
4
  import json
5
  import gradio as gr
6
- from datasets import Dataset
7
  from huggingface_hub import InferenceClient
8
 
9
  # ------------------------
@@ -12,17 +11,10 @@ from huggingface_hub import InferenceClient
12
  HF_TOKEN = os.environ.get("HF_TOKEN")
13
 
14
  # ------------------------
15
- # Clone Spider Dataset if Not Exists
16
  # ------------------------
17
- if not os.path.exists("spider/train_spider.json"):
18
- os.system("git clone https://github.com/taoyds/spider.git")
19
-
20
- # ------------------------
21
- # Load Spider Dataset
22
- # ------------------------
23
- with open("spider/train_spider.json", "r") as f:
24
- spider_raw = json.load(f)
25
- spider_dataset = Dataset.from_list(spider_raw)
26
 
27
  # ------------------------
28
  # Inference Clients
@@ -42,7 +34,7 @@ model_list = {
42
  }
43
 
44
  # ------------------------
45
- # Prompt Engineering Template
46
  # ------------------------
47
  def build_prompt(user_question):
48
  return f"""You are an expert SQL assistant. Convert the given question into a valid SQL query.
@@ -50,7 +42,7 @@ def build_prompt(user_question):
50
  Instructions:
51
  - Respond with only the SQL query.
52
  - Do not include markdown, explanations, or additional formatting.
53
- - Use correct table and column names from the schema.
54
  - Follow SQL best practices and Spider dataset formatting.
55
 
56
  Examples:
@@ -70,15 +62,16 @@ A:"""
70
  def evaluate_all_models(user_input, expected_sql, chat_history):
71
  evaluations = []
72
  full_chat_transcript = ""
73
- engineered_prompt = build_prompt(user_input)
74
 
75
  for model_name, model_config in model_list.items():
76
  client = model_config["client"]
77
  model_id = model_config["model_id"]
78
 
79
- messages = chat_history + [{"role": "user", "content": engineered_prompt}]
80
  try:
81
  start_time = time.time()
 
82
  result = client.chat.completions.create(
83
  model=model_id,
84
  messages=messages
@@ -106,6 +99,7 @@ def evaluate_all_models(user_input, expected_sql, chat_history):
106
  f"- Response Latency: {latency} ms ({latency_status})\n"
107
  )
108
  evaluations.append(summary)
 
109
  full_chat_transcript += f"\n👤 User: {user_input}\n🤖 {model_name}: {model_sql}\n"
110
 
111
  return full_chat_transcript.strip(), chat_history, "\n\n".join(evaluations)
@@ -121,7 +115,7 @@ def get_random_spider_prompt():
121
  # Gradio UI
122
  # ------------------------
123
  with gr.Blocks() as demo:
124
- gr.Markdown("## 🧠 Spider Dataset Model Evaluation with Prompt Engineering")
125
 
126
  prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
127
  expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)
@@ -147,5 +141,5 @@ with gr.Blocks() as demo:
147
  outputs=[chat_display, chat_memory, evaluation_display]
148
  )
149
 
150
- # Launch the app
151
  demo.launch()
 
3
  import time
4
  import json
5
  import gradio as gr
 
6
  from huggingface_hub import InferenceClient
7
 
8
  # ------------------------
 
11
  HF_TOKEN = os.environ.get("HF_TOKEN")
12
 
13
  # ------------------------
14
+ # Load Spider Dataset (local JSON)
15
  # ------------------------
16
+ with open("train_spider.json", "r") as f:
17
+ spider_dataset = json.load(f)
 
 
 
 
 
 
 
18
 
19
  # ------------------------
20
  # Inference Clients
 
34
  }
35
 
36
  # ------------------------
37
+ # Prompt Template for SQL Generation
38
  # ------------------------
39
  def build_prompt(user_question):
40
  return f"""You are an expert SQL assistant. Convert the given question into a valid SQL query.
 
42
  Instructions:
43
  - Respond with only the SQL query.
44
  - Do not include markdown, explanations, or additional formatting.
45
+ - Use correct table and column names.
46
  - Follow SQL best practices and Spider dataset formatting.
47
 
48
  Examples:
 
62
  def evaluate_all_models(user_input, expected_sql, chat_history):
63
  evaluations = []
64
  full_chat_transcript = ""
65
+ prompt = build_prompt(user_input)
66
 
67
  for model_name, model_config in model_list.items():
68
  client = model_config["client"]
69
  model_id = model_config["model_id"]
70
 
71
+ messages = chat_history + [{"role": "user", "content": prompt}]
72
  try:
73
  start_time = time.time()
74
+
75
  result = client.chat.completions.create(
76
  model=model_id,
77
  messages=messages
 
99
  f"- Response Latency: {latency} ms ({latency_status})\n"
100
  )
101
  evaluations.append(summary)
102
+
103
  full_chat_transcript += f"\n👤 User: {user_input}\n🤖 {model_name}: {model_sql}\n"
104
 
105
  return full_chat_transcript.strip(), chat_history, "\n\n".join(evaluations)
 
115
  # Gradio UI
116
  # ------------------------
117
  with gr.Blocks() as demo:
118
+ gr.Markdown("## 🧠 Spider Dataset Model Evaluation")
119
 
120
  prompt_input = gr.Textbox(label="Your Prompt", lines=3, placeholder="Ask your BI question...")
121
  expected_sql_display = gr.Textbox(label="Expected SQL", lines=2, interactive=False)
 
141
  outputs=[chat_display, chat_memory, evaluation_display]
142
  )
143
 
144
+ # Launch
145
  demo.launch()
spider DELETED
@@ -1 +0,0 @@
1
- Subproject commit b7b5b8c890cd30e35427348bb9eb8c6d1350ca7c
 
 
train_spider.json DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c43d0d72e59e1a9e1a60837da9bf70d5a6277226bdb7f634d544f380646f527a
3
- size 24928884