# Page header captured with this file: LeonceNsh · "Update app.py" · dfe1769 (verified) · 9.35 kB
import json
import os
import re
from functools import lru_cache

import duckdb
import gradio as gr
import openai
import pandas as pd
import plotly.express as px
from transformers import pipeline
# Path of the Parquet file that every query in this app reads from
dataset_path = 'sample_contract_df.parquet'  # Update with your Parquet file path

# Declared dataset schema: one {"column_name", "column_type"} entry per column,
# kept in the order listed below.
schema = [
    {"column_name": name, "column_type": sql_type}
    for name, sql_type in (
        ("department_ind_agency", "VARCHAR"),
        ("cgac", "BIGINT"),
        ("sub_tier", "VARCHAR"),
        ("fpds_code", "VARCHAR"),
        ("office", "VARCHAR"),
        ("aac_code", "VARCHAR"),
        ("posteddate", "VARCHAR"),
        ("type", "VARCHAR"),
        ("basetype", "VARCHAR"),
        ("popstreetaddress", "VARCHAR"),
        ("popcity", "VARCHAR"),
        ("popstate", "VARCHAR"),
        ("popzip", "VARCHAR"),
        ("popcountry", "VARCHAR"),
        ("active", "VARCHAR"),
        ("awardnumber", "VARCHAR"),
        ("awarddate", "VARCHAR"),
        ("award", "DOUBLE"),
        ("awardee", "VARCHAR"),
        ("state", "VARCHAR"),
        ("city", "VARCHAR"),
        ("zipcode", "VARCHAR"),
        ("countrycode", "VARCHAR"),
    )
]
# Memoised accessor for the declared schema
@lru_cache(maxsize=1)
def get_schema():
    """Return the module-level ``schema`` list.

    The function takes no arguments, so ``lru_cache`` stores a single result
    and hands back that same list object on every later call.
    """
    return schema
# Lookup table: column name -> declared column type
COLUMN_TYPES = {
    entry['column_name']: entry['column_type']
    for entry in get_schema()
}
# Start-up check that DuckDB can build the contract_data view over the dataset
@lru_cache(maxsize=1)
def load_dataset_schema():
    """Check that a ``contract_data`` view can be created over the dataset.

    A fresh DuckDB connection is opened, the view is created over
    ``dataset_path`` and the connection is closed again before returning, so
    the view created here is not shared with later queries
    (``execute_query`` creates the view again on its own connection).
    The result is memoised by ``lru_cache``, so the check runs at most once
    per process.

    Returns:
        bool: ``True`` when the view could be created, ``False`` when any
        step failed (the error is printed to stdout).
    """
    con = None
    try:
        # Connect inside the ``try`` so that a failure to open DuckDB is
        # reported through the boolean return value like every other error.
        con = duckdb.connect()
        # CREATE OR REPLACE covers the "view already exists" case without a
        # separate DROP, and matches the statement used by execute_query.
        con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
        return True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
        return False
    finally:
        if con is not None:
            con.close()
# Natural-language to SQL translation backed by the OpenAI completion API
def parse_query(nl_query):
    """
    Convert a natural language query into a SQL query using OpenAI.

    The prompt contains the dataset schema (as JSON) and asks for a DuckDB
    query against the ``contract_data`` table.

    Args:
        nl_query: The user's request in plain English.

    Returns:
        str: The generated SQL text, or a message starting with
        ``"Error generating SQL query:"`` when the input is blank or the API
        call fails. Errors are returned, not raised.
    """
    # Nothing useful can be generated from a blank request; skip the API call.
    if not nl_query or not nl_query.strip():
        return "Error generating SQL query: the natural language query is empty."

    # Read the key from the environment instead of hard-coding a secret in
    # the source file. When the variable is unset, whatever key the openai
    # module already holds is left untouched.
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        openai.api_key = api_key

    prompt = (
        "\n"
        "Convert the following natural language query into a SQL query for a DuckDB database."
        " Use 'contract_data' as the table name.\n"
        "Schema:\n"
        f"{json.dumps(schema, indent=2)}\n"
        "Query:\n"
        f'"{nl_query}"\n'
    )
    try:
        # NOTE(review): openai.Completion and "text-davinci-003" look like the
        # legacy (pre-1.0) SDK interface and model — confirm the installed
        # openai package and account still support them.
        response = openai.Completion.create(
            engine="text-davinci-003",
            prompt=prompt,
            temperature=0,  # deterministic output for the same prompt
            max_tokens=150,
            top_p=1,
            n=1,
            stop=None,
        )
        return response.choices[0].text.strip()
    except Exception as e:
        return f"Error generating SQL query: {e}"
# Keyword-based check for whether the user is asking for a visualisation
def detect_plot_intent(nl_query):
    """
    Report whether the query text mentions a plotting-related keyword.

    The check is a case-insensitive substring search, so a keyword also
    matches when it appears inside a longer word.

    Returns:
        bool: ``True`` if any keyword occurs in ``nl_query``, else ``False``.
    """
    plot_keywords = (
        'plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram',
        'bar chart', 'line chart', 'scatter plot', 'pie chart',
    )
    lowered = nl_query.lower()
    return any(keyword in lowered for keyword in plot_keywords)
# Turn one natural-language request into SQL plus optional plotting code
def generate_sql_and_plot_code(query):
    """
    Build the SQL query and, for plotting requests, a plot-code template.

    Returns:
        tuple[str, str]: ``(sql_query, plot_code)``. ``plot_code`` is the
        empty string unless ``detect_plot_intent`` flags the request, in
        which case it is a bar-chart snippet with ``x_column``/``y_column``
        placeholders.
    """
    wants_plot = detect_plot_intent(query)
    sql_query = parse_query(query)
    if not wants_plot:
        return sql_query, ""

    # Only a basic bar chart is generated; the placeholders are replaced
    # with real column names when the plot is rendered.
    bar_chart_template = (
        "\n"
        "import plotly.express as px\n"
        "fig = px.bar(result_df, x='x_column', y='y_column')\n"
    )
    return sql_query, bar_chart_template
# Run a SQL statement against the Parquet-backed contract_data view
def execute_query(sql_query):
    """
    Execute the SQL query and return the results as a DataFrame.

    A new DuckDB connection is opened for every call, the ``contract_data``
    view is (re)created over ``dataset_path``, the query is run, and the
    connection is always closed again.

    Args:
        sql_query: SQL text to execute.

    Returns:
        tuple: ``(result_df, "")`` on success, or ``(None, message)`` with a
        message starting with ``"Error executing query:"`` on failure.
    """
    if not sql_query or not sql_query.strip():
        return None, "Error executing query: no SQL query provided."

    con = None
    try:
        con = duckdb.connect()
        # Ensure the view is created
        con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{dataset_path}'")
        # NOTE(security): sql_query is executed verbatim. It originates from
        # model output / an editable text box, so it can contain any
        # statement DuckDB accepts — restrict it before exposing the app to
        # untrusted users.
        result_df = con.execute(sql_query).fetchdf()
        return result_df, ""
    except Exception as e:
        # In case of error, return None and error message
        return None, f"Error executing query: {e}"
    finally:
        # Close the connection on the error path too (previously it was
        # only closed after a successful query).
        if con is not None:
            con.close()
# Generate and display plot
def _fill_plot_placeholders(plot_code, x_name, y_name):
    """Replace the ``x_column``/``y_column`` placeholders in a single pass."""
    def _quote_safe(name):
        # The placeholders sit inside string literals in the generated
        # template, so escape characters that would end or corrupt a literal.
        return str(name).replace('\\', '\\\\').replace("'", "\\'").replace('"', '\\"')

    substitutions = {
        'x_column': _quote_safe(x_name),
        'y_column': _quote_safe(y_name),
    }
    # One regex pass: text inserted for one placeholder is never rescanned
    # for the other, which chained str.replace calls would do.
    return re.sub(
        r'x_column|y_column',
        lambda match: substitutions[match.group(0)],
        plot_code,
    )


def generate_plot(plot_code, result_df):
    """
    Execute the plot code to generate a plot from the result DataFrame.

    The first two columns of ``result_df`` are substituted for the
    ``x_column`` and ``y_column`` placeholders, the code is executed with
    ``px`` and ``result_df`` in scope, and the ``fig`` variable it defines is
    returned.

    Args:
        plot_code: Python source that assigns a figure to ``fig``.
        result_df: DataFrame holding the query result (may be ``None``).

    Returns:
        tuple: ``(fig, "")`` on success, otherwise ``(None, message)``.
        Errors are reported through the message, never raised.
    """
    if not plot_code or not plot_code.strip():
        return None, "No plot code provided."
    try:
        if result_df is None or result_df.empty:
            return None, "Result DataFrame is empty."
        columns = result_df.columns.tolist()
        if len(columns) < 2:
            return None, "Not enough columns to plot."
        plot_code = _fill_plot_placeholders(plot_code, columns[0], columns[1])
        # NOTE(security): exec runs arbitrary Python. plot_code comes from an
        # editable UI field, so anyone who can reach the app can run code in
        # this process — do not expose this to untrusted users as-is.
        local_vars = {'result_df': result_df}
        exec(plot_code, {'px': px}, local_vars)
        fig = local_vars.get('fig')
        # Compare against None explicitly rather than relying on the
        # truthiness of a figure object.
        if fig is not None:
            return fig, ""
        return None, "Plot could not be generated."
    except Exception as e:
        return None, f"Error generating plot: {e}"
# Memoised JSON rendering of the schema, used for display
@lru_cache(maxsize=1)
def get_schema_json():
    """Return the schema from ``get_schema`` as JSON text with 2-space indent."""
    schema_entries = get_schema()
    return json.dumps(schema_entries, indent=2)
# Verify the dataset at import time and stop start-up when it cannot be loaded
_schema_loaded = load_dataset_schema()
if not _schema_loaded:
    raise Exception("Failed to load dataset schema. Please check the dataset path and format.")
# Gradio app UI: a query tab (natural language -> SQL -> results/plot) and a
# schema tab. Event handlers are defined and wired inside the Blocks context.
with gr.Blocks() as demo:
    gr.Markdown("""
# Parquet SQL Query and Plotting App
**Query and visualize data** in `sample_contract_df.parquet`
## Instructions
1. **Describe the data you want to retrieve or plot**: For example:
- `Show all awards greater than 1,000,000 in California`
- `Plot the distribution of awards by state`
- `Show a bar chart of total awards per department`
- `List awardees who received multiple awards along with award amounts`
- `Number of awards issued by each department division`
2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
3. **Execute Query**: Click "Execute Query" to run the query and view the results.
4. **View Plot**: If your query involves plotting, the plot will be displayed.
5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.
## Example Queries
- `Plot the total award amount by state`
- `Show a histogram of awards over time`
- `award greater than 1000000 and state equal to "CA"`
- `List awards where department_ind_agency contains "Defense"`
""")
    # NOTE(review): the nesting of the layout below is the conventional
    # reading of the component order — confirm it matches the intended layout.
    with gr.Tabs():
        # Query Tab
        with gr.TabItem("Query Data"):
            with gr.Row():
                # Left column: inputs, generated code and action buttons
                with gr.Column(scale=1):
                    query = gr.Textbox(
                        label="Natural Language Query",
                        placeholder='e.g., "Show all awards greater than 1,000,000 in California"',
                        lines=4
                    )
                    btn_generate = gr.Button("Generate SQL")
                    sql_out = gr.Code(label="Generated SQL Query", language="sql")
                    plot_code_out = gr.Code(label="Generated Plot Code", language="python")
                    btn_execute = gr.Button("Execute Query")
                    # Hidden until a handler reports an error
                    error_out = gr.Markdown("", visible=False)
                # Right column: query results and plot
                with gr.Column(scale=2):
                    results_out = gr.Dataframe(label="Query Results", interactive=False)
                    plot_out = gr.Plot(label="Plot")
        # Schema Tab
        with gr.TabItem("Dataset Schema"):
            gr.Markdown("### Dataset Schema")
            schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))

    # Set up click events
    def _error_update(message):
        """Show ``message`` in ``error_out``, hiding the component when empty.

        ``error_out`` is created with ``visible=False``; returning a plain
        string would change its text but leave it hidden, so the visibility
        is updated together with the value.
        """
        return gr.update(value=message, visible=bool(message))

    def on_generate_click(nl_query):
        """Return ``(sql_query, plot_code)`` for the natural-language query."""
        sql_query, plot_code = generate_sql_and_plot_code(nl_query)
        return sql_query, plot_code

    def on_execute_click(sql_query, plot_code):
        """Run the SQL and, when plot code is present, render the plot.

        Returns ``(result_df, fig, error_update)`` matching the
        ``results_out``, ``plot_out`` and ``error_out`` components.
        """
        result_df, error_msg = execute_query(sql_query)
        if error_msg:
            return None, None, _error_update(error_msg)
        # Treat a missing plot-code value the same as blank text.
        if not (plot_code or "").strip():
            return result_df, None, _error_update("")
        fig, plot_error = generate_plot(plot_code, result_df)
        if plot_error:
            # Keep the query results visible even though the plot failed.
            return result_df, None, _error_update(plot_error)
        return result_df, fig, _error_update("")

    btn_generate.click(
        fn=on_generate_click,
        inputs=query,
        outputs=[sql_out, plot_code_out],
    )
    btn_execute.click(
        fn=on_execute_click,
        inputs=[sql_out, plot_code_out],
        outputs=[results_out, plot_out, error_out],
    )
# Launch the app only when this file is run as a script, so that importing
# the module (e.g. to reuse `demo` or its helpers) does not start a server.
if __name__ == "__main__":
    demo.launch()