LeonceNsh committed on
Commit
1ee39a9
·
verified ·
1 Parent(s): 8c62fa6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -42
app.py CHANGED
@@ -1,18 +1,18 @@
 
1
  import json
2
  import openai
3
- import gradio as gr
4
  import duckdb
 
5
  from functools import lru_cache
6
- import os
7
 
8
  # =========================
9
  # Configuration and Setup
10
  # =========================
11
 
12
  openai.api_key = os.getenv("OPENAI_API_KEY")
13
- dataset_path = 'hsas.parquet' # Update with your Parquet file path
14
 
15
- schema = [
16
  {"column_name": "total_charges", "column_type": "BIGINT"},
17
  {"column_name": "medicare_prov_num", "column_type": "BIGINT"},
18
  {"column_name": "zip_cd_of_residence", "column_type": "VARCHAR"},
@@ -22,7 +22,7 @@ schema = [
22
 
23
  @lru_cache(maxsize=1)
24
  def get_schema():
25
- return schema
26
 
27
  COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
28
 
@@ -32,13 +32,22 @@ COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
32
 
33
  def parse_query(nl_query):
34
  messages = [
35
- {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'hsa_data' table."},
36
- {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
 
 
 
 
 
 
 
 
 
37
  ]
38
 
39
  try:
40
  response = openai.chat.completions.create(
41
- model="gpt-4",
42
  messages=messages,
43
  temperature=0,
44
  max_tokens=150,
@@ -54,8 +63,8 @@ def parse_query(nl_query):
54
 
55
  def execute_sql_query(sql_query):
56
  try:
57
- con = duckdb.connect()
58
- con.execute(f"CREATE OR REPLACE VIEW hsa_data AS SELECT * FROM '{dataset_path}'")
59
  result_df = con.execute(sql_query).fetchdf()
60
  con.close()
61
  return result_df, ""
@@ -68,40 +77,41 @@ def execute_sql_query(sql_query):
68
 
69
  with gr.Blocks() as demo:
70
  gr.Markdown("""
71
- # Text to SQL healthcare AI data Analyst agent to analyze U.S prescription data from the Center of Medicare and Medicaid
72
-
73
- # Replicate papers from academic journals on prescription drug prices
74
 
75
  ## Instructions
76
 
77
- ### 1. **Describe the data you want**: e.g., `Show total days of care by zip`
78
- ### 2. **Use Example Queries**: Click on any example query button below to execute.
79
- ### 3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
80
 
81
  ## Example Queries
82
  """)
83
 
84
  with gr.Row():
85
  with gr.Column(scale=1):
86
-
87
- gr.Markdown("### Click on an example query:")
88
- with gr.Row():
89
- btn_example1 = gr.Button("Calculate the average total_charges by zip_cd_of_residence")
90
- btn_example2 = gr.Button("For each zip_cd_of_residence, calculate the sum of total_charges")
91
- btn_example3 = gr.Button("SELECT * from hsa_data where total_days_of_care > 40 LIMIT 30;")
 
92
 
93
  query_input = gr.Textbox(
94
  label="Your Query",
95
- placeholder='e.g., "What are the total awards over 1M in California?"',
96
- lines=1
97
  )
98
 
99
  btn_generate_sql = gr.Button("Generate SQL Query")
100
  sql_query_out = gr.Code(label="Generated SQL Query", language="sql")
101
  btn_execute_query = gr.Button("Execute Query")
102
- error_out = gr.Markdown("", visible=False)
103
  with gr.Column(scale=2):
104
- results_out = gr.Dataframe(label="Query Results", interactive=False)
105
 
106
  with gr.Tab("Dataset Schema"):
107
  gr.Markdown("### Dataset Schema")
@@ -113,22 +123,27 @@ with gr.Blocks() as demo:
113
 
114
  def generate_sql(nl_query):
115
  sql_query, error = parse_query(nl_query)
 
116
  return sql_query, error
117
 
118
  def execute_query(sql_query):
119
  result_df, error = execute_sql_query(sql_query)
 
120
  return result_df, error
121
 
122
  def handle_example_click(example_query):
123
  if example_query.strip().upper().startswith("SELECT"):
124
  sql_query = example_query
125
  result_df, error = execute_sql_query(sql_query)
126
- return sql_query, "", result_df, error
 
127
  else:
128
  sql_query, error = parse_query(example_query)
129
  if error:
 
130
  return sql_query, error, None, error
131
  result_df, exec_error = execute_sql_query(sql_query)
 
132
  return sql_query, exec_error, result_df, exec_error
133
 
134
  # =========================
@@ -138,27 +153,21 @@ with gr.Blocks() as demo:
138
  btn_generate_sql.click(
139
  fn=generate_sql,
140
  inputs=query_input,
141
- outputs=[sql_query_out, error_out]
142
  )
143
 
144
  btn_execute_query.click(
145
  fn=execute_query,
146
  inputs=sql_query_out,
147
- outputs=[results_out, error_out]
148
  )
149
 
150
- btn_example1.click(
151
- fn=lambda: handle_example_click("Calculate the average total_charges by zip_cd_of_residence"),
152
- outputs=[sql_query_out, error_out, results_out, error_out]
153
- )
154
- btn_example2.click(
155
- fn=lambda: handle_example_click("For each zip_cd_of_residence, calculate the sum of total_charges"),
156
- outputs=[sql_query_out, error_out, results_out, error_out]
157
- )
158
- btn_example3.click(
159
- fn=lambda: handle_example_click("SELECT * from hsa_data where total_days_of_care > 40 LIMIT 30;"),
160
- outputs=[sql_query_out, error_out, results_out, error_out]
161
- )
162
 
163
  # Launch the Gradio App
164
- demo.launch()
 
 
1
+ import os
2
  import json
3
  import openai
 
4
  import duckdb
5
+ import gradio as gr
6
  from functools import lru_cache
 
7
 
8
  # =========================
9
  # Configuration and Setup
10
  # =========================
11
 
12
  openai.api_key = os.getenv("OPENAI_API_KEY")
13
+ DATASET_PATH = 'hsas.parquet' # Update with your Parquet file path
14
 
15
+ SCHEMA = [
16
  {"column_name": "total_charges", "column_type": "BIGINT"},
17
  {"column_name": "medicare_prov_num", "column_type": "BIGINT"},
18
  {"column_name": "zip_cd_of_residence", "column_type": "VARCHAR"},
 
22
 
23
  @lru_cache(maxsize=1)
24
  def get_schema():
25
+ return SCHEMA
26
 
27
  COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
28
 
 
32
 
33
  def parse_query(nl_query):
34
  messages = [
35
+ {
36
+ "role": "system",
37
+ "content": (
38
+ "You are an assistant that converts natural language queries into SQL queries for the 'hsa_data' table. "
39
+ "Ensure the SQL query is syntactically correct and uses only the columns provided in the schema."
40
+ ),
41
+ },
42
+ {
43
+ "role": "user",
44
+ "content": f"Schema:\n{json.dumps(get_schema(), indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:",
45
+ },
46
  ]
47
 
48
  try:
49
  response = openai.chat.completions.create(
50
+ model="gpt-4o-mini",
51
  messages=messages,
52
  temperature=0,
53
  max_tokens=150,
 
63
 
64
  def execute_sql_query(sql_query):
65
  try:
66
+ con = duckdb.connect(database=':memory:')
67
+ con.execute(f"CREATE OR REPLACE VIEW hsa_data AS SELECT * FROM '{DATASET_PATH}'")
68
  result_df = con.execute(sql_query).fetchdf()
69
  con.close()
70
  return result_df, ""
 
77
 
78
  with gr.Blocks() as demo:
79
  gr.Markdown("""
80
+ # Text-to-SQL Healthcare Data Analyst Agent
81
+
82
+ Analyze U.S. prescription data from the Center of Medicare and Medicaid.
83
 
84
  ## Instructions
85
 
86
+ 1. **Describe the data you want**: e.g., `Show total days of care by zip`
87
+ 2. **Use Example Queries**: Click on any example query button below to execute.
88
+ 3. **Generate SQL**: Or, enter your own query and click "Generate SQL".
89
 
90
  ## Example Queries
91
  """)
92
 
93
  with gr.Row():
94
  with gr.Column(scale=1):
95
+ gr.Markdown("### Example Queries:")
96
+ query_buttons = [
97
+ "Calculate the average total_charges by zip_cd_of_residence",
98
+ "For each zip_cd_of_residence, calculate the sum of total_charges",
99
+ "SELECT * FROM hsa_data WHERE total_days_of_care > 40 LIMIT 30;",
100
+ ]
101
+ btn_queries = [gr.Button(q) for q in query_buttons]
102
 
103
  query_input = gr.Textbox(
104
  label="Your Query",
105
+ placeholder='e.g., "Show total charges over 1M by state"',
106
+ lines=1,
107
  )
108
 
109
  btn_generate_sql = gr.Button("Generate SQL Query")
110
  sql_query_out = gr.Code(label="Generated SQL Query", language="sql")
111
  btn_execute_query = gr.Button("Execute Query")
112
+ error_out = gr.Markdown(visible=False)
113
  with gr.Column(scale=2):
114
+ results_out = gr.Dataframe(label="Query Results")
115
 
116
  with gr.Tab("Dataset Schema"):
117
  gr.Markdown("### Dataset Schema")
 
123
 
124
  def generate_sql(nl_query):
125
  sql_query, error = parse_query(nl_query)
126
+ error_out.update(visible=bool(error))
127
  return sql_query, error
128
 
129
  def execute_query(sql_query):
130
  result_df, error = execute_sql_query(sql_query)
131
+ error_out.update(visible=bool(error))
132
  return result_df, error
133
 
134
  def handle_example_click(example_query):
135
  if example_query.strip().upper().startswith("SELECT"):
136
  sql_query = example_query
137
  result_df, error = execute_sql_query(sql_query)
138
+ error_out.update(visible=bool(error))
139
+ return sql_query, "", result_df, ""
140
  else:
141
  sql_query, error = parse_query(example_query)
142
  if error:
143
+ error_out.update(visible=True)
144
  return sql_query, error, None, error
145
  result_df, exec_error = execute_sql_query(sql_query)
146
+ error_out.update(visible=bool(exec_error))
147
  return sql_query, exec_error, result_df, exec_error
148
 
149
  # =========================
 
153
  btn_generate_sql.click(
154
  fn=generate_sql,
155
  inputs=query_input,
156
+ outputs=[sql_query_out, error_out],
157
  )
158
 
159
  btn_execute_query.click(
160
  fn=execute_query,
161
  inputs=sql_query_out,
162
+ outputs=[results_out, error_out],
163
  )
164
 
165
+ for btn, query in zip(btn_queries, query_buttons):
166
+ btn.click(
167
+ fn=lambda q=query: handle_example_click(q),
168
+ outputs=[sql_query_out, error_out, results_out, error_out],
169
+ )
 
 
 
 
 
 
 
170
 
171
  # Launch the Gradio App
172
+ if __name__ == "__main__":
173
+ demo.launch()