LeonceNsh commited on
Commit
00c05fa
·
verified ·
1 Parent(s): a1792a1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -12
app.py CHANGED
@@ -7,6 +7,9 @@ import plotly.express as px
7
  import openai
8
  import os
9
 
 
 
 
10
  # =========================
11
  # Configuration and Setup
12
  # =========================
@@ -75,23 +78,25 @@ def load_dataset_schema():
75
 
76
  async def parse_query(nl_query):
77
  """
78
- Converts a natural language query into a SQL query using OpenAI's GPT-4-turbo model.
79
  """
 
80
  messages = [
81
- {"role": "system", "content": (
82
- "You are an assistant that converts natural language queries into SQL queries "
83
- "for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries."
84
- )},
85
- {"role": "user", "content": (
86
- f"Schema:\n{json.dumps(schema, indent=2)}\n\nNatural Language Query:\n\"{nl_query}\"\n\nSQL Query:"
87
- )}
 
88
  ]
89
 
90
  try:
91
  response = await openai.ChatCompletion.acreate(
92
  model="gpt-3.5-turbo",
93
  messages=messages,
94
- temperature=0,
95
  max_tokens=150,
96
  )
97
  sql_query = response.choices[0].message['content'].strip()
@@ -114,7 +119,7 @@ def detect_plot_intent(nl_query):
114
  for keyword in plot_keywords:
115
  if keyword in nl_query.lower():
116
  return True
117
- return False
118
 
119
  async def generate_sql_and_plot_code(query):
120
  """
@@ -125,6 +130,7 @@ async def generate_sql_and_plot_code(query):
125
  plot_code = ""
126
  if is_plot and not sql_query.startswith("Error"):
127
  # Generate plot code based on the query
 
128
  plot_code = """
129
  import plotly.express as px
130
  fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
@@ -141,11 +147,13 @@ def execute_query(sql_query):
141
 
142
  try:
143
  con = duckdb.connect()
 
144
  con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
145
  result_df = con.execute(sql_query).fetchdf()
146
  con.close()
147
  return result_df, ""
148
  except Exception as e:
 
149
  return None, f"Error executing query: {e}"
150
 
151
  def generate_plot(plot_code, result_df):
@@ -155,6 +163,7 @@ def generate_plot(plot_code, result_df):
155
  if not plot_code.strip():
156
  return None, "No plot code provided."
157
  try:
 
158
  if result_df.empty:
159
  return None, "Result DataFrame is empty."
160
  columns = result_df.columns.tolist()
@@ -163,10 +172,14 @@ def generate_plot(plot_code, result_df):
163
  plot_code = plot_code.replace('x_column', columns[0])
164
  plot_code = plot_code.replace('y_column', columns[1])
165
 
 
166
  local_vars = {'result_df': result_df, 'px': px}
167
  exec(plot_code, {}, local_vars)
168
  fig = local_vars.get('fig', None)
169
- return fig, "" if fig else "Plot could not be generated."
 
 
 
170
  except Exception as e:
171
  return None, f"Error generating plot: {e}"
172
 
@@ -194,9 +207,31 @@ with gr.Blocks() as demo:
194
  # Parquet SQL Query and Plotting App
195
 
196
  **Query and visualize data** in `sample_contract_df.parquet`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  """)
198
 
199
  with gr.Tabs():
 
200
  with gr.TabItem("Query Data"):
201
  with gr.Row():
202
  with gr.Column(scale=1):
@@ -214,21 +249,35 @@ with gr.Blocks() as demo:
214
  results_out = gr.Dataframe(label="Query Results", interactive=False)
215
  plot_out = gr.Plot(label="Plot")
216
 
 
217
  with gr.TabItem("Dataset Schema"):
218
  gr.Markdown("### Dataset Schema")
219
  schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
220
 
 
 
 
 
221
  async def on_generate_click(nl_query):
 
 
 
222
  sql_query, plot_code = await generate_sql_and_plot_code(nl_query)
223
  return sql_query, plot_code
224
 
225
  def on_execute_click(sql_query, plot_code):
 
 
 
226
  result_df, error_msg = execute_query(sql_query)
227
  if error_msg:
228
  return None, None, error_msg
229
  if plot_code.strip():
230
  fig, plot_error = generate_plot(plot_code, result_df)
231
- return result_df, fig, plot_error if plot_error else ""
 
 
 
232
  else:
233
  return result_df, None, ""
234
 
@@ -243,4 +292,8 @@ with gr.Blocks() as demo:
243
  outputs=[results_out, plot_out, error_out],
244
  )
245
 
 
 
 
 
246
  demo.launch()
 
7
  import openai
8
  import os
9
 
10
+ # Set OpenAI API key
11
+ openai.api_key = os.getenv("OPENAI_API_KEY")
12
+
13
  # =========================
14
  # Configuration and Setup
15
  # =========================
 
78
 
79
  async def parse_query(nl_query):
80
  """
81
+ Converts a natural language query into a SQL query using OpenAI's GPT-3.5-turbo model.
82
  """
83
+
84
  messages = [
85
+ {"role": "system",
86
+ "content": (
87
+ "You are an assistant that converts natural language queries into SQL queries "
88
+ "for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries.")
89
+ },
90
+ {"role": "user",
91
+ "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nNatural Language Query:\n\"{nl_query}\"\n\nSQL Query:"
92
+ }
93
  ]
94
 
95
  try:
96
  response = await openai.ChatCompletion.acreate(
97
  model="gpt-3.5-turbo",
98
  messages=messages,
99
+ temperature=0, # Set to 0 for deterministic output
100
  max_tokens=150,
101
  )
102
  sql_query = response.choices[0].message['content'].strip()
 
119
  for keyword in plot_keywords:
120
  if keyword in nl_query.lower():
121
  return True
122
+ return False
123
 
124
  async def generate_sql_and_plot_code(query):
125
  """
 
130
  plot_code = ""
131
  if is_plot and not sql_query.startswith("Error"):
132
  # Generate plot code based on the query
133
+ # For simplicity, we'll generate a basic plot code
134
  plot_code = """
135
  import plotly.express as px
136
  fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
 
147
 
148
  try:
149
  con = duckdb.connect()
150
+ # Ensure the view is created
151
  con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
152
  result_df = con.execute(sql_query).fetchdf()
153
  con.close()
154
  return result_df, ""
155
  except Exception as e:
156
+ # In case of error, return None and error message
157
  return None, f"Error executing query: {e}"
158
 
159
  def generate_plot(plot_code, result_df):
 
163
  if not plot_code.strip():
164
  return None, "No plot code provided."
165
  try:
166
+ # Replace placeholders in plot_code with actual column names
167
  if result_df.empty:
168
  return None, "Result DataFrame is empty."
169
  columns = result_df.columns.tolist()
 
172
  plot_code = plot_code.replace('x_column', columns[0])
173
  plot_code = plot_code.replace('y_column', columns[1])
174
 
175
+ # Execute the plot code
176
  local_vars = {'result_df': result_df, 'px': px}
177
  exec(plot_code, {}, local_vars)
178
  fig = local_vars.get('fig', None)
179
+ if fig:
180
+ return fig, ""
181
+ else:
182
+ return None, "Plot could not be generated."
183
  except Exception as e:
184
  return None, f"Error generating plot: {e}"
185
 
 
207
  # Parquet SQL Query and Plotting App
208
 
209
  **Query and visualize data** in `sample_contract_df.parquet`
210
+
211
+ ## Instructions
212
+
213
+ 1. **Describe the data you want to retrieve or plot**: For example:
214
+ - `Show all awards greater than 1,000,000 in California`
215
+ - `Plot the distribution of awards by state`
216
+ - `Show a bar chart of total awards per department`
217
+ - `List awardees who received multiple awards along with award amounts`
218
+ - `Number of awards issued by each department division`
219
+
220
+ 2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
221
+ 3. **Execute Query**: Click "Execute Query" to run the query and view the results.
222
+ 4. **View Plot**: If your query involves plotting, the plot will be displayed.
223
+ 5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
224
+
225
+ ## Example Queries
226
+
227
+ - `Plot the total award amount by state`
228
+ - `Show a histogram of awards over time`
229
+ - `award greater than 1000000 and state equal to "CA"`
230
+ - `List awards where department_ind_agency contains "Defense"`
231
  """)
232
 
233
  with gr.Tabs():
234
+ # Query Tab
235
  with gr.TabItem("Query Data"):
236
  with gr.Row():
237
  with gr.Column(scale=1):
 
249
  results_out = gr.Dataframe(label="Query Results", interactive=False)
250
  plot_out = gr.Plot(label="Plot")
251
 
252
+ # Schema Tab
253
  with gr.TabItem("Dataset Schema"):
254
  gr.Markdown("### Dataset Schema")
255
  schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
256
 
257
+ # =========================
258
+ # Click Event Handlers
259
+ # =========================
260
+
261
  async def on_generate_click(nl_query):
262
+ """
263
+ Handles the "Generate SQL" button click event.
264
+ """
265
  sql_query, plot_code = await generate_sql_and_plot_code(nl_query)
266
  return sql_query, plot_code
267
 
268
  def on_execute_click(sql_query, plot_code):
269
+ """
270
+ Handles the "Execute Query" button click event.
271
+ """
272
  result_df, error_msg = execute_query(sql_query)
273
  if error_msg:
274
  return None, None, error_msg
275
  if plot_code.strip():
276
  fig, plot_error = generate_plot(plot_code, result_df)
277
+ if plot_error:
278
+ return result_df, None, plot_error
279
+ else:
280
+ return result_df, fig, ""
281
  else:
282
  return result_df, None, ""
283
 
 
292
  outputs=[results_out, plot_out, error_out],
293
  )
294
 
295
+ # =========================
296
+ # Launch the Gradio App
297
+ # =========================
298
+
299
  demo.launch()