Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,10 @@ import plotly.express as px
|
|
7 |
import openai
|
8 |
import os
|
9 |
|
|
|
|
|
|
|
|
|
10 |
# Load the Parquet dataset path
|
11 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
12 |
|
@@ -45,9 +49,14 @@ def get_schema():
|
|
45 |
# Map column names to their types
|
46 |
COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
|
47 |
|
48 |
-
#
|
49 |
-
|
|
|
|
|
50 |
def load_dataset_schema():
|
|
|
|
|
|
|
51 |
con = duckdb.connect()
|
52 |
try:
|
53 |
# Drop the view if it exists to avoid errors
|
@@ -60,45 +69,51 @@ def load_dataset_schema():
|
|
60 |
finally:
|
61 |
con.close()
|
62 |
|
63 |
-
#
|
|
|
|
|
|
|
64 |
def parse_query(nl_query):
|
65 |
"""
|
66 |
-
Converts a natural language query into a SQL query using OpenAI's GPT-3.
|
67 |
"""
|
68 |
-
openai.api_key = os.getenv('OPENAI_API_KEY') #
|
69 |
|
70 |
-
|
|
|
71 |
|
72 |
-
user_prompt = f"""
|
73 |
Schema:
|
74 |
{json.dumps(schema, indent=2)}
|
75 |
-
|
76 |
-
|
77 |
"{nl_query}"
|
|
|
|
|
78 |
"""
|
79 |
|
80 |
try:
|
81 |
-
response = openai.
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
{"role": "user", "content": user_prompt}
|
86 |
-
],
|
87 |
-
temperature=0,
|
88 |
max_tokens=150,
|
89 |
top_p=1,
|
90 |
-
|
91 |
-
|
|
|
92 |
)
|
93 |
-
sql_query = response.choices[0].
|
94 |
return sql_query
|
95 |
except Exception as e:
|
96 |
return f"Error generating SQL query: {e}"
|
97 |
|
98 |
-
#
|
|
|
|
|
|
|
99 |
def detect_plot_intent(nl_query):
|
100 |
"""
|
101 |
-
Detects if the user's query involves plotting.
|
102 |
"""
|
103 |
plot_keywords = [
|
104 |
'plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram',
|
@@ -109,7 +124,6 @@ def detect_plot_intent(nl_query):
|
|
109 |
return True
|
110 |
return False
|
111 |
|
112 |
-
# Generate SQL and Plot Code based on user query
|
113 |
def generate_sql_and_plot_code(query):
|
114 |
"""
|
115 |
Generates SQL query and plotting code based on the natural language input.
|
@@ -127,10 +141,9 @@ fig.update_layout(title_x=0.5)
|
|
127 |
"""
|
128 |
return sql_query, plot_code
|
129 |
|
130 |
-
# Execute the SQL query and return results or error
|
131 |
def execute_query(sql_query):
|
132 |
"""
|
133 |
-
Executes the SQL query and returns
|
134 |
"""
|
135 |
if sql_query.startswith("Error"):
|
136 |
return None, sql_query # Pass the error message forward
|
@@ -146,7 +159,6 @@ def execute_query(sql_query):
|
|
146 |
# In case of error, return None and error message
|
147 |
return None, f"Error executing query: {e}"
|
148 |
|
149 |
-
# Generate and display plot
|
150 |
def generate_plot(plot_code, result_df):
|
151 |
"""
|
152 |
Executes the plot code to generate a plot from the result DataFrame.
|
@@ -174,16 +186,25 @@ def generate_plot(plot_code, result_df):
|
|
174 |
except Exception as e:
|
175 |
return None, f"Error generating plot: {e}"
|
176 |
|
177 |
-
#
|
|
|
|
|
|
|
178 |
@lru_cache(maxsize=1)
|
179 |
def get_schema_json():
|
180 |
return json.dumps(get_schema(), indent=2)
|
181 |
|
182 |
-
#
|
|
|
|
|
|
|
183 |
if not load_dataset_schema():
|
184 |
raise Exception("Failed to load dataset schema. Please check the dataset path and format.")
|
185 |
|
186 |
-
#
|
|
|
|
|
|
|
187 |
with gr.Blocks() as demo:
|
188 |
gr.Markdown("""
|
189 |
# Parquet SQL Query and Plotting App
|
@@ -202,7 +223,7 @@ with gr.Blocks() as demo:
|
|
202 |
2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
|
203 |
3. **Execute Query**: Click "Execute Query" to run the query and view the results.
|
204 |
4. **View Plot**: If your query involves plotting, the plot will be displayed.
|
205 |
-
5. **View Dataset Schema**: to understand available columns and their types.
|
206 |
|
207 |
## Example Queries
|
208 |
|
@@ -236,12 +257,21 @@ with gr.Blocks() as demo:
|
|
236 |
gr.Markdown("### Dataset Schema")
|
237 |
schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
|
238 |
|
239 |
-
#
|
|
|
|
|
|
|
240 |
def on_generate_click(nl_query):
|
|
|
|
|
|
|
241 |
sql_query, plot_code = generate_sql_and_plot_code(nl_query)
|
242 |
return sql_query, plot_code
|
243 |
|
244 |
def on_execute_click(sql_query, plot_code):
|
|
|
|
|
|
|
245 |
result_df, error_msg = execute_query(sql_query)
|
246 |
if error_msg:
|
247 |
return None, None, error_msg
|
@@ -265,5 +295,8 @@ with gr.Blocks() as demo:
|
|
265 |
outputs=[results_out, plot_out, error_out],
|
266 |
)
|
267 |
|
268 |
-
#
|
269 |
-
|
|
|
|
|
|
|
|
7 |
import openai
|
8 |
import os
|
9 |
|
10 |
+
# =========================
|
11 |
+
# Configuration and Setup
|
12 |
+
# =========================
|
13 |
+
|
14 |
# Load the Parquet dataset path
|
15 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
16 |
|
|
|
49 |
# Map column names to their types
|
50 |
COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
|
51 |
|
52 |
+
# =========================
|
53 |
+
# Database Interaction
|
54 |
+
# =========================
|
55 |
+
|
56 |
def load_dataset_schema():
|
57 |
+
"""
|
58 |
+
Loads the dataset schema into DuckDB by creating a view.
|
59 |
+
"""
|
60 |
con = duckdb.connect()
|
61 |
try:
|
62 |
# Drop the view if it exists to avoid errors
|
|
|
69 |
finally:
|
70 |
con.close()
|
71 |
|
72 |
+
# =========================
|
73 |
+
# OpenAI API Integration
|
74 |
+
# =========================
|
75 |
+
|
76 |
def parse_query(nl_query):
|
77 |
"""
|
78 |
+
Converts a natural language query into a SQL query using OpenAI's GPT-3 Completion API.
|
79 |
"""
|
80 |
+
openai.api_key = os.getenv('OPENAI_API_KEY') # Ensure your API key is set as an environment variable
|
81 |
|
82 |
+
prompt = f"""
|
83 |
+
You are an assistant that converts natural language queries into SQL queries for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries.
|
84 |
|
|
|
85 |
Schema:
|
86 |
{json.dumps(schema, indent=2)}
|
87 |
+
|
88 |
+
Natural Language Query:
|
89 |
"{nl_query}"
|
90 |
+
|
91 |
+
SQL Query:
|
92 |
"""
|
93 |
|
94 |
try:
|
95 |
+
response = openai.Completion.create(
|
96 |
+
engine="text-davinci-003", # You can choose a different engine if preferred
|
97 |
+
prompt=prompt,
|
98 |
+
temperature=0, # Set to 0 for deterministic output
|
|
|
|
|
|
|
99 |
max_tokens=150,
|
100 |
top_p=1,
|
101 |
+
frequency_penalty=0,
|
102 |
+
presence_penalty=0,
|
103 |
+
stop=["\n\n"] # Stop generation after two newlines
|
104 |
)
|
105 |
+
sql_query = response.choices[0].text.strip()
|
106 |
return sql_query
|
107 |
except Exception as e:
|
108 |
return f"Error generating SQL query: {e}"
|
109 |
|
110 |
+
# =========================
|
111 |
+
# Plotting Utilities
|
112 |
+
# =========================
|
113 |
+
|
114 |
def detect_plot_intent(nl_query):
|
115 |
"""
|
116 |
+
Detects if the user's query involves plotting based on the presence of specific keywords.
|
117 |
"""
|
118 |
plot_keywords = [
|
119 |
'plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram',
|
|
|
124 |
return True
|
125 |
return False
|
126 |
|
|
|
127 |
def generate_sql_and_plot_code(query):
|
128 |
"""
|
129 |
Generates SQL query and plotting code based on the natural language input.
|
|
|
141 |
"""
|
142 |
return sql_query, plot_code
|
143 |
|
|
|
144 |
def execute_query(sql_query):
|
145 |
"""
|
146 |
+
Executes the SQL query and returns results or an error message.
|
147 |
"""
|
148 |
if sql_query.startswith("Error"):
|
149 |
return None, sql_query # Pass the error message forward
|
|
|
159 |
# In case of error, return None and error message
|
160 |
return None, f"Error executing query: {e}"
|
161 |
|
|
|
162 |
def generate_plot(plot_code, result_df):
|
163 |
"""
|
164 |
Executes the plot code to generate a plot from the result DataFrame.
|
|
|
186 |
except Exception as e:
|
187 |
return None, f"Error generating plot: {e}"
|
188 |
|
189 |
+
# =========================
|
190 |
+
# Schema Display
|
191 |
+
# =========================
|
192 |
+
|
193 |
@lru_cache(maxsize=1)
|
194 |
def get_schema_json():
|
195 |
return json.dumps(get_schema(), indent=2)
|
196 |
|
197 |
+
# =========================
|
198 |
+
# Initialize Dataset Schema
|
199 |
+
# =========================
|
200 |
+
|
201 |
if not load_dataset_schema():
|
202 |
raise Exception("Failed to load dataset schema. Please check the dataset path and format.")
|
203 |
|
204 |
+
# =========================
|
205 |
+
# Gradio Application UI
|
206 |
+
# =========================
|
207 |
+
|
208 |
with gr.Blocks() as demo:
|
209 |
gr.Markdown("""
|
210 |
# Parquet SQL Query and Plotting App
|
|
|
223 |
2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
|
224 |
3. **Execute Query**: Click "Execute Query" to run the query and view the results.
|
225 |
4. **View Plot**: If your query involves plotting, the plot will be displayed.
|
226 |
+
5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
|
227 |
|
228 |
## Example Queries
|
229 |
|
|
|
257 |
gr.Markdown("### Dataset Schema")
|
258 |
schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))
|
259 |
|
260 |
+
# =========================
|
261 |
+
# Click Event Handlers
|
262 |
+
# =========================
|
263 |
+
|
264 |
def on_generate_click(nl_query):
|
265 |
+
"""
|
266 |
+
Handles the "Generate SQL" button click event.
|
267 |
+
"""
|
268 |
sql_query, plot_code = generate_sql_and_plot_code(nl_query)
|
269 |
return sql_query, plot_code
|
270 |
|
271 |
def on_execute_click(sql_query, plot_code):
|
272 |
+
"""
|
273 |
+
Handles the "Execute Query" button click event.
|
274 |
+
"""
|
275 |
result_df, error_msg = execute_query(sql_query)
|
276 |
if error_msg:
|
277 |
return None, None, error_msg
|
|
|
295 |
outputs=[results_out, plot_out, error_out],
|
296 |
)
|
297 |
|
298 |
+
# =========================
|
299 |
+
# Launch the Gradio App
|
300 |
+
# =========================
|
301 |
+
|
302 |
+
demo.launch()
|