import asyncio
import json
import os
import re
from functools import lru_cache

import duckdb
import gradio as gr
import openai
import pandas as pd
import plotly.express as px

# Set OpenAI API key
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))

# =========================
# Configuration and Setup
# =========================

# Path to the Parquet dataset that backs the `contract_data` view.
dataset_path = 'sample_contract_df.parquet'  # Update with your Parquet file path

# Provided schema (sent to the model as context and shown in the "Dataset Schema" tab).
schema = [
    {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
    {"column_name": "cgac", "column_type": "BIGINT"},
    {"column_name": "sub_tier", "column_type": "VARCHAR"},
    {"column_name": "fpds_code", "column_type": "VARCHAR"},
    {"column_name": "office", "column_type": "VARCHAR"},
    {"column_name": "aac_code", "column_type": "VARCHAR"},
    {"column_name": "posteddate", "column_type": "VARCHAR"},
    {"column_name": "type", "column_type": "VARCHAR"},
    {"column_name": "basetype", "column_type": "VARCHAR"},
    {"column_name": "popstreetaddress", "column_type": "VARCHAR"},
    {"column_name": "popcity", "column_type": "VARCHAR"},
    {"column_name": "popstate", "column_type": "VARCHAR"},
    {"column_name": "popzip", "column_type": "VARCHAR"},
    {"column_name": "popcountry", "column_type": "VARCHAR"},
    {"column_name": "active", "column_type": "VARCHAR"},
    {"column_name": "awardnumber", "column_type": "VARCHAR"},
    {"column_name": "awarddate", "column_type": "VARCHAR"},
    {"column_name": "award", "column_type": "DOUBLE"},
    {"column_name": "awardee", "column_type": "VARCHAR"},
    {"column_name": "state", "column_type": "VARCHAR"},
    {"column_name": "city", "column_type": "VARCHAR"},
    {"column_name": "zipcode", "column_type": "VARCHAR"},
    {"column_name": "countrycode", "column_type": "VARCHAR"},
]


@lru_cache(maxsize=1)
def get_schema():
    """Return the static dataset schema (a list of column-name/column-type dicts)."""
    return schema


# Column name -> DuckDB type lookup built from the schema.
COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}

# Intro text for the UI. Kept unindented at module level so the Markdown is not
# rendered as an indented code block.
_INTRO_MARKDOWN = """
# Parquet SQL Query and Plotting App

**Query and visualize data** in `sample_contract_df.parquet`

## Instructions

1. **Describe the data you want**: e.g., `Show awards over 1M in CA`
2. **Use Example Queries**: Click on any example query button below to execute.
3. **Generate SQL**: Or, enter your own query and click "Generate SQL" to see the SQL query.
4. **Execute Query**: Run the query to view results and plots.
5. **Dataset Schema**: See available columns and types in the "Dataset Schema" tab.

## Example Queries
"""

# Plotting code offered when a query asks for a chart. The quoted placeholders
# 'x_column' / 'y_column' are filled in by generate_plot. The text must stay
# unindented because it is run with exec().
_PLOT_CODE_TEMPLATE = """import plotly.express as px
fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
fig.update_layout(title_x=0.5)
"""

# Matches a quoted plot placeholder: 'x_column', "x_column", 'y_column' or "y_column".
_PLOT_PLACEHOLDER_RE = re.compile(r"""(['"])(x_column|y_column)\1""")

# =========================
# Database Interaction
# =========================


def _create_contract_view(con):
    """Create (or replace) the `contract_data` view over `dataset_path` on connection `con`."""
    # Double embedded single quotes so the path is always a valid SQL string literal.
    escaped_path = dataset_path.replace("'", "''")
    con.execute(f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM '{escaped_path}'")


def load_dataset_schema():
    """
    Try to create the `contract_data` view over the Parquet dataset.

    The view is created on a fresh `duckdb.connect()` connection that is closed
    before this function returns, so the view does not persist; `execute_query`
    recreates it on every call. The function only reports whether the view
    definition succeeded.

    Returns:
        True if the view was created, False otherwise (the error is printed).
    """
    con = duckdb.connect()
    try:
        _create_contract_view(con)
        return True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
        return False
    finally:
        con.close()


def execute_query(sql_query):
    """
    Execute a SQL query against the `contract_data` view.

    Args:
        sql_query: SQL text. A string starting with "Error" (as produced by
            `parse_query` on failure) is passed through as the error message.

    Returns:
        (result_df, "") on success, or (None, error_message) on failure.
    """
    if not sql_query or not sql_query.strip():
        return None, "Error: no SQL query to execute."
    if sql_query.startswith("Error"):
        return None, sql_query
    # NOTE(review): this SQL comes from the model or from the editable code box
    # and runs unrestricted. DuckDB SQL can typically read/write local files
    # (e.g. read_csv, COPY ... TO) — consider sandboxing before public deployment.
    con = None
    try:
        con = duckdb.connect()
        _create_contract_view(con)
        result_df = con.execute(sql_query).fetchdf()
        return result_df, ""
    except Exception as e:
        return None, f"Error executing query: {e}"
    finally:
        # Close even when the query raises, so failed queries do not leak connections.
        if con is not None:
            con.close()

# =========================
# OpenAI API Integration
# =========================


def _strip_code_fences(text):
    """Remove a surrounding Markdown code fence (``` or ```sql) from model output."""
    text = text.strip()
    if not text.startswith("```"):
        return text
    newline = text.find("\n")
    # Drop the opening fence together with any language tag on its line.
    text = text[newline + 1:] if newline != -1 else text[3:]
    text = text.rstrip()
    if text.endswith("```"):
        text = text[:-3]
    return text.strip()


async def parse_query(nl_query):
    """
    Convert a natural language query into a SQL query using OpenAI's API.

    Args:
        nl_query: The user's natural language request.

    Returns:
        The SQL text (code fences removed), or a string starting with
        "Error generating SQL query" if the request could not be made.
    """
    if not nl_query or not nl_query.strip():
        return "Error generating SQL query: the query is empty."
    messages = [
        {"role": "system", "content": (
            "Convert natural language queries to SQL queries for 'contract_data'. "
            "Use DuckDB SQL syntax and respond with only the SQL query, "
            "without explanations or Markdown code fences."
        )},
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"},
    ]
    try:
        # The OpenAI client call is synchronous; run it in the default executor
        # so it does not block the event loop serving the Gradio app.
        loop = asyncio.get_running_loop()
        response = await loop.run_in_executor(
            None,
            lambda: client.chat.completions.create(
                model="gpt-4",
                messages=messages,
                temperature=0,
                max_tokens=150,
            ),
        )
        content = response.choices[0].message.content or ""
        return _strip_code_fences(content)
    except Exception as e:
        return f"Error generating SQL query: {e}"

# =========================
# Plotting Utilities
# =========================


def detect_plot_intent(nl_query):
    """Return True if the query mentions plotting (plot, graph, chart, distribution, visualize)."""
    plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize']
    text = (nl_query or "").lower()
    return any(keyword in text for keyword in plot_keywords)


async def generate_sql_and_plot_code(query):
    """
    Generate a SQL query and, if the query asks for a chart, plotting code.

    Returns:
        (sql_query, plot_code). plot_code is "" when no plot was requested or
        when SQL generation failed.
    """
    is_plot = detect_plot_intent(query)
    sql_query = await parse_query(query)
    plot_code = ""
    if is_plot and not sql_query.startswith("Error"):
        plot_code = _PLOT_CODE_TEMPLATE
    return sql_query, plot_code


def _fill_plot_placeholders(plot_code, x_col, y_col):
    """Replace quoted 'x_column'/'y_column' placeholders with quoted column names."""
    columns = {"x_column": x_col, "y_column": y_col}
    # A single regex pass means a substituted column name that itself contains a
    # placeholder (e.g. "day_column") is never rewritten again; repr() produces a
    # valid Python string literal even when the name contains quotes.
    return _PLOT_PLACEHOLDER_RE.sub(lambda m: repr(str(columns[m.group(2)])), plot_code)


def generate_plot(plot_code, result_df):
    """
    Execute plotting code to build a figure from the result DataFrame.

    The first two result columns are substituted for the 'x_column' and
    'y_column' placeholders. The code must assign the figure to `fig`.

    Returns:
        (fig, "") on success, or (None, error_message) on failure.
    """
    if not plot_code or not plot_code.strip():
        return None, "No plot code provided."
    if result_df is None:
        return None, "No data to plot."
    try:
        columns = result_df.columns.tolist()
        if len(columns) < 2:
            return None, "Not enough columns to plot."
        plot_code = _fill_plot_placeholders(plot_code, columns[0], columns[1])
        local_vars = {'result_df': result_df, 'px': px}
        # SECURITY NOTE(review): plot_code comes from an editable UI code box, so
        # this exec() runs arbitrary user-supplied Python on the server.
        exec(plot_code, {}, local_vars)
        fig = local_vars.get('fig')
        if fig is None:
            return None, "Plot could not be generated."
        return fig, ""
    except Exception as e:
        return None, f"Error generating plot: {e}"


def _run_query_and_plot(sql_query, plot_code):
    """
    Run `sql_query` and, if `plot_code` is non-blank, build a plot from the result.

    Returns:
        (result_df, fig, message): message is "" on full success. A SQL error
        yields (None, None, error). A plot error keeps the result table and
        yields (result_df, None, error).
    """
    result_df, error_msg = execute_query(sql_query)
    if error_msg:
        return None, None, error_msg
    if not (plot_code or "").strip():
        return result_df, None, ""
    fig, plot_error = generate_plot(plot_code, result_df)
    if plot_error:
        return result_df, None, plot_error
    return result_df, fig, ""

# =========================
# Gradio Application UI
# =========================


with gr.Blocks() as demo:
    gr.Markdown(_INTRO_MARKDOWN)

    with gr.Tabs():
        with gr.TabItem("Query Data"):
            with gr.Row():
                with gr.Column(scale=1):
                    query = gr.Textbox(label="Natural Language Query", placeholder='e.g., "Awards > 1M in CA"')

                    # Example query buttons
                    gr.Markdown("### Click on an example query:")
                    with gr.Row():
                        btn_example1 = gr.Button("Show awards over 1M in CA")
                        btn_example2 = gr.Button("List all contracts in New York")
                        btn_example3 = gr.Button("Show top 5 departments by award amount")
                        btn_example4 = gr.Button("Execute: SELECT * from contract_data LIMIT 10;")

                    btn_generate = gr.Button("Generate SQL")
                    sql_out = gr.Code(label="Generated SQL Query", language="sql")
                    plot_code_out = gr.Code(label="Generated Plot Code", language="python")
                    btn_execute = gr.Button("Execute Query")
                    # Visible from the start: updating only the value of a hidden
                    # Markdown never shows it, so errors would be silently dropped.
                    # An empty Markdown renders nothing.
                    error_out = gr.Markdown("")
                with gr.Column(scale=2):
                    results_out = gr.Dataframe(label="Query Results", interactive=False)
                    plot_out = gr.Plot(label="Plot")

        with gr.TabItem("Dataset Schema"):
            gr.Markdown("### Dataset Schema")
            schema_display = gr.JSON(label="Schema", value=get_schema())

    # =========================
    # Click Event Handlers
    # =========================

    async def on_generate_click(nl_query):
        """Handle "Generate SQL": return (sql_query, plot_code) for the query."""
        sql_query, plot_code = await generate_sql_and_plot_code(nl_query)
        return sql_query, plot_code

    def on_execute_click(sql_query, plot_code):
        """Handle "Execute Query": return (result_df, fig, message)."""
        return _run_query_and_plot(sql_query, plot_code)

    # Functions for example query buttons
    async def on_example_nl_click(query_text):
        """Generate SQL/plot code for `query_text`, run it, and return all five outputs."""
        sql_query, plot_code = await generate_sql_and_plot_code(query_text)
        result_df, fig, error_msg = _run_query_and_plot(sql_query, plot_code)
        return sql_query, plot_code, result_df, fig, error_msg

    def on_example_sql_click(sql_query):
        """Run a literal SQL example (no plotting) and return all five outputs."""
        result_df, fig, error_msg = _run_query_and_plot(sql_query, "")
        return sql_query, "", result_df, fig, error_msg

    async def on_example1_click():
        return await on_example_nl_click("Show awards over 1M in CA")

    async def on_example2_click():
        return await on_example_nl_click("List all contracts in New York")

    async def on_example3_click():
        return await on_example_nl_click("Show top 5 departments by award amount")

    def on_example4_click():
        return on_example_sql_click("SELECT * from contract_data LIMIT 10;")

    example_outputs = [sql_out, plot_code_out, results_out, plot_out, error_out]
    btn_example1.click(fn=on_example1_click, inputs=[], outputs=example_outputs)
    btn_example2.click(fn=on_example2_click, inputs=[], outputs=example_outputs)
    btn_example3.click(fn=on_example3_click, inputs=[], outputs=example_outputs)
    btn_example4.click(fn=on_example4_click, inputs=[], outputs=example_outputs)

    btn_generate.click(fn=on_generate_click, inputs=query, outputs=[sql_out, plot_code_out])
    btn_execute.click(fn=on_execute_click, inputs=[sql_out, plot_code_out], outputs=[results_out, plot_out, error_out])

# =========================
# Launch the Gradio App
# =========================

demo.launch()