LeonceNsh committed · verified
Commit b1a9f46 · 1 Parent(s): b19ad01

Update app.py

Files changed (1)
  1. app.py +86 -324
app.py CHANGED
@@ -1,326 +1,88 @@
- import os
- import json
- import openai
- import duckdb
  import gradio as gr
- from functools import lru_cache
- from dotenv import load_dotenv
- from e2b_code_interpreter import Sandbox
-
- # =========================
- # Configuration and Setup
- # =========================
-
- # Load environment variables
- load_dotenv()
-
- # Initialize OpenAI API key
- openai.api_key = os.getenv("OPENAI_API_KEY")
- if not openai.api_key:
-     raise ValueError("Please set the OPENAI_API_KEY environment variable.")
-
- # Initialize the Sandbox
- sbx = Sandbox()  # By default, the sandbox is alive for 5 minutes
-
- # Path to your Parquet dataset
- DATASET_PATH = 'hsas.parquet'  # Update with your Parquet file path
-
- # Define the schema of your dataset
- SCHEMA = [
-     {"column_name": "total_charges", "column_type": "BIGINT"},
-     {"column_name": "medicare_prov_num", "column_type": "BIGINT"},
-     {"column_name": "zip_cd_of_residence", "column_type": "VARCHAR"},
-     {"column_name": "total_days_of_care", "column_type": "BIGINT"},
-     {"column_name": "total_cases", "column_type": "BIGINT"},
- ]
-
- @lru_cache(maxsize=1)
- def get_schema():
-     """Returns the schema of the dataset."""
-     return SCHEMA
-
- COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
-
- # =========================
- # OpenAI API Integration
- # =========================
-
- def parse_query(nl_query):
-     """
-     Converts a natural language query into an SQL query using OpenAI's GPT model.
-
-     Args:
-         nl_query (str): The natural language query.
-
-     Returns:
-         tuple: A tuple containing the SQL query and an error message (if any).
-     """
-     messages = [
-         {
-             "role": "system",
-             "content": (
-                 "You are an assistant that converts natural language queries into SQL queries for the 'hsa_data' table. "
-                 "Ensure the SQL query is syntactically correct and uses only the columns provided in the schema."
-             ),
-         },
-         {
-             "role": "user",
-             "content": f"Schema:\n{json.dumps(get_schema(), indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:",
-         },
-     ]
-
-     try:
-         response = openai.chat.completions.create(
-             model="gpt-4o-mini",  # Use a valid and accessible model
-             messages=messages,
-             temperature=0,
-             max_tokens=150,
-         )
-         sql_query = response.choices[0].message.content.strip()
-         return sql_query, ""
-     except Exception as e:
-         return "", f"Error generating SQL query: {e}"
-
- # =========================
- # Database Interaction
- # =========================
-
- def init_db():
-     """
-     Initializes the DuckDB in-memory database and loads the dataset.
-
-     Returns:
-         duckdb.DuckDBPyConnection: The DuckDB connection object.
-     """
-     try:
-         con = duckdb.connect(database=':memory:')
-         con.execute(f"CREATE OR REPLACE VIEW hsa_data AS SELECT * FROM read_parquet('{DATASET_PATH}')")
-         return con
-     except Exception as e:
-         raise RuntimeError(f"Failed to initialize DuckDB: {e}")
-
- # Initialize the database connection once
- db_connection = init_db()
-
- def execute_sql_query(sql_query):
-     """
-     Executes an SQL query against the DuckDB database.
-
-     Args:
-         sql_query (str): The SQL query to execute.
-
-     Returns:
-         tuple: A tuple containing the result dataframe and an error message (if any).
-     """
-     try:
-         result_df = db_connection.execute(sql_query).fetchdf()
-         return result_df, ""
-     except Exception as e:
-         return None, f"Error executing query: {e}"
-
- # =========================
- # Gradio Application UI
- # =========================
-
- with gr.Blocks(css="""
-     .error-message {
-         color: red;
-         font-weight: bold;
-     }
-     .gradio-container {
-         max-width: 1200px;
-         margin: auto;
-         font-family: -apple-system, BlinkMacSystemFont, 'San Francisco', 'Helvetica Neue', Helvetica, Arial, sans-serif;
-     }
-     .header {
-         text-align: center;
-         padding: 30px 0;
-     }
-     .instructions {
-         margin: 20px 0;
-         font-size: 18px;
-         line-height: 1.6;
-     }
-     .example-queries {
-         margin-bottom: 20px;
-     }
-     .button-row {
-         margin-top: 20px;
-     }
-     .input-area {
-         margin-bottom: 20px;
-     }
-     .schema-tab {
-         padding: 20px;
-     }
-     .results {
-         margin-top: 20px;
-     }
-     .copy-button {
-         margin-top: 10px;
-     }
- """) as demo:
-     # Header
-     gr.Markdown("""
-     # 🏥 Text-to-SQL Healthcare Data Analyst Agent
-
-     Analyze data from the U.S. Center of Medicare and Medicaid using natural language queries.
-
-     """, elem_classes="header")
-
-     # Instructions
-     gr.Markdown("""
-     ### Instructions
-
-     1. **Describe the data you want**: e.g., *"Show total days of care by zip code"*
-     2. **Use Example Queries**: Click on any example query button below to execute
-     3. **Generate SQL**: Or, enter your own query and click **Generate SQL Query**
-     4. **Execute the Query**: Review the generated SQL and click **Execute Query** to see the results
-
-     """, elem_classes="instructions")
-
-     with gr.Row():
-         with gr.Column(scale=1, min_width=350):
-             gr.Markdown("### 💡 Example Queries", elem_classes="example-queries")
-             query_buttons = [
-                 "Calculate the average total_charges by zip_cd_of_residence",
-                 "For each zip_cd_of_residence, calculate the sum of total_charges",
-                 "SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
-             ]
-             btn_queries = [gr.Button(q, variant="secondary") for q in query_buttons]
-
-             query_input = gr.Textbox(
-                 label="🔍 Your Query",
-                 placeholder='e.g., "Show total charges over 1M by zip code"',
-                 lines=2,
-                 interactive=True,
-                 elem_classes="input-area"
-             )
-
-             with gr.Row(elem_classes="button-row"):
-                 btn_generate_sql = gr.Button("Generate SQL Query", variant="primary")
-                 btn_execute_query = gr.Button("Execute Query", variant="primary")
-
-             sql_query_out = gr.Code(label="📝 Generated SQL Query", language="sql")
-             error_out = gr.HTML(elem_classes="error-message", visible=False)
-
-         with gr.Column(scale=2, min_width=650):
-             gr.Markdown("### 📊 Query Results", elem_classes="results")
-             results_out = gr.Dataframe(label="Query Results", interactive=False)
-
-             # Copy to Clipboard Button
-             btn_copy_results = gr.Button("Copy Results to Clipboard", variant="secondary", elem_classes="copy-button")
-
-             # JavaScript for copying to clipboard
-             copy_script = gr.HTML("""
-             <script>
-             function copyToClipboard() {
-                 const resultsContainer = document.querySelector('div[data-testid="dataframe"] table');
-                 if (resultsContainer) {
-                     const text = Array.from(resultsContainer.rows)
-                         .map(row => Array.from(row.cells)
-                             .map(cell => cell.innerText).join("\\t"))
-                         .join("\\n");
-                     navigator.clipboard.writeText(text).then(function() {
-                         alert("Copied results to clipboard!");
-                     }, function(err) {
-                         alert("Failed to copy results: " + err);
-                     });
-                 } else {
-                     alert("No results to copy!");
-                 }
-             }
-
-             // Attach the copy function to the button
-             document.addEventListener('DOMContentLoaded', function() {
-                 const copyButton = document.querySelector('.copy-button button');
-                 if (copyButton) {
-                     copyButton.addEventListener('click', copyToClipboard);
-                 }
-             });
-             </script>
-             """)
-
-             # Include the JavaScript in the app
-             copy_script
-
-     # Dataset Schema Tab
-     with gr.Tab("📋 Dataset Schema", elem_classes="schema-tab"):
-         gr.Markdown("### Dataset Schema")
-         schema_display = gr.JSON(label="Schema", value=get_schema())
-
-     # =========================
-     # Event Functions
-     # =========================
-
-     def generate_sql(nl_query):
-         if not nl_query.strip():
-             return "", "<p>Please enter a query.</p>", gr.update(visible=True)
-         sql_query, error = parse_query(nl_query)
-         if error:
-             return sql_query, f"<p>{error}</p>", gr.update(visible=True)
-         else:
-             return sql_query, "", gr.update(visible=False)
-
-     def execute_query(sql_query):
-         if not sql_query.strip():
-             return None, "<p>No SQL query to execute.</p>", gr.update(visible=True)
-         result_df, error = execute_sql_query(sql_query)
-         if error:
-             return None, f"<p>{error}</p>", gr.update(visible=True)
-         else:
-             return result_df, "", gr.update(visible=False)
-
-     def handle_example_click(example_query):
-         if example_query.strip().upper().startswith("SELECT"):
-             sql_query = example_query
-             result_df, error = execute_sql_query(sql_query)
-             if error:
-                 return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
-             else:
-                 return sql_query, "", result_df, gr.update(visible=False)
-         else:
-             sql_query, error = parse_query(example_query)
-             if error:
-                 return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
-             result_df, exec_error = execute_sql_query(sql_query)
-             if exec_error:
-                 return sql_query, f"<p>{exec_error}</p>", None, gr.update(visible=True)
-             else:
-                 return sql_query, "", result_df, gr.update(visible=False)
-
-     # Button Click Event Handlers
-     btn_generate_sql.click(
-         fn=generate_sql,
-         inputs=query_input,
-         outputs=[sql_query_out, error_out, error_out],
-     )
-
-     btn_execute_query.click(
-         fn=execute_query,
-         inputs=sql_query_out,
-         outputs=[results_out, error_out, error_out],
-     )
-
-     for btn in btn_queries:
-         btn.click(
-             fn=lambda q=btn.value: handle_example_click(q),
-             inputs=None,
-             outputs=[sql_query_out, error_out, results_out, error_out],
-         )
-
-     # Hide error message when inputs change
-     query_input.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
-     sql_query_out.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
-
- # =========================
- # Launch the Gradio App
- # =========================
-
  if __name__ == "__main__":
-     demo.launch(
-         server_name="0.0.0.0",
-         server_port=7860,
-         share=False,
-         debug=True,
-     )
  import gradio as gr
+ import pandas as pd
+ import duckdb
+ from sklearn.decomposition import PCA
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.feature_selection import SelectKBest, f_regression
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ # Function to load data from a Parquet file into a DuckDB in-memory database
+ def load_data(parquet_file):
+     con = duckdb.connect(database=':memory:')
+     con.execute(f"CREATE TABLE data AS SELECT * FROM read_parquet('{parquet_file}')")
+     df = con.execute("SELECT * FROM data").fetchdf()
+     return df
+
+ # Function to preprocess data and perform PCA
+ def preprocess_and_pca(df, target_column, n_components=5):
+     # Drop non-numeric columns
+     X = df.select_dtypes(include=[float, int])
+     y = df[target_column]
+
+     # Replace infinity values with NaN
+     X.replace([np.inf, -np.inf], np.nan, inplace=True)
+
+     # Handle missing values by imputing with the median
+     X = X.fillna(X.median())
+     y = y.fillna(y.median())
+
+     # Standardize the data
+     scaler = StandardScaler()
+     X_scaled = scaler.fit_transform(X)
+
+     # Apply PCA
+     pca = PCA(n_components=n_components)
+     X_pca = pca.fit_transform(X_scaled)
+
+     return X_pca, y
+
+ # Function to visualize the PCA components
+ def visualize_pca(X_pca, y):
+     # Visualize the first two principal components
+     plt.figure(figsize=(10, 6))
+     plt.scatter(X_pca[:, 0], X_pca[:, 1], c=y, cmap='viridis', edgecolor='k', s=50)
+     plt.xlabel('Principal Component 1')
+     plt.ylabel('Principal Component 2')
+     plt.title('PCA - First Two Principal Components')
+     plt.colorbar(label='Median Income Household')
+     plt.savefig('pca_scatter.png')
+     plt.close()
+
+     # Create a DataFrame with the first few principal components for pair plot
+     pca_df = pd.DataFrame(X_pca, columns=[f'PC{i+1}' for i in range(X_pca.shape[1])])
+     pca_df['Median_Income_Household'] = y
+
+     # Pair plot of the first few principal components
+     sns.pairplot(pca_df, vars=[f'PC{i+1}' for i in range(5)], hue='Median_Income_Household', palette='viridis')
+     plt.suptitle('Pair Plot of Principal Components', y=1.02)
+     plt.savefig('pca_pairplot.png')
+     plt.close()
+
+     return 'pca_scatter.png', 'pca_pairplot.png'
+
+ # Gradio interface function
+ def gradio_interface(target_column):
+     df = load_data('df_usa_health_features.parquet')
+     X_pca, y = preprocess_and_pca(df, target_column)
+     scatter_plot, pair_plot = visualize_pca(X_pca, y)
+     return scatter_plot, pair_plot
+
+ # Create Gradio interface
+ iface = gr.Interface(
+     fn=gradio_interface,
+     inputs=[
+         gr.inputs.Textbox(label="Target Column")
+     ],
+     outputs=[
+         gr.outputs.Image(type="file", label="PCA Scatter Plot"),
+         gr.outputs.Image(type="file", label="PCA Pair Plot")
+     ],
+     title="PCA Visualization with DuckDB and Gradio",
+     description="Specify the target column to visualize PCA components from the df_usa_health_features.parquet file."
+ )
+
+ # Launch the Gradio app
  if __name__ == "__main__":
+     iface.launch()
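
Note: the interface added in this commit is built through the legacy gr.inputs / gr.outputs namespaces, which newer Gradio releases no longer ship. The snippet below is a minimal sketch, not part of commit b1a9f46, of the same Interface against the current component API; it assumes a Gradio 4.x runtime and that it replaces the committed construction at the bottom of the new app.py, so that load_data, preprocess_and_pca, visualize_pca, and gradio_interface are already defined in the module.

# Minimal sketch, assuming Gradio 4.x; not part of commit b1a9f46.
# Relies on gradio_interface (defined above in the committed app.py),
# which returns the saved PNG paths, hence gr.Image(type="filepath").
import gradio as gr

iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(label="Target Column"),
    outputs=[
        gr.Image(type="filepath", label="PCA Scatter Plot"),
        gr.Image(type="filepath", label="PCA Pair Plot"),
    ],
    title="PCA Visualization with DuckDB and Gradio",
    description="Specify the target column to visualize PCA components from the df_usa_health_features.parquet file.",
)

if __name__ == "__main__":
    iface.launch()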