LeonceNsh committed on
Commit
a1792a1
·
verified ·
1 Parent(s): c6f04cb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -73
app.py CHANGED
@@ -73,31 +73,25 @@ def load_dataset_schema():
73
  # OpenAI API Integration
74
  # =========================
75
 
76
- def parse_query(nl_query):
77
  """
78
  Converts a natural language query into a SQL query using OpenAI's GPT-4-turbo model.
79
  """
80
-
81
- # new
82
- from openai import AsyncOpenAI
83
-
84
- client = AsyncOpenAI()
85
- completion = await client.chat.completions.create(model="gpt-3.5-turbo",
86
- messages = [{"role": "system",
87
- "content": (
88
- "You are an assistant that converts natural language queries into SQL queries "
89
- "for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries.")
90
- },
91
- {"role": "user",
92
- "content": ( f"Schema:\n{json.dumps(schema, indent=2)}\n\n" f"Natural Language Query:\n\"{nl_query}\"\n\nSQL Query:"
93
- )}
94
- ])
95
 
96
  try:
97
- response = openai.ChatCompletion.create(
98
- model="gpt-4-turbo",
99
  messages=messages,
100
- temperature=0, # Set to 0 for deterministic output
101
  max_tokens=150,
102
  )
103
  sql_query = response.choices[0].message['content'].strip()
@@ -122,16 +116,15 @@ def detect_plot_intent(nl_query):
122
  return True
123
  return False
124
 
125
- def generate_sql_and_plot_code(query):
126
  """
127
  Generates SQL query and plotting code based on the natural language input.
128
  """
129
  is_plot = detect_plot_intent(query)
130
- sql_query = parse_query(query)
131
  plot_code = ""
132
  if is_plot and not sql_query.startswith("Error"):
133
  # Generate plot code based on the query
134
- # For simplicity, we'll generate a basic plot code
135
  plot_code = """
136
  import plotly.express as px
137
  fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
@@ -148,13 +141,11 @@ def execute_query(sql_query):
148
 
149
  try:
150
  con = duckdb.connect()
151
- # Ensure the view is created
152
  con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
153
  result_df = con.execute(sql_query).fetchdf()
154
  con.close()
155
  return result_df, ""
156
  except Exception as e:
157
- # In case of error, return None and error message
158
  return None, f"Error executing query: {e}"
159
 
160
  def generate_plot(plot_code, result_df):
@@ -164,7 +155,6 @@ def generate_plot(plot_code, result_df):
164
  if not plot_code.strip():
165
  return None, "No plot code provided."
166
  try:
167
- # Replace placeholders in plot_code with actual column names
168
  if result_df.empty:
169
  return None, "Result DataFrame is empty."
170
  columns = result_df.columns.tolist()
@@ -173,14 +163,10 @@ def generate_plot(plot_code, result_df):
173
  plot_code = plot_code.replace('x_column', columns[0])
174
  plot_code = plot_code.replace('y_column', columns[1])
175
 
176
- # Execute the plot code
177
  local_vars = {'result_df': result_df, 'px': px}
178
  exec(plot_code, {}, local_vars)
179
  fig = local_vars.get('fig', None)
180
- if fig:
181
- return fig, ""
182
- else:
183
- return None, "Plot could not be generated."
184
  except Exception as e:
185
  return None, f"Error generating plot: {e}"
186
 
@@ -208,31 +194,9 @@ with gr.Blocks() as demo:
208
  # Parquet SQL Query and Plotting App
209
 
210
  **Query and visualize data** in `sample_contract_df.parquet`
211
-
212
- ## Instructions
213
-
214
- 1. **Describe the data you want to retrieve or plot**: For example:
215
- - `Show all awards greater than 1,000,000 in California`
216
- - `Plot the distribution of awards by state`
217
- - `Show a bar chart of total awards per department`
218
- - `List awardees who received multiple awards along with award amounts`
219
- - `Number of awards issued by each department division`
220
-
221
- 2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
222
- 3. **Execute Query**: Click "Execute Query" to run the query and view the results.
223
- 4. **View Plot**: If your query involves plotting, the plot will be displayed.
224
- 5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
225
-
226
- ## Example Queries
227
-
228
- - `Plot the total award amount by state`
229
- - `Show a histogram of awards over time`
230
- - `award greater than 1000000 and state equal to "CA"`
231
- - `List awards where department_ind_agency contains "Defense"`
232
  """)
233
 
234
  with gr.Tabs():
235
- # Query Tab
236
  with gr.TabItem("Query Data"):
237
  with gr.Row():
238
  with gr.Column(scale=1):
@@ -250,35 +214,21 @@ with gr.Blocks() as demo:
250
  results_out = gr.Dataframe(label="Query Results", interactive=False)
251
  plot_out = gr.Plot(label="Plot")
252
 
253
- # Schema Tab
254
  with gr.TabItem("Dataset Schema"):
255
  gr.Markdown("### Dataset Schema")
256
  schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
257
 
258
- # =========================
259
- # Click Event Handlers
260
- # =========================
261
-
262
- def on_generate_click(nl_query):
263
- """
264
- Handles the "Generate SQL" button click event.
265
- """
266
- sql_query, plot_code = generate_sql_and_plot_code(nl_query)
267
  return sql_query, plot_code
268
 
269
  def on_execute_click(sql_query, plot_code):
270
- """
271
- Handles the "Execute Query" button click event.
272
- """
273
  result_df, error_msg = execute_query(sql_query)
274
  if error_msg:
275
  return None, None, error_msg
276
  if plot_code.strip():
277
  fig, plot_error = generate_plot(plot_code, result_df)
278
- if plot_error:
279
- return result_df, None, plot_error
280
- else:
281
- return result_df, fig, ""
282
  else:
283
  return result_df, None, ""
284
 
@@ -293,8 +243,4 @@ with gr.Blocks() as demo:
293
  outputs=[results_out, plot_out, error_out],
294
  )
295
 
296
- # =========================
297
- # Launch the Gradio App
298
- # =========================
299
-
300
  demo.launch()
 
73
  # OpenAI API Integration
74
  # =========================
75
 
76
+ async def parse_query(nl_query):
77
  """
78
  Converts a natural language query into a SQL query using OpenAI's gpt-3.5-turbo model.
79
  """
80
+ messages = [
81
+ {"role": "system", "content": (
82
+ "You are an assistant that converts natural language queries into SQL queries "
83
+ "for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries."
84
+ )},
85
+ {"role": "user", "content": (
86
+ f"Schema:\n{json.dumps(schema, indent=2)}\n\nNatural Language Query:\n\"{nl_query}\"\n\nSQL Query:"
87
+ )}
88
+ ]
 
 
 
 
 
 
89
 
90
  try:
91
+ response = await openai.ChatCompletion.acreate(
92
+ model="gpt-3.5-turbo",
93
  messages=messages,
94
+ temperature=0,
95
  max_tokens=150,
96
  )
97
  sql_query = response.choices[0].message['content'].strip()
 
116
  return True
117
  return False
118
 
119
+ async def generate_sql_and_plot_code(query):
120
  """
121
  Generates SQL query and plotting code based on the natural language input.
122
  """
123
  is_plot = detect_plot_intent(query)
124
+ sql_query = await parse_query(query)
125
  plot_code = ""
126
  if is_plot and not sql_query.startswith("Error"):
127
  # Generate plot code based on the query
 
128
  plot_code = """
129
  import plotly.express as px
130
  fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
 
141
 
142
  try:
143
  con = duckdb.connect()
 
144
  con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
145
  result_df = con.execute(sql_query).fetchdf()
146
  con.close()
147
  return result_df, ""
148
  except Exception as e:
 
149
  return None, f"Error executing query: {e}"
150
 
151
  def generate_plot(plot_code, result_df):
 
155
  if not plot_code.strip():
156
  return None, "No plot code provided."
157
  try:
 
158
  if result_df.empty:
159
  return None, "Result DataFrame is empty."
160
  columns = result_df.columns.tolist()
 
163
  plot_code = plot_code.replace('x_column', columns[0])
164
  plot_code = plot_code.replace('y_column', columns[1])
165
 
 
166
  local_vars = {'result_df': result_df, 'px': px}
167
  exec(plot_code, {}, local_vars)
168
  fig = local_vars.get('fig', None)
169
+ return fig, "" if fig else "Plot could not be generated."
 
 
 
170
  except Exception as e:
171
  return None, f"Error generating plot: {e}"
172
 
 
194
  # Parquet SQL Query and Plotting App
195
 
196
  **Query and visualize data** in `sample_contract_df.parquet`
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  """)
198
 
199
  with gr.Tabs():
 
200
  with gr.TabItem("Query Data"):
201
  with gr.Row():
202
  with gr.Column(scale=1):
 
214
  results_out = gr.Dataframe(label="Query Results", interactive=False)
215
  plot_out = gr.Plot(label="Plot")
216
 
 
217
  with gr.TabItem("Dataset Schema"):
218
  gr.Markdown("### Dataset Schema")
219
  schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
220
 
221
+ async def on_generate_click(nl_query):
222
+ sql_query, plot_code = await generate_sql_and_plot_code(nl_query)
 
 
 
 
 
 
 
223
  return sql_query, plot_code
224
 
225
  def on_execute_click(sql_query, plot_code):
 
 
 
226
  result_df, error_msg = execute_query(sql_query)
227
  if error_msg:
228
  return None, None, error_msg
229
  if plot_code.strip():
230
  fig, plot_error = generate_plot(plot_code, result_df)
231
+ return result_df, fig, plot_error if plot_error else ""
 
 
 
232
  else:
233
  return result_df, None, ""
234
 
 
243
  outputs=[results_out, plot_out, error_out],
244
  )
245
 
 
 
 
 
246
  demo.launch()