Update app.py
app.py CHANGED
@@ -1,326 +1,88 @@
-import os
-import json
-import openai
-import duckdb
 import gradio as gr
[removed lines 6-91 are collapsed in the diff view; only a truncated "from ..." import and a few "#" comment markers are visible]
-    Returns:
-        duckdb.DuckDBPyConnection: The DuckDB connection object.
-    """
-    try:
-        con = duckdb.connect(database=':memory:')
-        con.execute(f"CREATE OR REPLACE VIEW hsa_data AS SELECT * FROM read_parquet('{DATASET_PATH}')")
-        return con
-    except Exception as e:
-        raise RuntimeError(f"Failed to initialize DuckDB: {e}")
-
-# Initialize the database connection once
-db_connection = init_db()
-
-def execute_sql_query(sql_query):
-    """
-    Executes an SQL query against the DuckDB database.
-
-    Args:
-        sql_query (str): The SQL query to execute.
-
-    Returns:
-        tuple: A tuple containing the result dataframe and an error message (if any).
-    """
-    try:
-        result_df = db_connection.execute(sql_query).fetchdf()
-        return result_df, ""
-    except Exception as e:
-        return None, f"Error executing query: {e}"
-
-# =========================
-# Gradio Application UI
-# =========================
-
-with gr.Blocks(css="""
-    .error-message {
-        color: red;
-        font-weight: bold;
-    }
-    .gradio-container {
-        max-width: 1200px;
-        margin: auto;
-        font-family: -apple-system, BlinkMacSystemFont, 'San Francisco', 'Helvetica Neue', Helvetica, Arial, sans-serif;
-    }
-    .header {
-        text-align: center;
-        padding: 30px 0;
-    }
-    .instructions {
-        margin: 20px 0;
-        font-size: 18px;
-        line-height: 1.6;
-    }
-    .example-queries {
-        margin-bottom: 20px;
-    }
-    .button-row {
-        margin-top: 20px;
-    }
-    .input-area {
-        margin-bottom: 20px;
-    }
-    .schema-tab {
-        padding: 20px;
-    }
-    .results {
-        margin-top: 20px;
-    }
-    .copy-button {
-        margin-top: 10px;
-    }
-""") as demo:
-    # Header
-    gr.Markdown("""
-    # 🏥 Text-to-SQL Healthcare Data Analyst Agent
-
-    Analyze data from the U.S. Center of Medicare and Medicaid using natural language queries.
-
-    """, elem_classes="header")
-
-    # Instructions
-    gr.Markdown("""
-    ### Instructions
-
-    1. **Describe the data you want**: e.g., *"Show total days of care by zip code"*
-    2. **Use Example Queries**: Click on any example query button below to execute
-    3. **Generate SQL**: Or, enter your own query and click **Generate SQL Query**
-    4. **Execute the Query**: Review the generated SQL and click **Execute Query** to see the results
-
-    """, elem_classes="instructions")
-
-    with gr.Row():
-        with gr.Column(scale=1, min_width=350):
-            gr.Markdown("### 💡 Example Queries", elem_classes="example-queries")
-            query_buttons = [
-                "Calculate the average total_charges by zip_cd_of_residence",
-                "For each zip_cd_of_residence, calculate the sum of total_charges",
-                "SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
-            ]
-            btn_queries = [gr.Button(q, variant="secondary") for q in query_buttons]
-
-            query_input = gr.Textbox(
-                label="🔍 Your Query",
-                placeholder='e.g., "Show total charges over 1M by zip code"',
-                lines=2,
-                interactive=True,
-                elem_classes="input-area"
-            )
-
-            with gr.Row(elem_classes="button-row"):
-                btn_generate_sql = gr.Button("Generate SQL Query", variant="primary")
-                btn_execute_query = gr.Button("Execute Query", variant="primary")
-
-            sql_query_out = gr.Code(label="📝 Generated SQL Query", language="sql")
-            error_out = gr.HTML(elem_classes="error-message", visible=False)
-
-        with gr.Column(scale=2, min_width=650):
-            gr.Markdown("### 📊 Query Results", elem_classes="results")
-            results_out = gr.Dataframe(label="Query Results", interactive=False)
-
-            # Copy to Clipboard Button
-            btn_copy_results = gr.Button("Copy Results to Clipboard", variant="secondary", elem_classes="copy-button")
-
-            # JavaScript for copying to clipboard
-            copy_script = gr.HTML("""
-            <script>
-            function copyToClipboard() {
-                const resultsContainer = document.querySelector('div[data-testid="dataframe"] table');
-                if (resultsContainer) {
-                    const text = Array.from(resultsContainer.rows)
-                        .map(row => Array.from(row.cells)
-                            .map(cell => cell.innerText).join("\\t"))
-                        .join("\\n");
-                    navigator.clipboard.writeText(text).then(function() {
-                        alert("Copied results to clipboard!");
-                    }, function(err) {
-                        alert("Failed to copy results: " + err);
-                    });
-                } else {
-                    alert("No results to copy!");
-                }
-            }
-
-            // Attach the copy function to the button
-            document.addEventListener('DOMContentLoaded', function() {
-                const copyButton = document.querySelector('.copy-button button');
-                if (copyButton) {
-                    copyButton.addEventListener('click', copyToClipboard);
-                }
-            });
-            </script>
-            """)
-
-            # Include the JavaScript in the app
-            copy_script
-
-    # Dataset Schema Tab
-    with gr.Tab("📋 Dataset Schema", elem_classes="schema-tab"):
-        gr.Markdown("### Dataset Schema")
-        schema_display = gr.JSON(label="Schema", value=get_schema())
-
-    # =========================
-    # Event Functions
-    # =========================
-
-    def generate_sql(nl_query):
-        if not nl_query.strip():
-            return "", "<p>Please enter a query.</p>", gr.update(visible=True)
-        sql_query, error = parse_query(nl_query)
-        if error:
-            return sql_query, f"<p>{error}</p>", gr.update(visible=True)
-        else:
-            return sql_query, "", gr.update(visible=False)
-
-    def execute_query(sql_query):
-        if not sql_query.strip():
-            return None, "<p>No SQL query to execute.</p>", gr.update(visible=True)
-        result_df, error = execute_sql_query(sql_query)
-        if error:
-            return None, f"<p>{error}</p>", gr.update(visible=True)
-        else:
-            return result_df, "", gr.update(visible=False)
-
-    def handle_example_click(example_query):
-        if example_query.strip().upper().startswith("SELECT"):
-            sql_query = example_query
-            result_df, error = execute_sql_query(sql_query)
-            if error:
-                return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
-            else:
-                return sql_query, "", result_df, gr.update(visible=False)
-        else:
-            sql_query, error = parse_query(example_query)
-            if error:
-                return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
-            result_df, exec_error = execute_sql_query(sql_query)
-            if exec_error:
-                return sql_query, f"<p>{exec_error}</p>", None, gr.update(visible=True)
-            else:
-                return sql_query, "", result_df, gr.update(visible=False)
-
-    # Button Click Event Handlers
-    btn_generate_sql.click(
-        fn=generate_sql,
-        inputs=query_input,
-        outputs=[sql_query_out, error_out, error_out],
-    )
-
-    btn_execute_query.click(
-        fn=execute_query,
-        inputs=sql_query_out,
-        outputs=[results_out, error_out, error_out],
-    )
-
-    for btn in btn_queries:
-        btn.click(
-            fn=lambda q=btn.value: handle_example_click(q),
-            inputs=None,
-            outputs=[sql_query_out, error_out, results_out, error_out],
-        )
-
-    # Hide error message when inputs change
-    query_input.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
-    sql_query_out.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
-
-    # =========================
-    # Launch the Gradio App
-    # =========================
-
 if __name__ == "__main__":
[removed line 321 is collapsed in the diff view]
-        server_name="0.0.0.0",
-        server_port=7860,
-        share=False,
-        debug=True,
-    )
 import gradio as gr
+import pandas as pd
+import duckdb
+from sklearn.decomposition import PCA
+from sklearn.preprocessing import StandardScaler
+from sklearn.feature_selection import SelectKBest, f_regression
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+# Function to load data from a Parquet file into a DuckDB in-memory database
+def load_data(parquet_file):
+    con = duckdb.connect(database=':memory:')
+    con.execute(f"CREATE TABLE data AS SELECT * FROM read_parquet('{parquet_file}')")
+    df = con.execute("SELECT * FROM data").fetchdf()
+    return df
+
+# Function to preprocess data and perform PCA
+def preprocess_and_pca(df, target_column, n_components=5):
+    # Drop non-numeric columns
+    X = df.select_dtypes(include=[float, int])
+    y = df[target_column]
+
+    # Replace infinity values with NaN
+    X.replace([np.inf, -np.inf], np.nan, inplace=True)
+
+    # Handle missing values by imputing with the median
+    X = X.fillna(X.median())
+    y = y.fillna(y.median())
+
+    # Standardize the data
+    scaler = StandardScaler()
+    X_scaled = scaler.fit_transform(X)
+
+    # Apply PCA
+    pca = PCA(n_components=n_components)
+    X_pca = pca.fit_transform(X_scaled)
+
+    return X_pca, y
+
+# Function to visualize the PCA components
+def visualize_pca(X_pca, y):
+    # Visualize the first two principal components
+    plt.figure(figsize=(10, 6))
+    plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='viridis', edgecolor='k', s=50)
+    plt.xlabel('Principal Component 1')
+    plt.ylabel('Principal Component 2')
+    plt.title('PCA - First Two Principal Components')
+    plt.colorbar(label='Median Income Household')
+    plt.savefig('pca_scatter.png')
+    plt.close()
+
+    # Create a DataFrame with the first few principal components for pair plot
+    pca_df = pd.DataFrame(X_pca, columns=[f'PC{i+1}' for i in range(X_pca.shape[1])])
+    pca_df['Median_Income_Household'] = y
+
+    # Pair plot of the first few principal components
+    sns.pairplot(pca_df, vars=[f'PC{i+1}' for i in range(5)], hue='Median_Income_Household', palette='viridis')
+    plt.suptitle('Pair Plot of Principal Components', y=1.02)
+    plt.savefig('pca_pairplot.png')
+    plt.close()
+
+    return 'pca_scatter.png', 'pca_pairplot.png'
+
+# Gradio interface function
+def gradio_interface(target_column):
+    df = load_data('df_usa_health_features.parquet')
+    X_pca, y = preprocess_and_pca(df, target_column)
+    scatter_plot, pair_plot = visualize_pca(X_pca, y)
+    return scatter_plot, pair_plot
+
+# Create Gradio interface
+iface = gr.Interface(
+    fn=gradio_interface,
+    inputs=[
+        gr.inputs.Textbox(label="Target Column")
+    ],
+    outputs=[
+        gr.outputs.Image(type="file", label="PCA Scatter Plot"),
+        gr.outputs.Image(type="file", label="PCA Pair Plot")
+    ],
+    title="PCA Visualization with DuckDB and Gradio",
+    description="Specify the target column to visualize PCA components from the df_usa_health_features.parquet file."
+)
+
+# Launch the Gradio app
 if __name__ == "__main__":
+    iface.launch()
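Note that the added interface uses the pre-3.0 `gr.inputs` / `gr.outputs` component namespaces, which newer Gradio releases no longer provide; current versions expose the same components at the top level. As a minimal, hypothetical sketch (assuming the rest of app.py stays as committed, and that the Space runs a recent Gradio), the equivalent interface definition would look roughly like this:

```python
# Sketch only: the same interface on a current Gradio release, where components
# are top-level and image outputs are returned as file paths (type="filepath").
import gradio as gr

# gradio_interface is the function defined earlier in app.py (assumed unchanged).
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Target Column"),
    outputs=[
        gr.Image(type="filepath", label="PCA Scatter Plot"),
        gr.Image(type="filepath", label="PCA Pair Plot"),
    ],
    title="PCA Visualization with DuckDB and Gradio",
    description="Specify the target column to visualize PCA components "
                "from the df_usa_health_features.parquet file.",
)

if __name__ == "__main__":
    iface.launch()
```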