import json
import os
import re
from functools import lru_cache

import duckdb
import gradio as gr
import openai
import pandas as pd
import plotly.express as px

# Set OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")

# =========================
# Configuration and Setup
# =========================

# Path of the Parquet dataset exposed to SQL as the `contract_data` view.
dataset_path = 'sample_contract_df.parquet'  # Update with your Parquet file path

# Provided schema (sent verbatim to the LLM so it knows the column names/types).
schema = [
    {"column_name": "department_ind_agency", "column_type": "VARCHAR"},
    {"column_name": "cgac", "column_type": "BIGINT"},
    {"column_name": "sub_tier", "column_type": "VARCHAR"},
    {"column_name": "fpds_code", "column_type": "VARCHAR"},
    {"column_name": "office", "column_type": "VARCHAR"},
    {"column_name": "aac_code", "column_type": "VARCHAR"},
    {"column_name": "posteddate", "column_type": "VARCHAR"},
    {"column_name": "type", "column_type": "VARCHAR"},
    {"column_name": "basetype", "column_type": "VARCHAR"},
    {"column_name": "popstreetaddress", "column_type": "VARCHAR"},
    {"column_name": "popcity", "column_type": "VARCHAR"},
    {"column_name": "popstate", "column_type": "VARCHAR"},
    {"column_name": "popzip", "column_type": "VARCHAR"},
    {"column_name": "popcountry", "column_type": "VARCHAR"},
    {"column_name": "active", "column_type": "VARCHAR"},
    {"column_name": "awardnumber", "column_type": "VARCHAR"},
    {"column_name": "awarddate", "column_type": "VARCHAR"},
    {"column_name": "award", "column_type": "DOUBLE"},
    {"column_name": "awardee", "column_type": "VARCHAR"},
    {"column_name": "state", "column_type": "VARCHAR"},
    {"column_name": "city", "column_type": "VARCHAR"},
    {"column_name": "zipcode", "column_type": "VARCHAR"},
    {"column_name": "countrycode", "column_type": "VARCHAR"},
]


@lru_cache(maxsize=1)
def get_schema():
    """Return the static column schema of the contract dataset."""
    return schema


# Mapping of column name -> DuckDB type, derived from the schema above.
COLUMN_TYPES = {col['column_name']: col['column_type'] for col in get_schema()}


def _quoted_dataset_path():
    """Return ``dataset_path`` as a single-quoted SQL string literal."""
    # Double any embedded single quotes so a path such as "o'neil.parquet"
    # cannot terminate the literal early and corrupt the statement.
    return "'" + dataset_path.replace("'", "''") + "'"


# =========================
# Database Interaction
# =========================

def load_dataset_schema():
    """
    Try to register the dataset as the ``contract_data`` view.

    The view is created on a throwaway in-memory DuckDB connection that is
    closed before this function returns, so the view does NOT persist;
    ``execute_query`` recreates it for every query. This function is only
    useful as a check that the view definition can be created against
    ``dataset_path``.

    Returns:
        True if the view was created, False otherwise (the error is printed).
    """
    con = duckdb.connect()
    try:
        con.execute("DROP VIEW IF EXISTS contract_data")
        con.execute(
            f"CREATE VIEW contract_data AS SELECT * FROM {_quoted_dataset_path()}"
        )
        return True
    except Exception as e:
        print(f"Error loading dataset schema: {e}")
        return False
    finally:
        con.close()


# =========================
# OpenAI API Integration
# =========================

# Matches the first Markdown code fence (``` or ```sql ...) in a reply.
_CODE_FENCE_RE = re.compile(r"```[^\n]*\n(.*?)```", re.DOTALL)


def _extract_sql(text):
    """Return the SQL inside the first Markdown code fence of *text*, or *text* stripped."""
    text = (text or "").strip()
    match = _CODE_FENCE_RE.search(text)
    if match:
        return match.group(1).strip()
    return text


def parse_query(nl_query):
    """
    Convert a natural language query into a SQL query using OpenAI's API.

    Args:
        nl_query: The user's question in plain language.

    Returns:
        The generated SQL text. On failure, a message starting with
        "Error" is returned instead of raising; callers check that prefix.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "Convert natural language queries to SQL queries for 'contract_data'. "
                "Use DuckDB SQL and reply with a single SELECT statement only, "
                "without explanations or Markdown formatting."
            ),
        },
        {"role": "user", "content": f"Schema:\n{json.dumps(schema, indent=2)}\n\nQuery:\n\"{nl_query}\"\n\nSQL:"},
    ]
    try:
        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=messages,
            temperature=0,
            # 150 tokens can truncate queries with several joins/aggregates.
            max_tokens=500,
        )
        # Models frequently wrap SQL in ```sql fences, which DuckDB cannot
        # parse, so unwrap them here; content may also be None.
        sql_query = _extract_sql(response.choices[0].message.content)
    except Exception as e:
        return f"Error generating SQL query: {e}"
    if not sql_query:
        return "Error generating SQL query: the model returned an empty response."
    return sql_query


# =========================
# Plotting Utilities
# =========================

# Whole-word match so that e.g. "deadline", "online" or "embargo" do not
# count as requests for a line/bar chart.
_PLOT_INTENT_RE = re.compile(
    r"\b(?:plot\w*|graph\w*|charts?|distributions?|visuali[sz]\w*|trends?"
    r"|histograms?|bars?|lines?)\b",
    re.IGNORECASE,
)


def detect_plot_intent(nl_query):
    """
    Return True if the user's text asks for a plot.

    A request counts when it contains a plotting word such as "plot",
    "graph", "chart", "distribution", "visualize", "trend", "histogram",
    "bar" or "line" (matched as whole words, case-insensitively).
    """
    if not nl_query:
        return False
    return _PLOT_INTENT_RE.search(nl_query) is not None


def generate_plot_code(sql_query, result_df):
    """
    Build a bar chart for a query result when the request asks for a plot.

    Args:
        sql_query: Text checked for plotting intent. Despite the name, the UI
            passes the user's natural-language question here, because the
            plotting keywords live in the question rather than in the SQL.
        result_df: Query result as a pandas DataFrame (may be None or empty).

    Returns:
        A Plotly bar figure using the first column as x and the first numeric
        column after it as y, or falling back to the second column if none is
        numeric. Returns None when no plot is requested or the result has
        fewer than two columns or no rows.
    """
    if not detect_plot_intent(sql_query):
        return None
    if result_df is None or result_df.empty:
        return None
    columns = result_df.columns.tolist()
    if len(columns) < 2:
        return None
    x_col = columns[0]
    # Bar heights only make sense for numeric data; prefer such a column.
    y_col = next(
        (c for c in columns[1:] if pd.api.types.is_numeric_dtype(result_df[c])),
        columns[1],
    )
    fig = px.bar(result_df, x=x_col, y=y_col, title='Generated Plot')
    fig.update_layout(title_x=0.5)
    return fig


# =========================
# Helper Functions
# =========================

# Leading keywords of statements that only read data.
_READ_ONLY_PREFIXES = ("select", "with", "from", "(", "describe", "summarize", "show")
# Leading "-- ..." line comments, skipped before checking the keyword.
_LEADING_SQL_COMMENTS_RE = re.compile(r"^(?:\s*--[^\n]*\n?)*\s*")


def execute_query(sql_query):
    """
    Execute a generated SQL query against the Parquet dataset.

    Args:
        sql_query: SQL text, typically produced by ``parse_query``. A value
            starting with "Error" is treated as an upstream failure and
            passed through unchanged.

    Returns:
        A ``(result_df, error_message)`` tuple: the result DataFrame and ""
        on success, or None and a message starting with "Error" on failure.
    """
    if not sql_query or not sql_query.strip():
        return None, "Error executing query: empty SQL query."
    if sql_query.startswith("Error"):
        return None, sql_query

    statement = sql_query.strip().rstrip(";").strip()
    # SECURITY: this SQL is generated by an LLM from untrusted user text.
    # Allow only one read-only statement (no COPY/ATTACH/INSTALL/DDL, no
    # stacked statements). A ";" inside a string literal is conservatively
    # rejected as well. NOTE: DuckDB SELECTs can still read local files via
    # table functions (read_csv, read_parquet, ...); sandbox the process if
    # it is exposed to untrusted users.
    keyword_part = _LEADING_SQL_COMMENTS_RE.sub("", statement, count=1).lower()
    if not keyword_part.startswith(_READ_ONLY_PREFIXES) or ";" in statement:
        return None, "Error executing query: only a single read-only SELECT statement is allowed."

    con = None
    try:
        con = duckdb.connect()
        con.execute(
            f"CREATE OR REPLACE VIEW contract_data AS SELECT * FROM {_quoted_dataset_path()}"
        )
        result_df = con.execute(statement).fetchdf()
        return result_df, ""
    except Exception as e:
        return None, f"Error executing query: {e}"
    finally:
        # Close on both paths; the original leaked the connection on errors.
        if con is not None:
            con.close()


# =========================
# Gradio Application UI
# =========================

with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue")) as demo:
    gr.Markdown(
        "## Parquet Data Explorer\n\n**Query and visualize data effortlessly.**",
        elem_id="main-title",
    )
    with gr.Row():
        with gr.Column(scale=1):
            query = gr.Textbox(
                label="Ask a question about the data",
                placeholder='e.g., "What are the total awards over 1M in California?"',
                lines=1,
            )
            # Schema shown next to the input; revealed when the textbox gains focus.
            schema_display = gr.JSON(value=get_schema(), visible=False)
            # Gradio has no `Alert` component; a hidden read-only textbox
            # carries error messages instead.
            error_out = gr.Textbox(label="Error", visible=False, interactive=False)
        with gr.Column(scale=2):
            results_out = gr.DataFrame(label="Results")
            plot_out = gr.Plot()

    def on_query_submit(nl_query):
        """Translate the question to SQL, run it, and optionally plot the result.

        Returns updates for (error box, results table, plot).
        """
        if not nl_query or not nl_query.strip():
            return gr.update(visible=True, value="Please enter a question about the data."), None, None
        sql_query = parse_query(nl_query)
        if sql_query.startswith("Error"):
            return gr.update(visible=True, value=sql_query), None, None
        result_df, error_msg = execute_query(sql_query)
        if error_msg:
            return gr.update(visible=True, value=error_msg), None, None
        # Plot intent is detected from the user's wording, not from the SQL.
        fig = generate_plot_code(nl_query, result_df)
        return gr.update(visible=False, value=""), result_df, fig

    def on_focus():
        """Reveal the schema panel when the user focuses the query box."""
        return gr.update(visible=True)

    query.submit(
        fn=on_query_submit,
        inputs=query,
        outputs=[error_out, results_out, plot_out],
    )
    query.focus(
        fn=on_focus,
        outputs=schema_display,
    )

# =========================
# Launch the Gradio App
# =========================

if __name__ == "__main__":
    # Guarded so importing this module (e.g. for tests or `gradio` reload
    # mode) does not start a server as a side effect.
    demo.launch()