LeonceNsh commited on
Commit
dfe1769
·
verified ·
1 Parent(s): 776a658

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +116 -88
app.py CHANGED
@@ -1,9 +1,11 @@
1
  import json
2
  import gradio as gr
3
  import duckdb
4
- import re
5
  from functools import lru_cache
6
  from transformers import pipeline
 
 
 
7
 
8
  # Load the Parquet dataset path
9
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
@@ -58,86 +60,62 @@ def load_dataset_schema():
58
  finally:
59
  con.close()
60
 
61
- # Initialize the NLP model for query parsing
62
- @lru_cache(maxsize=1)
63
- def get_nlp_model():
64
- # We use a zero-shot-classification pipeline for query intent understanding
65
- classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
66
- return classifier
67
-
68
- # Advanced Natural Language to SQL Parser using NLP
69
  def parse_query(nl_query):
70
  """
71
- Converts a natural language query into SQL WHERE conditions based on the schema.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  """
73
- # Tokenize and normalize the query
74
- query = nl_query.lower()
75
-
76
- # Identify columns and possible operations
77
- columns = [col['column_name'] for col in get_schema()]
78
- operations = ['greater than or equal to', 'less than or equal to', 'greater than', 'less than', 'equal to', 'not equal to', 'between', 'contains', 'starts with', 'ends with']
79
-
80
- # Extract conditions from the query
81
- conditions = []
82
-
83
- # Simple heuristic parsing (can be replaced with more advanced NLP techniques)
84
- for col in columns:
85
- if col in query:
86
- for op in operations:
87
- if op in query:
88
- pattern = rf"{col}\s+{op}\s+(.*)"
89
- match = re.search(pattern, query)
90
- if match:
91
- value = match.group(1).strip(' "')
92
- sql_condition = ""
93
-
94
- # Map operations to SQL syntax
95
- if op == 'greater than or equal to':
96
- sql_condition = f"{col} >= {value}"
97
- elif op == 'less than or equal to':
98
- sql_condition = f"{col} <= {value}"
99
- elif op == 'greater than':
100
- sql_condition = f"{col} > {value}"
101
- elif op == 'less than':
102
- sql_condition = f"{col} < {value}"
103
- elif op == 'equal to':
104
- sql_condition = f"{col} = '{value}'"
105
- elif op == 'not equal to':
106
- sql_condition = f"{col} != '{value}'"
107
- elif op == 'between':
108
- values = value.split(' and ')
109
- if len(values) == 2:
110
- sql_condition = f"{col} BETWEEN {values[0]} AND {values[1]}"
111
- elif op == 'contains':
112
- sql_condition = f"{col} LIKE '%{value}%'"
113
- elif op == 'starts with':
114
- sql_condition = f"{col} LIKE '{value}%'"
115
- elif op == 'ends with':
116
- sql_condition = f"{col} LIKE '%{value}'"
117
-
118
- if sql_condition:
119
- conditions.append(sql_condition)
120
- break
121
-
122
- # Combine conditions with AND
123
- if conditions:
124
- where_clause = ' AND '.join(conditions)
125
- else:
126
- where_clause = ''
127
-
128
- return where_clause
129
-
130
- # Generate SQL based on user query
131
- def generate_sql_query(query):
132
  """
133
- Generates a SQL query based on the natural language input.
134
  """
135
- condition = parse_query(query)
136
- if condition:
137
- sql_query = f"SELECT * FROM contract_data WHERE {condition}"
138
- else:
139
- sql_query = "SELECT * FROM contract_data"
140
- return sql_query
 
 
 
 
 
141
 
142
  # Execute the SQL query and return results or error
143
  def execute_query(sql_query):
@@ -152,9 +130,37 @@ def execute_query(sql_query):
152
  con.close()
153
  return result_df, ""
154
  except Exception as e:
155
- # In case of error, return empty dataframe and error message
156
  return None, f"Error executing query: {e}"
157
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  # Cache the schema JSON for display
159
  @lru_cache(maxsize=1)
160
  def get_schema_json():
@@ -167,25 +173,28 @@ if not load_dataset_schema():
167
  # Gradio app UI
168
  with gr.Blocks() as demo:
169
  gr.Markdown("""
170
- # Parquet SQL Query App
171
 
172
- **Query data** in `sample_contract_df.parquet`
173
 
174
  ## Instructions
175
 
176
- 1. **Describe the data you want to retrieve**: For example:
177
  - `Show all awards greater than 1,000,000 in California`
 
 
178
  - `List awardees who received multiple awards along with award amounts`
179
  - `Number of awards issued by each department division`
180
- - `Distribution of awards by city and zip code across different countries`
181
- - `Active awards with their award numbers and dates`
182
 
183
  2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
184
  3. **Execute Query**: Click "Execute Query" to run the query and view the results.
185
- 4. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
 
186
 
187
  ## Example Queries
188
 
 
 
189
  - `award greater than 1000000 and state equal to "CA"`
190
  - `List awards where department_ind_agency contains "Defense"`
191
  """)
@@ -202,10 +211,12 @@ with gr.Blocks() as demo:
202
  )
203
  btn_generate = gr.Button("Generate SQL")
204
  sql_out = gr.Code(label="Generated SQL Query", language="sql")
 
205
  btn_execute = gr.Button("Execute Query")
206
  error_out = gr.Markdown("", visible=False)
207
  with gr.Column(scale=2):
208
  results_out = gr.Dataframe(label="Query Results", interactive=False)
 
209
 
210
  # Schema Tab
211
  with gr.TabItem("Dataset Schema"):
@@ -213,15 +224,32 @@ with gr.Blocks() as demo:
213
  schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
214
 
215
  # Set up click events
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
216
  btn_generate.click(
217
- fn=generate_sql_query,
218
  inputs=query,
219
- outputs=sql_out,
220
  )
221
  btn_execute.click(
222
- fn=execute_query,
223
- inputs=sql_out,
224
- outputs=[results_out, error_out],
225
  )
226
 
227
  # Launch the app
 
1
  import json
2
  import gradio as gr
3
  import duckdb
 
4
  from functools import lru_cache
5
  from transformers import pipeline
6
+ import pandas as pd
7
+ import plotly.express as px
8
+ import openai
9
 
10
  # Load the Parquet dataset path
11
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
 
60
  finally:
61
  con.close()
62
 
63
+ # Advanced Natural Language to SQL Parser using OpenAI's GPT-3
 
 
 
 
 
 
 
64
  def parse_query(nl_query):
65
  """
66
+ Converts a natural language query into SQL query using OpenAI GPT-3.
67
+ """
68
+ openai.api_key = 'YOUR_OPENAI_API_KEY' # Replace with your OpenAI API key
69
+
70
+ prompt = f"""
71
+ Convert the following natural language query into a SQL query for a DuckDB database. Use 'contract_data' as the table name.
72
+ Schema:
73
+ {json.dumps(schema, indent=2)}
74
+ Query:
75
+ "{nl_query}"
76
+ """
77
+ try:
78
+ response = openai.Completion.create(
79
+ engine="text-davinci-003",
80
+ prompt=prompt,
81
+ temperature=0,
82
+ max_tokens=150,
83
+ top_p=1,
84
+ n=1,
85
+ stop=None
86
+ )
87
+ sql_query = response.choices[0].text.strip()
88
+ return sql_query
89
+ except Exception as e:
90
+ return f"Error generating SQL query: {e}"
91
+
92
+ # Function to detect if the user wants a plot
93
+ def detect_plot_intent(nl_query):
94
+ """
95
+ Detects if the user's query involves plotting.
96
  """
97
+ plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram', 'bar chart', 'line chart', 'scatter plot', 'pie chart']
98
+ for keyword in plot_keywords:
99
+ if keyword in nl_query.lower():
100
+ return True
101
+ return False
102
+
103
+ # Generate SQL and Plot Code based on user query
104
+ def generate_sql_and_plot_code(query):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  """
106
+ Generates SQL query and plotting code based on the natural language input.
107
  """
108
+ is_plot = detect_plot_intent(query)
109
+ sql_query = parse_query(query)
110
+ plot_code = ""
111
+ if is_plot:
112
+ # Generate plot code based on the query
113
+ # For simplicity, we'll generate a basic plot code
114
+ plot_code = """
115
+ import plotly.express as px
116
+ fig = px.bar(result_df, x='x_column', y='y_column')
117
+ """
118
+ return sql_query, plot_code
119
 
120
  # Execute the SQL query and return results or error
121
  def execute_query(sql_query):
 
130
  con.close()
131
  return result_df, ""
132
  except Exception as e:
133
+ # In case of error, return None and error message
134
  return None, f"Error executing query: {e}"
135
 
136
+ # Generate and display plot
137
+ def generate_plot(plot_code, result_df):
138
+ """
139
+ Executes the plot code to generate a plot from the result DataFrame.
140
+ """
141
+ if not plot_code.strip():
142
+ return None, "No plot code provided."
143
+ try:
144
+ # Replace placeholders in plot_code with actual column names
145
+ if result_df.empty:
146
+ return None, "Result DataFrame is empty."
147
+ columns = result_df.columns.tolist()
148
+ if len(columns) < 2:
149
+ return None, "Not enough columns to plot."
150
+ plot_code = plot_code.replace('x_column', columns[0])
151
+ plot_code = plot_code.replace('y_column', columns[1])
152
+
153
+ # Execute the plot code
154
+ local_vars = {'result_df': result_df}
155
+ exec(plot_code, {'px': px}, local_vars)
156
+ fig = local_vars.get('fig', None)
157
+ if fig:
158
+ return fig, ""
159
+ else:
160
+ return None, "Plot could not be generated."
161
+ except Exception as e:
162
+ return None, f"Error generating plot: {e}"
163
+
164
  # Cache the schema JSON for display
165
  @lru_cache(maxsize=1)
166
  def get_schema_json():
 
173
  # Gradio app UI
174
  with gr.Blocks() as demo:
175
  gr.Markdown("""
176
+ # Parquet SQL Query and Plotting App
177
 
178
+ **Query and visualize data** in `sample_contract_df.parquet`
179
 
180
  ## Instructions
181
 
182
+ 1. **Describe the data you want to retrieve or plot**: For example:
183
  - `Show all awards greater than 1,000,000 in California`
184
+ - `Plot the distribution of awards by state`
185
+ - `Show a bar chart of total awards per department`
186
  - `List awardees who received multiple awards along with award amounts`
187
  - `Number of awards issued by each department division`
 
 
188
 
189
  2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
190
  3. **Execute Query**: Click "Execute Query" to run the query and view the results.
191
+ 4. **View Plot**: If your query involves plotting, the plot will be displayed.
192
+ 5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
193
 
194
  ## Example Queries
195
 
196
+ - `Plot the total award amount by state`
197
+ - `Show a histogram of awards over time`
198
  - `award greater than 1000000 and state equal to "CA"`
199
  - `List awards where department_ind_agency contains "Defense"`
200
  """)
 
211
  )
212
  btn_generate = gr.Button("Generate SQL")
213
  sql_out = gr.Code(label="Generated SQL Query", language="sql")
214
+ plot_code_out = gr.Code(label="Generated Plot Code", language="python")
215
  btn_execute = gr.Button("Execute Query")
216
  error_out = gr.Markdown("", visible=False)
217
  with gr.Column(scale=2):
218
  results_out = gr.Dataframe(label="Query Results", interactive=False)
219
+ plot_out = gr.Plot(label="Plot")
220
 
221
  # Schema Tab
222
  with gr.TabItem("Dataset Schema"):
 
224
  schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
225
 
226
  # Set up click events
227
+ def on_generate_click(nl_query):
228
+ sql_query, plot_code = generate_sql_and_plot_code(nl_query)
229
+ return sql_query, plot_code
230
+
231
+ def on_execute_click(sql_query, plot_code):
232
+ result_df, error_msg = execute_query(sql_query)
233
+ if error_msg:
234
+ return None, None, error_msg
235
+ if plot_code.strip():
236
+ fig, plot_error = generate_plot(plot_code, result_df)
237
+ if plot_error:
238
+ return result_df, None, plot_error
239
+ else:
240
+ return result_df, fig, ""
241
+ else:
242
+ return result_df, None, ""
243
+
244
  btn_generate.click(
245
+ fn=on_generate_click,
246
  inputs=query,
247
+ outputs=[sql_out, plot_code_out],
248
  )
249
  btn_execute.click(
250
+ fn=on_execute_click,
251
+ inputs=[sql_out, plot_code_out],
252
+ outputs=[results_out, plot_out, error_out],
253
  )
254
 
255
  # Launch the app