File size: 6,984 Bytes
b474ae1
ec9d21a
06f01b3
b474ae1
d33fe62
1fa796c
5b4c268
ae610aa
 
 
 
94bf8f1
f146007
5b4c268
d33fe62
5a73339
 
92494e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d33fe62
 
 
 
 
 
04fd164
 
ae610aa
 
 
 
c490b83
a1792a1
94bf8f1
88c83f6
a1792a1
1fa796c
dfe1769
9977fca
2cc33e1
13f0f94
88c83f6
dfe1769
 
238955b
04fd164
dfe1769
04fd164
dfe1769
b89b3ba
 
 
 
 
 
 
 
 
 
 
 
 
dfe1769
ae610aa
 
 
 
f5a9d48
06f01b3
f5a9d48
 
0d87975
f5a9d48
 
 
0d87975
 
 
f5a9d48
 
78f16f0
b474ae1
c490b83
 
efc74be
b89b3ba
 
cf2c742
0925d63
 
b89b3ba
efc74be
 
 
 
 
 
f5a9d48
b89b3ba
f5a9d48
 
b89b3ba
 
 
04fd164
b89b3ba
 
00c05fa
94bf8f1
f5a9d48
94bf8f1
dfe1769
04fd164
 
 
94bf8f1
04fd164
 
 
 
 
 
 
 
 
c27620c
04fd164
 
 
 
 
c27620c
04fd164
 
 
c27620c
04fd164
 
 
 
 
 
 
 
 
 
 
 
 
cf2c742
04fd164
 
 
 
 
 
0925d63
04fd164
 
 
c27620c
f5a9d48
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
import json
import openai
import gradio as gr
import duckdb
from functools import lru_cache
import os

# =========================
# Configuration and Setup
# =========================

# The API key is read from the environment so it never has to live in the source.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# Parquet file that queries run against; point this at your own dataset.
dataset_path = 'sample_contract_df.parquet'

# (column name, column type) pairs, in the order they are presented in the schema.
_SCHEMA_COLUMNS = (
    ("department_ind_agency", "VARCHAR"),
    ("cgac", "BIGINT"),
    ("sub_tier", "VARCHAR"),
    ("fpds_code", "VARCHAR"),
    ("office", "VARCHAR"),
    ("aac_code", "VARCHAR"),
    ("posteddate", "VARCHAR"),
    ("type", "VARCHAR"),
    ("basetype", "VARCHAR"),
    ("popstreetaddress", "VARCHAR"),
    ("popcity", "VARCHAR"),
    ("popstate", "VARCHAR"),
    ("popzip", "VARCHAR"),
    ("popcountry", "VARCHAR"),
    ("active", "VARCHAR"),
    ("awardnumber", "VARCHAR"),
    ("awarddate", "VARCHAR"),
    ("award", "DOUBLE"),
    ("awardee", "VARCHAR"),
    ("state", "VARCHAR"),
    ("city", "VARCHAR"),
    ("zipcode", "VARCHAR"),
    ("countrycode", "VARCHAR"),
)

# One {"column_name", "column_type"} dict per column; this is the structure that
# is serialized into the model prompt and shown in the schema tab.
schema = [
    {"column_name": name, "column_type": col_type}
    for name, col_type in _SCHEMA_COLUMNS
]

@lru_cache(maxsize=1)
def get_schema():
    """Return the module-level ``schema`` list.

    The result is memoized by ``lru_cache``, so every call hands back the
    same list object that ``schema`` referred to on the first call.
    """
    columns = schema
    return columns

# Lookup table: column name -> column type, derived from the schema entries.
COLUMN_TYPES = dict(
    (entry['column_name'], entry['column_type']) for entry in get_schema()
)

# =========================
# OpenAI API Integration
# =========================

def parse_query(nl_query):
    """Convert a natural-language request into SQL for the 'contract_data' table.

    The table schema (as JSON) and the request are sent to the OpenAI chat
    completions API, and the model's reply is returned as the SQL text.

    Args:
        nl_query: The user's request in plain language.

    Returns:
        A ``(sql_query, error)`` tuple. On success ``sql_query`` is the
        whitespace-stripped model reply and ``error`` is ``""``. On failure
        ``sql_query`` is ``""`` and ``error`` is a human-readable message.
    """
    # A blank request cannot produce meaningful SQL; answer locally instead of
    # spending an API call on it.
    if not nl_query or not str(nl_query).strip():
        return "", "Error generating SQL query: the query is empty."

    messages = [
        {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
    ]

    try:
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0,  # deterministic output: the same request yields the same SQL
            max_tokens=150,
        )
        # Kept inside the try: the reply content may be missing, in which case
        # .strip() raises and is reported like any other API failure.
        sql_query = response.choices[0].message.content.strip()
    except Exception as e:
        # UI boundary: surface any failure as a message rather than crashing the app.
        return "", f"Error generating SQL query: {e}"
    return sql_query, ""

# =========================
# Database Interaction
# =========================

def execute_sql_query(sql_query):
    """Run ``sql_query`` against the Parquet dataset and return the rows.

    A fresh in-memory DuckDB connection is opened for every call, the Parquet
    file at ``dataset_path`` is exposed as the view ``contract_data``, and the
    query is executed against it.

    NOTE(review): ``sql_query`` is executed verbatim. It comes from the user or
    from the language model, so it can contain any statement DuckDB accepts.

    Args:
        sql_query: SQL text to execute.

    Returns:
        A ``(result_df, error)`` tuple. On success ``result_df`` is the
        DataFrame returned by ``fetchdf()`` and ``error`` is ``""``. On failure
        ``result_df`` is ``None`` and ``error`` is a human-readable message.
    """
    # Nothing to run: report it directly instead of opening a connection.
    if not sql_query or not str(sql_query).strip():
        return None, "Error executing query: the SQL query is empty."

    con = None
    try:
        con = duckdb.connect()
        con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
        return con.execute(sql_query).fetchdf(), ""
    except Exception as e:
        # UI boundary: surface any failure as a message rather than crashing the app.
        return None, f"Error executing query: {e}"
    finally:
        # Always release the connection, including when the query itself fails.
        if con is not None:
            con.close()

# =========================
# Gradio Application UI
# =========================

# Page layout: instructions on top, then a controls column (examples, request
# box, generated SQL, error panel) beside a wider results column, and finally
# a tab showing the dataset schema.
with gr.Blocks() as demo:
    gr.Markdown("""
    # Parquet SQL Query and Plotting App

    ## **Query and visualize data** in `sample_contract_df.parquet`

    ## Instructions

    ### 1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
    ### 2. **Use Example Queries**: Click on any example query button below to execute.
    ### 3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.

    ## Example Queries
    """)

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column(scale=1):
            gr.Markdown("### Click on an example query:")
            with gr.Row():
                btn_example1 = gr.Button("Retrieve the top 15 records from contract_data where basetype is Award Notice, awardee has at least 12 characters, and popcity has more than 5 characters. Exclude the fields sub_tier, popzip, awardnumber, basetype, popstate, active, popcountry, type, countrycode, and popstreetaddress")
                btn_example2 = gr.Button("Show top 5 departments by award amount")
                btn_example3 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")

            query_input = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )

            btn_generate_sql = gr.Button("Generate SQL Query")
            sql_query_out = gr.Code(label="Generated SQL Query", language="sql")
            btn_execute_query = gr.Button("Execute Query")
            # The callbacks return the error text as a plain string, which only
            # updates this component's value. It must therefore be visible,
            # otherwise error messages are never shown to the user. While the
            # value is empty there is no text to display.
            error_out = gr.Markdown("")
        # Right column: query results, given twice the width of the controls.
        with gr.Column(scale=2):
            results_out = gr.Dataframe(label="Query Results", interactive=False)

    with gr.Tab("Dataset Schema"):
        gr.Markdown("### Dataset Schema")
        schema_display = gr.JSON(label="Schema", value=get_schema())

    # =========================
    # Event Functions
    # =========================

    def generate_sql(nl_query):
        """Callback for the "Generate SQL Query" button.

        Forwards the request text to ``parse_query`` and returns its
        ``(sql_query, error)`` pair unchanged.
        """
        return parse_query(nl_query)

    def execute_query(sql_query):
        """Callback for the "Execute Query" button.

        Forwards the SQL text to ``execute_sql_query`` and returns its
        ``(result_df, error)`` pair unchanged.
        """
        return execute_sql_query(sql_query)

    def handle_example_click(example_query):
        """Run one of the canned example requests end to end.

        Text that already starts with ``SELECT`` (case-insensitive, ignoring
        surrounding whitespace) is executed as-is; anything else is first
        translated to SQL with ``parse_query``.

        Args:
            example_query: Either raw SQL or a natural-language request.

        Returns:
            A 4-tuple ``(sql_query, error, result_df, error)``. The error
            message is repeated because the error component is wired to two of
            the four outputs; both slots always carry the same text so the
            displayed message does not depend on which one is applied.
            ``result_df`` is ``None`` when translation or execution fails.
        """
        if example_query.strip().upper().startswith("SELECT"):
            sql_query = example_query
        else:
            sql_query, error = parse_query(example_query)
            if error:
                # Translation failed, so there is nothing to execute.
                return sql_query, error, None, error

        result_df, exec_error = execute_sql_query(sql_query)
        return sql_query, exec_error, result_df, exec_error

    # =========================
    # Button Click Event Handlers
    # =========================

    # Requests replayed by the three example buttons.
    example_request_1 = "Retrieve the top 15 records from contract_data where basetype is Award Notice, awardee has at least 12 characters, and popcity has more than 5 characters. Exclude the fields sub_tier, popzip, awardnumber, basetype, popstate, active, popcountry, type, countrycode, and popstreetaddress"
    example_request_2 = "Show top 5 departments by award amount"
    example_request_3 = "SELECT * from contract_data LIMIT 10;"

    # Request text -> generated SQL (plus an error message, if any).
    btn_generate_sql.click(
        fn=generate_sql,
        outputs=[sql_query_out, error_out],
        inputs=query_input,
    )

    # SQL currently in the code box -> result table (plus an error message, if any).
    btn_execute_query.click(
        fn=execute_query,
        outputs=[results_out, error_out],
        inputs=sql_query_out,
    )

    # Example buttons take no inputs; each one feeds its fixed request to
    # handle_example_click, which returns four values matching the four
    # listed outputs (error_out is listed twice).
    btn_example1.click(
        fn=lambda: handle_example_click(example_request_1),
        outputs=[sql_query_out, error_out, results_out, error_out],
    )
    btn_example2.click(
        fn=lambda: handle_example_click(example_request_2),
        outputs=[sql_query_out, error_out, results_out, error_out],
    )
    btn_example3.click(
        fn=lambda: handle_example_click(example_request_3),
        outputs=[sql_query_out, error_out, results_out, error_out],
    )

# Launch the Gradio app only when this file is run as a script, so that
# importing the module (for tests or tooling) does not start a web server.
if __name__ == "__main__":
    demo.launch()