Update app.py
Browse files
app.py
CHANGED
@@ -11,13 +11,16 @@ from e2b_code_interpreter import Sandbox
|
|
11 |
# Configuration and Setup
|
12 |
# =========================
|
13 |
|
|
|
|
|
|
|
14 |
# Initialize OpenAI API key
|
15 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
16 |
if not openai.api_key:
|
17 |
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
|
18 |
|
19 |
-
|
20 |
-
sbx = Sandbox()
|
21 |
|
22 |
# Path to your Parquet dataset
|
23 |
DATASET_PATH = 'hsas.parquet' # Update with your Parquet file path
|
@@ -67,8 +70,8 @@ def parse_query(nl_query):
|
|
67 |
]
|
68 |
|
69 |
try:
|
70 |
-
response = openai.
|
71 |
-
model="gpt-
|
72 |
messages=messages,
|
73 |
temperature=0,
|
74 |
max_tokens=150,
|
@@ -122,23 +125,27 @@ def execute_sql_query(sql_query):
|
|
122 |
with gr.Blocks(css="""
|
123 |
.error-message {
|
124 |
color: red;
|
|
|
125 |
}
|
126 |
.gradio-container {
|
127 |
-
max-width:
|
128 |
margin: auto;
|
|
|
129 |
}
|
130 |
.header {
|
131 |
text-align: center;
|
132 |
-
padding:
|
133 |
}
|
134 |
.instructions {
|
135 |
-
margin
|
|
|
|
|
136 |
}
|
137 |
.example-queries {
|
138 |
margin-bottom: 20px;
|
139 |
}
|
140 |
.button-row {
|
141 |
-
margin-top:
|
142 |
}
|
143 |
.input-area {
|
144 |
margin-bottom: 20px;
|
@@ -146,41 +153,49 @@ with gr.Blocks(css="""
|
|
146 |
.schema-tab {
|
147 |
padding: 20px;
|
148 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
149 |
""") as demo:
|
150 |
-
|
151 |
-
gr.Markdown("# π₯ Text-to-SQL Healthcare Data Analyst Agent", elem_classes="header")
|
152 |
-
|
153 |
gr.Markdown("""
|
154 |
-
#
|
155 |
|
156 |
-
|
157 |
|
158 |
-
|
159 |
|
160 |
# Instructions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
|
162 |
-
# 1. **Describe the data you want**: e.g., `Show total days of care by zip`
|
163 |
-
# 2. **Use Example Queries**: Click on any example query button below to execute
|
164 |
-
# 3. **Generate SQL**: Or, enter your own query and click "Generate SQL"
|
165 |
""", elem_classes="instructions")
|
166 |
|
167 |
with gr.Row():
|
168 |
-
with gr.Column(scale=1, min_width=
|
169 |
-
gr.Markdown("
|
170 |
query_buttons = [
|
171 |
"Calculate the average total_charges by zip_cd_of_residence",
|
172 |
"For each zip_cd_of_residence, calculate the sum of total_charges",
|
173 |
"SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
|
174 |
]
|
175 |
-
btn_queries = [gr.Button(q, variant="secondary"
|
176 |
-
|
177 |
-
|
178 |
-
|
179 |
-
|
180 |
-
|
181 |
-
|
182 |
-
|
183 |
-
|
184 |
|
185 |
with gr.Row(elem_classes="button-row"):
|
186 |
btn_generate_sql = gr.Button("Generate SQL Query", variant="primary")
|
@@ -189,28 +204,28 @@ with gr.Blocks(css="""
|
|
189 |
sql_query_out = gr.Code(label="π Generated SQL Query", language="sql")
|
190 |
error_out = gr.HTML(elem_classes="error-message", visible=False)
|
191 |
|
192 |
-
with gr.Column(scale=2, min_width=
|
193 |
gr.Markdown("### π Query Results", elem_classes="results")
|
194 |
results_out = gr.Dataframe(label="Query Results", interactive=False)
|
195 |
-
btn_copy_results = gr.Button("Copy to Clipboard", variant="secondary", size="sm")
|
196 |
|
197 |
-
#
|
|
|
|
|
|
|
198 |
copy_script = gr.HTML("""
|
199 |
<script>
|
200 |
function copyToClipboard() {
|
201 |
-
const resultsContainer = document.querySelector('
|
202 |
if (resultsContainer) {
|
203 |
const text = Array.from(resultsContainer.rows)
|
204 |
.map(row => Array.from(row.cells)
|
205 |
-
.map(cell => cell.innerText).join("
|
206 |
-
.join("
|
207 |
-
|
208 |
-
|
209 |
-
|
210 |
-
|
211 |
-
|
212 |
-
document.body.removeChild(tempTextArea);
|
213 |
-
alert("Copied results to clipboard!");
|
214 |
} else {
|
215 |
alert("No results to copy!");
|
216 |
}
|
@@ -220,9 +235,10 @@ with gr.Blocks(css="""
|
|
220 |
|
221 |
# Connect JavaScript function to button
|
222 |
btn_copy_results.click(
|
223 |
-
|
224 |
)
|
225 |
|
|
|
226 |
with gr.Tab("π Dataset Schema", elem_classes="schema-tab"):
|
227 |
gr.Markdown("### Dataset Schema")
|
228 |
schema_display = gr.JSON(label="Schema", value=get_schema())
|
@@ -233,19 +249,19 @@ with gr.Blocks(css="""
|
|
233 |
|
234 |
def generate_sql(nl_query):
|
235 |
if not nl_query.strip():
|
236 |
-
return "", "<p
|
237 |
sql_query, error = parse_query(nl_query)
|
238 |
if error:
|
239 |
-
return sql_query, f"<p
|
240 |
else:
|
241 |
return sql_query, "", gr.update(visible=False)
|
242 |
|
243 |
def execute_query(sql_query):
|
244 |
if not sql_query.strip():
|
245 |
-
return None, "<p
|
246 |
result_df, error = execute_sql_query(sql_query)
|
247 |
if error:
|
248 |
-
return None, f"<p
|
249 |
else:
|
250 |
return result_df, "", gr.update(visible=False)
|
251 |
|
@@ -254,16 +270,16 @@ with gr.Blocks(css="""
|
|
254 |
sql_query = example_query
|
255 |
result_df, error = execute_sql_query(sql_query)
|
256 |
if error:
|
257 |
-
return sql_query, f"<p
|
258 |
else:
|
259 |
-
return sql_query, "", result_df, gr.update(visible
|
260 |
else:
|
261 |
sql_query, error = parse_query(example_query)
|
262 |
if error:
|
263 |
-
return sql_query, f"<p
|
264 |
result_df, exec_error = execute_sql_query(sql_query)
|
265 |
if exec_error:
|
266 |
-
return sql_query, f"<p
|
267 |
else:
|
268 |
return sql_query, "", result_df, gr.update(visible=False)
|
269 |
|
@@ -287,8 +303,9 @@ with gr.Blocks(css="""
|
|
287 |
outputs=[sql_query_out, error_out, results_out, error_out],
|
288 |
)
|
289 |
|
290 |
-
|
291 |
-
|
|
|
292 |
|
293 |
# =========================
|
294 |
# Launch the Gradio App
|
|
|
11 |
# Configuration and Setup
|
12 |
# =========================
|
13 |
|
14 |
+
# Load environment variables
|
15 |
+
load_dotenv()
|
16 |
+
|
17 |
# Initialize OpenAI API key
|
18 |
openai.api_key = os.getenv("OPENAI_API_KEY")
|
19 |
if not openai.api_key:
|
20 |
raise ValueError("Please set the OPENAI_API_KEY environment variable.")
|
21 |
|
22 |
+
# Initialize the Sandbox
|
23 |
+
sbx = Sandbox() # By default, the sandbox is alive for 5 minutes
|
24 |
|
25 |
# Path to your Parquet dataset
|
26 |
DATASET_PATH = 'hsas.parquet' # Update with your Parquet file path
|
|
|
70 |
]
|
71 |
|
72 |
try:
|
73 |
+
response = openai.ChatCompletion.create(
|
74 |
+
model="gpt-3.5-turbo", # Use a valid and accessible model
|
75 |
messages=messages,
|
76 |
temperature=0,
|
77 |
max_tokens=150,
|
|
|
125 |
with gr.Blocks(css="""
|
126 |
.error-message {
|
127 |
color: red;
|
128 |
+
font-weight: bold;
|
129 |
}
|
130 |
.gradio-container {
|
131 |
+
max-width: 1200px;
|
132 |
margin: auto;
|
133 |
+
font-family: -apple-system, BlinkMacSystemFont, 'San Francisco', 'Helvetica Neue', Helvetica, Arial, sans-serif;
|
134 |
}
|
135 |
.header {
|
136 |
text-align: center;
|
137 |
+
padding: 30px 0;
|
138 |
}
|
139 |
.instructions {
|
140 |
+
margin: 20px 0;
|
141 |
+
font-size: 18px;
|
142 |
+
line-height: 1.6;
|
143 |
}
|
144 |
.example-queries {
|
145 |
margin-bottom: 20px;
|
146 |
}
|
147 |
.button-row {
|
148 |
+
margin-top: 20px;
|
149 |
}
|
150 |
.input-area {
|
151 |
margin-bottom: 20px;
|
|
|
153 |
.schema-tab {
|
154 |
padding: 20px;
|
155 |
}
|
156 |
+
.results {
|
157 |
+
margin-top: 20px;
|
158 |
+
}
|
159 |
+
.copy-button {
|
160 |
+
margin-top: 10px;
|
161 |
+
}
|
162 |
""") as demo:
|
163 |
+
# Header
|
|
|
|
|
164 |
gr.Markdown("""
|
165 |
+
# π₯ Text-to-SQL Healthcare Data Analyst Agent
|
166 |
|
167 |
+
Analyze data from the U.S. Center of Medicare and Medicaid using natural language queries.
|
168 |
|
169 |
+
""", elem_classes="header")
|
170 |
|
171 |
# Instructions
|
172 |
+
gr.Markdown("""
|
173 |
+
### Instructions
|
174 |
+
|
175 |
+
1. **Describe the data you want**: e.g., *"Show total days of care by zip code"*
|
176 |
+
2. **Use Example Queries**: Click on any example query button below to execute
|
177 |
+
3. **Generate SQL**: Or, enter your own query and click **Generate SQL Query**
|
178 |
+
4. **Execute the Query**: Review the generated SQL and click **Execute Query** to see the results
|
179 |
|
|
|
|
|
|
|
180 |
""", elem_classes="instructions")
|
181 |
|
182 |
with gr.Row():
|
183 |
+
with gr.Column(scale=1, min_width=350):
|
184 |
+
gr.Markdown("### π‘ Example Queries", elem_classes="example-queries")
|
185 |
query_buttons = [
|
186 |
"Calculate the average total_charges by zip_cd_of_residence",
|
187 |
"For each zip_cd_of_residence, calculate the sum of total_charges",
|
188 |
"SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
|
189 |
]
|
190 |
+
btn_queries = [gr.Button(q, variant="secondary") for q in query_buttons]
|
191 |
+
|
192 |
+
query_input = gr.Textbox(
|
193 |
+
label="π Your Query",
|
194 |
+
placeholder='e.g., "Show total charges over 1M by zip code"',
|
195 |
+
lines=2,
|
196 |
+
interactive=True,
|
197 |
+
elem_classes="input-area"
|
198 |
+
)
|
199 |
|
200 |
with gr.Row(elem_classes="button-row"):
|
201 |
btn_generate_sql = gr.Button("Generate SQL Query", variant="primary")
|
|
|
204 |
sql_query_out = gr.Code(label="π Generated SQL Query", language="sql")
|
205 |
error_out = gr.HTML(elem_classes="error-message", visible=False)
|
206 |
|
207 |
+
with gr.Column(scale=2, min_width=650):
|
208 |
gr.Markdown("### π Query Results", elem_classes="results")
|
209 |
results_out = gr.Dataframe(label="Query Results", interactive=False)
|
|
|
210 |
|
211 |
+
# Copy to Clipboard Button
|
212 |
+
btn_copy_results = gr.Button("Copy Results to Clipboard", variant="secondary", elem_classes="copy-button")
|
213 |
+
|
214 |
+
# JavaScript for copying to clipboard
|
215 |
copy_script = gr.HTML("""
|
216 |
<script>
|
217 |
function copyToClipboard() {
|
218 |
+
const resultsContainer = document.querySelector('div[data-testid="dataframe"] table');
|
219 |
if (resultsContainer) {
|
220 |
const text = Array.from(resultsContainer.rows)
|
221 |
.map(row => Array.from(row.cells)
|
222 |
+
.map(cell => cell.innerText).join("\\t"))
|
223 |
+
.join("\\n");
|
224 |
+
navigator.clipboard.writeText(text).then(function() {
|
225 |
+
alert("Copied results to clipboard!");
|
226 |
+
}, function(err) {
|
227 |
+
alert("Failed to copy results: " + err);
|
228 |
+
});
|
|
|
|
|
229 |
} else {
|
230 |
alert("No results to copy!");
|
231 |
}
|
|
|
235 |
|
236 |
# Connect JavaScript function to button
|
237 |
btn_copy_results.click(
|
238 |
+
_js="copyToClipboard"
|
239 |
)
|
240 |
|
241 |
+
# Dataset Schema Tab
|
242 |
with gr.Tab("π Dataset Schema", elem_classes="schema-tab"):
|
243 |
gr.Markdown("### Dataset Schema")
|
244 |
schema_display = gr.JSON(label="Schema", value=get_schema())
|
|
|
249 |
|
250 |
def generate_sql(nl_query):
|
251 |
if not nl_query.strip():
|
252 |
+
return "", "<p>Please enter a query.</p>", gr.update(visible=True)
|
253 |
sql_query, error = parse_query(nl_query)
|
254 |
if error:
|
255 |
+
return sql_query, f"<p>{error}</p>", gr.update(visible=True)
|
256 |
else:
|
257 |
return sql_query, "", gr.update(visible=False)
|
258 |
|
259 |
def execute_query(sql_query):
|
260 |
if not sql_query.strip():
|
261 |
+
return None, "<p>No SQL query to execute.</p>", gr.update(visible=True)
|
262 |
result_df, error = execute_sql_query(sql_query)
|
263 |
if error:
|
264 |
+
return None, f"<p>{error}</p>", gr.update(visible=True)
|
265 |
else:
|
266 |
return result_df, "", gr.update(visible=False)
|
267 |
|
|
|
270 |
sql_query = example_query
|
271 |
result_df, error = execute_sql_query(sql_query)
|
272 |
if error:
|
273 |
+
return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
|
274 |
else:
|
275 |
+
return sql_query, "", result_df, gr.update(visible(False))
|
276 |
else:
|
277 |
sql_query, error = parse_query(example_query)
|
278 |
if error:
|
279 |
+
return sql_query, f"<p>{error}</p>", None, gr.update(visible=True)
|
280 |
result_df, exec_error = execute_sql_query(sql_query)
|
281 |
if exec_error:
|
282 |
+
return sql_query, f"<p>{exec_error}</p>", None, gr.update(visible=True)
|
283 |
else:
|
284 |
return sql_query, "", result_df, gr.update(visible=False)
|
285 |
|
|
|
303 |
outputs=[sql_query_out, error_out, results_out, error_out],
|
304 |
)
|
305 |
|
306 |
+
# Hide error message when inputs change
|
307 |
+
query_input.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
|
308 |
+
sql_query_out.change(fn=lambda: gr.update(visible=False), inputs=None, outputs=[error_out])
|
309 |
|
310 |
# =========================
|
311 |
# Launch the Gradio App
|