LeonceNsh's picture
Update app.py
94bf8f1 verified
raw
history blame
8.55 kB
import json
import openai
import gradio as gr
import duckdb
from functools import lru_cache
import pandas as pd
import plotly.express as px
import os
# =========================
# Configuration and Setup
# =========================
# OpenAI credentials come from the environment; the key is None when
# OPENAI_API_KEY is not set.
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Parquet file that backs the 'contract_data' view
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
# Provided schema: one {"column_name", "column_type"} record per dataset
# column, in the order the columns are listed below.
schema = [
    {"column_name": name, "column_type": dtype}
    for name, dtype in (
        ("department_ind_agency", "VARCHAR"),
        ("cgac", "BIGINT"),
        ("sub_tier", "VARCHAR"),
        ("fpds_code", "VARCHAR"),
        ("office", "VARCHAR"),
        ("aac_code", "VARCHAR"),
        ("posteddate", "VARCHAR"),
        ("type", "VARCHAR"),
        ("basetype", "VARCHAR"),
        ("popstreetaddress", "VARCHAR"),
        ("popcity", "VARCHAR"),
        ("popstate", "VARCHAR"),
        ("popzip", "VARCHAR"),
        ("popcountry", "VARCHAR"),
        ("active", "VARCHAR"),
        ("awardnumber", "VARCHAR"),
        ("awarddate", "VARCHAR"),
        ("award", "DOUBLE"),
        ("awardee", "VARCHAR"),
        ("state", "VARCHAR"),
        ("city", "VARCHAR"),
        ("zipcode", "VARCHAR"),
        ("countrycode", "VARCHAR"),
    )
]
@lru_cache(maxsize=1)
def get_schema():
    """
    Returns the module-level ``schema`` list.

    The result is memoised by ``lru_cache``, so every call hands back the
    same list object that was current on the first call.
    """
    return schema


# Lookup table: column name -> declared column type, built once at import.
COLUMN_TYPES = dict(
    (entry['column_name'], entry['column_type']) for entry in get_schema()
)
# =========================
# Database Interaction
# =========================
def load_dataset_schema():
    """
    Creates the 'contract_data' view over the Parquet file on a new DuckDB connection.

    Returns:
        True when both statements ran without raising, False otherwise
        (the error is printed, not raised).

    NOTE(review): the connection is closed before returning and
    duckdb.connect() is called without a database path, so the view
    presumably does not outlive this call; in practice this looks like a
    startup check that the dataset can be read. Confirm before relying on
    the view elsewhere.
    """
    connection = duckdb.connect()
    succeeded = False
    try:
        statements = (
            "DROP VIEW IF EXISTS contract_data",
            f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'",
        )
        for statement in statements:
            connection.execute(statement)
        succeeded = True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
    finally:
        # Always release the connection, whether or not the view was created.
        connection.close()
    return succeeded


# Load the dataset schema at startup
load_dataset_schema()
# =========================
# OpenAI API Integration
# =========================
def _extract_sql(text):
    """
    Returns the SQL carried by a model reply, without Markdown code fences.

    If the reply contains a ``` fence, only the content of the first fenced
    block is kept (an optional "sql"/"duckdb" language tag on the opening
    fence line is dropped). Replies without a fence are returned stripped.
    """
    body = text.strip()
    fence = "```"
    start = body.find(fence)
    if start == -1:
        return body
    end = body.find(fence, start + len(fence))
    # A reply cut off by max_tokens may lack the closing fence; keep the rest.
    inner = body[start + len(fence):end] if end != -1 else body[start + len(fence):]
    tag, newline, rest = inner.partition("\n")
    if newline and tag.strip().lower() in ("sql", "duckdb"):
        inner = rest
    return inner.strip()


def parse_query(nl_query):
    """
    Converts a natural language query into a SQL query using OpenAI's API.

    Args:
        nl_query: The user's question in plain language.

    Returns:
        The SQL text produced by the model, with surrounding whitespace and
        any Markdown code fence removed. On any failure, a string starting
        with "Error generating SQL query:" is returned instead of raising;
        callers detect failure through that prefix.
    """
    messages = [
        {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
    ]
    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=0,  # deterministic output for the same prompt
            max_tokens=150,
        )
        # Chat models frequently wrap SQL in ```sql ... ``` fences, which the
        # database cannot parse, so unwrap before handing the text back.
        return _extract_sql(response.choices[0].message.content)
    except Exception as e:
        return f"Error generating SQL query: {e}"
# =========================
# Plotting Utilities
# =========================
def detect_plot_intent(nl_query):
    """
    Detects if the user's query involves plotting.

    A keyword counts only when it begins a word, i.e. it sits at the start
    of the text or directly after a non-alphanumeric character. Plurals and
    compounds such as "plots" or "scatterplot" therefore still match, while
    unrelated words that merely contain a keyword ("airline", "online",
    "paragraph") no longer trigger a plot.

    Args:
        nl_query: The user's question in plain language.

    Returns:
        True if any plotting keyword starts a word in the query, else False.
    """
    plot_keywords = ('plot', 'graph', 'chart', 'distribution', 'visualize', 'trend', 'histogram', 'bar', 'line', 'scatter', 'pie')
    lowered = nl_query.lower()
    for keyword in plot_keywords:
        position = lowered.find(keyword)
        while position != -1:
            if position == 0 or not lowered[position - 1].isalnum():
                return True
            # This hit was inside another word; look for a later occurrence.
            position = lowered.find(keyword, position + 1)
    return False
def generate_plot(nl_query, result_df):
    """
    Builds a Plotly figure from the query result when the user asked for one.

    The first result column is used for x (or pie names) and the second for
    y (or pie values). The chart type is picked from the first of 'bar',
    'line', 'scatter', 'pie' found in the lower-cased query; a bar chart is
    the fallback.

    Returns:
        A (figure, message) pair: (None, "") when no plot was requested,
        (None, <message>) when the result has fewer than two columns, and
        (figure, "") otherwise.
    """
    if not detect_plot_intent(nl_query):
        return None, ""

    column_names = result_df.columns.tolist()
    if len(column_names) < 2:
        return None, "Not enough data to generate a plot."

    first_col, second_col = column_names[0], column_names[1]
    lowered_query = nl_query.lower()

    # Ordered (keyword, builder) pairs; the first keyword present wins.
    chart_builders = (
        ('bar', lambda: px.bar(result_df, x=first_col, y=second_col, title='Bar Chart')),
        ('line', lambda: px.line(result_df, x=first_col, y=second_col, title='Line Chart')),
        ('scatter', lambda: px.scatter(result_df, x=first_col, y=second_col, title='Scatter Plot')),
        ('pie', lambda: px.pie(result_df, names=first_col, values=second_col, title='Pie Chart')),
    )
    for keyword, build_chart in chart_builders:
        if keyword in lowered_query:
            fig = build_chart()
            break
    else:
        # Default to bar chart
        fig = px.bar(result_df, x=first_col, y=second_col, title='Bar Chart')

    # Center the chart title.
    fig.update_layout(title_x=0.5)
    return fig, ""
# =========================
# Gradio Application UI
# =========================
# Layout, handlers and event wiring all live inside the Blocks context so the
# components and their listeners are registered on `demo`.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    # Page header, written as raw HTML inside a Markdown component.
    gr.Markdown("""
<h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
<p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
""", elem_id="main-title")
    with gr.Row():
        # Left column: query input, schema viewer and error message.
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )
            # Hidden schema display that appears on focus
            schema_display = gr.JSON(
                label="Dataset Schema",
                value=get_schema(),
                interactive=False,
                visible=False
            )
            # Hidden until a handler has a message to show.
            error_out = gr.Markdown(
                value="",
                visible=False
            )
        # Right column: tabular results and the optional chart.
        with gr.Column(scale=2):
            results_out = gr.DataFrame(
                label="Results",
                interactive=False
            )
            plot_out = gr.Plot(
                label="Visualization"
            )
    gr.Markdown("""
<style>
/* Center the content */
.gradio-container {
max-width: 1000px;
margin: auto;
}
/* Style the main title */
#main-title h1 {
font-weight: bold;
}
/* Style the error alert */
.gradio-container .alert-error {
background-color: #ffe6e6;
color: #cc0000;
border: 1px solid #cc0000;
}
</style>
""")

    # =========================
    # Click Event Handlers
    # =========================
    def on_query_submit(nl_query):
        """
        Handles the submission of a natural language query.

        Pipeline: validate input -> parse_query -> execute_query ->
        generate_plot.

        Returns:
            A 3-tuple matching the outputs [error_out, results_out,
            plot_out]: an update for the message area, the result
            DataFrame (or None) and the figure (or None).
        """
        # Treat a missing value the same as blank text.
        if not nl_query or not nl_query.strip():
            return gr.update(visible=True, value="Please enter a query."), None, None
        sql_query = parse_query(nl_query)
        # parse_query reports failures in-band as a string starting with "Error".
        if sql_query.startswith("Error"):
            return gr.update(visible=True, value=sql_query), None, None
        result_df, error_msg = execute_query(sql_query)
        if error_msg:
            return gr.update(visible=True, value=error_msg), None, None
        fig, plot_error = generate_plot(nl_query, result_df)
        if plot_error:
            # The query itself succeeded: keep the table and only report
            # that no chart could be drawn.
            return gr.update(visible=True, value=plot_error), result_df, None
        return gr.update(visible=False, value=""), result_df, fig

    def on_input_focus():
        """
        Shows the dataset schema when the input box is focused.
        """
        return gr.update(visible=True)

    # =========================
    # Assign Event Handlers
    # =========================
    query.submit(
        fn=on_query_submit,
        inputs=query,
        outputs=[error_out, results_out, plot_out]
    )
    query.focus(
        fn=on_input_focus,
        inputs=None,
        outputs=schema_display
    )
# =========================
# Helper Functions
# =========================
def execute_query(sql_query):
    """
    Executes the SQL query and returns the results.

    A fresh DuckDB connection is opened per call, the 'contract_data' view
    is (re)created over the Parquet file, and the query is run against it.

    Args:
        sql_query: SQL text to run.

    Returns:
        (DataFrame, "") on success, or (None, "Error executing query: ...")
        if anything raised.

    NOTE(security): in this app sql_query is model-generated from user text
    and is executed verbatim. Nothing here restricts it to read-only
    statements; review before exposing the app to untrusted users.
    """
    con = None
    try:
        con = duckdb.connect()
        con.execute("PRAGMA threads=4") # Optimize for performance
        con.execute("DROP VIEW IF EXISTS contract_data")
        con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
        result_df = con.execute(sql_query).fetchdf()
        return result_df, ""
    except Exception as e:
        return None, f"Error executing query: {e}"
    finally:
        # Close on every path; previously a failing statement skipped close()
        # and leaked the connection.
        if con is not None:
            con.close()
# =========================
# Launch the Gradio App
# =========================
# Launch the app only when this file is executed directly, not when imported.
if __name__ == "__main__":
    demo.launch()