LeonceNsh committed on
Commit
94bf8f1
·
verified ·
1 Parent(s): c490b83

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -36
app.py CHANGED
@@ -7,13 +7,13 @@ import pandas as pd
7
  import plotly.express as px
8
  import os
9
 
10
- # Set OpenAI API key
11
- openai.api_key = os.getenv("OPENAI_API_KEY")
12
-
13
  # =========================
14
  # Configuration and Setup
15
  # =========================
16
 
 
 
 
17
  # Load the Parquet dataset path
18
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
19
 
@@ -69,6 +69,9 @@ def load_dataset_schema():
69
  finally:
70
  con.close()
71
 
 
 
 
72
  # =========================
73
  # OpenAI API Integration
74
  # =========================
@@ -78,13 +81,13 @@ def parse_query(nl_query):
78
  Converts a natural language query into a SQL query using OpenAI's API.
79
  """
80
  messages = [
81
- {"role": "system", "content": "Convert natural language queries to SQL queries for 'contract_data'."},
82
  {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
83
  ]
84
 
85
  try:
86
- response = openai.chat.completions.create(
87
- model="gpt-4o-mini",
88
  messages=messages,
89
  temperature=0,
90
  max_tokens=150,
@@ -102,23 +105,35 @@ def detect_plot_intent(nl_query):
102
  """
103
  Detects if the user's query involves plotting.
104
  """
105
- plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line']
106
  return any(keyword in nl_query.lower() for keyword in plot_keywords)
107
 
108
- def generate_plot_code(sql_query, result_df):
109
  """
110
- Generates plotting code based on the SQL query and result DataFrame.
111
  """
112
- if not detect_plot_intent(sql_query):
113
- return None
114
 
115
  columns = result_df.columns.tolist()
116
- if len(columns) >= 2:
117
- fig = px.bar(result_df, x=columns[0], y=columns[1], title='Generated Plot')
118
- fig.update_layout(title_x=0.5)
119
- return fig
 
 
 
 
 
 
 
 
120
  else:
121
- return None
 
 
 
 
122
 
123
  # =========================
124
  # Gradio Application UI
@@ -126,48 +141,101 @@ def generate_plot_code(sql_query, result_df):
126
 
127
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
128
  gr.Markdown("""
129
- ## Parquet Data Explorer
130
-
131
- **Query and visualize data effortlessly.**
132
-
133
  """, elem_id="main-title")
134
 
135
  with gr.Row():
136
  with gr.Column(scale=1):
137
  query = gr.Textbox(
138
- label="Ask a question about the data",
139
  placeholder='e.g., "What are the total awards over 1M in California?"',
140
  lines=1
141
  )
142
- # Display schema next to the input
143
- schema_display = gr.JSON(value=json.loads(json.dumps(get_schema(), indent=2)), visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
- error_out = gr.Alert(variant="error", visible=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
146
 
147
- with gr.Column(scale=2):
148
- results_out = gr.DataFrame(label="Results")
149
- plot_out = gr.Plot()
150
 
151
  def on_query_submit(nl_query):
 
 
 
 
 
 
152
  sql_query = parse_query(nl_query)
153
  if sql_query.startswith("Error"):
154
  return gr.update(visible=True, value=sql_query), None, None
 
155
  result_df, error_msg = execute_query(sql_query)
156
  if error_msg:
157
  return gr.update(visible=True, value=error_msg), None, None
158
- fig = generate_plot_code(nl_query, result_df)
159
- return gr.update(visible=False), result_df, fig
160
 
161
- def on_focus():
 
 
 
 
 
 
 
 
 
162
  return gr.update(visible=True)
163
 
 
 
 
 
164
  query.submit(
165
  fn=on_query_submit,
166
  inputs=query,
167
  outputs=[error_out, results_out, plot_out]
168
  )
 
169
  query.focus(
170
- fn=on_focus,
 
171
  outputs=schema_display
172
  )
173
 
@@ -179,12 +247,11 @@ def execute_query(sql_query):
179
  """
180
  Executes the SQL query and returns the results.
181
  """
182
- if sql_query.startswith("Error"):
183
- return None, sql_query
184
-
185
  try:
186
  con = duckdb.connect()
187
- con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
 
 
188
  result_df = con.execute(sql_query).fetchdf()
189
  con.close()
190
  return result_df, ""
@@ -195,4 +262,5 @@ def execute_query(sql_query):
195
  # Launch the Gradio App
196
  # =========================
197
 
198
- demo.launch()
 
 
7
  import plotly.express as px
8
  import os
9
 
 
 
 
10
  # =========================
11
  # Configuration and Setup
12
  # =========================
13
 
14
+ # Set OpenAI API key
15
+ openai.api_key = os.getenv("OPENAI_API_KEY")
16
+
17
  # Load the Parquet dataset path
18
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
19
 
 
69
  finally:
70
  con.close()
71
 
72
+ # Load the dataset schema at startup
73
+ load_dataset_schema()
74
+
75
  # =========================
76
  # OpenAI API Integration
77
  # =========================
 
81
  Converts a natural language query into a SQL query using OpenAI's API.
82
  """
83
  messages = [
84
+ {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
85
  {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
86
  ]
87
 
88
  try:
89
+ response = openai.ChatCompletion.create(
90
+ model="gpt-3.5-turbo",
91
  messages=messages,
92
  temperature=0,
93
  max_tokens=150,
 
105
  """
106
  Detects if the user's query involves plotting.
107
  """
108
+ plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie']
109
  return any(keyword in nl_query.lower() for keyword in plot_keywords)
110
 
111
def generate_plot(nl_query, result_df):
    """
    Build a Plotly figure for the query result when the query asks for one.

    Returns a ``(figure, message)`` pair:

    * ``(None, "")`` when ``detect_plot_intent`` reports no plotting intent,
    * ``(None, "Not enough data to generate a plot.")`` when the result has
      fewer than two columns,
    * ``(figure, "")`` otherwise.

    The first result column feeds the x axis (or pie names) and the second
    column feeds the y axis (or pie values).
    """
    wants_plot = detect_plot_intent(nl_query)
    if not wants_plot:
        return None, ""

    column_names = result_df.columns.tolist()
    if len(column_names) < 2:
        return None, "Not enough data to generate a plot."
    first_col, second_col = column_names[0], column_names[1]

    lowered_query = nl_query.lower()

    # Ordered keyword -> builder table; the first keyword found in the query
    # wins. Builders are lazy so only the selected figure is constructed.
    chart_builders = (
        ("bar", lambda: px.bar(result_df, x=first_col, y=second_col, title='Bar Chart')),
        ("line", lambda: px.line(result_df, x=first_col, y=second_col, title='Line Chart')),
        ("scatter", lambda: px.scatter(result_df, x=first_col, y=second_col, title='Scatter Plot')),
        ("pie", lambda: px.pie(result_df, names=first_col, values=second_col, title='Pie Chart')),
    )
    # With no chart-type keyword present, fall back to the bar chart builder.
    default_builder = chart_builders[0][1]
    build_figure = next(
        (builder for keyword, builder in chart_builders if keyword in lowered_query),
        default_builder,
    )

    fig = build_figure()
    fig.update_layout(title_x=0.5)
    return fig, ""
137
 
138
  # =========================
139
  # Gradio Application UI
 
141
 
142
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
143
  gr.Markdown("""
144
+ <h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
145
+ <p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
 
 
146
  """, elem_id="main-title")
147
 
148
  with gr.Row():
149
  with gr.Column(scale=1):
150
  query = gr.Textbox(
151
+ label="Your Query",
152
  placeholder='e.g., "What are the total awards over 1M in California?"',
153
  lines=1
154
  )
155
+ # Hidden schema display that appears on focus
156
+ schema_display = gr.JSON(
157
+ label="Dataset Schema",
158
+ value=get_schema(),
159
+ interactive=False,
160
+ visible=False
161
+ )
162
+ error_out = gr.Markdown(
163
+ value="",
164
+ visible=False
165
+ )
166
+ with gr.Column(scale=2):
167
+ results_out = gr.DataFrame(
168
+ label="Results",
169
+ interactive=False
170
+ )
171
+ plot_out = gr.Plot(
172
+ label="Visualization"
173
+ )
174
 
175
+ gr.Markdown("""
176
+ <style>
177
+ /* Center the content */
178
+ .gradio-container {
179
+ max-width: 1000px;
180
+ margin: auto;
181
+ }
182
+ /* Style the main title */
183
+ #main-title h1 {
184
+ font-weight: bold;
185
+ }
186
+ /* Style the error alert */
187
+ .gradio-container .alert-error {
188
+ background-color: #ffe6e6;
189
+ color: #cc0000;
190
+ border: 1px solid #cc0000;
191
+ }
192
+ </style>
193
+ """)
194
 
195
+ # =========================
196
+ # Event Handler Callbacks (submit / focus)
197
+ # =========================
198
 
199
  def on_query_submit(nl_query):
200
+ """
201
+ Handles the submission of a natural language query.
202
+ """
203
+ if not nl_query.strip():
204
+ return gr.update(visible=True, value="Please enter a query."), None, None
205
+
206
  sql_query = parse_query(nl_query)
207
  if sql_query.startswith("Error"):
208
  return gr.update(visible=True, value=sql_query), None, None
209
+
210
  result_df, error_msg = execute_query(sql_query)
211
  if error_msg:
212
  return gr.update(visible=True, value=error_msg), None, None
 
 
213
 
214
+ fig, plot_error = generate_plot(nl_query, result_df)
215
+ if plot_error:
216
+ return gr.update(visible=True, value=plot_error), None, None
217
+
218
+ return gr.update(visible=False, value=""), result_df, fig
219
+
220
+ def on_input_focus():
221
+ """
222
+ Shows the dataset schema when the input box is focused.
223
+ """
224
  return gr.update(visible=True)
225
 
226
+ # =========================
227
+ # Assign Event Handlers
228
+ # =========================
229
+
230
  query.submit(
231
  fn=on_query_submit,
232
  inputs=query,
233
  outputs=[error_out, results_out, plot_out]
234
  )
235
+
236
  query.focus(
237
+ fn=lambda: gr.update(visible=True),
238
+ inputs=None,
239
  outputs=schema_display
240
  )
241
 
 
247
  """
248
  Executes the SQL query and returns the results.
249
  """
 
 
 
250
  try:
251
  con = duckdb.connect()
252
+ con.execute("PRAGMA threads=4") # Optimize for performance
253
+ con.execute("DROP VIEW IF EXISTS contract_data")
254
+ con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
255
  result_df = con.execute(sql_query).fetchdf()
256
  con.close()
257
  return result_df, ""
 
262
  # Launch the Gradio App
263
  # =========================
264
 
265
+ if __name__ == "__main__":
266
+ demo.launch()