File size: 5,554 Bytes
b474ae1
ec9d21a
06f01b3
b474ae1
d33fe62
dfe1769
 
1fa796c
5b4c268
ae610aa
 
 
 
94bf8f1
f146007
5b4c268
d33fe62
5a73339
 
92494e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d33fe62
 
 
 
 
 
 
 
f5a9d48
 
 
 
f146007
b474ae1
d33fe62
 
 
 
 
 
 
 
 
5b4c268
ae610aa
 
 
 
c490b83
a1792a1
94bf8f1
88c83f6
a1792a1
1fa796c
dfe1769
f5a9d48
78f16f0
13f0f94
88c83f6
dfe1769
 
f5a9d48
dfe1769
 
 
 
 
f5a9d48
88c83f6
dfe1769
ae610aa
 
 
 
f5a9d48
06f01b3
f5a9d48
 
 
 
 
 
 
 
 
 
 
 
 
78f16f0
b474ae1
c490b83
 
 
94bf8f1
c490b83
 
 
f5a9d48
 
 
 
 
 
 
 
 
 
 
00c05fa
94bf8f1
f5a9d48
94bf8f1
dfe1769
f5a9d48
c490b83
f5a9d48
94bf8f1
f5a9d48
 
 
 
 
 
 
 
 
 
 
 
 
00c05fa
f5a9d48
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import json
import os
import re
from functools import lru_cache

import duckdb
import gradio as gr
import openai
import pandas as pd
import plotly.express as px

# =========================
# Configuration and Setup
# =========================

# OpenAI key taken from the OPENAI_API_KEY environment variable (None when unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Parquet file that backs the `contract_data` view; point this at your own data.
dataset_path = "sample_contract_df.parquet"

# Column names and their DuckDB types for the Parquet dataset. parse_query
# serializes this list with json.dumps into the prompt sent to the model.
schema = [
    {"column_name": name, "column_type": col_type}
    for name, col_type in (
        ("department_ind_agency", "VARCHAR"),
        ("cgac", "BIGINT"),
        ("sub_tier", "VARCHAR"),
        ("fpds_code", "VARCHAR"),
        ("office", "VARCHAR"),
        ("aac_code", "VARCHAR"),
        ("posteddate", "VARCHAR"),
        ("type", "VARCHAR"),
        ("basetype", "VARCHAR"),
        ("popstreetaddress", "VARCHAR"),
        ("popcity", "VARCHAR"),
        ("popstate", "VARCHAR"),
        ("popzip", "VARCHAR"),
        ("popcountry", "VARCHAR"),
        ("active", "VARCHAR"),
        ("awardnumber", "VARCHAR"),
        ("awarddate", "VARCHAR"),
        ("award", "DOUBLE"),
        ("awardee", "VARCHAR"),
        ("state", "VARCHAR"),
        ("city", "VARCHAR"),
        ("zipcode", "VARCHAR"),
        ("countrycode", "VARCHAR"),
    )
]


@lru_cache(1)
def get_schema():
    """Return the module-level ``schema`` list (the same object on every call)."""
    return schema


# Lookup table: column name -> declared DuckDB type.
COLUMN_TYPES = {
    column["column_name"]: column["column_type"] for column in get_schema()
}

# =========================
# Database Interaction
# =========================

def load_dataset_schema():
    """Try to create the ``contract_data`` view over ``dataset_path``.

    Opens a fresh in-memory DuckDB connection, drops any existing
    ``contract_data`` view, creates it as ``SELECT *`` over the Parquet file,
    and always closes the connection afterwards.

    Returns:
        True if both statements ran; False if DuckDB raised (the error is
        printed, not re-raised).
    """
    # NOTE(review): the view exists only on this connection and is discarded
    # by close() below, so this works as a "can the dataset be opened" check
    # rather than persistent setup — presumably DuckDB binds the Parquet scan
    # at CREATE VIEW time; confirm a missing file actually fails here.
    con = duckdb.connect()
    try:
        con.execute("DROP VIEW IF EXISTS contract_data")
        # Double any single quotes so the path stays a valid SQL string literal.
        escaped_path = dataset_path.replace("'", "''")
        con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{escaped_path}'")
        return True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
        return False
    finally:
        con.close()

# =========================
# OpenAI API Integration
# =========================

def parse_query(nl_query):
    messages = [
        {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
    ]

    try:
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=messages,
            temperature=0,
            max_tokens=150,
        )
        sql_query = response.choices[0].message['content'].strip()
        return sql_query
    except Exception as e:
        return f"Error generating SQL query: {e}"

def detect_plot_intent(nl_query):
    """Return True if the query text mentions a plotting-related keyword.

    The check is a case-insensitive substring search for any of: plot, graph,
    chart, distribution, visualize.
    """
    lowered = nl_query.lower()
    for word in ('plot', 'graph', 'chart', 'distribution', 'visualize'):
        if word in lowered:
            return True
    return False

# =========================
# Gradio Application UI
# =========================

with gr.Blocks() as demo:
    gr.Markdown("""
    # Parquet SQL Query and Plotting App

    **Query and visualize data** in `sample_contract_df.parquet`

    ## Instructions

    1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
    2. **Use Example Queries**: Click on any example query button below to execute.
    3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
    4. **Execute Query**: Run the query to view results and plots.
    5. **Dataset Schema**: See available columns and types in the "Schema" tab.

    ## Example Queries
    """)

    # NOTE(review): the instructions above promise example-query buttons,
    # plots and a "Schema" tab, but this layout defines none of them
    # (plot_out is never written and detect_plot_intent is unused) — confirm
    # whether those features are still planned.

    with gr.Row():
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )
            # Button to generate the SQL query from NL
            btn_generate_sql = gr.Button("Generate SQL Query")
            # Read-only box holding the generated SQL; it is the Execute input.
            sql_query_out = gr.Textbox(label="Generated SQL Query", interactive=False)
            # Execute button
            btn_execute_query = gr.Button("Execute Query")
            # Hidden until execute_sql_query reports an error.
            error_out = gr.Markdown("", visible=False)

        # Results and Plot output
        results_out = gr.DataFrame(label="Query Results")
        plot_out = gr.Plot(label="Plot")

    # =========================
    # Event Functions
    # =========================

    def generate_sql(nl_query):
        """Return parse_query's SQL for ``nl_query`` (or its error text)."""
        return parse_query(nl_query)

    def execute_sql_query(sql_query):
        """Run ``sql_query`` against the Parquet dataset.

        Returns:
            A ``(result, error_update)`` pair: the query result as a DataFrame
            (None on failure) and a ``gr.update`` for ``error_out`` that shows
            the error message, or hides the panel when the query succeeded.
        """
        if not sql_query or not sql_query.strip():
            return None, gr.update(value="Please generate a SQL query first.", visible=True)

        con = None
        try:
            con = duckdb.connect()
            # Double any single quotes so the path stays a valid SQL string literal.
            escaped_path = dataset_path.replace("'", "''")
            con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{escaped_path}'")
            result_df = con.execute(sql_query).fetchdf()
        except Exception as e:
            # error_out starts hidden, so it must be made visible explicitly.
            return None, gr.update(value=f"Error executing query: {e}", visible=True)
        finally:
            # Close on both paths; previously a failing query leaked the connection.
            if con is not None:
                con.close()
        return result_df, gr.update(value="", visible=False)

    # Button click event handlers
    btn_generate_sql.click(fn=generate_sql, inputs=query, outputs=sql_query_out)
    btn_execute_query.click(fn=execute_sql_query, inputs=sql_query_out, outputs=[results_out, error_out])

# Launch the Gradio app built above.
# NOTE(review): this runs at module level, so importing this file (e.g. from
# tests) also starts the server; consider an `if __name__ == "__main__":` guard.
demo.launch()