# Ari
# Update app.py
# f7a9fb8 verified
# raw
# history blame
# 7.26 kB
import os
import streamlit as st
import pandas as pd
import numpy as np
import sqlite3
from langchain import OpenAI, LLMChain, PromptTemplate
import sqlparse
import logging
from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
import statsmodels.api as sm # For time series analysis
from sklearn.metrics.pairwise import cosine_similarity # For recommendations
# Configure logging: root logger at ERROR level, with each record formatted
# as "<timestamp> - <level name> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.ERROR,
)
# Step 1: Load the dataset
def _read_csv_or_report(source, success_message):
    """Parse *source* as CSV and report the outcome in the Streamlit UI.

    Args:
        source: A path or file-like object accepted by ``pd.read_csv``.
        success_message: Text shown with ``st.success`` when parsing works.

    Returns:
        The parsed ``DataFrame``, or ``None`` when the content cannot be read
        as CSV (an ``st.error`` message is shown in that case).
    """
    try:
        frame = pd.read_csv(source)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as exc:
        # Report unreadable input the same way as a missing file instead of
        # letting the exception escape from the loader.
        st.error(f"Could not read the CSV data: {exc}")
        return None
    st.success(success_message)
    return frame


def load_data():
    """Let the user pick the bundled dataset or upload a CSV, then load it.

    Returns:
        The chosen data as a ``DataFrame``, or ``None`` when nothing could be
        loaded (no upload yet, file missing on disk, or unreadable CSV).
    """
    st.header("Select or Upload a Dataset")
    dataset_options = {
        "Default Dataset": "default_data.csv",
        # Add more datasets as needed
        "Upload Your Own Dataset": None,
    }
    selected_option = st.selectbox("Choose a dataset:", list(dataset_options.keys()))

    if selected_option == "Upload Your Own Dataset":
        uploaded_file = st.file_uploader("Upload your dataset (CSV file)", type=["csv"])
        if uploaded_file is None:
            st.info("Please upload a CSV file to proceed.")
            return None
        return _read_csv_or_report(uploaded_file, "Data successfully loaded!")

    file_path = dataset_options[selected_option]
    if not os.path.exists(file_path):
        st.error(f"File '{file_path}' not found.")
        return None
    return _read_csv_or_report(file_path, f"'{selected_option}' successfully loaded!")
data = load_data()
if data is None:
    st.stop()  # Stop the script if data is not loaded

# Reached only when a dataset was loaded.
table_name = "selected_table"  # Default table name
valid_columns = list(data.columns)
# Initialize the LLM
llm = OpenAI(temperature=0)

# Prompt Engineering
# The instructions spell out the contract that execute_code relies on:
# SQL is recognised by a leading SELECT, and Python snippets run with `data`
# plus a fixed set of pre-bound names and must leave their answer in `result`.
template = """
You are an expert data scientist assistant. Given a natural language question, the name of the table, and a list of valid columns, generate code that answers the question.
Instructions:
- If the question involves data retrieval or simple aggregations, generate a SQL query.
- If the question requires statistical analysis or time series analysis, generate a Python code snippet using pandas, numpy, and statsmodels.
- If the question involves predictions, modeling, or recommendations, generate a Python code snippet using scikit-learn or pandas.
- Ensure that you only use the columns provided.
- Do not include any import statements in the code.
- A SQL query must be SQLite-compatible, begin with SELECT, and read from the given table name.
- In Python code, the dataset is already loaded as a pandas DataFrame named data. These names are already available: pd, np, sm (statsmodels.api), LinearRegression, train_test_split, mean_squared_error, r2_score, cosine_similarity.
- In Python code, assign the final answer to a variable named result.
- Provide the code between <CODE> and </CODE> tags.
Question: {question}
Table name: {table_name}
Valid columns: {columns}
Response:
"""
prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])

# Set up the LLM Chain
sql_generation_chain = LLMChain(llm=llm, prompt=prompt)
# Helper functions
def extract_code(response):
    """Return the text between the first <CODE> ... </CODE> pair.

    The captured text has surrounding whitespace removed. ``None`` is
    returned when *response* contains no such pair.
    """
    import re

    # DOTALL lets the capture span several lines.
    found = re.search(r"<CODE>(.*?)</CODE>", response, flags=re.DOTALL)
    if found is None:
        return None
    return found.group(1).strip()
def execute_code(code):
    """Run generated SQL or Python against the loaded dataset.

    Code whose first word is ``select`` (case-insensitive) is treated as SQL
    and run against an in-memory SQLite copy of the module-level ``data``,
    stored under ``table_name``. Anything else is executed as Python with a
    copy of ``data`` and a fixed set of helpers pre-bound; the snippet
    reports its answer by assigning to ``result``.

    Args:
        code: The SQL query or Python snippet to run.

    Returns:
        A ``DataFrame`` for SQL. For Python, whatever the snippet assigned to
        ``result`` (``None`` if it never assigned it).

    Raises:
        Exception: Any error raised by pandas/SQLite or by the snippet itself
            propagates to the caller.
    """
    import re

    # Require a word boundary after "select" so Python that merely starts
    # with a name such as `selected_rows` is not routed to SQLite.
    if re.match(r'\s*select\b', code, re.IGNORECASE):
        # It's a SQL query
        conn = sqlite3.connect(':memory:')
        try:
            # to_sql sits inside the try so the connection is closed even
            # when loading the table fails.
            data.to_sql(table_name, conn, index=False)
            return pd.read_sql_query(code, conn)
        finally:
            conn.close()

    # It's Python code.
    # SECURITY(review): exec runs model-generated code in-process with no
    # sandboxing. Flagged here rather than changed; confirm this is acceptable
    # for the deployment.
    namespace = {
        'pd': pd,
        'np': np,
        'data': data.copy(),  # a copy, so the snippet cannot mutate the loaded dataset
        'result': None,
        'LinearRegression': LinearRegression,
        'train_test_split': train_test_split,
        'mean_squared_error': mean_squared_error,
        'r2_score': r2_score,
        'sm': sm,
        'cosine_similarity': cosine_similarity,
    }
    # A single dict serves as both globals and locals. With separate dicts the
    # snippet runs as if inside a class body, and functions or lambdas it
    # defines cannot see the pre-bound names.
    exec(code, namespace)
    return namespace.get('result')
# Process user input
def _add_assistant_message(content):
    """Append an assistant entry holding *content* to the chat history."""
    st.session_state.history.append({"role": "assistant", "content": content})


def process_input():
    """Handle the question currently stored in ``st.session_state['user_input']``.

    The question is appended to the chat history. If it mentions "columns",
    the column list is the reply. Otherwise the LLM chain is asked for code,
    the code is extracted and executed, and the result (or the error text) is
    appended as assistant messages. When the response carries no <CODE>
    block, the stripped response text is used as the reply. The input field
    is cleared at the end.
    """
    user_prompt = st.session_state['user_input']
    if not user_prompt or not user_prompt.strip():
        # Blank or whitespace-only input: nothing to ask, so skip the LLM call.
        st.session_state['user_input'] = ''
        return

    try:
        # Append user message to history
        st.session_state.history.append({"role": "user", "content": user_prompt})
        if "columns" in user_prompt.lower():
            _add_assistant_message(f"The columns are: {', '.join(valid_columns)}")
        else:
            response = sql_generation_chain.run({
                'question': user_prompt,
                'table_name': table_name,
                'columns': ', '.join(valid_columns),
            })
            # Extract code from response
            code = extract_code(response)
            if code:
                # NOTE(review): this is written directly rather than stored in
                # history, so the history loop will not redraw it later.
                st.write(f"**Generated Code:**\n```python\n{code}\n```")
                try:
                    result = execute_code(code)
                except Exception as e:
                    # logging.exception keeps the traceback, which the
                    # message alone does not carry.
                    logging.exception("An error occurred during code execution: %s", e)
                    _add_assistant_message(f"Error executing code: {e}")
                else:
                    _add_assistant_message("Result:")
                    _add_assistant_message(result)
            else:
                _add_assistant_message(response.strip())
    except Exception as e:
        logging.exception("An error occurred: %s", e)
        _add_assistant_message(f"Error: {e}")

    # Reset the user_input in session state
    st.session_state['user_input'] = ''
# Initialize session state variables (only when they are not already set)
for state_key, initial_value in (('history', []), ('user_input', '')):
    if state_key not in st.session_state:
        st.session_state[state_key] = initial_value

# Display the conversation history
for entry in st.session_state.history:
    speaker = entry['role']
    if speaker == 'user':
        st.markdown(f"**User:** {entry['content']}")
        continue
    if speaker != 'assistant':
        # Entries with any other role are not rendered.
        continue
    payload = entry['content']
    if isinstance(payload, pd.DataFrame):
        st.markdown("**Assistant:** Here are the results:")
        st.dataframe(payload)
    else:
        # Everything that is not a DataFrame is interpolated into markdown.
        st.markdown(f"**Assistant:** {payload}")

# Place the text input after displaying the conversation
st.text_input("Enter your question:", key='user_input', on_change=process_input)