Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -7,13 +7,13 @@ import pandas as pd
|
|
7 |
import plotly.express as px
|
8 |
import os
|
9 |
|
10 |
-
# Set OpenAI API key
|
11 |
-
openai.api_key = os.getenv("OPENAI_API_KEY")
|
12 |
-
|
13 |
# =========================
|
14 |
# Configuration and Setup
|
15 |
# =========================
|
16 |
|
|
|
|
|
|
|
17 |
# Load the Parquet dataset path
|
18 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
19 |
|
@@ -69,6 +69,9 @@ def load_dataset_schema():
|
|
69 |
finally:
|
70 |
con.close()
|
71 |
|
|
|
|
|
|
|
72 |
# =========================
|
73 |
# OpenAI API Integration
|
74 |
# =========================
|
@@ -78,13 +81,13 @@ def parse_query(nl_query):
|
|
78 |
Converts a natural language query into a SQL query using OpenAI's API.
|
79 |
"""
|
80 |
messages = [
|
81 |
-
{"role": "system", "content": "
|
82 |
{"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
|
83 |
]
|
84 |
|
85 |
try:
|
86 |
-
response = openai.
|
87 |
-
model="gpt-
|
88 |
messages=messages,
|
89 |
temperature=0,
|
90 |
max_tokens=150,
|
@@ -102,23 +105,35 @@ def detect_plot_intent(nl_query):
|
|
102 |
"""
|
103 |
Detects if the user's query involves plotting.
|
104 |
"""
|
105 |
-
plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line']
|
106 |
return any(keyword in nl_query.lower() for keyword in plot_keywords)
|
107 |
|
108 |
-
def
|
109 |
"""
|
110 |
-
Generates
|
111 |
"""
|
112 |
-
if not detect_plot_intent(
|
113 |
-
return None
|
114 |
|
115 |
columns = result_df.columns.tolist()
|
116 |
-
if len(columns)
|
117 |
-
|
118 |
-
|
119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
else:
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
|
123 |
# =========================
|
124 |
# Gradio Application UI
|
@@ -126,48 +141,101 @@ def generate_plot_code(sql_query, result_df):
|
|
126 |
|
127 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
128 |
gr.Markdown("""
|
129 |
-
|
130 |
-
|
131 |
-
**Query and visualize data effortlessly.**
|
132 |
-
|
133 |
""", elem_id="main-title")
|
134 |
|
135 |
with gr.Row():
|
136 |
with gr.Column(scale=1):
|
137 |
query = gr.Textbox(
|
138 |
-
label="
|
139 |
placeholder='e.g., "What are the total awards over 1M in California?"',
|
140 |
lines=1
|
141 |
)
|
142 |
-
#
|
143 |
-
schema_display = gr.JSON(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
144 |
|
145 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
146 |
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
|
151 |
def on_query_submit(nl_query):
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
sql_query = parse_query(nl_query)
|
153 |
if sql_query.startswith("Error"):
|
154 |
return gr.update(visible=True, value=sql_query), None, None
|
|
|
155 |
result_df, error_msg = execute_query(sql_query)
|
156 |
if error_msg:
|
157 |
return gr.update(visible=True, value=error_msg), None, None
|
158 |
-
fig = generate_plot_code(nl_query, result_df)
|
159 |
-
return gr.update(visible=False), result_df, fig
|
160 |
|
161 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
162 |
return gr.update(visible=True)
|
163 |
|
|
|
|
|
|
|
|
|
164 |
query.submit(
|
165 |
fn=on_query_submit,
|
166 |
inputs=query,
|
167 |
outputs=[error_out, results_out, plot_out]
|
168 |
)
|
|
|
169 |
query.focus(
|
170 |
-
fn=
|
|
|
171 |
outputs=schema_display
|
172 |
)
|
173 |
|
@@ -179,12 +247,11 @@ def execute_query(sql_query):
|
|
179 |
"""
|
180 |
Executes the SQL query and returns the results.
|
181 |
"""
|
182 |
-
if sql_query.startswith("Error"):
|
183 |
-
return None, sql_query
|
184 |
-
|
185 |
try:
|
186 |
con = duckdb.connect()
|
187 |
-
con.execute(
|
|
|
|
|
188 |
result_df = con.execute(sql_query).fetchdf()
|
189 |
con.close()
|
190 |
return result_df, ""
|
@@ -195,4 +262,5 @@ def execute_query(sql_query):
|
|
195 |
# Launch the Gradio App
|
196 |
# =========================
|
197 |
|
198 |
-
|
|
|
|
7 |
import plotly.express as px
|
8 |
import os
|
9 |
|
|
|
|
|
|
|
10 |
# =========================
|
11 |
# Configuration and Setup
|
12 |
# =========================
|
13 |
|
14 |
+
# Set OpenAI API key
|
15 |
+
openai.api_key = os.getenv("OPENAI_API_KEY")
|
16 |
+
|
17 |
# Load the Parquet dataset path
|
18 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
19 |
|
|
|
69 |
finally:
|
70 |
con.close()
|
71 |
|
72 |
+
# Load the dataset schema at startup
|
73 |
+
load_dataset_schema()
|
74 |
+
|
75 |
# =========================
|
76 |
# OpenAI API Integration
|
77 |
# =========================
|
|
|
81 |
Converts a natural language query into a SQL query using OpenAI's API.
|
82 |
"""
|
83 |
messages = [
|
84 |
+
{"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
|
85 |
{"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
|
86 |
]
|
87 |
|
88 |
try:
|
89 |
+
response = openai.ChatCompletion.create(
|
90 |
+
model="gpt-3.5-turbo",
|
91 |
messages=messages,
|
92 |
temperature=0,
|
93 |
max_tokens=150,
|
|
|
105 |
"""
|
106 |
Detects if the user's query involves plotting.
|
107 |
"""
|
108 |
+
plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie']
|
109 |
return any(keyword in nl_query.lower() for keyword in plot_keywords)
|
110 |
|
111 |
+
def generate_plot(nl_query, result_df):
    """
    Generates a Plotly figure based on the result DataFrame and the user's intent.

    Args:
        nl_query: The user's natural language query. It is scanned
            (case-insensitively, by substring) for chart-type keywords.
        result_df: DataFrame holding the query results. The first column is
            used for the x axis (pie: names) and the second column for the
            y axis (pie: values).

    Returns:
        A ``(fig, message)`` tuple:
        - ``(None, "")`` when the query does not ask for a plot,
        - ``(None, "Not enough data to generate a plot.")`` when the result
          is missing, has no rows, or has fewer than two columns,
        - ``(fig, "")`` with the generated figure otherwise.
    """
    if not detect_plot_intent(nl_query):
        return None, ""

    not_enough_data = "Not enough data to generate a plot."

    # A missing or row-less result cannot produce a meaningful chart.
    if result_df is None or result_df.empty:
        return None, not_enough_data

    columns = result_df.columns.tolist()
    if len(columns) < 2:
        return None, not_enough_data

    first_col, second_col = columns[0], columns[1]
    query_text = nl_query.lower()

    # Simple heuristic to choose plot type based on keywords.
    # Order matters: the first keyword found in the query wins.
    # Each entry: (keyword, plotly express function, axis argument names, title)
    chart_specs = (
        ('bar', px.bar, ('x', 'y'), 'Bar Chart'),
        ('line', px.line, ('x', 'y'), 'Line Chart'),
        ('scatter', px.scatter, ('x', 'y'), 'Scatter Plot'),
        ('pie', px.pie, ('names', 'values'), 'Pie Chart'),
    )
    # Default to bar chart when no chart-type keyword is present.
    default_spec = chart_specs[0]
    chosen = next(
        (spec for spec in chart_specs if spec[0] in query_text),
        default_spec,
    )
    _, make_chart, (first_arg, second_arg), title = chosen

    fig = make_chart(
        result_df,
        **{first_arg: first_col, second_arg: second_col},
        title=title,
    )

    # Center the chart title.
    fig.update_layout(title_x=0.5)
    return fig, ""
|
137 |
|
138 |
# =========================
|
139 |
# Gradio Application UI
|
|
|
141 |
|
142 |
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
143 |
gr.Markdown("""
|
144 |
+
<h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
|
145 |
+
<p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
|
|
|
|
|
146 |
""", elem_id="main-title")
|
147 |
|
148 |
with gr.Row():
|
149 |
with gr.Column(scale=1):
|
150 |
query = gr.Textbox(
|
151 |
+
label="Your Query",
|
152 |
placeholder='e.g., "What are the total awards over 1M in California?"',
|
153 |
lines=1
|
154 |
)
|
155 |
+
# Hidden schema display that appears on focus
|
156 |
+
schema_display = gr.JSON(
|
157 |
+
label="Dataset Schema",
|
158 |
+
value=get_schema(),
|
159 |
+
interactive=False,
|
160 |
+
visible=False
|
161 |
+
)
|
162 |
+
error_out = gr.Markdown(
|
163 |
+
value="",
|
164 |
+
visible=False
|
165 |
+
)
|
166 |
+
with gr.Column(scale=2):
|
167 |
+
results_out = gr.DataFrame(
|
168 |
+
label="Results",
|
169 |
+
interactive=False
|
170 |
+
)
|
171 |
+
plot_out = gr.Plot(
|
172 |
+
label="Visualization"
|
173 |
+
)
|
174 |
|
175 |
+
gr.Markdown("""
|
176 |
+
<style>
|
177 |
+
/* Center the content */
|
178 |
+
.gradio-container {
|
179 |
+
max-width: 1000px;
|
180 |
+
margin: auto;
|
181 |
+
}
|
182 |
+
/* Style the main title */
|
183 |
+
#main-title h1 {
|
184 |
+
font-weight: bold;
|
185 |
+
}
|
186 |
+
/* Style the error alert */
|
187 |
+
.gradio-container .alert-error {
|
188 |
+
background-color: #ffe6e6;
|
189 |
+
color: #cc0000;
|
190 |
+
border: 1px solid #cc0000;
|
191 |
+
}
|
192 |
+
</style>
|
193 |
+
""")
|
194 |
|
195 |
+
# =========================
|
196 |
+
# Event Handler Callbacks
|
197 |
+
# =========================
|
198 |
|
199 |
def on_query_submit(nl_query):
    """
    Handles the submission of a natural language query.

    The query is converted to SQL, executed, and optionally plotted.

    Args:
        nl_query: Text entered in the query box.

    Returns:
        A 3-tuple ``(error_update, results, figure)`` in the order of the
        outputs this handler is wired to (error message, results table,
        plot). ``results`` and ``figure`` are ``None`` when a step fails
        before results exist.
    """
    # Guard against a missing value as well as a blank/whitespace-only one.
    if not nl_query or not nl_query.strip():
        return gr.update(visible=True, value="Please enter a query."), None, None

    sql_query = parse_query(nl_query)
    # parse_query signals failure by returning a string starting with "Error".
    if sql_query.startswith("Error"):
        return gr.update(visible=True, value=sql_query), None, None

    result_df, error_msg = execute_query(sql_query)
    if error_msg:
        return gr.update(visible=True, value=error_msg), None, None

    fig, plot_error = generate_plot(nl_query, result_df)
    if plot_error:
        # The query itself succeeded: keep showing the results table and
        # only report that the plot could not be produced.
        return gr.update(visible=True, value=plot_error), result_df, None

    return gr.update(visible=False, value=""), result_df, fig
|
219 |
+
|
220 |
+
def on_input_focus():
    """
    Builds the update applied when the query input box gains focus.

    Returns a Gradio update that makes the bound output component visible
    (intended for the dataset schema display).
    """
    reveal = gr.update(visible=True)
    return reveal
|
225 |
|
226 |
+
# =========================
|
227 |
+
# Assign Event Handlers
|
228 |
+
# =========================
|
229 |
+
|
230 |
# Run the full query pipeline when the user presses Enter in the query box.
query.submit(
    fn=on_query_submit,
    inputs=query,
    outputs=[error_out, results_out, plot_out]
)

# Reveal the schema panel on focus. Uses the named handler rather than an
# inline lambda that duplicated it.
query.focus(
    fn=on_input_focus,
    inputs=None,
    outputs=schema_display
)
|
241 |
|
|
|
247 |
"""
|
248 |
Executes the SQL query and returns the results.
|
249 |
"""
|
|
|
|
|
|
|
250 |
try:
|
251 |
con = duckdb.connect()
|
252 |
+
con.execute("PRAGMA threads=4") # Optimize for performance
|
253 |
+
con.execute("DROP VIEW IF EXISTS contract_data")
|
254 |
+
con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
|
255 |
result_df = con.execute(sql_query).fetchdf()
|
256 |
con.close()
|
257 |
return result_df, ""
|
|
|
262 |
# Launch the Gradio App
|
263 |
# =========================
|
264 |
|
265 |
+
if __name__ == "__main__":
|
266 |
+
demo.launch()
|