Spaces:
Runtime error
Runtime error
import streamlit as st | |
import pandas as pd | |
import sqlite3 | |
import plotly.express as px | |
import json | |
# Set paths to the default files
DEFAULT_PROMPT_PATH = "prompt_engineering.json"
DEFAULT_METADATA_PATH = "default_metadata.csv"
DEFAULT_DATA_PATH = "default_data.csv"
# Load the prompt-engineering JSON file. get_query_from_prompt reads its
# prompt_data['prompts'] list of question -> SQL query entries.
# Read it as UTF-8 explicitly so non-ASCII questions do not depend on the
# platform's default encoding. If the file is missing, unreadable, or not valid
# JSON, show the problem in the app and halt this run instead of crashing with
# a raw traceback.
try:
    with open(DEFAULT_PROMPT_PATH, encoding="utf-8") as f:
        prompt_data = json.load(f)
except (OSError, UnicodeDecodeError, json.JSONDecodeError) as err:
    st.error(f"Could not load prompt file {DEFAULT_PROMPT_PATH!r}: {err}")
    st.stop()
# Function to find a query based on the user prompt
def get_query_from_prompt(user_prompt, prompts=None):
    """Return the SQL query whose question text appears in *user_prompt*.

    Matching is case-insensitive substring containment: an entry matches when
    its ``question`` occurs anywhere inside the user's prompt. When several
    entries match, the one with the longest question wins. That way a specific
    question such as "plot sales by region" is not shadowed by a shorter
    "plot sales" listed earlier. Ties keep the earliest-listed entry.

    Entries whose question is empty or blank are skipped, because an empty
    string is contained in every prompt and would match everything.

    Args:
        user_prompt: Free-text prompt typed by the user.
        prompts: Optional list of dicts with ``'question'`` and ``'query'``
            keys. Defaults to ``prompt_data['prompts']`` from the prompt file.

    Returns:
        The matching entry's ``'query'`` value, or ``None`` if nothing matches.
    """
    if prompts is None:
        prompts = prompt_data['prompts']
    haystack = user_prompt.casefold()
    best_query = None
    best_len = 0
    for item in prompts:
        question = item['question'].casefold()
        if not question.strip():
            continue
        # Strict '>' keeps the first-listed entry among equally long matches.
        if question in haystack and len(question) > best_len:
            best_query = item['query']
            best_len = len(question)
    return best_query  # None if no matching query is found
def _read_csv_or_stop(source, label):
    """Read a CSV into a DataFrame, or report the failure in the app and halt.

    Args:
        source: Anything ``pd.read_csv`` accepts. Here it is either a default
            file path or the uploaded-file object from ``st.file_uploader``.
        label: Human-readable name of the file, used in the error message.

    Returns:
        The parsed DataFrame. On failure ``st.stop()`` ends this script run,
        so the function does not return normally in that case.
    """
    try:
        return pd.read_csv(source)
    except (OSError, UnicodeDecodeError,
            pd.errors.EmptyDataError, pd.errors.ParserError) as err:
        # Without this, a malformed or non-CSV upload (or a missing default
        # file) surfaces as an unhandled traceback instead of a readable message.
        st.error(f"Could not read {label}: {err}")
        st.stop()


# Step 1: Upload metadata.csv file (or use default)
metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
if metadata_file is None:
    metadata = _read_csv_or_stop(DEFAULT_METADATA_PATH, DEFAULT_METADATA_PATH)
    st.write("Using default metadata.csv file.")
else:
    metadata = _read_csv_or_stop(metadata_file, "the uploaded metadata file")
    st.write("Metadata loaded successfully!")
st.dataframe(metadata)
# Step 2: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
if csv_file is None:
    data = _read_csv_or_stop(DEFAULT_DATA_PATH, DEFAULT_DATA_PATH)
    st.write("Using default data.csv file.")
else:
    data = _read_csv_or_stop(csv_file, "the uploaded data file")
st.write("Data Preview:")
st.dataframe(data.head())
# Step 3: Load CSV data into a SQLite database (SQL agent)
# Build a private in-memory SQLite database that exists only as long as `conn`.
# Then write the DataFrame into it as table "sales_data", replacing any
# existing table of that name and omitting the DataFrame index.
# NOTE(review): the queries in the prompt file presumably select from
# sales_data; confirm the table name stays in sync with them.
conn = sqlite3.connect(':memory:')
data.to_sql(
    name='sales_data',
    con=conn,
    if_exists='replace',
    index=False,
)
# Step 4: Get user prompt and map to SQL query
user_prompt = st.text_input("Enter your natural language prompt:")
# Step 5: Process the prompt and run the matching SQL query.
# Whitespace-only input is treated the same as an empty box.
if user_prompt and user_prompt.strip():
    query = get_query_from_prompt(user_prompt)
    if query:
        # The canned queries run against whatever CSV was loaded, so a column
        # they reference may not exist. pandas wraps the sqlite3 failure in
        # its DatabaseError.
        try:
            result = pd.read_sql(query, conn)
        except pd.io.sql.DatabaseError as err:
            st.error(f"The query for this prompt failed on the loaded data: {err}")
        else:
            st.write("Query Results:")
            st.dataframe(result)
            # If the prompt asks for a plot (like "plot sales"), chart the result.
            if "plot" in user_prompt.lower():
                # The chart needs both columns. Check first so a query result
                # without them yields a warning instead of a plotly exception.
                missing = [col for col in ("Date", "Sales") if col not in result.columns]
                if missing:
                    st.warning(
                        "Cannot plot: the query result has no "
                        + ", ".join(missing)
                        + " column(s)."
                    )
                else:
                    fig = px.bar(result, x='Date', y='Sales', title="Sales Over Time")
                    st.plotly_chart(fig)
    else:
        st.write("Sorry, I couldn't find a matching query for your prompt.")