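# NOTE: the six lines below appear to be residue from the repository file
# viewer (author, commit message, commit id, viewer links, file size); they are
# kept verbatim as comments so that the module parses.
# Ari
# Update app.py
# 75829f5 verified
# raw
# history blame
# 2.3 kB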
import streamlit as st
import pandas as pd
import sqlite3
import plotly.express as px
import json
# Paths to the bundled default files, used when the user uploads nothing.
DEFAULT_PROMPT_PATH = "prompt_engineering.json"
DEFAULT_METADATA_PATH = "default_metadata.csv"
DEFAULT_DATA_PATH = "default_data.csv"

# Load the prompt engineering JSON file. The encoding is given explicitly
# because JSON text is UTF-8 and the platform default encoding may differ.
with open(DEFAULT_PROMPT_PATH, encoding="utf-8") as f:
    prompt_data = json.load(f)
# Function to find a query based on the user prompt
def get_query_from_prompt(user_prompt):
    """Return the SQL query of the first configured prompt found in *user_prompt*.

    Each entry of ``prompt_data['prompts']`` is read for its ``'question'`` and
    ``'query'`` keys. An entry matches when its question occurs, ignoring case,
    as a substring of ``user_prompt``; entries are tried in file order and the
    first match wins.

    Args:
        user_prompt: The free-text prompt typed by the user.

    Returns:
        The ``'query'`` value of the first matching entry, or ``None`` when no
        entry matches.
    """
    # Lower-case the prompt once instead of once per configured entry.
    normalized_prompt = user_prompt.lower()
    for item in prompt_data['prompts']:
        question = item['question'].lower()
        # An empty or blank question is a substring of (almost) every prompt
        # and would act as a silent catch-all, so such entries are skipped.
        if not question.strip():
            continue
        if question in normalized_prompt:
            return item['query']
    return None  # No configured question matched
# Step 1: Upload metadata.csv file (or use default)
metadata_file = st.file_uploader("Upload your metadata.csv file", type=["csv"])
metadata_source = DEFAULT_METADATA_PATH if metadata_file is None else metadata_file
try:
    metadata = pd.read_csv(metadata_source)
except (
    OSError,
    UnicodeDecodeError,
    pd.errors.EmptyDataError,
    pd.errors.ParserError,
) as exc:
    # An empty, malformed or unreadable CSV would otherwise surface as a raw
    # traceback; report it and halt this run since later steps need the data.
    st.error(f"Could not read the metadata CSV: {exc}")
    st.stop()

if metadata_file is None:
    st.write("Using default metadata.csv file.")
else:
    st.write("Metadata loaded successfully!")
# NOTE(review): the original indentation was lost; the table is shown here for
# both the default and the uploaded metadata - confirm this is intended.
st.dataframe(metadata)
# Step 2: Upload CSV data file (or use default)
csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
data_source = DEFAULT_DATA_PATH if csv_file is None else csv_file
try:
    data = pd.read_csv(data_source)
except (
    OSError,
    UnicodeDecodeError,
    pd.errors.EmptyDataError,
    pd.errors.ParserError,
) as exc:
    # An empty, malformed or unreadable CSV would otherwise surface as a raw
    # traceback; report it and halt this run since later steps need the data.
    st.error(f"Could not read the data CSV: {exc}")
    st.stop()

if csv_file is None:
    st.write("Using default data.csv file.")
# NOTE(review): the original indentation was lost; the preview is shown here
# for both the default and the uploaded data - confirm this is intended.
st.write("Data Preview:")
st.dataframe(data.head())
# Step 3: Load CSV data into a SQLite database (SQL agent)
# Streamlit re-executes the script on each interaction, so the in-memory
# connection is closed explicitly in the `finally` below rather than being
# left for the garbage collector.
conn = sqlite3.connect(':memory:')  # Use an in-memory SQLite database
try:
    data.to_sql('sales_data', conn, index=False, if_exists='replace')

    # Step 4: Get user prompt and map to SQL query
    user_prompt = st.text_input("Enter your natural language prompt:")

    # Step 5: Look up the SQL query configured for the prompt and run it
    if user_prompt:
        query = get_query_from_prompt(user_prompt)
        if query:
            try:
                result = pd.read_sql(query, conn)
            except (sqlite3.Error, pd.errors.DatabaseError) as exc:
                # A configured query may reference columns the loaded CSV
                # does not have; report that instead of crashing the app.
                st.error(f"Could not run the query for this prompt: {exc}")
            else:
                st.write("Query Results:")
                st.dataframe(result)
                # If the prompt asks for a plot (like "plot sales"), show the chart
                if "plot" in user_prompt.lower():
                    missing = [
                        column
                        for column in ('Date', 'Sales')
                        if column not in result.columns
                    ]
                    if missing:
                        # The bar chart needs both columns in the result.
                        st.warning(
                            "Cannot plot: the query result has no "
                            + ", ".join(missing)
                            + " column."
                        )
                    else:
                        fig = px.bar(result, x='Date', y='Sales', title="Sales Over Time")
                        st.plotly_chart(fig)
        else:
            st.write("Sorry, I couldn't find a matching query for your prompt.")
finally:
    conn.close()