# Hugging Face Space status: Running on CPU Upgrade
import json
import os
import re
from functools import lru_cache

import duckdb
import gradio as gr
import openai
import pandas as pd
import plotly.express as px
from transformers import pipeline  # NOTE(review): not used in the code visible here
# Location of the Parquet file backing the `contract_data` view.
# Point this at your own dataset if it lives elsewhere.
dataset_path = "sample_contract_df.parquet"
# Column schema of the contract dataset. It is embedded (as JSON) in the
# LLM prompt and shown on the "Dataset Schema" tab, so every entry keeps the
# {"column_name": ..., "column_type": ...} shape in that key order.
schema = [
    {"column_name": name, "column_type": sql_type}
    for name, sql_type in (
        ("department_ind_agency", "VARCHAR"),
        ("cgac", "BIGINT"),
        ("sub_tier", "VARCHAR"),
        ("fpds_code", "VARCHAR"),
        ("office", "VARCHAR"),
        ("aac_code", "VARCHAR"),
        ("posteddate", "VARCHAR"),
        ("type", "VARCHAR"),
        ("basetype", "VARCHAR"),
        ("popstreetaddress", "VARCHAR"),
        ("popcity", "VARCHAR"),
        ("popstate", "VARCHAR"),
        ("popzip", "VARCHAR"),
        ("popcountry", "VARCHAR"),
        ("active", "VARCHAR"),
        ("awardnumber", "VARCHAR"),
        ("awarddate", "VARCHAR"),
        ("award", "DOUBLE"),
        ("awardee", "VARCHAR"),
        ("state", "VARCHAR"),
        ("city", "VARCHAR"),
        ("zipcode", "VARCHAR"),
        ("countrycode", "VARCHAR"),
    )
]
# Accessor for the dataset schema (no caching involved).
def get_schema():
    """Return the dataset column schema.

    This is the module-level ``schema`` list itself, not a copy, so callers
    should treat it as read-only.
    """
    return schema
# Lookup table: column name -> DuckDB type string, derived from the schema.
COLUMN_TYPES = {
    entry["column_name"]: entry["column_type"]
    for entry in get_schema()
}
# Start-up check that DuckDB can open the Parquet dataset.
def load_dataset_schema(path=None):
    """Check that the Parquet dataset can be opened as the ``contract_data`` view.

    The view is created on a throwaway in-memory connection that is closed
    before returning. Nothing persists, so this only checks that DuckDB can
    resolve the file. ``execute_query`` sets up its own connection
    independently.

    Args:
        path: Parquet file to check. Defaults to the module-level
            ``dataset_path``.

    Returns:
        True if the view could be created, False otherwise. On failure the
        error is printed.
    """
    if path is None:
        path = dataset_path
    # Escape embedded single quotes so the path stays a valid SQL string
    # literal (e.g. a directory named "O'Brien").
    quoted_path = "'" + str(path).replace("'", "''") + "'"
    con = duckdb.connect()
    try:
        # Drop the view if it exists to avoid errors.
        con.execute("DROP VIEW IF EXISTS contract_data")
        con.execute(f"CREATE VIEW contract_data AS SELECT * FROM {quoted_path}")
        return True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
        return False
    finally:
        con.close()
# Natural-language -> SQL translation using the OpenAI completions endpoint.
def _strip_sql_fences(text):
    """Remove a surrounding ``` / ```sql markdown fence from model output."""
    sql = text.strip()
    if sql.startswith("```"):
        sql = sql[3:]
        if sql.lower().startswith("sql"):
            sql = sql[3:]
        sql = sql.rstrip()
        if sql.endswith("```"):
            sql = sql[:-3]
    return sql.strip()


def parse_query(nl_query):
    """Convert a natural-language question into a DuckDB SQL query via OpenAI.

    The prompt embeds the column schema as JSON and asks for a query against
    the ``contract_data`` table, returned as bare SQL.

    NOTE(review): ``openai.Completion`` exists only in the legacy
    ``openai<1.0`` SDK. Migrating to the ``OpenAI()`` client is a follow-up.

    Args:
        nl_query: the user's question as typed in the UI.

    Returns:
        The generated SQL text with any markdown fences removed. If the API
        call fails, a string starting with "Error generating SQL query:".
        Errors are returned in-band, so callers must not assume the result
        is valid SQL.
    """
    # Take the key from the environment (e.g. a Space secret). The old code
    # overwrote any configured key with a placeholder string.
    api_key = os.environ.get("OPENAI_API_KEY")
    if api_key:
        openai.api_key = api_key
    prompt = (
        "Convert the following natural language query into a SQL query for a "
        "DuckDB database. Use 'contract_data' as the table name.\n"
        "Return only the SQL query, with no explanation and no markdown.\n"
        "Schema:\n"
        f"{json.dumps(schema, indent=2)}\n"
        "Query:\n"
        f'"{nl_query}"\n'
        "SQL:\n"
    )
    try:
        response = openai.Completion.create(
            # text-davinci-003 has been shut down; gpt-3.5-turbo-instruct is
            # OpenAI's listed replacement on the completions endpoint.
            model="gpt-3.5-turbo-instruct",
            prompt=prompt,
            temperature=0,
            # 150 tokens could truncate multi-clause queries mid-statement.
            max_tokens=300,
            top_p=1,
            n=1,
            stop=None,
        )
        return _strip_sql_fences(response.choices[0].text)
    except Exception as e:
        return f"Error generating SQL query: {e}"
# Keywords that signal the user wants a chart rather than a table. Matches
# are anchored at the start of a word, so "plots", "plotting", "charts" and
# "visualization"/"visualise" match. Words that merely contain a keyword
# ("geography", "paragraph") do not, and "charter" is excluded explicitly.
_PLOT_INTENT_RE = re.compile(
    r"\b(?:plot|graph|chart(?!er)|distribution|visuali[sz]|histogram)",
    re.IGNORECASE,
)


def detect_plot_intent(nl_query):
    """Return True if the natural-language query asks for a plot.

    Args:
        nl_query: the user's question. ``None`` or empty counts as no intent.

    Returns:
        bool: whether a plotting keyword appears at the start of any word.
    """
    if not nl_query:
        return False
    return _PLOT_INTENT_RE.search(nl_query) is not None
# Build both the SQL text and (optionally) the plotting snippet for a request.
def generate_sql_and_plot_code(query):
    """Turn a natural-language request into a ``(sql, plot_code)`` pair.

    The SQL comes from ``parse_query``. Plot code is produced only when
    ``detect_plot_intent`` fires. It is a fixed Plotly bar-chart template whose
    ``'x_column'`` / ``'y_column'`` placeholders are substituted later by
    ``generate_plot``. Without plot intent the plot code is an empty string.
    """
    wants_plot = detect_plot_intent(query)
    sql = parse_query(query)
    if not wants_plot:
        return sql, ""
    # Template lines start at column 0 so the snippet can be exec()'d as-is.
    template = "\n".join([
        "",
        "import plotly.express as px",
        "fig = px.bar(result_df, x='x_column', y='y_column')",
        "",
    ])
    return sql, template
# Run user-supplied SQL against the dataset in a sandboxed DuckDB connection.
def execute_query(sql_query):
    """Execute a SQL query against the dataset and return the result.

    A fresh in-memory DuckDB connection is used per call. The Parquet file is
    first copied into a ``contract_data`` table. Then external access is
    disabled and the configuration locked, and only after that does the
    caller's SQL run. The SQL arrives from an editable UI text box, so without
    this it could read or write arbitrary server files (e.g. via
    ``read_csv(...)`` or ``COPY ... TO``). Note that the table is held in
    memory for the duration of the call.

    Args:
        sql_query: SQL text to run.

    Returns:
        ``(DataFrame, "")`` on success, or
        ``(None, "Error executing query: ...")`` on failure.
    """
    if not sql_query or not sql_query.strip():
        return None, "Error executing query: no SQL query provided."
    # Escape single quotes so the path is a valid SQL string literal.
    quoted_path = "'" + str(dataset_path).replace("'", "''") + "'"
    con = duckdb.connect()
    try:
        # Load the data while file access is still permitted...
        con.execute(
            f"CREATE OR REPLACE TABLE contract_data AS SELECT * FROM {quoted_path}"
        )
        # ...then sandbox the connection before running untrusted SQL.
        con.execute("SET enable_external_access = false")
        con.execute("SET lock_configuration = true")
        return con.execute(sql_query).fetchdf(), ""
    except Exception as e:
        # In case of error, return None and the error message.
        return None, f"Error executing query: {e}"
    finally:
        # Always release the connection; the old code leaked it on errors.
        con.close()
# Execute plot code against the query results and return the Plotly figure.
def generate_plot(plot_code, result_df):
    """Run ``plot_code`` on ``result_df`` and return the figure it builds.

    The quoted placeholders ``'x_column'`` and ``'y_column'`` (single or double
    quotes) are replaced by the first and second result column names. The
    code then runs with ``result_df`` and ``px`` in scope and must assign the
    figure to ``fig``.

    SECURITY NOTE(review): ``plot_code`` comes from an editable UI box, so this
    ``exec`` lets any visitor run arbitrary Python on the server. Consider
    replacing it with a fixed whitelist of chart builders.

    Args:
        plot_code: Python source for the plot, or empty/None for no plot.
        result_df: query results (a DataFrame), or None.

    Returns:
        ``(figure, "")`` on success, ``(None, reason)`` otherwise. It never
        raises.
    """
    if not plot_code or not plot_code.strip():
        return None, "No plot code provided."
    try:
        if result_df is None:
            return None, "No query results to plot."
        if result_df.empty:
            return None, "Result DataFrame is empty."
        columns = result_df.columns.tolist()
        if len(columns) < 2:
            return None, "Not enough columns to plot."
        # Substitute repr() of each column name so names containing quotes or
        # backslashes still yield a valid Python string literal.
        for placeholder, column in (("x_column", columns[0]), ("y_column", columns[1])):
            for quote in ("'", '"'):
                plot_code = plot_code.replace(f"{quote}{placeholder}{quote}", repr(column))
        local_vars = {"result_df": result_df}
        exec(plot_code, {"px": px}, local_vars)
        fig = local_vars.get("fig")
        if fig is None:
            return None, "Plot could not be generated."
        return fig, ""
    except Exception as e:
        return None, f"Error generating plot: {e}"
# Render the schema as JSON text (recomputed on each call; nothing is cached).
def get_schema_json():
    """Serialise the column schema as pretty-printed JSON with 2-space indents."""
    schema_entries = get_schema()
    return json.dumps(schema_entries, indent=2)
# Fail fast at import time if DuckDB cannot open the dataset, rather than
# surfacing the problem on the first query. RuntimeError is a subclass of
# Exception, so existing `except Exception` handlers still catch it.
if not load_dataset_schema():
    raise RuntimeError(
        "Failed to load dataset schema. Please check the dataset path and format."
    )
# Gradio app UI
with gr.Blocks() as demo:
    gr.Markdown(
        """
# Parquet SQL Query and Plotting App
**Query and visualize data** in `sample_contract_df.parquet`

## Instructions
1. **Describe the data you want to retrieve or plot**: For example:
   - `Show all awards greater than 1,000,000 in California`
   - `Plot the distribution of awards by state`
   - `Show a bar chart of total awards per department`
   - `List awardees who received multiple awards along with award amounts`
   - `Number of awards issued by each department division`
2. **Generate SQL**: Click "Generate SQL" to see the SQL query that will be executed.
3. **Execute Query**: Click "Execute Query" to run the query and view the results.
4. **View Plot**: If your query involves plotting, the plot will be displayed.
5. **View Dataset Schema**: Check the "Dataset Schema" tab to understand available columns and their types.

## Example Queries
- `Plot the total award amount by state`
- `Show a histogram of awards over time`
- `award greater than 1000000 and state equal to "CA"`
- `List awards where department_ind_agency contains "Defense"`
"""
    )
    with gr.Tabs():
        # Query Tab
        with gr.TabItem("Query Data"):
            with gr.Row():
                with gr.Column(scale=1):
                    query = gr.Textbox(
                        label="Natural Language Query",
                        placeholder='e.g., "Show all awards greater than 1,000,000 in California"',
                        lines=4,
                    )
                    btn_generate = gr.Button("Generate SQL")
                    sql_out = gr.Code(label="Generated SQL Query", language="sql")
                    plot_code_out = gr.Code(label="Generated Plot Code", language="python")
                    btn_execute = gr.Button("Execute Query")
                    # Hidden until there is an error to show (see _error_update).
                    error_out = gr.Markdown("", visible=False)
                with gr.Column(scale=2):
                    results_out = gr.Dataframe(label="Query Results", interactive=False)
                    plot_out = gr.Plot(label="Plot")
        # Schema Tab
        with gr.TabItem("Dataset Schema"):
            gr.Markdown("### Dataset Schema")
            schema_display = gr.JSON(label="Schema", value=json.loads(get_schema_json()))

    # Set up click events
    def on_generate_click(nl_query):
        """Return (SQL, plot code) for the question; skip the API call if blank."""
        if not nl_query or not nl_query.strip():
            return "", ""
        sql_query, plot_code = generate_sql_and_plot_code(nl_query)
        return sql_query, plot_code

    def _error_update(message):
        # error_out is created hidden; returning a bare string would update
        # its text but leave it invisible, so toggle visibility explicitly.
        return gr.update(value=message, visible=bool(message))

    def on_execute_click(sql_query, plot_code):
        """Run the SQL, optionally plot it, and surface any error message."""
        result_df, error_msg = execute_query(sql_query)
        if error_msg:
            return None, None, _error_update(error_msg)
        # The code box may hand back None when it is empty.
        if not (plot_code or "").strip():
            return result_df, None, _error_update("")
        fig, plot_error = generate_plot(plot_code, result_df)
        if plot_error:
            return result_df, None, _error_update(plot_error)
        return result_df, fig, _error_update("")

    btn_generate.click(
        fn=on_generate_click,
        inputs=query,
        outputs=[sql_out, plot_code_out],
    )
    btn_execute.click(
        fn=on_execute_click,
        inputs=[sql_out, plot_code_out],
        outputs=[results_out, plot_out, error_out],
    )

# Launch the app only when run as a script (`python app.py`), not on import.
if __name__ == "__main__":
    demo.launch()