Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,9 @@ import os
|
|
11 |
# Configuration and Setup
|
12 |
# =========================
|
13 |
|
14 |
-
# Set OpenAI API key
|
15 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
16 |
-
|
17 |
-
# Load the Parquet dataset path
|
18 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
19 |
|
20 |
-
# Provided schema
|
21 |
schema = [
|
22 |
{"column_name": "department_ind_agency", "column_type": "VARCHAR"},
|
23 |
{"column_name": "cgac", "column_type": "BIGINT"},
|
@@ -50,14 +46,7 @@ def get_schema():
|
|
50 |
|
51 |
COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
|
52 |
|
53 |
-
# =========================
|
54 |
-
# Database Interaction
|
55 |
-
# =========================
|
56 |
-
|
57 |
def load_dataset_schema():
|
58 |
-
"""
|
59 |
-
Loads the dataset schema into DuckDB by creating a view.
|
60 |
-
"""
|
61 |
con = duckdb.connect()
|
62 |
try:
|
63 |
con.execute("DROP VIEW IF EXISTS contract_data")
|
@@ -69,7 +58,6 @@ def load_dataset_schema():
|
|
69 |
finally:
|
70 |
con.close()
|
71 |
|
72 |
-
# Load the dataset schema at startup
|
73 |
load_dataset_schema()
|
74 |
|
75 |
# =========================
|
@@ -77,9 +65,6 @@ load_dataset_schema()
|
|
77 |
# =========================
|
78 |
|
79 |
def parse_query(nl_query):
|
80 |
-
"""
|
81 |
-
Converts a natural language query into a SQL query using OpenAI's API.
|
82 |
-
"""
|
83 |
messages = [
|
84 |
{"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
|
85 |
{"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
|
@@ -87,7 +72,7 @@ def parse_query(nl_query):
|
|
87 |
|
88 |
try:
|
89 |
response = openai.chat.completions.create(
|
90 |
-
model="gpt-
|
91 |
messages=messages,
|
92 |
temperature=0,
|
93 |
max_tokens=150,
|
@@ -97,21 +82,11 @@ def parse_query(nl_query):
|
|
97 |
except Exception as e:
|
98 |
return f"Error generating SQL query: {e}"
|
99 |
|
100 |
-
# =========================
|
101 |
-
# Plotting Utilities
|
102 |
-
# =========================
|
103 |
-
|
104 |
def detect_plot_intent(nl_query):
|
105 |
-
"""
|
106 |
-
Detects if the user's query involves plotting.
|
107 |
-
"""
|
108 |
plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie']
|
109 |
return any(keyword in nl_query.lower() for keyword in plot_keywords)
|
110 |
|
111 |
def generate_plot(nl_query, result_df):
|
112 |
-
"""
|
113 |
-
Generates a Plotly figure based on the result DataFrame and the user's intent.
|
114 |
-
"""
|
115 |
if not detect_plot_intent(nl_query):
|
116 |
return None, ""
|
117 |
|
@@ -119,7 +94,6 @@ def generate_plot(nl_query, result_df):
|
|
119 |
if len(columns) < 2:
|
120 |
return None, "Not enough data to generate a plot."
|
121 |
|
122 |
-
# Simple heuristic to choose plot type based on keywords
|
123 |
if 'bar' in nl_query.lower():
|
124 |
fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
|
125 |
elif 'line' in nl_query.lower():
|
@@ -129,7 +103,6 @@ def generate_plot(nl_query, result_df):
|
|
129 |
elif 'pie' in nl_query.lower():
|
130 |
fig = px.pie(result_df, names=columns[0], values=columns[1], title='Pie Chart')
|
131 |
else:
|
132 |
-
# Default to bar chart
|
133 |
fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
|
134 |
|
135 |
fig.update_layout(title_x=0.5)
|
@@ -143,7 +116,7 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
143 |
gr.Markdown("""
|
144 |
<h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
|
145 |
<p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
|
146 |
-
"""
|
147 |
|
148 |
with gr.Row():
|
149 |
with gr.Column(scale=1):
|
@@ -152,12 +125,13 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
152 |
placeholder='e.g., "What are the total awards over 1M in California?"',
|
153 |
lines=1
|
154 |
)
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
|
|
|
161 |
error_out = gr.Markdown(
|
162 |
value="",
|
163 |
visible=False
|
@@ -170,24 +144,12 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
170 |
label="Visualization"
|
171 |
)
|
172 |
|
|
|
173 |
gr.Markdown("""
|
174 |
-
|
175 |
-
|
176 |
-
.
|
177 |
-
|
178 |
-
margin: auto;
|
179 |
-
}
|
180 |
-
/* Style the main title */
|
181 |
-
#main-title h1 {
|
182 |
-
font-weight: bold;
|
183 |
-
}
|
184 |
-
/* Style the error alert */
|
185 |
-
.gradio-container .alert-error {
|
186 |
-
background-color: #ffe6e6;
|
187 |
-
color: #cc0000;
|
188 |
-
border: 1px solid #cc0000;
|
189 |
-
}
|
190 |
-
</style>
|
191 |
""")
|
192 |
|
193 |
# =========================
|
@@ -195,9 +157,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
195 |
# =========================
|
196 |
|
197 |
def on_query_submit(nl_query):
|
198 |
-
"""
|
199 |
-
Handles the submission of a natural language query.
|
200 |
-
"""
|
201 |
if not nl_query.strip():
|
202 |
return gr.update(visible=True, value="Please enter a query."), None, None
|
203 |
|
@@ -215,15 +174,18 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
215 |
|
216 |
return gr.update(visible=False, value=""), result_df, fig
|
217 |
|
218 |
-
def
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
|
|
|
|
223 |
|
224 |
-
|
225 |
-
|
226 |
-
|
|
|
227 |
|
228 |
query.submit(
|
229 |
fn=on_query_submit,
|
@@ -231,31 +193,6 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
|
|
231 |
outputs=[error_out, results_out, plot_out]
|
232 |
)
|
233 |
|
234 |
-
query.focus(
|
235 |
-
fn=lambda: gr.update(visible=True),
|
236 |
-
inputs=None,
|
237 |
-
outputs=schema_display
|
238 |
-
)
|
239 |
-
|
240 |
-
# =========================
|
241 |
-
# Helper Functions
|
242 |
-
# =========================
|
243 |
-
|
244 |
-
def execute_query(sql_query):
|
245 |
-
"""
|
246 |
-
Executes the SQL query and returns the results.
|
247 |
-
"""
|
248 |
-
try:
|
249 |
-
con = duckdb.connect()
|
250 |
-
con.execute("PRAGMA threads=4") # Optimize for performance
|
251 |
-
con.execute("DROP VIEW IF EXISTS contract_data")
|
252 |
-
con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
|
253 |
-
result_df = con.execute(sql_query).fetchdf()
|
254 |
-
con.close()
|
255 |
-
return result_df, ""
|
256 |
-
except Exception as e:
|
257 |
-
return None, f"Error executing query: {e}"
|
258 |
-
|
259 |
# =========================
|
260 |
# Launch the Gradio App
|
261 |
# =========================
|
|
|
11 |
# Configuration and Setup
|
12 |
# =========================
|
13 |
|
|
|
14 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
|
|
|
|
15 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
16 |
|
|
|
17 |
schema = [
|
18 |
{"column_name": "department_ind_agency", "column_type": "VARCHAR"},
|
19 |
{"column_name": "cgac", "column_type": "BIGINT"},
|
|
|
46 |
|
47 |
COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
|
48 |
|
|
|
|
|
|
|
|
|
49 |
def load_dataset_schema():
|
|
|
|
|
|
|
50 |
con = duckdb.connect()
|
51 |
try:
|
52 |
con.execute("DROP VIEW IF EXISTS contract_data")
|
|
|
58 |
finally:
|
59 |
con.close()
|
60 |
|
|
|
61 |
load_dataset_schema()
|
62 |
|
63 |
# =========================
|
|
|
65 |
# =========================
|
66 |
|
67 |
def parse_query(nl_query):
|
|
|
|
|
|
|
68 |
messages = [
|
69 |
{"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
|
70 |
{"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
|
|
|
72 |
|
73 |
try:
|
74 |
response = openai.chat.completions.create(
|
75 |
+
model="gpt-4",
|
76 |
messages=messages,
|
77 |
temperature=0,
|
78 |
max_tokens=150,
|
|
|
82 |
except Exception as e:
|
83 |
return f"Error generating SQL query: {e}"
|
84 |
|
|
|
|
|
|
|
|
|
85 |
def detect_plot_intent(nl_query):
|
|
|
|
|
|
|
86 |
plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie']
|
87 |
return any(keyword in nl_query.lower() for keyword in plot_keywords)
|
88 |
|
89 |
def generate_plot(nl_query, result_df):
|
|
|
|
|
|
|
90 |
if not detect_plot_intent(nl_query):
|
91 |
return None, ""
|
92 |
|
|
|
94 |
if len(columns) < 2:
|
95 |
return None, "Not enough data to generate a plot."
|
96 |
|
|
|
97 |
if 'bar' in nl_query.lower():
|
98 |
fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
|
99 |
elif 'line' in nl_query.lower():
|
|
|
103 |
elif 'pie' in nl_query.lower():
|
104 |
fig = px.pie(result_df, names=columns[0], values=columns[1], title='Pie Chart')
|
105 |
else:
|
|
|
106 |
fig = px.bar(result_df, x=columns[0], y=columns[1], title='Bar Chart')
|
107 |
|
108 |
fig.update_layout(title_x=0.5)
|
|
|
116 |
gr.Markdown("""
|
117 |
<h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
|
118 |
<p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
|
119 |
+
""")
|
120 |
|
121 |
with gr.Row():
|
122 |
with gr.Column(scale=1):
|
|
|
125 |
placeholder='e.g., "What are the total awards over 1M in California?"',
|
126 |
lines=1
|
127 |
)
|
128 |
+
gr.Markdown("### Example Queries")
|
129 |
+
with gr.Row():
|
130 |
+
btn_example1 = gr.Button("Show awards over 1M in CA")
|
131 |
+
btn_example2 = gr.Button("List all contracts in New York")
|
132 |
+
btn_example3 = gr.Button("Show top 5 departments by award amount")
|
133 |
+
btn_example4 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")
|
134 |
+
|
135 |
error_out = gr.Markdown(
|
136 |
value="",
|
137 |
visible=False
|
|
|
144 |
label="Visualization"
|
145 |
)
|
146 |
|
147 |
+
# Instructions
|
148 |
gr.Markdown("""
|
149 |
+
## Instructions
|
150 |
+
1. **Enter a query**: Type in a natural language query in the textbox.
|
151 |
+
2. **Use Example Queries**: Click on any example query button above.
|
152 |
+
3. **Generate SQL and Plot**: Click "Execute" to see results and visualization.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
153 |
""")
|
154 |
|
155 |
# =========================
|
|
|
157 |
# =========================
|
158 |
|
159 |
def on_query_submit(nl_query):
|
|
|
|
|
|
|
160 |
if not nl_query.strip():
|
161 |
return gr.update(visible=True, value="Please enter a query."), None, None
|
162 |
|
|
|
174 |
|
175 |
return gr.update(visible=False, value=""), result_df, fig
|
176 |
|
177 |
+
def on_example_click(query_text):
|
178 |
+
sql_query = parse_query(query_text)
|
179 |
+
result_df, error_msg = execute_query(sql_query)
|
180 |
+
if error_msg:
|
181 |
+
return sql_query, None, None, error_msg
|
182 |
+
fig, plot_error = generate_plot(query_text, result_df)
|
183 |
+
return sql_query, result_df, fig, plot_error if plot_error else ""
|
184 |
|
185 |
+
btn_example1.click(lambda: on_example_click("Show awards over 1M in CA"), outputs=[results_out, plot_out, error_out])
|
186 |
+
btn_example2.click(lambda: on_example_click("List all contracts in New York"), outputs=[results_out, plot_out, error_out])
|
187 |
+
btn_example3.click(lambda: on_example_click("Show top 5 departments by award amount"), outputs=[results_out, plot_out, error_out])
|
188 |
+
btn_example4.click(lambda: on_example_click("SELECT * from contract_data LIMIT 10;"), outputs=[results_out, plot_out, error_out])
|
189 |
|
190 |
query.submit(
|
191 |
fn=on_query_submit,
|
|
|
193 |
outputs=[error_out, results_out, plot_out]
|
194 |
)
|
195 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
# =========================
|
197 |
# Launch the Gradio App
|
198 |
# =========================
|