# File size: 6,953 Bytes
# 564d637
import json
import openai
import gradio as gr
import duckdb
from functools import lru_cache
import os
# =========================
# Configuration and Setup
# =========================
# Read the OpenAI API key from the OPENAI_API_KEY environment variable
# (None when the variable is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Parquet file exposed to queries; update with your Parquet file path.
dataset_path = 'hsas.parquet'
# Column descriptors for the dataset, one {"column_name", "column_type"} dict
# per column, in the order listed here.
schema = [
    {"column_name": name, "column_type": dtype}
    for name, dtype in (
        ("department_ind_agency", "VARCHAR"),
        ("cgac", "BIGINT"),
        ("sub_tier", "VARCHAR"),
        ("fpds_code", "VARCHAR"),
        ("office", "VARCHAR"),
        ("aac_code", "VARCHAR"),
        ("posteddate", "VARCHAR"),
        ("type", "VARCHAR"),
        ("basetype", "VARCHAR"),
        ("popstreetaddress", "VARCHAR"),
        ("popcity", "VARCHAR"),
        ("popstate", "VARCHAR"),
        ("popzip", "VARCHAR"),
        ("popcountry", "VARCHAR"),
        ("active", "VARCHAR"),
        ("awardnumber", "VARCHAR"),
        ("awarddate", "VARCHAR"),
        ("award", "DOUBLE"),
        ("awardee", "VARCHAR"),
        ("state", "VARCHAR"),
        ("city", "VARCHAR"),
        ("zipcode", "VARCHAR"),
        ("countrycode", "VARCHAR"),
    )
]
@lru_cache(maxsize=1)
def get_schema():
    """Return the module-level ``schema`` list of column descriptors.

    The list object itself is returned (no copy is made), so callers share
    the same object.
    """
    column_descriptors = schema
    return column_descriptors
# Lookup table mapping each column name to its column type string.
COLUMN_TYPES = dict(
    (entry['column_name'], entry['column_type']) for entry in get_schema()
)
# =========================
# OpenAI API Integration
# =========================
def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if present."""
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    body = cleaned[3:]
    if body.endswith("```"):
        body = body[:-3]
    first_line, newline, remainder = body.partition("\n")
    # The opening fence may carry a language tag ("```sql"); drop that line
    # only when it is blank or the tag, never when it already holds SQL.
    if newline and first_line.strip().lower() in ("", "sql"):
        body = remainder
    return body.strip()


def parse_query(nl_query):
    """Translate a natural-language request into SQL via the OpenAI chat API.

    Args:
        nl_query: The user's request in plain language.

    Returns:
        A ``(sql_query, error)`` tuple. On success ``error`` is ``""``; on
        failure ``sql_query`` is ``""`` and ``error`` holds a message that
        starts with ``"Error generating SQL query: "``.
    """
    # Reject blank input up front instead of spending an API call on it.
    if not nl_query or not nl_query.strip():
        return "", "Error generating SQL query: the query is empty"

    messages = [
        {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
    ]
    try:
        response = openai.chat.completions.create(
            model="gpt-4",
            messages=messages,
            temperature=0,
            max_tokens=150,
        )
        # Chat models frequently wrap SQL in a Markdown fence; the fence
        # would be a syntax error if handed to the database unchanged.
        sql_query = _strip_code_fence(response.choices[0].message.content or "")
    except Exception as e:
        # Boundary with the UI: every failure is reported as text, not raised.
        return "", f"Error generating SQL query: {e}"
    if not sql_query:
        return "", "Error generating SQL query: the model returned an empty response"
    return sql_query, ""
# =========================
# Database Interaction
# =========================
def execute_sql_query(sql_query):
    """Run *sql_query* against the Parquet dataset through DuckDB.

    A new DuckDB connection is opened per call, the Parquet file named by
    ``dataset_path`` is exposed as the view ``contract_data``, and the query
    result is fetched as a DataFrame.

    Args:
        sql_query: SQL text to execute.

    Returns:
        A ``(result_df, error)`` tuple. On success ``error`` is ``""``; on
        failure ``result_df`` is ``None`` and ``error`` holds a message that
        starts with ``"Error executing query: "``.
    """
    # Nothing to run: report it clearly without opening a connection.
    if not sql_query or not sql_query.strip():
        return None, "Error executing query: the SQL query is empty"

    con = None
    try:
        con = duckdb.connect()
        con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
        # NOTE(security): sql_query is executed verbatim. It may be typed by
        # the user or produced by the language model and is not restricted
        # to read-only statements here.
        result_df = con.execute(sql_query).fetchdf()
        return result_df, ""
    except Exception as e:
        # Boundary with the UI: every failure is reported as text, not raised.
        return None, f"Error executing query: {e}"
    finally:
        # Close the connection on the error path too (previously leaked).
        if con is not None:
            con.close()
# =========================
# Gradio Application UI
# =========================
with gr.Blocks() as demo:
    # NOTE(review): the widget nesting below is the presumed layout (controls
    # in the left column, results in the right column, schema tab after the
    # row) -- confirm against the running app.
    gr.Markdown("""
# Use Text to SQL to analyze US Government contract data
## Instructions
### 1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
### 2. **Use Example Queries**: Click on any example query button below to execute.
### 3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
## Example Queries
""")

    # Example 1 is used both as the button label and as the prompt sent to
    # the model, so it is defined once.
    example1_query = "Retrieve the top 15 records from contract_data where basetype is Award Notice, awardee has at least 12 characters, and popcity has more than 5 characters. Exclude the fields sub_tier, popzip, awardnumber, basetype, popstate, active, popcountry, type, countrycode, and popstreetaddress"
    # Example 2 sends a slightly longer prompt than its button label shows.
    example2_query = "Show top 10 departments by award amount. Round to zero decimal places."
    # Example 3 is raw SQL and is executed without calling the model.
    example3_query = "SELECT * from contract_data LIMIT 10;"

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Click on an example query:")
            with gr.Row():
                btn_example1 = gr.Button(example1_query)
                btn_example2 = gr.Button("Show top 10 departments by award amount")
                btn_example3 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")
            query_input = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )
            btn_generate_sql = gr.Button("Generate SQL Query")
            sql_query_out = gr.Code(label="Generated SQL Query", language="sql")
            btn_execute_query = gr.Button("Execute Query")
            # Starts hidden; the handlers below reveal it only when there is
            # an error message to show.
            error_out = gr.Markdown("", visible=False)
        with gr.Column(scale=2):
            results_out = gr.Dataframe(label="Query Results", interactive=False)
    with gr.Tab("Dataset Schema"):
        gr.Markdown("### Dataset Schema")
        schema_display = gr.JSON(label="Schema", value=get_schema())

    # =========================
    # Event Functions
    # =========================
    def _error_update(message):
        """Build the update for ``error_out``: show *message*, hide when empty.

        Writing a plain string to a component created with ``visible=False``
        leaves it hidden, so visibility is toggled explicitly.
        """
        return gr.update(value=message, visible=bool(message))

    def generate_sql(nl_query):
        """Turn the typed request into SQL; returns (sql, error update)."""
        sql_query, error = parse_query(nl_query)
        return sql_query, _error_update(error)

    def execute_query(sql_query):
        """Run the SQL shown in the code box; returns (dataframe, error update)."""
        result_df, error = execute_sql_query(sql_query)
        return result_df, _error_update(error)

    def handle_example_click(example_query):
        """Generate (when needed) and execute an example query.

        Text starting with ``SELECT`` (case-insensitive, after stripping) is
        treated as SQL and executed directly; anything else is first sent to
        ``parse_query``.

        Returns:
            ``(sql_query, error update, result_df)``; ``result_df`` is
            ``None`` when generation or execution failed.
        """
        if example_query.strip().upper().startswith("SELECT"):
            sql_query, error = example_query, ""
        else:
            sql_query, error = parse_query(example_query)
        if error:
            return sql_query, _error_update(error), None
        result_df, exec_error = execute_sql_query(sql_query)
        return sql_query, _error_update(exec_error), result_df

    # =========================
    # Button Click Event Handlers
    # =========================
    btn_generate_sql.click(
        fn=generate_sql,
        inputs=query_input,
        outputs=[sql_query_out, error_out]
    )
    btn_execute_query.click(
        fn=execute_query,
        inputs=sql_query_out,
        outputs=[results_out, error_out]
    )
    # Each output component appears once per event so a single value is
    # written to it (error_out used to be listed twice).
    btn_example1.click(
        fn=lambda: handle_example_click(example1_query),
        outputs=[sql_query_out, error_out, results_out]
    )
    btn_example2.click(
        fn=lambda: handle_example_click(example2_query),
        outputs=[sql_query_out, error_out, results_out]
    )
    btn_example3.click(
        fn=lambda: handle_example_click(example3_query),
        outputs=[sql_query_out, error_out, results_out]
    )
# Launch the Gradio App (calls launch() on the `demo` Blocks object).
demo.launch()