LeonceNsh commited on
Commit
c490b83
·
verified ·
1 Parent(s): f7a7a3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -154
app.py CHANGED
@@ -8,7 +8,7 @@ import plotly.express as px
8
  import os
9
 
10
  # Set OpenAI API key
11
- client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
12
 
13
  # =========================
14
  # Configuration and Setup
@@ -73,7 +73,7 @@ def load_dataset_schema():
73
  # OpenAI API Integration
74
  # =========================
75
 
76
- async def parse_query(nl_query):
77
  """
78
  Converts a natural language query into a SQL query using OpenAI's API.
79
  """
@@ -102,180 +102,94 @@ def detect_plot_intent(nl_query):
102
  """
103
  Detects if the user's query involves plotting.
104
  """
105
- plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize']
106
  return any(keyword in nl_query.lower() for keyword in plot_keywords)
107
 
108
- async def generate_sql_and_plot_code(query):
109
  """
110
- Generates SQL query and optional plotting code.
111
  """
112
- is_plot = detect_plot_intent(query)
113
- sql_query = await parse_query(query)
114
- plot_code = ""
115
- if is_plot and not sql_query.startswith("Error"):
116
- plot_code = """
117
- import plotly.express as px
118
- fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
119
- fig.update_layout(title_x=0.5)
120
- """
121
- return sql_query, plot_code
122
-
123
- def execute_query(sql_query):
124
- """
125
- Executes the SQL query and returns the results.
126
- """
127
- if sql_query.startswith("Error"):
128
- return None, sql_query
129
-
130
- try:
131
- con = duckdb.connect()
132
- con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
133
- result_df = con.execute(sql_query).fetchdf()
134
- con.close()
135
- return result_df, ""
136
- except Exception as e:
137
- return None, f"Error executing query: {e}"
138
 
139
- def generate_plot(plot_code, result_df):
140
- """
141
- Executes the plot code to generate a plot from the result DataFrame.
142
- """
143
- if not plot_code.strip():
144
- return None, "No plot code provided."
145
- try:
146
- columns = result_df.columns.tolist()
147
- if len(columns) < 2:
148
- return None, "Not enough columns to plot."
149
- plot_code = plot_code.replace('x_column', columns[0])
150
- plot_code = plot_code.replace('y_column', columns[1])
151
- local_vars = {'result_df': result_df, 'px': px}
152
- exec(plot_code, {}, local_vars)
153
- fig = local_vars.get('fig', None)
154
- return fig, "" if fig else "Plot could not be generated."
155
- except Exception as e:
156
- return None, f"Error generating plot: {e}"
157
 
158
  # =========================
159
  # Gradio Application UI
160
  # =========================
161
 
162
- with gr.Blocks() as demo:
163
  gr.Markdown("""
164
- # Parquet SQL Query and Plotting App
165
-
166
- **Query and visualize data** in `sample_contract_df.parquet`
167
-
168
- ## Instructions
169
-
170
- 1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
171
- 2. **Use Example Queries**: Click on any example query button below to execute.
172
- 3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
173
- 4. **Execute Query**: Run the query to view results and plots.
174
- 5. **Dataset Schema**: See available columns and types in the "Schema" tab.
175
-
176
- ## Example Queries
177
- """)
178
-
179
- with gr.Tabs():
180
- with gr.TabItem("Query Data"):
181
- with gr.Row():
182
- with gr.Column(scale=1):
183
- query = gr.Textbox(label="Natural Language Query", placeholder='e.g., "Awards > 1M in CA"')
184
 
185
- # Example query buttons
186
- gr.Markdown("### Click on an example query:")
187
- with gr.Row():
188
- btn_example1 = gr.Button("Show awards over 1M in CA")
189
- btn_example2 = gr.Button("List all contracts in New York")
190
- btn_example3 = gr.Button("Show top 5 departments by award amount")
191
- btn_example4 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")
192
 
193
- btn_generate = gr.Button("Generate SQL")
194
- sql_out = gr.Code(label="Generated SQL Query", language="sql")
195
- plot_code_out = gr.Code(label="Generated Plot Code", language="python")
196
- btn_execute = gr.Button("Execute Query")
197
- error_out = gr.Markdown("", visible=False)
198
- with gr.Column(scale=2):
199
- results_out = gr.Dataframe(label="Query Results", interactive=False)
200
- plot_out = gr.Plot(label="Plot")
201
 
202
- with gr.TabItem("Dataset Schema"):
203
- gr.Markdown("### Dataset Schema")
204
- schema_display = gr.JSON(label="Schema", value=json.loads(json.dumps(get_schema(), indent=2)))
 
 
 
 
 
 
205
 
206
- # =========================
207
- # Click Event Handlers
208
- # =========================
209
 
210
- async def on_generate_click(nl_query):
211
- """
212
- Handles the "Generate SQL" button click event.
213
- """
214
- sql_query, plot_code = await generate_sql_and_plot_code(nl_query)
215
- return sql_query, plot_code
216
 
217
- def on_execute_click(sql_query, plot_code):
218
- """
219
- Handles the "Execute Query" button click event.
220
- """
221
  result_df, error_msg = execute_query(sql_query)
222
  if error_msg:
223
- return None, None, error_msg
224
- if plot_code.strip():
225
- fig, plot_error = generate_plot(plot_code, result_df)
226
- if plot_error:
227
- return result_df, None, plot_error
228
- else:
229
- return result_df, fig, ""
230
- else:
231
- return result_df, None, ""
 
 
 
 
 
 
 
232
 
233
- # Functions for example query buttons
234
- async def on_example_nl_click(query_text):
235
- sql_query, plot_code = await generate_sql_and_plot_code(query_text)
236
- result_df, error_msg = execute_query(sql_query)
237
- fig = None
238
- if error_msg:
239
- return sql_query, plot_code, None, None, error_msg
240
- if plot_code.strip():
241
- fig, plot_error = generate_plot(plot_code, result_df)
242
- if plot_error:
243
- error_msg = plot_error
244
- else:
245
- error_msg = ""
246
- else:
247
- fig = None
248
- error_msg = ""
249
- return sql_query, plot_code, result_df, fig, error_msg
250
-
251
- def on_example_sql_click(sql_query):
252
- result_df, error_msg = execute_query(sql_query)
253
- fig = None
254
- plot_code = ""
255
- if error_msg:
256
- return sql_query, plot_code, None, None, error_msg
257
- else:
258
- return sql_query, plot_code, result_df, fig, ""
259
-
260
- async def on_example1_click():
261
- return await on_example_nl_click("Show awards over 1M in CA")
262
-
263
- async def on_example2_click():
264
- return await on_example_nl_click("List all contracts in New York")
265
-
266
- async def on_example3_click():
267
- return await on_example_nl_click("Show top 5 departments by award amount")
268
-
269
- def on_example4_click():
270
- return on_example_sql_click("SELECT * from contract_data LIMIT 10;")
271
 
272
- btn_example1.click(fn=on_example1_click, inputs=[], outputs=[sql_out, plot_code_out, results_out, plot_out, error_out])
273
- btn_example2.click(fn=on_example2_click, inputs=[], outputs=[sql_out, plot_code_out, results_out, plot_out, error_out])
274
- btn_example3.click(fn=on_example3_click, inputs=[], outputs=[sql_out, plot_code_out, results_out, plot_out, error_out])
275
- btn_example4.click(fn=on_example4_click, inputs=[], outputs=[sql_out, plot_code_out, results_out, plot_out, error_out])
 
 
276
 
277
- btn_generate.click(fn=on_generate_click, inputs=query, outputs=[sql_out, plot_code_out])
278
- btn_execute.click(fn=on_execute_click, inputs=[sql_out, plot_code_out], outputs=[results_out, plot_out, error_out])
 
 
 
 
 
 
279
 
280
  # =========================
281
  # Launch the Gradio App
 
8
  import os
9
 
10
  # Set OpenAI API key
11
+ openai.api_key = os.getenv("OPENAI_API_KEY")
12
 
13
  # =========================
14
  # Configuration and Setup
 
73
  # OpenAI API Integration
74
  # =========================
75
 
76
+ def parse_query(nl_query):
77
  """
78
  Converts a natural language query into a SQL query using OpenAI's API.
79
  """
 
102
  """
103
  Detects if the user's query involves plotting.
104
  """
105
+ plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line']
106
  return any(keyword in nl_query.lower() for keyword in plot_keywords)
107
 
108
+ def generate_plot_code(sql_query, result_df):
109
  """
110
+ Generates plotting code based on the SQL query and result DataFrame.
111
  """
112
+ if not detect_plot_intent(sql_query):
113
+ return None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
 
115
+ columns = result_df.columns.tolist()
116
+ if len(columns) >= 2:
117
+ fig = px.bar(result_df, x=columns[0], y=columns[1], title='Generated Plot')
118
+ fig.update_layout(title_x=0.5)
119
+ return fig
120
+ else:
121
+ return None
 
 
 
 
 
 
 
 
 
 
 
122
 
123
  # =========================
124
  # Gradio Application UI
125
  # =========================
126
 
127
+ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
128
  gr.Markdown("""
129
+ ## Parquet Data Explorer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
131
+ **Query and visualize data effortlessly.**
 
 
 
 
 
 
132
 
133
+ """, elem_id="main-title")
 
 
 
 
 
 
 
134
 
135
+ with gr.Row():
136
+ with gr.Column(scale=1):
137
+ query = gr.Textbox(
138
+ label="Ask a question about the data",
139
+ placeholder='e.g., "What are the total awards over 1M in California?"',
140
+ lines=1
141
+ )
142
+ # Display schema next to the input
143
+ schema_display = gr.JSON(value=json.loads(json.dumps(get_schema(), indent=2)), visible=False)
144
 
145
+ error_out = gr.Alert(variant="error", visible=False)
 
 
146
 
147
+ with gr.Column(scale=2):
148
+ results_out = gr.DataFrame(label="Results")
149
+ plot_out = gr.Plot()
 
 
 
150
 
151
+ def on_query_submit(nl_query):
152
+ sql_query = parse_query(nl_query)
153
+ if sql_query.startswith("Error"):
154
+ return gr.update(visible=True, value=sql_query), None, None
155
  result_df, error_msg = execute_query(sql_query)
156
  if error_msg:
157
+ return gr.update(visible=True, value=error_msg), None, None
158
+ fig = generate_plot_code(nl_query, result_df)
159
+ return gr.update(visible=False), result_df, fig
160
+
161
+ def on_focus():
162
+ return gr.update(visible=True)
163
+
164
+ query.submit(
165
+ fn=on_query_submit,
166
+ inputs=query,
167
+ outputs=[error_out, results_out, plot_out]
168
+ )
169
+ query.focus(
170
+ fn=on_focus,
171
+ outputs=schema_display
172
+ )
173
 
174
+ # =========================
175
+ # Helper Functions
176
+ # =========================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ def execute_query(sql_query):
179
+ """
180
+ Executes the SQL query and returns the results.
181
+ """
182
+ if sql_query.startswith("Error"):
183
+ return None, sql_query
184
 
185
+ try:
186
+ con = duckdb.connect()
187
+ con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
188
+ result_df = con.execute(sql_query).fetchdf()
189
+ con.close()
190
+ return result_df, ""
191
+ except Exception as e:
192
+ return None, f"Error executing query: {e}"
193
 
194
  # =========================
195
  # Launch the Gradio App