Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,9 +1,11 @@
|
|
1 |
import json
|
2 |
import gradio as gr
|
3 |
import duckdb
|
4 |
-
import re
|
5 |
from functools import lru_cache
|
6 |
from transformers import pipeline
|
|
|
|
|
|
|
7 |
|
8 |
# Load the Parquet dataset path
|
9 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
@@ -58,86 +60,62 @@ def load_dataset_schema():
|
|
58 |
finally:
|
59 |
con.close()
|
60 |
|
61 |
-
#
|
62 |
-
@lru_cache(maxsize=1)
|
63 |
-
def get_nlp_model():
|
64 |
-
# We use a zero-shot-classification pipeline for query intent understanding
|
65 |
-
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
|
66 |
-
return classifier
|
67 |
-
|
68 |
-
# Advanced Natural Language to SQL Parser using NLP
|
69 |
def parse_query(nl_query):
|
70 |
"""
|
71 |
-
Converts a natural language query into SQL
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
"""
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
conditions = []
|
82 |
-
|
83 |
-
# Simple heuristic parsing (can be replaced with more advanced NLP techniques)
|
84 |
-
for col in columns:
|
85 |
-
if col in query:
|
86 |
-
for op in operations:
|
87 |
-
if op in query:
|
88 |
-
pattern = rf"{col}\s+{op}\s+(.*)"
|
89 |
-
match = re.search(pattern, query)
|
90 |
-
if match:
|
91 |
-
value = match.group(1).strip(' "')
|
92 |
-
sql_condition = ""
|
93 |
-
|
94 |
-
# Map operations to SQL syntax
|
95 |
-
if op == 'greater than or equal to':
|
96 |
-
sql_condition = f"{col} >= {value}"
|
97 |
-
elif op == 'less than or equal to':
|
98 |
-
sql_condition = f"{col} <= {value}"
|
99 |
-
elif op == 'greater than':
|
100 |
-
sql_condition = f"{col} > {value}"
|
101 |
-
elif op == 'less than':
|
102 |
-
sql_condition = f"{col} < {value}"
|
103 |
-
elif op == 'equal to':
|
104 |
-
sql_condition = f"{col} = '{value}'"
|
105 |
-
elif op == 'not equal to':
|
106 |
-
sql_condition = f"{col} != '{value}'"
|
107 |
-
elif op == 'between':
|
108 |
-
values = value.split(' and ')
|
109 |
-
if len(values) == 2:
|
110 |
-
sql_condition = f"{col} BETWEEN {values[0]} AND {values[1]}"
|
111 |
-
elif op == 'contains':
|
112 |
-
sql_condition = f"{col} LIKE '%{value}%'"
|
113 |
-
elif op == 'starts with':
|
114 |
-
sql_condition = f"{col} LIKE '{value}%'"
|
115 |
-
elif op == 'ends with':
|
116 |
-
sql_condition = f"{col} LIKE '%{value}'"
|
117 |
-
|
118 |
-
if sql_condition:
|
119 |
-
conditions.append(sql_condition)
|
120 |
-
break
|
121 |
-
|
122 |
-
# Combine conditions with AND
|
123 |
-
if conditions:
|
124 |
-
where_clause = ' AND '.join(conditions)
|
125 |
-
else:
|
126 |
-
where_clause = ''
|
127 |
-
|
128 |
-
return where_clause
|
129 |
-
|
130 |
-
# Generate SQL based on user query
|
131 |
-
def generate_sql_query(query):
|
132 |
"""
|
133 |
-
Generates
|
134 |
"""
|
135 |
-
|
136 |
-
|
137 |
-
|
138 |
-
|
139 |
-
|
140 |
-
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
# Execute the SQL query and return results or error
|
143 |
def execute_query(sql_query):
|
@@ -152,9 +130,37 @@ def execute_query(sql_query):
|
|
152 |
con.close()
|
153 |
return result_df, ""
|
154 |
except Exception as e:
|
155 |
-
# In case of error, return
|
156 |
return None, f"Error executing query: {e}"
|
157 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
158 |
# Cache the schema JSON for display
|
159 |
@lru_cache(maxsize=1)
|
160 |
def get_schema_json():
|
@@ -167,25 +173,28 @@ if not load_dataset_schema():
|
|
167 |
# Gradio app UI
|
168 |
with gr.Blocks() as demo:
|
169 |
gr.Markdown("""
|
170 |
-
# Parquet SQL Query App
|
171 |
|
172 |
-
**Query data** in `sample_contract_df.parquet`
|
173 |
|
174 |
## Instructions
|
175 |
|
176 |
-
1. **Describe the data you want to retrieve**: For example:
|
177 |
- `Show all awards greater than 1,000,000 in California`
|
|
|
|
|
178 |
- `List awardees who received multiple awards along with award amounts`
|
179 |
- `Number of awards issued by each department division`
|
180 |
-
- `Distribution of awards by city and zip code across different countries`
|
181 |
-
- `Active awards with their award numbers and dates`
|
182 |
|
183 |
2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
|
184 |
3. **Execute Query**: Click "Execute Query" to run the query and view the results.
|
185 |
-
4. **View
|
|
|
186 |
|
187 |
## Example Queries
|
188 |
|
|
|
|
|
189 |
- `award greater than 1000000 and state equal to "CA"`
|
190 |
- `List awards where department_ind_agency contains "Defense"`
|
191 |
""")
|
@@ -202,10 +211,12 @@ with gr.Blocks() as demo:
|
|
202 |
)
|
203 |
btn_generate = gr.Button("Generate SQL")
|
204 |
sql_out = gr.Code(label="Generated SQL Query", language="sql")
|
|
|
205 |
btn_execute = gr.Button("Execute Query")
|
206 |
error_out = gr.Markdown("", visible=False)
|
207 |
with gr.Column(scale=2):
|
208 |
results_out = gr.Dataframe(label="Query Results", interactive=False)
|
|
|
209 |
|
210 |
# Schema Tab
|
211 |
with gr.TabItem("Dataset Schema"):
|
@@ -213,15 +224,32 @@ with gr.Blocks() as demo:
|
|
213 |
schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
|
214 |
|
215 |
# Set up click events
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
216 |
btn_generate.click(
|
217 |
-
fn=
|
218 |
inputs=query,
|
219 |
-
outputs=sql_out,
|
220 |
)
|
221 |
btn_execute.click(
|
222 |
-
fn=
|
223 |
-
inputs=sql_out,
|
224 |
-
outputs=[results_out, error_out],
|
225 |
)
|
226 |
|
227 |
# Launch the app
|
|
|
1 |
import json
|
2 |
import gradio as gr
|
3 |
import duckdb
|
|
|
4 |
from functools import lru_cache
|
5 |
from transformers import pipeline
|
6 |
+
import pandas as pd
|
7 |
+
import plotly.express as px
|
8 |
+
import openai
|
9 |
|
10 |
# Load the Parquet dataset path
|
11 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
|
|
60 |
finally:
|
61 |
con.close()
|
62 |
|
63 |
+
# Advanced Natural Language to SQL Parser using OpenAI's GPT-3
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
def parse_query(nl_query):
    """
    Converts a natural language query into a SQL query using OpenAI GPT-3.

    The prompt embeds the module-level ``schema`` (serialized as JSON) and
    instructs the model to target the ``contract_data`` table in DuckDB.

    Args:
        nl_query: The user's request in plain English.

    Returns:
        The generated SQL text with surrounding whitespace removed, or a
        string starting with "Error generating SQL query:" when the query is
        empty or the completion request fails.
    """
    import os  # local import: only needed to look up the API key

    # Reject empty input up front instead of spending an API call on it.
    if not nl_query or not nl_query.strip():
        return "Error generating SQL query: no query provided."

    # Credentials must not live in source code: take the key from the
    # OPENAI_API_KEY environment variable. When it is unset, whatever key is
    # already configured on the client is left untouched.
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        openai.api_key = api_key

    prompt = f"""
Convert the following natural language query into a SQL query for a DuckDB database. Use 'contract_data' as the table name.
Schema:
{json.dumps(schema, indent=2)}
Query:
"{nl_query}"
"""
    try:
        response = openai.Completion.create(
            engine="text-davinci-003",
            prompt=prompt,
            temperature=0,  # deterministic output for the same request
            max_tokens=150,
            top_p=1,
            n=1,
            stop=None
        )
        sql_query = response.choices[0].text.strip()
        return sql_query
    except Exception as e:
        # Errors are reported as text so the UI can display them in place of SQL.
        return f"Error generating SQL query: {e}"
|
91 |
+
|
92 |
+
# Function to detect if the user wants a plot
|
93 |
+
def detect_plot_intent(nl_query):
    """
    Reports whether the user's query asks for a visualization.

    The check is a case-insensitive substring search for any of a fixed set
    of plotting-related keywords; returns True on the first hit, else False.
    """
    lowered_query = nl_query.lower()
    plot_terms = (
        'plot',
        'graph',
        'chart',
        'distribution',
        'visualize',
        'histogram',
        'bar chart',
        'line chart',
        'scatter plot',
        'pie chart',
    )
    return any(term in lowered_query for term in plot_terms)
|
102 |
+
|
103 |
+
# Generate SQL and Plot Code based on user query
|
104 |
+
def generate_sql_and_plot_code(query):
    """
    Builds the SQL query and, when plotting is requested, plotting code.

    Returns a ``(sql_query, plot_code)`` pair. ``plot_code`` is an empty
    string unless ``detect_plot_intent`` reports a plotting request, in which
    case it is a basic bar-chart template using the ``x_column`` and
    ``y_column`` placeholders.
    """
    # Intent detection runs first, then SQL generation (same order as before).
    wants_plot = detect_plot_intent(query)
    generated_sql = parse_query(query)

    # Basic template only; the placeholders are filled in at plot time.
    bar_chart_template = """
import plotly.express as px
fig = px.bar(result_df, x='x_column', y='y_column')
"""
    return generated_sql, (bar_chart_template if wants_plot else "")
|
119 |
|
120 |
# Execute the SQL query and return results or error
|
121 |
def execute_query(sql_query):
|
|
|
130 |
con.close()
|
131 |
return result_df, ""
|
132 |
except Exception as e:
|
133 |
+
# In case of error, return None and error message
|
134 |
return None, f"Error executing query: {e}"
|
135 |
|
136 |
+
# Generate and display plot
|
137 |
+
def generate_plot(plot_code, result_df):
    """
    Executes the plot code to generate a plot from the result DataFrame.

    The placeholders ``x_column`` and ``y_column`` in ``plot_code`` are
    replaced with the names of the first and second columns of
    ``result_df``. The code is then executed with ``px`` as its only global
    and ``result_df`` as a local, and is expected to assign the figure to a
    variable named ``fig``.

    Args:
        plot_code: Python source that builds the figure.
        result_df: DataFrame holding the query results.

    Returns:
        A ``(fig, error_message)`` pair. On success ``error_message`` is an
        empty string; on any failure ``fig`` is None and ``error_message``
        explains why.
    """
    import re  # local import: used only for the single-pass placeholder swap

    if not plot_code or not plot_code.strip():
        return None, "No plot code provided."
    if result_df is None:
        return None, "No query results to plot."
    try:
        if result_df.empty:
            return None, "Result DataFrame is empty."
        columns = result_df.columns.tolist()
        if len(columns) < 2:
            return None, "Not enough columns to plot."

        # Replace both placeholders in one pass so a column name that itself
        # contains "x_column"/"y_column" is not rewritten a second time;
        # str() keeps non-string column labels from raising.
        substitutions = {
            'x_column': str(columns[0]),
            'y_column': str(columns[1]),
        }
        plot_code = re.sub(
            r"x_column|y_column",
            lambda match: substitutions[match.group(0)],
            plot_code,
        )

        # SECURITY: exec() runs arbitrary Python. plot_code comes from an
        # editable UI field, so this must only be exposed to trusted users
        # (or be replaced by a fixed set of plot builders).
        local_vars = {'result_df': result_df}
        exec(plot_code, {'px': px}, local_vars)
        fig = local_vars.get('fig')
        # Test for presence explicitly rather than relying on the figure
        # object's truthiness.
        if fig is None:
            return None, "Plot could not be generated."
        return fig, ""
    except Exception as e:
        return None, f"Error generating plot: {e}"
|
163 |
+
|
164 |
# Cache the schema JSON for display
|
165 |
@lru_cache(maxsize=1)
|
166 |
def get_schema_json():
|
|
|
173 |
# Gradio app UI
|
174 |
with gr.Blocks() as demo:
|
175 |
gr.Markdown("""
|
176 |
+
# Parquet SQL Query and Plotting App
|
177 |
|
178 |
+
**Query and visualize data** in `sample_contract_df.parquet`
|
179 |
|
180 |
## Instructions
|
181 |
|
182 |
+
1. **Describe the data you want to retrieve or plot**: For example:
|
183 |
- `Show all awards greater than 1,000,000 in California`
|
184 |
+
- `Plot the distribution of awards by state`
|
185 |
+
- `Show a bar chart of total awards per department`
|
186 |
- `List awardees who received multiple awards along with award amounts`
|
187 |
- `Number of awards issued by each department division`
|
|
|
|
|
188 |
|
189 |
2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
|
190 |
3. **Execute Query**: Click "Execute Query" to run the query and view the results.
|
191 |
+
4. **View Plot**: If your query involves plotting, the plot will be displayed.
|
192 |
+
5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
|
193 |
|
194 |
## Example Queries
|
195 |
|
196 |
+
- `Plot the total award amount by state`
|
197 |
+
- `Show a histogram of awards over time`
|
198 |
- `award greater than 1000000 and state equal to "CA"`
|
199 |
- `List awards where department_ind_agency contains "Defense"`
|
200 |
""")
|
|
|
211 |
)
|
212 |
btn_generate = gr.Button("Generate SQL")
|
213 |
sql_out = gr.Code(label="Generated SQL Query", language="sql")
|
214 |
+
plot_code_out = gr.Code(label="Generated Plot Code", language="python")
|
215 |
btn_execute = gr.Button("Execute Query")
|
216 |
error_out = gr.Markdown("", visible=False)
|
217 |
with gr.Column(scale=2):
|
218 |
results_out = gr.Dataframe(label="Query Results", interactive=False)
|
219 |
+
plot_out = gr.Plot(label="Plot")
|
220 |
|
221 |
# Schema Tab
|
222 |
with gr.TabItem("Dataset Schema"):
|
|
|
224 |
schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
|
225 |
|
226 |
# Set up click events
|
227 |
+
def on_generate_click(nl_query):
    """
    Click handler for "Generate SQL".

    Returns the generated SQL and the generated plot code, in the order of
    the two output components the handler is wired to.
    """
    generated_sql, generated_plot_code = generate_sql_and_plot_code(nl_query)
    return generated_sql, generated_plot_code
|
230 |
+
|
231 |
+
def on_execute_click(sql_query, plot_code):
    """
    Click handler for "Execute Query".

    Runs the SQL and, when plot code is present, renders the plot from the
    query results.

    Args:
        sql_query: SQL text taken from the generated-SQL box.
        plot_code: Plotting code taken from the plot-code box (may be empty).

    Returns:
        A ``(result_df, fig, error_update)`` triple for the results table,
        the plot, and the error message component. ``result_df`` and ``fig``
        are None when unavailable.
    """
    def _error_update(message):
        # The error component is created hidden, so a message only becomes
        # visible to the user if visibility is switched on with it; an empty
        # message hides the component again.
        return gr.update(value=message, visible=bool(message))

    result_df, error_msg = execute_query(sql_query)
    if error_msg:
        return None, None, _error_update(error_msg)

    # An untouched code box can yield None rather than "", so guard before
    # calling .strip().
    if not plot_code or not plot_code.strip():
        return result_df, None, _error_update("")

    fig, plot_error = generate_plot(plot_code, result_df)
    if plot_error:
        # Keep the query results on screen even though the plot failed.
        return result_df, None, _error_update(plot_error)
    return result_df, fig, _error_update("")
|
243 |
+
|
244 |
btn_generate.click(
|
245 |
+
fn=on_generate_click,
|
246 |
inputs=query,
|
247 |
+
outputs=[sql_out, plot_code_out],
|
248 |
)
|
249 |
btn_execute.click(
|
250 |
+
fn=on_execute_click,
|
251 |
+
inputs=[sql_out, plot_code_out],
|
252 |
+
outputs=[results_out, plot_out, error_out],
|
253 |
)
|
254 |
|
255 |
# Launch the app
|