LeonceNsh commited on
Commit
564d637
·
verified ·
1 Parent(s): 87f90ca

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import openai
3
+ import gradio as gr
4
+ import duckdb
5
+ from functools import lru_cache
6
+ import os
7
+
8
+ # =========================
9
+ # Configuration and Setup
10
+ # =========================
11
+
12
+ openai.api_key = os.getenv("OPENAI_API_KEY")
13
+ dataset_path = 'hsas.parquet' # Update with your Parquet file path
14
+
15
+ schema = [
16
+ {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
17
+ {"column_name": "cgac", "column_type": "BIGINT"},
18
+ {"column_name": "sub_tier", "column_type": "VARCHAR"},
19
+ {"column_name": "fpds_code", "column_type": "VARCHAR"},
20
+ {"column_name": "office", "column_type": "VARCHAR"},
21
+ {"column_name": "aac_code", "column_type": "VARCHAR"},
22
+ {"column_name": "posteddate", "column_type": "VARCHAR"},
23
+ {"column_name": "type", "column_type": "VARCHAR"},
24
+ {"column_name": "basetype", "column_type": "VARCHAR"},
25
+ {"column_name": "popstreetaddress", "column_type": "VARCHAR"},
26
+ {"column_name": "popcity", "column_type": "VARCHAR"},
27
+ {"column_name": "popstate", "column_type": "VARCHAR"},
28
+ {"column_name": "popzip", "column_type": "VARCHAR"},
29
+ {"column_name": "popcountry", "column_type": "VARCHAR"},
30
+ {"column_name": "active", "column_type": "VARCHAR"},
31
+ {"column_name": "awardnumber", "column_type": "VARCHAR"},
32
+ {"column_name": "awarddate", "column_type": "VARCHAR"},
33
+ {"column_name": "award", "column_type": "DOUBLE"},
34
+ {"column_name": "awardee", "column_type": "VARCHAR"},
35
+ {"column_name": "state", "column_type": "VARCHAR"},
36
+ {"column_name": "city", "column_type": "VARCHAR"},
37
+ {"column_name": "zipcode", "column_type": "VARCHAR"},
38
+ {"column_name": "countrycode", "column_type": "VARCHAR"}
39
+ ]
40
+
41
+ @lru_cache(maxsize=1)
42
+ def get_schema():
43
+ return schema
44
+
45
+ COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}
46
+
47
+ # =========================
48
+ # OpenAI API Integration
49
+ # =========================
50
+
51
+ def parse_query(nl_query):
52
+ messages = [
53
+ {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
54
+ {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
55
+ ]
56
+
57
+ try:
58
+ response = openai.chat.completions.create(
59
+ model="gpt-4",
60
+ messages=messages,
61
+ temperature=0,
62
+ max_tokens=150,
63
+ )
64
+ sql_query = response.choices[0].message.content.strip()
65
+ return sql_query, ""
66
+ except Exception as e:
67
+ return "", f"Error generating SQL query: {e}"
68
+
69
+ # =========================
70
+ # Database Interaction
71
+ # =========================
72
+
73
+ def execute_sql_query(sql_query):
74
+ try:
75
+ con = duckdb.connect()
76
+ con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
77
+ result_df = con.execute(sql_query).fetchdf()
78
+ con.close()
79
+ return result_df, ""
80
+ except Exception as e:
81
+ return None, f"Error executing query: {e}"
82
+
83
+ # =========================
84
+ # Gradio Application UI
85
+ # =========================
86
+
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown("""
89
+ # Use Text to SQL to analyze US Government contract data
90
+
91
+ ## Instructions
92
+
93
+ ### 1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
94
+ ### 2. **Use Example Queries**: Click on any example query button below to execute.
95
+ ### 3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
96
+
97
+ ## Example Queries
98
+ """)
99
+
100
+ with gr.Row():
101
+ with gr.Column(scale=1):
102
+
103
+ gr.Markdown("### Click on an example query:")
104
+ with gr.Row():
105
+ btn_example1 = gr.Button("Retrieve the top 15 records from contract_data where basetype is Award Notice, awardee has at least 12 characters, and popcity has more than 5 characters. Exclude the fields sub_tier, popzip, awardnumber, basetype, popstate, active, popcountry, type, countrycode, and popstreetaddress")
106
+ btn_example2 = gr.Button("Show top 10 departments by award amount")
107
+ btn_example3 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")
108
+
109
+ query_input = gr.Textbox(
110
+ label="Your Query",
111
+ placeholder='e.g., "What are the total awards over 1M in California?"',
112
+ lines=1
113
+ )
114
+
115
+ btn_generate_sql = gr.Button("Generate SQL Query")
116
+ sql_query_out = gr.Code(label="Generated SQL Query", language="sql")
117
+ btn_execute_query = gr.Button("Execute Query")
118
+ error_out = gr.Markdown("", visible=False)
119
+ with gr.Column(scale=2):
120
+ results_out = gr.Dataframe(label="Query Results", interactive=False)
121
+
122
+ with gr.Tab("Dataset Schema"):
123
+ gr.Markdown("### Dataset Schema")
124
+ schema_display = gr.JSON(label="Schema", value=get_schema())
125
+
126
+ # =========================
127
+ # Event Functions
128
+ # =========================
129
+
130
+ def generate_sql(nl_query):
131
+ sql_query, error = parse_query(nl_query)
132
+ return sql_query, error
133
+
134
+ def execute_query(sql_query):
135
+ result_df, error = execute_sql_query(sql_query)
136
+ return result_df, error
137
+
138
+ def handle_example_click(example_query):
139
+ if example_query.strip().upper().startswith("SELECT"):
140
+ sql_query = example_query
141
+ result_df, error = execute_sql_query(sql_query)
142
+ return sql_query, "", result_df, error
143
+ else:
144
+ sql_query, error = parse_query(example_query)
145
+ if error:
146
+ return sql_query, error, None, error
147
+ result_df, exec_error = execute_sql_query(sql_query)
148
+ return sql_query, exec_error, result_df, exec_error
149
+
150
+ # =========================
151
+ # Button Click Event Handlers
152
+ # =========================
153
+
154
+ btn_generate_sql.click(
155
+ fn=generate_sql,
156
+ inputs=query_input,
157
+ outputs=[sql_query_out, error_out]
158
+ )
159
+
160
+ btn_execute_query.click(
161
+ fn=execute_query,
162
+ inputs=sql_query_out,
163
+ outputs=[results_out, error_out]
164
+ )
165
+
166
+ btn_example1.click(
167
+ fn=lambda: handle_example_click("Retrieve the top 15 records from contract_data where basetype is Award Notice, awardee has at least 12 characters, and popcity has more than 5 characters. Exclude the fields sub_tier, popzip, awardnumber, basetype, popstate, active, popcountry, type, countrycode, and popstreetaddress"),
168
+ outputs=[sql_query_out, error_out, results_out, error_out]
169
+ )
170
+ btn_example2.click(
171
+ fn=lambda: handle_example_click("Show top 10 departments by award amount. Round to zero decimal places."),
172
+ outputs=[sql_query_out, error_out, results_out, error_out]
173
+ )
174
+ btn_example3.click(
175
+ fn=lambda: handle_example_click("SELECT * from contract_data LIMIT 10;"),
176
+ outputs=[sql_query_out, error_out, results_out, error_out]
177
+ )
178
+
179
+ # Launch the Gradio App
180
+ demo.launch()