# Page header captured along with the source ("Spaces: Sleeping"); not part of the program.
import json
import os
from functools import lru_cache

import duckdb
import gradio as gr
import openai
import pandas as pd
import plotly.express as px
# =========================
# Configuration and Setup
# =========================

# Parquet file queried by the app.
dataset_path = 'sample_contract_df.parquet'  # Update with your Parquet file path

# The OpenAI API key comes from the environment (None when the variable is unset).
openai.api_key = os.environ.get("OPENAI_API_KEY")
# Provided schema, written as (column name, column type) pairs in the order
# they are exposed to the rest of the module.
_SCHEMA_COLUMNS = (
    ("department_ind_agency", "VARCHAR"),
    ("cgac", "BIGINT"),
    ("sub_tier", "VARCHAR"),
    ("fpds_code", "VARCHAR"),
    ("office", "VARCHAR"),
    ("aac_code", "VARCHAR"),
    ("posteddate", "VARCHAR"),
    ("type", "VARCHAR"),
    ("basetype", "VARCHAR"),
    ("popstreetaddress", "VARCHAR"),
    ("popcity", "VARCHAR"),
    ("popstate", "VARCHAR"),
    ("popzip", "VARCHAR"),
    ("popcountry", "VARCHAR"),
    ("active", "VARCHAR"),
    ("awardnumber", "VARCHAR"),
    ("awarddate", "VARCHAR"),
    ("award", "DOUBLE"),
    ("awardee", "VARCHAR"),
    ("state", "VARCHAR"),
    ("city", "VARCHAR"),
    ("zipcode", "VARCHAR"),
    ("countrycode", "VARCHAR"),
)

# One {"column_name": ..., "column_type": ...} dict per column.
schema = [
    {"column_name": name, "column_type": column_type}
    for name, column_type in _SCHEMA_COLUMNS
]
def get_schema():
    """
    Returns the module-level `schema` list (the same object, not a copy).
    """
    return schema
# Lookup table: column name -> column type, derived from the schema.
COLUMN_TYPES = {
    column['column_name']: column['column_type']
    for column in get_schema()
}
# =========================
# Database Interaction
# =========================
def load_dataset_schema():
    """
    Checks that the Parquet dataset can be exposed to DuckDB as a view.

    A `contract_data` view over `dataset_path` is created on a connection that
    is opened here and closed before returning. `duckdb.connect()` is called
    without a database path (an in-memory database per the DuckDB docs), so
    the view does not outlive this call; in practice this is a startup sanity
    check that the file can be read.

    Returns:
        True if the view was created, False if any step failed (the error is
        printed).
    """
    con = None
    try:
        con = duckdb.connect()
        # Double any single quote so the path cannot terminate the SQL literal.
        quoted_path = dataset_path.replace("'", "''")
        con.execute("DROP VIEW IF EXISTS contract_data")
        con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{quoted_path}'")
        return True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
        return False
    finally:
        # `con` stays None when connect() itself failed.
        if con is not None:
            con.close()


# Load the dataset schema at startup
load_dataset_schema()
# =========================
# OpenAI API Integration
# =========================
def _strip_code_fence(text):
    """
    Returns `text` without a surrounding Markdown code fence.

    Handles "```sql\\n...\\n```" and "```...```"; text that does not start with
    a fence is only stripped of surrounding whitespace.
    """
    cleaned = text.strip()
    if not cleaned.startswith("```"):
        return cleaned
    cleaned = cleaned[3:]
    if cleaned.endswith("```"):
        cleaned = cleaned[:-3]
    # Drop a language tag ("sql") that sits alone on the opening fence line.
    first_line, newline, remainder = cleaned.partition("\n")
    if newline and first_line.strip().lower() == "sql":
        cleaned = remainder
    return cleaned.strip()


def parse_query(nl_query):
    """
    Converts a natural language query into a SQL query using OpenAI's API.

    Args:
        nl_query: The user's question as free text.

    Returns:
        The generated SQL text with any Markdown code fence removed. On
        failure, a string starting with "Error generating SQL query:" is
        returned instead of raising; callers detect failure by that prefix.
    """
    messages = [
        {"role": "system", "content": "You are an assistant that converts natural language queries into SQL queries for the 'contract_data' table."},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"}
    ]
    try:
        # NOTE(review): `openai.ChatCompletion` is the pre-1.0 SDK interface;
        # confirm the pinned openai version still provides it.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            temperature=0,
            max_tokens=150,
        )
        content = response.choices[0].message.content
        # Models often wrap SQL in ```sql fences, which DuckDB cannot parse.
        sql_query = _strip_code_fence(content or "")
        if not sql_query:
            return "Error generating SQL query: the model returned an empty response."
        return sql_query
    except Exception as e:
        return f"Error generating SQL query: {e}"
# =========================
# Plotting Utilities
# =========================
def detect_plot_intent(nl_query):
    """
    Reports whether the query text mentions a plotting keyword.

    Matching is a case-insensitive substring test, so a keyword embedded in a
    longer word (e.g. "line" in "airline") also counts.
    """
    lowered = nl_query.lower()
    for keyword in ('plot', 'graph', 'chart', 'distribution', 'visualize', 'trend',
                    'histogram', 'bar', 'line', 'scatter', 'pie'):
        if keyword in lowered:
            return True
    return False
def generate_plot(nl_query, result_df):
    """
    Builds a Plotly figure from the query result when a plot was requested.

    The first result column is used for the x axis (or pie slice names) and
    the second for the y axis (or pie slice values). The chart type follows
    the first of 'bar', 'line', 'scatter', 'pie' found in the query text, and
    falls back to a bar chart.

    Returns:
        A `(figure, message)` pair: `(None, "")` when no plot was requested,
        `(None, "Not enough data to generate a plot.")` when the result has
        fewer than two columns, otherwise `(figure, "")`.
    """
    if not detect_plot_intent(nl_query):
        return None, ""

    column_names = result_df.columns.tolist()
    if len(column_names) < 2:
        return None, "Not enough data to generate a plot."

    x_col, y_col = column_names[0], column_names[1]
    lowered = nl_query.lower()

    # Simple keyword heuristic; order gives precedence when several match.
    chart_builders = (
        ('bar', lambda: px.bar(result_df, x=x_col, y=y_col, title='Bar Chart')),
        ('line', lambda: px.line(result_df, x=x_col, y=y_col, title='Line Chart')),
        ('scatter', lambda: px.scatter(result_df, x=x_col, y=y_col, title='Scatter Plot')),
        ('pie', lambda: px.pie(result_df, names=x_col, values=y_col, title='Pie Chart')),
    )
    # Default to the bar chart builder when no chart keyword is present.
    build_figure = next(
        (builder for keyword, builder in chart_builders if keyword in lowered),
        chart_builders[0][1],
    )
    fig = build_figure()
    fig.update_layout(title_x=0.5)
    return fig, ""
# =========================
# Gradio Application UI
# =========================
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown("""
<h1 style="text-align: center; font-size: 2.5em; color: #333333;">Parquet Data Explorer</h1>
<p style="text-align: center; color: #666666;">Query and visualize your data effortlessly.</p>
""", elem_id="main-title")
    with gr.Row():
        # Left column: query input, schema reference and error messages.
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Your Query",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1
            )
            # Hidden schema display that appears on focus
            schema_display = gr.JSON(
                label="Dataset Schema",
                value=get_schema(),
                interactive=False,
                visible=False
            )
            error_out = gr.Markdown(
                value="",
                visible=False
            )
        # Right column: result table and optional chart.
        with gr.Column(scale=2):
            results_out = gr.DataFrame(
                label="Results",
                interactive=False
            )
            plot_out = gr.Plot(
                label="Visualization"
            )
    # NOTE(review): CSS is injected through a Markdown component; confirm the
    # installed Gradio version renders <style> tags in Markdown.
    gr.Markdown("""
<style>
/* Center the content */
.gradio-container {
max-width: 1000px;
margin: auto;
}
/* Style the main title */
#main-title h1 {
font-weight: bold;
}
/* Style the error alert */
.gradio-container .alert-error {
background-color: #ffe6e6;
color: #cc0000;
border: 1px solid #cc0000;
}
</style>
""")

    # =========================
    # Click Event Handlers
    # =========================
    def on_query_submit(nl_query):
        """
        Handles the submission of a natural language query.

        Runs the pipeline: natural language -> SQL -> DuckDB -> optional plot.

        Returns:
            A 3-tuple matching the outputs `[error_out, results_out, plot_out]`:
            an update for the error Markdown, the result DataFrame (or None)
            and the Plotly figure (or None).
        """
        if not nl_query or not nl_query.strip():
            return gr.update(visible=True, value="Please enter a query."), None, None

        # parse_query reports failures in-band as a string starting with "Error".
        sql_query = parse_query(nl_query)
        if sql_query.startswith("Error"):
            return gr.update(visible=True, value=sql_query), None, None

        # execute_query is defined further down in this file; the name is
        # resolved when the handler runs, not when it is defined.
        result_df, error_msg = execute_query(sql_query)
        if error_msg:
            return gr.update(visible=True, value=error_msg), None, None

        fig, plot_error = generate_plot(nl_query, result_df)
        if plot_error:
            # The query itself succeeded: keep the table and only report that
            # no chart could be drawn.
            return gr.update(visible=True, value=plot_error), result_df, None

        return gr.update(visible=False, value=""), result_df, fig

    def on_input_focus():
        """
        Shows the dataset schema when the input box is focused.
        """
        return gr.update(visible=True)

    # =========================
    # Assign Event Handlers
    # =========================
    query.submit(
        fn=on_query_submit,
        inputs=query,
        outputs=[error_out, results_out, plot_out]
    )
    query.focus(
        fn=on_input_focus,
        inputs=None,
        outputs=schema_display
    )
# =========================
# Helper Functions
# =========================
def execute_query(sql_query):
    """
    Executes the SQL query and returns the results.

    A new DuckDB connection is opened for every call, the `contract_data` view
    over `dataset_path` is (re)created on it, the query is run, and the
    connection is closed again whether or not the query succeeded.

    Args:
        sql_query: SQL text to run against the `contract_data` view.

    Returns:
        `(result_df, "")` on success, or `(None, "Error executing query: ...")`
        on failure; this function does not raise.
    """
    con = None
    try:
        con = duckdb.connect()
        con.execute("PRAGMA threads=4")  # Optimize for performance
        # Double any single quote so the path cannot terminate the SQL literal.
        quoted_path = dataset_path.replace("'", "''")
        con.execute("DROP VIEW IF EXISTS contract_data")
        con.execute(f"CREATE VIEW contract_data AS SELECT * FROM '{quoted_path}'")
        # NOTE(security): sql_query is model-generated from user input and is
        # executed verbatim; nothing here restricts it to read-only statements.
        result_df = con.execute(sql_query).fetchdf()
        return result_df, ""
    except Exception as e:
        return None, f"Error executing query: {e}"
    finally:
        # Close on the error path too; `con` stays None if connect() failed.
        if con is not None:
            con.close()
# =========================
# Launch the Gradio App
# =========================
# Start the web UI only when this file is run as a script, not on import.
if __name__ == "__main__":
    demo.launch()