# LeonceNsh's picture
# Update app.py
# f5a9d48 verified
# raw
# history blame
# 5.55 kB
import json
import openai
import gradio as gr
import duckdb
from functools import lru_cache
import pandas as pd
import plotly.express as px
import os
# =========================
# Configuration and Setup
# =========================
# The OpenAI key is read from the OPENAI_API_KEY environment variable
# (the value is None when the variable is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Parquet file that is exposed to SQL as the `contract_data` view.
# Update with your Parquet file path.
dataset_path = 'sample_contract_df.parquet'
# Column descriptions for the dataset, in column order. Each entry is a dict
# with the keys "column_name" and "column_type"; the list is serialised into
# the prompt that asks the model for SQL.
schema = [
    {"column_name": name, "column_type": sql_type}
    for name, sql_type in (
        ("department_ind_agency", "VARCHAR"),
        ("cgac", "BIGINT"),
        ("sub_tier", "VARCHAR"),
        ("fpds_code", "VARCHAR"),
        ("office", "VARCHAR"),
        ("aac_code", "VARCHAR"),
        ("posteddate", "VARCHAR"),
        ("type", "VARCHAR"),
        ("basetype", "VARCHAR"),
        ("popstreetaddress", "VARCHAR"),
        ("popcity", "VARCHAR"),
        ("popstate", "VARCHAR"),
        ("popzip", "VARCHAR"),
        ("popcountry", "VARCHAR"),
        ("active", "VARCHAR"),
        ("awardnumber", "VARCHAR"),
        ("awarddate", "VARCHAR"),
        ("award", "DOUBLE"),
        ("awardee", "VARCHAR"),
        ("state", "VARCHAR"),
        ("city", "VARCHAR"),
        ("zipcode", "VARCHAR"),
        ("countrycode", "VARCHAR"),
    )
]
@lru_cache(maxsize=1)
def get_schema():
    """Return the module-level ``schema`` list of column descriptions.

    The list object itself is handed back (no copy is made), so callers
    share it with the rest of the module.
    """
    return schema
# Lookup table: column name -> declared column type, derived from the schema.
COLUMN_TYPES = {
    entry['column_name']: entry['column_type']
    for entry in get_schema()
}
# =========================
# Database Interaction
# =========================
def load_dataset_schema():
    """Check that ``dataset_path`` can be exposed as the ``contract_data`` view.

    A fresh DuckDB connection is opened, any existing ``contract_data`` view
    is dropped, the view is recreated over the dataset file, and the
    connection is closed again before returning.

    NOTE(review): ``duckdb.connect()`` is called without a database path,
    which by DuckDB's default is an in-memory database, so the view would not
    outlive this call. Confirm that this is meant as a readability check only.

    Returns:
        True when both statements succeed; False when either raises (the
        error is printed to stdout).
    """
    connection = duckdb.connect()
    try:
        statements = (
            "DROP VIEW IF EXISTS contract_data",
            f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'",
        )
        for statement in statements:
            connection.execute(statement)
    except Exception as exc:
        print(f"Error loading dataset schema: {exc}")
        return False
    else:
        return True
    finally:
        # Runs on both the success and the failure path.
        connection.close()
# =========================
# OpenAI API Integration
# =========================
def _strip_code_fence(text):
    """Return *text* without a single surrounding Markdown code fence."""
    cleaned = text.strip()
    fenced = cleaned.startswith("```") and cleaned.endswith("```")
    if len(cleaned) < 6 or not fenced:
        return cleaned
    inner = cleaned[3:-3]
    first_line, newline, remainder = inner.partition("\n")
    # The opening fence may carry a language tag (```sql); drop that line so
    # only the statement itself is left.
    if newline and first_line.strip().lower() in ("", "sql"):
        inner = remainder
    return inner.strip()


def parse_query(nl_query):
    """Convert a natural-language request into SQL for ``contract_data``.

    The dataset schema and the request are sent to the chat model, and the
    reply text is returned with surrounding whitespace and any enclosing
    Markdown code fence removed, so it can be executed directly.

    Args:
        nl_query: The user's request in plain language.

    Returns:
        The generated SQL text. On failure (empty request, API error,
        unexpected response shape) a message starting with
        ``"Error generating SQL query:"`` is returned instead; no exception
        propagates to the caller.
    """
    # An empty request cannot produce meaningful SQL; skip the API call.
    if nl_query is None or not str(nl_query).strip():
        return "Error generating SQL query: the query is empty"

    messages = [
        {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
    ]
    try:
        # NOTE(review): `openai.ChatCompletion` is the interface of openai
        # releases before 1.0 -- confirm the installed version is pinned
        # accordingly.
        response = openai.ChatCompletion.create(
            model="gpt-4",
            messages=messages,
            temperature=0,  # deterministic output for the same request
            max_tokens=150,
        )
        sql_query = response.choices[0].message['content']
        return _strip_code_fence(sql_query)
    except Exception as e:
        # Deliberately broad: any failure is reported as text to the UI.
        return f"Error generating SQL query: {e}"
def detect_plot_intent(nl_query):
    """Report whether a request mentions one of the plotting keywords.

    The match is a case-insensitive substring check against 'plot', 'graph',
    'chart', 'distribution' and 'visualize'.

    Args:
        nl_query: The user's request in plain language; may be empty or None.

    Returns:
        True when any keyword occurs in the request, otherwise False
        (including for empty or None input).
    """
    if not nl_query:
        return False
    # Lower-case once instead of once per keyword.
    lowered = nl_query.lower()
    plot_keywords = ('plot', 'graph', 'chart', 'distribution', 'visualize')
    return any(keyword in lowered for keyword in plot_keywords)
# =========================
# Gradio Application UI
# =========================
# NOTE(review): the nesting of the layout below was reconstructed; confirm
# that the results table and plot belong in the same column as the inputs.
# NOTE(review): the instructions text mentions example-query buttons and a
# "Schema" tab, and `plot_out` is created, but none of them is wired to a
# handler in this block.
with gr.Blocks() as demo:
    gr.Markdown("""
# Parquet SQL Query and Plotting App
**Query and visualize data** in `sample_contract_df.parquet`
## Instructions
1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
2. **Use Example Queries**: Click on any example query button below to execute.
3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
4. **Execute Query**: Run the query to view results and plots.
5. **Dataset Schema**: See available columns and types in the "Schema" tab.
## Example Queries
""")
    with gr.Row():
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )
            # Button to generate the SQL query from NL
            btn_generate_sql = gr.Button("Generate SQL Query")
            # Textbox to display generated SQL
            sql_query_out = gr.Textbox(label="Generated SQL Query", interactive=False)
            # Execute button
            btn_execute_query = gr.Button("Execute Query")
            # Hidden until a handler reports an error.
            error_out = gr.Markdown("", visible=False)
            # Results and Plot output
            results_out = gr.DataFrame(label="Query Results")
            plot_out = gr.Plot(label="Plot")

    # =========================
    # Event Functions
    # =========================
    def generate_sql(nl_query):
        """Return the SQL text produced by ``parse_query`` for *nl_query*."""
        sql_query = parse_query(nl_query)
        return sql_query

    def execute_sql_query(sql_query):
        """Run *sql_query* against the ``contract_data`` view.

        Args:
            sql_query: SQL text taken from the "Generated SQL Query" box.

        Returns:
            A ``(dataframe, error_update)`` pair for ``results_out`` and
            ``error_out``. On success the dataframe holds the result and the
            error area is cleared and hidden; on failure the dataframe is
            None and the error area is made visible with the message.
        """
        if not sql_query or not sql_query.strip():
            return None, gr.update(
                value="Error executing query: no SQL query to run",
                visible=True,
            )

        con = None
        try:
            con = duckdb.connect()
            con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
            # SECURITY NOTE: the statement is model-generated text and is
            # executed verbatim; it is not restricted to read-only queries
            # on `contract_data`.
            result_df = con.execute(sql_query).fetchdf()
        except Exception as e:
            # `error_out` starts hidden, so it must be made visible here for
            # the message to reach the user.
            return None, gr.update(value=f"Error executing query: {e}", visible=True)
        finally:
            # Close the connection on the failure path as well.
            if con is not None:
                con.close()
        return result_df, gr.update(value="", visible=False)

    # Button click event handlers
    btn_generate_sql.click(fn=generate_sql, inputs=query, outputs=sql_query_out)
    btn_execute_query.click(fn=execute_sql_query, inputs=sql_query_out, outputs=[results_out, error_out])

# Launch the Gradio App
demo.launch()