LeonceNsh commited on
Commit
78f16f0
·
verified ·
1 Parent(s): 48aeb6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -88
app.py CHANGED
@@ -11,13 +11,9 @@ import os
11
  # Configuration and Setup
12
  # =========================
13
 
14
- # Set OpenAI API key
15
  openai.api_key = os.getenv("OPENAI_API_KEY")
16
-
17
- # Load the Parquet dataset path
18
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
19
 
20
- # Provided schema
21
  schema = [
22
  {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
23
  {"column_name": "cgac", "column_type": "BIGINT"},
@@ -50,14 +46,7 @@ def get_schema():
50
 
51
  COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
52
 
53
- # =========================
54
- # Database Interaction
55
- # =========================
56
-
57
  def load_dataset_schema():
58
- """
59
- Loads the dataset schema into DuckDB by creating a view.
60
- """
61
  con = duckdb.connect()
62
  try:
63
  con.execute("DROP VIEW IF EXISTS contract_data")
@@ -69,7 +58,6 @@ def load_dataset_schema():
69
  finally:
70
  con.close()
71
 
72
- # Load the dataset schema at startup
73
  load_dataset_schema()
74
 
75
  # =========================
@@ -77,9 +65,6 @@ load_dataset_schema()
77
  # =========================
78
 
79
  def parse_query(nl_query):
80
- """
81
- Converts a natural language query into a SQL query using OpenAI's API.
82
- """
83
  messages = [
84
  {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
85
  {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
@@ -87,7 +72,7 @@ def parse_query(nl_query):
87
 
88
  try:
89
  response = openai.chat.completions.create(
90
- model="gpt-4o-mini",
91
  messages=messages,
92
  temperature=0,
93
  max_tokens=150,
@@ -97,21 +82,11 @@ def parse_query(nl_query):
97
  except Exception as e:
98
  return f"Error generating SQL query: {e}"
99
 
100
- # =========================
101
- # Plotting Utilities
102
- # =========================
103
-
104
  def detect_plot_intent(nl_query):
105
- """
106
- Detects if the user's query involves plotting.
107
- """
108
  plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie']
109
  return any(keyword in nl_query.lower() for keyword in plot_keywords)
110
 
111
  def generate_plot(nl_query, result_df):
112
- """
113
- Generates a Plotly figure based on the result DataFrame and the user's intent.
114
- """
115
  if not detect_plot_intent(nl_query):
116
  return None, ""
117
 
@@ -119,7 +94,6 @@ def generate_plot(nl_query, result_df):
119
  if len(columns) < 2:
120
  return None, "Not enough data to generate a plot."
121
 
122
- # Simple heuristic to choose plot type based on keywords
123
  if 'bar' in nl_query.lower():
124
  fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
125
  elif 'line' in nl_query.lower():
@@ -129,7 +103,6 @@ def generate_plot(nl_query, result_df):
129
  elif 'pie' in nl_query.lower():
130
  fig = px.pie(result_df, names=columns[0], values=columns[1], title='Pie Chart')
131
  else:
132
- # Default to bar chart
133
  fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
134
 
135
  fig.update_layout(title_x=0.5)
@@ -143,7 +116,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
143
  gr.Markdown("""
144
  <h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
145
  <p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
146
- """, elem_id="main-title")
147
 
148
  with gr.Row():
149
  with gr.Column(scale=1):
@@ -152,12 +125,13 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
152
  placeholder='e.g., "What are the total awards over 1M in California?"',
153
  lines=1
154
  )
155
- # Hidden schema display that appears on focus
156
- schema_display = gr.JSON(
157
- label="Dataset Schema",
158
- value=get_schema(),
159
- visible=False
160
- )
 
161
  error_out = gr.Markdown(
162
  value="",
163
  visible=False
@@ -170,24 +144,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
170
  label="Visualization"
171
  )
172
 
 
173
  gr.Markdown("""
174
- <style>
175
- /* Center the content */
176
- .gradio-container {
177
- max-width: 1000px;
178
- margin: auto;
179
- }
180
- /* Style the main title */
181
- #main-title h1 {
182
- font-weight: bold;
183
- }
184
- /* Style the error alert */
185
- .gradio-container .alert-error {
186
- background-color: #ffe6e6;
187
- color: #cc0000;
188
- border: 1px solid #cc0000;
189
- }
190
- </style>
191
  """)
192
 
193
  # =========================
@@ -195,9 +157,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
195
  # =========================
196
 
197
  def on_query_submit(nl_query):
198
- """
199
- Handles the submission of a natural language query.
200
- """
201
  if not nl_query.strip():
202
  return gr.update(visible=True, value="Please enter a query."), None, None
203
 
@@ -215,15 +174,18 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
215
 
216
  return gr.update(visible=False, value=""), result_df, fig
217
 
218
- def on_input_focus():
219
- """
220
- Shows the dataset schema when the input box is focused.
221
- """
222
- return gr.update(visible=True)
 
 
223
 
224
- # =========================
225
- # Assign Event Handlers
226
- # =========================
 
227
 
228
  query.submit(
229
  fn=on_query_submit,
@@ -231,31 +193,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
231
  outputs=[error_out, results_out, plot_out]
232
  )
233
 
234
- query.focus(
235
- fn=lambda: gr.update(visible=True),
236
- inputs=None,
237
- outputs=schema_display
238
- )
239
-
240
- # =========================
241
- # Helper Functions
242
- # =========================
243
-
244
- def execute_query(sql_query):
245
- """
246
- Executes the SQL query and returns the results.
247
- """
248
- try:
249
- con = duckdb.connect()
250
- con.execute("PRAGMA threads=4") # Optimize for performance
251
- con.execute("DROP VIEW IF EXISTS contract_data")
252
- con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
253
- result_df = con.execute(sql_query).fetchdf()
254
- con.close()
255
- return result_df, ""
256
- except Exception as e:
257
- return None, f"Error executing query: {e}"
258
-
259
  # =========================
260
  # Launch the Gradio App
261
  # =========================
 
11
  # Configuration and Setup
12
  # =========================
13
 
 
14
  openai.api_key = os.getenv("OPENAI_API_KEY")
 
 
15
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
16
 
 
17
  schema = [
18
  {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
19
  {"column_name": "cgac", "column_type": "BIGINT"},
 
46
 
47
  COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
48
 
 
 
 
 
49
  def load_dataset_schema():
 
 
 
50
  con = duckdb.connect()
51
  try:
52
  con.execute("DROP VIEW IF EXISTS contract_data")
 
58
  finally:
59
  con.close()
60
 
 
61
  load_dataset_schema()
62
 
63
  # =========================
 
65
  # =========================
66
 
67
  def parse_query(nl_query):
 
 
 
68
  messages = [
69
  {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
70
  {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
 
72
 
73
  try:
74
  response = openai.chat.completions.create(
75
+ model="gpt-4",
76
  messages=messages,
77
  temperature=0,
78
  max_tokens=150,
 
82
  except Exception as e:
83
  return f"Error generating SQL query: {e}"
84
 
 
 
 
 
85
  def detect_plot_intent(nl_query):
 
 
 
86
  plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie']
87
  return any(keyword in nl_query.lower() for keyword in plot_keywords)
88
 
89
  def generate_plot(nl_query, result_df):
 
 
 
90
  if not detect_plot_intent(nl_query):
91
  return None, ""
92
 
 
94
  if len(columns) < 2:
95
  return None, "Not enough data to generate a plot."
96
 
 
97
  if 'bar' in nl_query.lower():
98
  fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
99
  elif 'line' in nl_query.lower():
 
103
  elif 'pie' in nl_query.lower():
104
  fig = px.pie(result_df, names=columns[0], values=columns[1], title='Pie Chart')
105
  else:
 
106
  fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
107
 
108
  fig.update_layout(title_x=0.5)
 
116
  gr.Markdown("""
117
  <h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
118
  <p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
119
+ """)
120
 
121
  with gr.Row():
122
  with gr.Column(scale=1):
 
125
  placeholder='e.g., "What are the total awards over 1M in California?"',
126
  lines=1
127
  )
128
+ gr.Markdown("### Example Queries")
129
+ with gr.Row():
130
+ btn_example1 = gr.Button("Show awards over 1M in CA")
131
+ btn_example2 = gr.Button("List all contracts in New York")
132
+ btn_example3 = gr.Button("Show top 5 departments by award amount")
133
+ btn_example4 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")
134
+
135
  error_out = gr.Markdown(
136
  value="",
137
  visible=False
 
144
  label="Visualization"
145
  )
146
 
147
+ # Instructions
148
  gr.Markdown("""
149
+ ## Instructions
150
+ 1. **Enter a query**: Type in a natural language query in the textbox.
151
+ 2. **Use Example Queries**: Click on any example query button above.
152
+ 3. **Generate SQL and Plot**: Click "Execute" to see results and visualization.
 
 
 
 
 
 
 
 
 
 
 
 
 
153
  """)
154
 
155
  # =========================
 
157
  # =========================
158
 
159
  def on_query_submit(nl_query):
 
 
 
160
  if not nl_query.strip():
161
  return gr.update(visible=True, value="Please enter a query."), None, None
162
 
 
174
 
175
  return gr.update(visible=False, value=""), result_df, fig
176
 
177
+ def on_example_click(query_text):
178
+ sql_query = parse_query(query_text)
179
+ result_df, error_msg = execute_query(sql_query)
180
+ if error_msg:
181
+ return sql_query, None, None, error_msg
182
+ fig, plot_error = generate_plot(query_text, result_df)
183
+ return sql_query, result_df, fig, plot_error if plot_error else ""
184
 
185
+ btn_example1.click(lambda: on_example_click("Show awards over 1M in CA"), outputs=[results_out, plot_out, error_out])
186
+ btn_example2.click(lambda: on_example_click("List all contracts in New York"), outputs=[results_out, plot_out, error_out])
187
+ btn_example3.click(lambda: on_example_click("Show top 5 departments by award amount"), outputs=[results_out, plot_out, error_out])
188
+ btn_example4.click(lambda: on_example_click("SELECT * from contract_data LIMIT 10;"), outputs=[results_out, plot_out, error_out])
189
 
190
  query.submit(
191
  fn=on_query_submit,
 
193
  outputs=[error_out, results_out, plot_out]
194
  )
195
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
  # =========================
197
  # Launch the Gradio App
198
  # =========================