LeonceNsh committed on
Commit
88c83f6
·
verified ·
1 Parent(s): 00c05fa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -115
app.py CHANGED
@@ -4,11 +4,11 @@ import duckdb
4
  from functools import lru_cache
5
  import pandas as pd
6
  import plotly.express as px
7
- import openai
8
  import os
 
9
 
10
  # Set OpenAI API key
11
- openai.api_key = os.getenv("OPENAI_API_KEY")
12
 
13
  # =========================
14
  # Configuration and Setup
@@ -21,35 +21,13 @@ dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file pat
21
  schema = [
22
  {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
23
  {"column_name": "cgac", "column_type": "BIGINT"},
24
- {"column_name": "sub_tier", "column_type": "VARCHAR"},
25
- {"column_name": "fpds_code", "column_type": "VARCHAR"},
26
- {"column_name": "office", "column_type": "VARCHAR"},
27
- {"column_name": "aac_code", "column_type": "VARCHAR"},
28
- {"column_name": "posteddate", "column_type": "VARCHAR"},
29
- {"column_name": "type", "column_type": "VARCHAR"},
30
- {"column_name": "basetype", "column_type": "VARCHAR"},
31
- {"column_name": "popstreetaddress", "column_type": "VARCHAR"},
32
- {"column_name": "popcity", "column_type": "VARCHAR"},
33
- {"column_name": "popstate", "column_type": "VARCHAR"},
34
- {"column_name": "popzip", "column_type": "VARCHAR"},
35
- {"column_name": "popcountry", "column_type": "VARCHAR"},
36
- {"column_name": "active", "column_type": "VARCHAR"},
37
- {"column_name": "awardnumber", "column_type": "VARCHAR"},
38
- {"column_name": "awarddate", "column_type": "VARCHAR"},
39
- {"column_name": "award", "column_type": "DOUBLE"},
40
- {"column_name": "awardee", "column_type": "VARCHAR"},
41
- {"column_name": "state", "column_type": "VARCHAR"},
42
- {"column_name": "city", "column_type": "VARCHAR"},
43
- {"column_name": "zipcode", "column_type": "VARCHAR"},
44
- {"column_name": "countrycode", "column_type": "VARCHAR"}
45
  ]
46
 
47
- # Cache the schema loading
48
  @lru_cache(maxsize=1)
49
  def get_schema():
50
  return schema
51
 
52
- # Map column names to their types
53
  COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
54
 
55
  # =========================
@@ -62,7 +40,6 @@ def load_dataset_schema():
62
  """
63
  con = duckdb.connect()
64
  try:
65
- # Drop the view if it exists to avoid errors
66
  con.execute("DROP VIEW IF EXISTS contract_data")
67
  con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
68
  return True
@@ -78,28 +55,21 @@ def load_dataset_schema():
78
 
79
  async def parse_query(nl_query):
80
  """
81
- Converts a natural language query into a SQL query using OpenAI's GPT-3.5-turbo model.
82
  """
83
-
84
  messages = [
85
- {"role": "system",
86
- "content": (
87
- "You are an assistant that converts natural language queries into SQL queries "
88
- "for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries.")
89
- },
90
- {"role": "user",
91
- "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nNatural Language Query:\n\"{nl_query}\"\n\nSQL Query:"
92
- }
93
  ]
94
 
95
  try:
96
- response = await openai.ChatCompletion.acreate(
97
  model="gpt-3.5-turbo",
98
  messages=messages,
99
- temperature=0, # Set to 0 for deterministic output
100
  max_tokens=150,
101
  )
102
- sql_query = response.choices[0].message['content'].strip()
103
  return sql_query
104
  except Exception as e:
105
  return f"Error generating SQL query: {e}"
@@ -110,27 +80,19 @@ async def parse_query(nl_query):
110
 
111
  def detect_plot_intent(nl_query):
112
  """
113
- Detects if the user's query involves plotting based on the presence of specific keywords.
114
  """
115
- plot_keywords = [
116
- 'plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram',
117
- 'bar chart', 'line chart', 'scatter plot', 'pie chart'
118
- ]
119
- for keyword in plot_keywords:
120
- if keyword in nl_query.lower():
121
- return True
122
- return False
123
 
124
  async def generate_sql_and_plot_code(query):
125
  """
126
- Generates SQL query and plotting code based on the natural language input.
127
  """
128
  is_plot = detect_plot_intent(query)
129
  sql_query = await parse_query(query)
130
  plot_code = ""
131
  if is_plot and not sql_query.startswith("Error"):
132
- # Generate plot code based on the query
133
- # For simplicity, we'll generate a basic plot code
134
  plot_code = """
135
  import plotly.express as px
136
  fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
@@ -140,20 +102,18 @@ fig.update_layout(title_x=0.5)
140
 
141
  def execute_query(sql_query):
142
  """
143
- Executes the SQL query and returns results or an error message.
144
  """
145
  if sql_query.startswith("Error"):
146
- return None, sql_query # Pass the error message forward
147
 
148
  try:
149
  con = duckdb.connect()
150
- # Ensure the view is created
151
  con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
152
  result_df = con.execute(sql_query).fetchdf()
153
  con.close()
154
  return result_df, ""
155
  except Exception as e:
156
- # In case of error, return None and error message
157
  return None, f"Error executing query: {e}"
158
 
159
  def generate_plot(plot_code, result_df):
@@ -163,41 +123,18 @@ def generate_plot(plot_code, result_df):
163
  if not plot_code.strip():
164
  return None, "No plot code provided."
165
  try:
166
- # Replace placeholders in plot_code with actual column names
167
- if result_df.empty:
168
- return None, "Result DataFrame is empty."
169
  columns = result_df.columns.tolist()
170
  if len(columns) < 2:
171
  return None, "Not enough columns to plot."
172
  plot_code = plot_code.replace('x_column', columns[0])
173
  plot_code = plot_code.replace('y_column', columns[1])
174
-
175
- # Execute the plot code
176
  local_vars = {'result_df': result_df, 'px': px}
177
  exec(plot_code, {}, local_vars)
178
  fig = local_vars.get('fig', None)
179
- if fig:
180
- return fig, ""
181
- else:
182
- return None, "Plot could not be generated."
183
  except Exception as e:
184
  return None, f"Error generating plot: {e}"
185
 
186
- # =========================
187
- # Schema Display
188
- # =========================
189
-
190
- @lru_cache(maxsize=1)
191
- def get_schema_json():
192
- return json.dumps(get_schema(), indent=2)
193
-
194
- # =========================
195
- # Initialize Dataset Schema
196
- # =========================
197
-
198
- if not load_dataset_schema():
199
- raise Exception("Failed to load dataset schema. Please check the dataset path and format.")
200
-
201
  # =========================
202
  # Gradio Application UI
203
  # =========================
@@ -210,36 +147,17 @@ with gr.Blocks() as demo:
210
 
211
  ## Instructions
212
 
213
- 1. **Describe the data you want to retrieve or plot**: For example:
214
- - `Show all awards greater than 1,000,000 in California`
215
- - `Plot the distribution of awards by state`
216
- - `Show a bar chart of total awards per department`
217
- - `List awardees who received multiple awards along with award amounts`
218
- - `Number of awards issued by each department division`
219
-
220
- 2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
221
- 3. **Execute Query**: Click "Execute Query" to run the query and view the results.
222
- 4. **View Plot**: If your query involves plotting, the plot will be displayed.
223
- 5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
224
-
225
- ## Example Queries
226
-
227
- - `Plot the total award amount by state`
228
- - `Show a histogram of awards over time`
229
- - `award greater than 1000000 and state equal to "CA"`
230
- - `List awards where department_ind_agency contains "Defense"`
231
  """)
232
 
233
  with gr.Tabs():
234
- # Query Tab
235
  with gr.TabItem("Query Data"):
236
  with gr.Row():
237
  with gr.Column(scale=1):
238
- query = gr.Textbox(
239
- label="Natural Language Query",
240
- placeholder='e.g., "Show all awards greater than 1,000,000 in California"',
241
- lines=4
242
- )
243
  btn_generate = gr.Button("Generate SQL")
244
  sql_out = gr.Code(label="Generated SQL Query", language="sql")
245
  plot_code_out = gr.Code(label="Generated Plot Code", language="python")
@@ -249,10 +167,9 @@ with gr.Blocks() as demo:
249
  results_out = gr.Dataframe(label="Query Results", interactive=False)
250
  plot_out = gr.Plot(label="Plot")
251
 
252
- # Schema Tab
253
  with gr.TabItem("Dataset Schema"):
254
  gr.Markdown("### Dataset Schema")
255
- schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
256
 
257
  # =========================
258
  # Click Event Handlers
@@ -281,16 +198,8 @@ with gr.Blocks() as demo:
281
  else:
282
  return result_df, None, ""
283
 
284
- btn_generate.click(
285
- fn=on_generate_click,
286
- inputs=query,
287
- outputs=[sql_out, plot_code_out],
288
- )
289
- btn_execute.click(
290
- fn=on_execute_click,
291
- inputs=[sql_out, plot_code_out],
292
- outputs=[results_out, plot_out, error_out],
293
- )
294
 
295
  # =========================
296
  # Launch the Gradio App
 
4
  from functools import lru_cache
5
  import pandas as pd
6
  import plotly.express as px
 
7
  import os
8
+ from openai import OpenAI
9
 
10
  # Set OpenAI API key
11
+ client = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
12
 
13
  # =========================
14
  # Configuration and Setup
 
21
  schema = [
22
  {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
23
  {"column_name": "cgac", "column_type": "BIGINT"},
24
+ # Additional columns go here...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  ]
26
 
 
27
  @lru_cache(maxsize=1)
28
  def get_schema():
29
  return schema
30
 
 
31
  COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
32
 
33
  # =========================
 
40
  """
41
  con = duckdb.connect()
42
  try:
 
43
  con.execute("DROP VIEW IF EXISTS contract_data")
44
  con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
45
  return True
 
55
 
56
  async def parse_query(nl_query):
57
  """
58
+ Converts a natural language query into a SQL query using OpenAI's API.
59
  """
 
60
  messages = [
61
+ {"role": "system", "content": "Convert natural language queries to SQL queries for 'contract_data'."},
62
+ {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
 
 
 
 
 
 
63
  ]
64
 
65
  try:
66
+ response = await client.chat.completions.create(
67
  model="gpt-3.5-turbo",
68
  messages=messages,
69
+ temperature=0,
70
  max_tokens=150,
71
  )
72
+ sql_query = response.choices[0].message.content.strip()
73
  return sql_query
74
  except Exception as e:
75
  return f"Error generating SQL query: {e}"
 
80
 
81
  def detect_plot_intent(nl_query):
82
  """
83
+ Detects if the user's query involves plotting.
84
  """
85
+ plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize']
86
+ return any(keyword in nl_query.lower() for keyword in plot_keywords)
 
 
 
 
 
 
87
 
88
  async def generate_sql_and_plot_code(query):
89
  """
90
+ Generates SQL query and optional plotting code.
91
  """
92
  is_plot = detect_plot_intent(query)
93
  sql_query = await parse_query(query)
94
  plot_code = ""
95
  if is_plot and not sql_query.startswith("Error"):
 
 
96
  plot_code = """
97
  import plotly.express as px
98
  fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
 
102
 
103
  def execute_query(sql_query):
104
  """
105
+ Executes the SQL query and returns the results.
106
  """
107
  if sql_query.startswith("Error"):
108
+ return None, sql_query
109
 
110
  try:
111
  con = duckdb.connect()
 
112
  con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
113
  result_df = con.execute(sql_query).fetchdf()
114
  con.close()
115
  return result_df, ""
116
  except Exception as e:
 
117
  return None, f"Error executing query: {e}"
118
 
119
  def generate_plot(plot_code, result_df):
 
123
  if not plot_code.strip():
124
  return None, "No plot code provided."
125
  try:
 
 
 
126
  columns = result_df.columns.tolist()
127
  if len(columns) < 2:
128
  return None, "Not enough columns to plot."
129
  plot_code = plot_code.replace('x_column', columns[0])
130
  plot_code = plot_code.replace('y_column', columns[1])
 
 
131
  local_vars = {'result_df': result_df, 'px': px}
132
  exec(plot_code, {}, local_vars)
133
  fig = local_vars.get('fig', None)
134
+ return fig, "" if fig else "Plot could not be generated."
 
 
 
135
  except Exception as e:
136
  return None, f"Error generating plot: {e}"
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  # =========================
139
  # Gradio Application UI
140
  # =========================
 
147
 
148
  ## Instructions
149
 
150
+ 1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
151
+ 2. **Generate SQL**: Click "Generate SQL" to see the SQL query.
152
+ 3. **Execute Query**: Run the query to view results and plots.
153
+ 4. **Dataset Schema**: See available columns and types in the "Dataset Schema" tab.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  """)
155
 
156
  with gr.Tabs():
 
157
  with gr.TabItem("Query Data"):
158
  with gr.Row():
159
  with gr.Column(scale=1):
160
+ query = gr.Textbox(label="Natural Language Query", placeholder='e.g., "Awards > 1M in CA"')
 
 
 
 
161
  btn_generate = gr.Button("Generate SQL")
162
  sql_out = gr.Code(label="Generated SQL Query", language="sql")
163
  plot_code_out = gr.Code(label="Generated Plot Code", language="python")
 
167
  results_out = gr.Dataframe(label="Query Results", interactive=False)
168
  plot_out = gr.Plot(label="Plot")
169
 
 
170
  with gr.TabItem("Dataset Schema"):
171
  gr.Markdown("### Dataset Schema")
172
+ schema_display = gr.JSON(label="Schema", value=json.loads(json.dumps(get_schema(), indent=2)))
173
 
174
  # =========================
175
  # Click Event Handlers
 
198
  else:
199
  return result_df, None, ""
200
 
201
+ btn_generate.click(fn=on_generate_click, inputs=query, outputs=[sql_out, plot_code_out])
202
+ btn_execute.click(fn=on_execute_click, inputs=[sql_out, plot_code_out], outputs=[results_out, plot_out, error_out])
 
 
 
 
 
 
 
 
203
 
204
  # =========================
205
  # Launch the Gradio App