LeonceNsh commited on
Commit
44c9f07
Β·
verified Β·
1 Parent(s): ce0f72b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -52
app.py CHANGED
@@ -11,13 +11,16 @@ from e2b_code_interpreter import Sandbox
11
  # Configuration and Setup
12
  # =========================
13
 
 
 
 
14
  # Initialize OpenAI API key
15
  openai.api_key = os.getenv("OPENAI_API_KEY")
16
  if not openai.api_key:
17
  raise ValueError("Please set the OPENAI_API_KEY environment variable.")
18
 
19
- load_dotenv()
20
- sbx = Sandbox() # By default, the sandbox is alive for 5 minutes
21
 
22
  # Path to your Parquet dataset
23
  DATASET_PATH = 'hsas.parquet' # Update with your Parquet file path
@@ -67,8 +70,8 @@ def parse_query(nl_query):
67
  ]
68
 
69
  try:
70
- response = openai.chat.completions.create(
71
- model="gpt-4o-mini", # Use a valid and accessible model
72
  messages=messages,
73
  temperature=0,
74
  max_tokens=150,
@@ -122,23 +125,27 @@ def execute_sql_query(sql_query):
122
  with gr.Blocks(css="""
123
  .error-message {
124
  color: red;
 
125
  }
126
  .gradio-container {
127
- max-width: 1000px;
128
  margin: auto;
 
129
  }
130
  .header {
131
  text-align: center;
132
- padding: 20px;
133
  }
134
  .instructions {
135
- margin-bottom: 20px;
 
 
136
  }
137
  .example-queries {
138
  margin-bottom: 20px;
139
  }
140
  .button-row {
141
- margin-top: 10px;
142
  }
143
  .input-area {
144
  margin-bottom: 20px;
@@ -146,41 +153,49 @@ with gr.Blocks(css="""
146
  .schema-tab {
147
  padding: 20px;
148
  }
 
 
 
 
 
 
149
  """) as demo:
150
- with gr.Row():
151
- gr.Markdown("# πŸ₯ Text-to-SQL Healthcare Data Analyst Agent", elem_classes="header")
152
-
153
  gr.Markdown("""
154
- # Analyze data from the U.S. Center of Medicare and Medicaid
155
 
156
- ## Demonstrate how to use AI for data analysis and visualization in an isolated sandbox
157
 
158
- ## Use SQL directly in the browser to answer complex business questions starting with plain English.
159
 
160
  # Instructions
 
 
 
 
 
 
 
161
 
162
- # 1. **Describe the data you want**: e.g., `Show total days of care by zip`
163
- # 2. **Use Example Queries**: Click on any example query button below to execute
164
- # 3. **Generate SQL**: Or, enter your own query and click "Generate SQL"
165
  """, elem_classes="instructions")
166
 
167
  with gr.Row():
168
- with gr.Column(scale=1, min_width=300):
169
- gr.Markdown("# πŸ’‘ Example Queries:", elem_classes="example-queries")
170
  query_buttons = [
171
  "Calculate the average total_charges by zip_cd_of_residence",
172
  "For each zip_cd_of_residence, calculate the sum of total_charges",
173
  "SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
174
  ]
175
- btn_queries = [gr.Button(q, variant="secondary", size="sm") for q in query_buttons]
176
-
177
- with gr.Group():
178
- query_input = gr.Textbox(
179
- label="πŸ” Your Query",
180
- placeholder='e.g., "Show total charges over 1M by state"',
181
- lines=2,
182
- interactive=True,
183
- )
184
 
185
  with gr.Row(elem_classes="button-row"):
186
  btn_generate_sql = gr.Button("Generate SQL Query", variant="primary")
@@ -189,28 +204,28 @@ with gr.Blocks(css="""
189
  sql_query_out = gr.Code(label="πŸ“ Generated SQL Query", language="sql")
190
  error_out = gr.HTML(elem_classes="error-message", visible=False)
191
 
192
- with gr.Column(scale=2, min_width=600):
193
  gr.Markdown("### πŸ“Š Query Results", elem_classes="results")
194
  results_out = gr.Dataframe(label="Query Results", interactive=False)
195
- btn_copy_results = gr.Button("Copy to Clipboard", variant="secondary", size="sm")
196
 
197
- # JavaScript for copying to clipboard - updated approach
 
 
 
198
  copy_script = gr.HTML("""
199
  <script>
200
  function copyToClipboard() {
201
- const resultsContainer = document.querySelector('.dataframe-output table');
202
  if (resultsContainer) {
203
  const text = Array.from(resultsContainer.rows)
204
  .map(row => Array.from(row.cells)
205
- .map(cell => cell.innerText).join("\t"))
206
- .join("\n");
207
- const tempTextArea = document.createElement('textarea');
208
- tempTextArea.value = text;
209
- document.body.appendChild(tempTextArea);
210
- tempTextArea.select();
211
- document.execCommand('copy');
212
- document.body.removeChild(tempTextArea);
213
- alert("Copied results to clipboard!");
214
  } else {
215
  alert("No results to copy!");
216
  }
@@ -220,9 +235,10 @@ with gr.Blocks(css="""
220
 
221
  # Connect JavaScript function to button
222
  btn_copy_results.click(
223
- None, None, None, _js="copyToClipboard"
224
  )
225
 
 
226
  with gr.Tab("πŸ“‹ Dataset Schema", elem_classes="schema-tab"):
227
  gr.Markdown("### Dataset Schema")
228
  schema_display = gr.JSON(label="Schema", value=get_schema())
@@ -233,19 +249,19 @@ with gr.Blocks(css="""
233
 
234
  def generate_sql(nl_query):
235
  if not nl_query.strip():
236
- return "", "<p class='error-message'>Please enter a query.</p>", gr.update(visible=True)
237
  sql_query, error = parse_query(nl_query)
238
  if error:
239
- return sql_query, f"<p class='error-message'>{error}</p>", gr.update(visible=True)
240
  else:
241
  return sql_query, "", gr.update(visible=False)
242
 
243
  def execute_query(sql_query):
244
  if not sql_query.strip():
245
- return None, "<p class='error-message'>No SQL query to execute.</p>", gr.update(visible=True)
246
  result_df, error = execute_sql_query(sql_query)
247
  if error:
248
- return None, f"<p class='error-message'>{error}</p>", gr.update(visible=True)
249
  else:
250
  return result_df, "", gr.update(visible=False)
251
 
@@ -254,16 +270,16 @@ with gr.Blocks(css="""
254
  sql_query = example_query
255
  result_df, error = execute_sql_query(sql_query)
256
  if error:
257
- return sql_query, f"<p class='error-message'>{error}</p>", None, gr.update(visible=True)
258
  else:
259
- return sql_query, "", result_df, gr.update(visible=False)
260
  else:
261
  sql_query, error = parse_query(example_query)
262
  if error:
263
- return sql_query, f"<p class='error-message'>{error}</p>", None, gr.update(visible=True)
264
  result_df, exec_error = execute_sql_query(sql_query)
265
  if exec_error:
266
- return sql_query, f"<p class='error-message'>{exec_error}</p>", None, gr.update(visible=True)
267
  else:
268
  return sql_query, "", result_df, gr.update(visible=False)
269
 
@@ -287,8 +303,9 @@ with gr.Blocks(css="""
287
  outputs=[sql_query_out, error_out, results_out, error_out],
288
  )
289
 
290
- query_input.change(fn=lambda: ("", gr.update(visible=False)), inputs=None, outputs=[error_out])
291
- sql_query_out.change(fn=lambda: ("", gr.update(visible=False)), inputs=None, outputs=[error_out])
 
292
 
293
  # =========================
294
  # Launch the Gradio App
 
11
  # Configuration and Setup
12
  # =========================
13
 
14
+ # Load environment variables
15
+ load_dotenv()
16
+
17
  # Initialize OpenAI API key
18
  openai.api_key = os.getenv("OPENAI_API_KEY")
19
  if not openai.api_key:
20
  raise ValueError("Please set the OPENAI_API_KEY environment variable.")
21
 
22
+ # Initialize the Sandbox
23
+ sbx = Sandbox() # By default, the sandbox is alive for 5 minutes
24
 
25
  # Path to your Parquet dataset
26
  DATASET_PATH = 'hsas.parquet' # Update with your Parquet file path
 
70
  ]
71
 
72
  try:
73
+ response = openai.ChatCompletion.create(
74
+ model="gpt-3.5-turbo", # Use a valid and accessible model
75
  messages=messages,
76
  temperature=0,
77
  max_tokens=150,
 
125
  with gr.Blocks(css="""
126
  .error-message {
127
  color: red;
128
+ font-weight: bold;
129
  }
130
  .gradio-container {
131
+ max-width: 1200px;
132
  margin: auto;
133
+ font-family: -apple-system, BlinkMacSystemFont, 'San Francisco', 'Helvetica Neue', Helvetica, Arial, sans-serif;
134
  }
135
  .header {
136
  text-align: center;
137
+ padding: 30px 0;
138
  }
139
  .instructions {
140
+ margin: 20px 0;
141
+ font-size: 18px;
142
+ line-height: 1.6;
143
  }
144
  .example-queries {
145
  margin-bottom: 20px;
146
  }
147
  .button-row {
148
+ margin-top: 20px;
149
  }
150
  .input-area {
151
  margin-bottom: 20px;
 
153
  .schema-tab {
154
  padding: 20px;
155
  }
156
+ .results {
157
+ margin-top: 20px;
158
+ }
159
+ .copy-button {
160
+ margin-top: 10px;
161
+ }
162
  """) as demo:
163
+ # Header
 
 
164
  gr.Markdown("""
165
+ # πŸ₯ Text-to-SQL Healthcare Data Analyst Agent
166
 
167
+ Analyze data from the U.S. Center of Medicare and Medicaid using natural language queries.
168
 
169
+ """, elem_classes="header")
170
 
171
  # Instructions
172
+ gr.Markdown("""
173
+ ### Instructions
174
+
175
+ 1. **Describe the data you want**: e.g., *"Show total days of care by zip code"*
176
+ 2. **Use Example Queries**: Click on any example query button below to execute
177
+ 3. **Generate SQL**: Or, enter your own query and click **Generate SQL Query**
178
+ 4. **Execute the Query**: Review the generated SQL and click **Execute Query** to see the results
179
 
 
 
 
180
  """, elem_classes="instructions")
181
 
182
  with gr.Row():
183
+ with gr.Column(scale=1, min_width=350):
184
+ gr.Markdown("### πŸ’‘ Example Queries", elem_classes="example-queries")
185
  query_buttons = [
186
  "Calculate the average total_charges by zip_cd_of_residence",
187
  "For each zip_cd_of_residence, calculate the sum of total_charges",
188
  "SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
189
  ]
190
+ btn_queries = [gr.Button(q, variant="secondary") for q in query_buttons]
191
+
192
+ query_input = gr.Textbox(
193
+ label="πŸ” Your Query",
194
+ placeholder='e.g., "Show total charges over 1M by zip code"',
195
+ lines=2,
196
+ interactive=True,
197
+ elem_classes="input-area"
198
+ )
199
 
200
  with gr.Row(elem_classes="button-row"):
201
  btn_generate_sql = gr.Button("Generate SQL Query", variant="primary")
 
204
  sql_query_out = gr.Code(label="πŸ“ Generated SQL Query", language="sql")
205
  error_out = gr.HTML(elem_classes="error-message", visible=False)
206
 
207
+ with gr.Column(scale=2, min_width=650):
208
  gr.Markdown("### πŸ“Š Query Results", elem_classes="results")
209
  results_out = gr.Dataframe(label="Query Results", interactive=False)
 
210
 
211
+ # Copy to Clipboard Button
212
+ btn_copy_results = gr.Button("Copy Results to Clipboard", variant="secondary", elem_classes="copy-button")
213
+
214
+ # JavaScript for copying to clipboard
215
  copy_script = gr.HTML("""
216
  <script>
217
  function copyToClipboard() {
218
+ const resultsContainer = document.querySelector('div[data-testid="dataframe"] table');
219
  if (resultsContainer) {
220
  const text = Array.from(resultsContainer.rows)
221
  .map(row => Array.from(row.cells)
222
+ .map(cell => cell.innerText).join("\\t"))
223
+ .join("\\n");
224
+ navigator.clipboard.writeText(text).then(function() {
225
+ alert("Copied results to clipboard!");
226
+ }, function(err) {
227
+ alert("Failed to copy results: " + err);
228
+ });
 
 
229
  } else {
230
  alert("No results to copy!");
231
  }
 
235
 
236
  # Connect JavaScript function to button
237
  btn_copy_results.click(
238
+ _js="copyToClipboard"
239
  )
240
 
241
+ # Dataset Schema Tab
242
  with gr.Tab("πŸ“‹ Dataset Schema", elem_classes="schema-tab"):
243
  gr.Markdown("### Dataset Schema")
244
  schema_display = gr.JSON(label="Schema", value=get_schema())
 
249
 
250
  def generate_sql(nl_query):
251
  if not nl_query.strip():
252
+ return "", "<p>Please enter a query.</p>", gr.update(visible=True)
253
  sql_query, error = parse_query(nl_query)
254
  if error:
255
+ return sql_query, f"<p>{error}</p>", gr.update(visible=True)
256
  else:
257
  return sql_query, "", gr.update(visible=False)
258
 
259
  def execute_query(sql_query):
260
  if not sql_query.strip():
261
+ return None, "<p>No SQL query to execute.</p>", gr.update(visible=True)
262
  result_df, error = execute_sql_query(sql_query)
263
  if error:
264
+ return None, f"<p>{error}</p>", gr.update(visible=True)
265
  else:
266
  return result_df, "", gr.update(visible=False)
267
 
 
270
  sql_query = example_query
271
  result_df, error = execute_sql_query(sql_query)
272
  if error:
273
+ return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
274
  else:
275
+ return sql_query, "", result_df, gr.update(visible(False))
276
  else:
277
  sql_query, error = parse_query(example_query)
278
  if error:
279
+ return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
280
  result_df, exec_error = execute_sql_query(sql_query)
281
  if exec_error:
282
+ return sql_query, f"<p>{exec_error}</p>", None, gr.update(visible=True)
283
  else:
284
  return sql_query, "", result_df, gr.update(visible=False)
285
 
 
303
  outputs=[sql_query_out, error_out, results_out, error_out],
304
  )
305
 
306
+ # Hide error message when inputs change
307
+ query_input.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
308
+ sql_query_out.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
309
 
310
  # =========================
311
  # Launch the Gradio App