File size: 6,953 Bytes
564d637
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import json
import os
import re
from functools import lru_cache

import duckdb
import gradio as gr
import openai

# =========================
# Configuration and Setup
# =========================

# The OpenAI key is taken from the environment (None when the variable is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Parquet file that is exposed to SQL as the `contract_data` view.
dataset_path = "hsas.parquet"  # Update with your Parquet file path

# (column name, column type) for every column of `contract_data`, in the order
# they are presented to the model and shown in the schema tab.
_SCHEMA_COLUMNS = (
    ("department_ind_agency", "VARCHAR"),
    ("cgac", "BIGINT"),
    ("sub_tier", "VARCHAR"),
    ("fpds_code", "VARCHAR"),
    ("office", "VARCHAR"),
    ("aac_code", "VARCHAR"),
    ("posteddate", "VARCHAR"),
    ("type", "VARCHAR"),
    ("basetype", "VARCHAR"),
    ("popstreetaddress", "VARCHAR"),
    ("popcity", "VARCHAR"),
    ("popstate", "VARCHAR"),
    ("popzip", "VARCHAR"),
    ("popcountry", "VARCHAR"),
    ("active", "VARCHAR"),
    ("awardnumber", "VARCHAR"),
    ("awarddate", "VARCHAR"),
    ("award", "DOUBLE"),
    ("awardee", "VARCHAR"),
    ("state", "VARCHAR"),
    ("city", "VARCHAR"),
    ("zipcode", "VARCHAR"),
    ("countrycode", "VARCHAR"),
)

# Schema as a list of {"column_name": ..., "column_type": ...} records; this is
# the shape that gets JSON-serialized into the model prompt.
schema = [
    {"column_name": column_name, "column_type": column_type}
    for column_name, column_type in _SCHEMA_COLUMNS
]


@lru_cache(maxsize=1)
def get_schema():
    """Return the module-level ``schema`` list (the same object on every call)."""
    return schema


# Lookup table mapping each column name to its declared type.
COLUMN_TYPES = {
    entry["column_name"]: entry["column_type"]
    for entry in get_schema()
}

# =========================
# OpenAI API Integration
# =========================

# Matches a Markdown code fence: an optional info string (e.g. "sql") on the
# opening line, then the body up to the closing fence or the end of the text
# (the reply may have been cut off by max_tokens before the fence closed).
_SQL_FENCE_RE = re.compile(r"```(?:[A-Za-z0-9_+-]*\r?\n)?(.*?)(?:```|\Z)", re.DOTALL)


def extract_sql(text):
    """Return the bare SQL contained in a model reply.

    Chat models often wrap SQL in a Markdown code fence and may add a line of
    explanation; DuckDB cannot parse either. When a fence is present, only
    the contents of the first fenced block are kept; otherwise the text is
    returned with surrounding whitespace removed.

    Args:
        text: Raw message content from the model; ``None`` is treated as "".

    Returns:
        The SQL text without fences or surrounding whitespace ("" if none).
    """
    if not text:
        return ""
    match = _SQL_FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return text.strip()


def parse_query(nl_query):
    """Translate a natural-language question into SQL for ``contract_data``.

    Args:
        nl_query: The user's question about the ``contract_data`` table.

    Returns:
        ``(sql_query, error)``. On success ``error`` is ``""``; on failure
        ``sql_query`` is ``""`` and ``error`` is a human-readable message.
    """
    if not nl_query or not nl_query.strip():
        # Nothing to translate: skip the (paid) API round-trip.
        return "", "Please enter a question to convert into SQL."

    messages = [
        {
            "role": "system",
            "content": (
                "You are an assistant that converts natural language queries into "
                "SQL queries for the 'contract_data' table. The SQL is executed by "
                "DuckDB. Respond with a single SQL statement only, with no "
                "explanation and no Markdown code fences."
            ),
        },
        {
            "role": "user",
            "content": f"Schema:\n{json.dumps(get_schema(), indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:",
        },
    ]

    try:
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0,
            max_tokens=150,
        )
        # Strip any code fence / prose the model adds despite the instructions.
        sql_query = extract_sql(response.choices[0].message.content)
    except Exception as e:  # UI boundary: report API/network failures instead of crashing
        return "", f"Error generating SQL query: {e}"

    if not sql_query:
        return "", "Error generating SQL query: the model returned an empty response."
    return sql_query, ""

# =========================
# Database Interaction
# =========================

def execute_sql_query(sql_query):
    """Run ``sql_query`` against the Parquet dataset exposed as ``contract_data``.

    A fresh in-memory DuckDB connection is opened for each call, and a
    ``contract_data`` view over ``dataset_path`` is (re)created before the
    query runs. The connection is closed whether or not the query succeeds.

    Args:
        sql_query: SQL text to execute (model-generated or user-edited).

    Returns:
        ``(result_df, error)``. On success ``result_df`` is the DataFrame
        produced by DuckDB's ``fetchdf()`` and ``error`` is ``""``; on failure
        ``result_df`` is ``None`` and ``error`` is a human-readable message.
    """
    if not sql_query or not sql_query.strip():
        return None, "Error executing query: no SQL query to execute."

    # NOTE(security): the query text is executed as-is. DuckDB SQL can read and
    # write local files and load extensions, so do not expose this to untrusted
    # users without restricting the connection.
    try:
        # Double any single quotes so the path cannot break out of the SQL literal.
        escaped_path = dataset_path.replace("'", "''")
        # Context manager guarantees close() even when a statement raises
        # (the original leaked the connection on every failed query).
        with duckdb.connect() as con:
            con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{escaped_path}'")
            result_df = con.execute(sql_query).fetchdf()
    except Exception as e:  # UI boundary: surface SQL/IO errors to the user
        return None, f"Error executing query: {e}"
    return result_df, ""

# =========================
# Gradio Application UI
# =========================

# Text submitted by each example button. Example 2 adds a rounding instruction
# that is not on its label, and example 3's label carries an "Execute:" prefix
# that is not part of the SQL.
EXAMPLE_QUERY_1 = (
    "Retrieve the top 15 records from contract_data where basetype is Award Notice, "
    "awardee has at least 12 characters, and popcity has more than 5 characters. "
    "Exclude the fields sub_tier, popzip, awardnumber, basetype, popstate, active, "
    "popcountry, type, countrycode, and popstreetaddress"
)
EXAMPLE_QUERY_2 = "Show top 10 departments by award amount. Round to zero decimal places."
EXAMPLE_QUERY_3 = "SELECT * from contract_data LIMIT 10;"

with gr.Blocks() as demo:
    gr.Markdown("""
    # Use Text to SQL to analyze US Government contract data

    ## Instructions

    ### 1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
    ### 2. **Use Example Queries**: Click on any example query button below to execute.
    ### 3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.

    ## Example Queries
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Click on an example query:")
            with gr.Row():
                btn_example1 = gr.Button(EXAMPLE_QUERY_1)
                btn_example2 = gr.Button("Show top 10 departments by award amount")
                btn_example3 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")

            query_input = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )

            btn_generate_sql = gr.Button("Generate SQL Query")
            sql_query_out = gr.Code(label="Generated SQL Query", language="sql")
            btn_execute_query = gr.Button("Execute Query")
            # Starts hidden; handlers toggle visibility via _error_update().
            error_out = gr.Markdown("", visible=False)
        with gr.Column(scale=2):
            results_out = gr.Dataframe(label="Query Results", interactive=False)

    with gr.Tab("Dataset Schema"):
        gr.Markdown("### Dataset Schema")
        schema_display = gr.JSON(label="Schema", value=get_schema())

    # =========================
    # Event Functions
    # =========================

    def _error_update(error):
        """Return an ``error_out`` update that shows ``error``, or hides the box if empty.

        ``error_out`` is created with ``visible=False``; updating only its value
        would leave every error message invisible to the user.
        """
        return gr.update(value=error, visible=bool(error))

    def generate_sql(nl_query):
        """Translate ``nl_query``; returns (SQL text, error-box update)."""
        sql_query, error = parse_query(nl_query)
        return sql_query, _error_update(error)

    def execute_query(sql_query):
        """Run ``sql_query``; returns (results or None, error-box update)."""
        result_df, error = execute_sql_query(sql_query)
        return result_df, _error_update(error)

    def handle_example_click(example_query):
        """Run an example query end to end.

        Text starting with SELECT is executed as SQL directly; anything else is
        first translated by the model. Returns (SQL shown in the editor,
        error-box update, results or None).
        """
        if example_query.strip().upper().startswith("SELECT"):
            sql_query = example_query
        else:
            sql_query, error = parse_query(example_query)
            if error:
                return sql_query, _error_update(error), None
        result_df, error = execute_sql_query(sql_query)
        return sql_query, _error_update(error), result_df

    # =========================
    # Button Click Event Handlers
    # =========================

    btn_generate_sql.click(
        fn=generate_sql,
        inputs=query_input,
        outputs=[sql_query_out, error_out]
    )

    btn_execute_query.click(
        fn=execute_query,
        inputs=sql_query_out,
        outputs=[results_out, error_out]
    )

    # Each component appears once; the original listed error_out twice.
    example_outputs = [sql_query_out, error_out, results_out]
    btn_example1.click(
        fn=lambda: handle_example_click(EXAMPLE_QUERY_1),
        outputs=example_outputs
    )
    btn_example2.click(
        fn=lambda: handle_example_click(EXAMPLE_QUERY_2),
        outputs=example_outputs
    )
    btn_example3.click(
        fn=lambda: handle_example_click(EXAMPLE_QUERY_3),
        outputs=example_outputs
    )

# Launch the Gradio App
demo.launch()