import os

import anthropic
import matplotlib.pyplot as plt
import openai
import pandas as pd
import streamlit as st
import torch
from dotenv import load_dotenv
from transformers import AutoModelForCausalLM, AutoTokenizer
# Load environment variables from a local .env file.
# load_dotenv() already places the keys in os.environ; re-assigning them here
# would raise a TypeError whenever a key is missing, so the assignments are dropped.
load_dotenv()
# UI styling
st.markdown(
    """
    <style>
    .stButton button {
        background-color: #1F6FEB;
        color: white;
        border-radius: 8px;
        border: none;
        padding: 10px 20px;
        font-weight: bold;
    }
    .stButton button:hover {
        background-color: #1A4FC5;
    }
    .stTextInput > div > input {
        border: 1px solid #30363D;
        background-color: #161B22;
        color: #C9D1D9;
        border-radius: 6px;
        padding: 10px;
    }
    .stFileUploader > div {
        border: 2px dashed #30363D;
        background-color: #161B22;
        color: #C9D1D9;
        border-radius: 6px;
        padding: 10px;
    }
    .response-box {
        background-color: #161B22;
        padding: 10px;
        border-radius: 6px;
        margin-bottom: 10px;
        color: #FFFFFF;
    }
    </style>
    """,
    unsafe_allow_html=True,
)
st.title("Excel Q&A Chatbot π") | |
# Model selection
model_choice = st.selectbox("Select LLM Model", ["OpenAI GPT-3.5", "Claude 3 Haiku", "Mistral-7B"])
# Load the appropriate model based on the selection
if model_choice == "Mistral-7B":
    # The hub id needs a version suffix; a bare "mistralai/Mistral-7B-Instruct" repo does not exist
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)

    def ask_mistral(query):
        inputs = tokenizer(query, return_tensors="pt").to(device)
        output = model.generate(**inputs, max_new_tokens=256)
        # Decode only the newly generated tokens, not the echoed prompt
        return tokenizer.decode(output[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
elif model_choice == "Claude 3 Haiku":
    client = anthropic.Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY"))

    def ask_claude(query):
        response = client.messages.create(
            model="claude-3-haiku-20240307",  # the Messages API expects a dated model id
            max_tokens=1024,  # max_tokens is a required parameter
            messages=[{"role": "user", "content": query}],
        )
        # response.content is a list of content blocks; return the text of the first
        return response.content[0].text
else:
    client = openai.OpenAI()  # reads OPENAI_API_KEY from the environment

    def ask_gpt(query):
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": query}],
        )
        return response.choices[0].message.content
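# For illustration only (hypothetical column names): given a question like
# "What is the average sales per region?", the model is expected to reply with
# a bare method chain such as
#   groupby("region")["sales"].mean()
# so that eval(f"df.{parsed_query}") below evaluates it against the DataFrame.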
# File upload
uploaded_file = st.file_uploader("Upload an Excel or CSV file", type=["csv", "xlsx"])
if uploaded_file is not None:
    file_extension = uploaded_file.name.split(".")[-1].lower()
    # pd.read_excel requires the openpyxl package for .xlsx files
    df = pd.read_csv(uploaded_file) if file_extension == "csv" else pd.read_excel(uploaded_file)
    st.write("### Preview of Data:")
    st.write(df.head())
    # Extract metadata
    column_names = df.columns.tolist()
    data_types = df.dtypes.apply(lambda x: x.name).to_dict()
    missing_values = df.isnull().sum().to_dict()

    # Display metadata
    st.write("### Column Details:")
    st.write(pd.DataFrame({
        "Column": column_names,
        "Type": list(data_types.values()),
        "Missing Values": list(missing_values.values()),
    }))
    # User query
    query = st.text_input("Ask a question about this data:")
    if st.button("Submit Query"):
        if query:
            # Ask the selected LLM to translate the question into a Pandas expression
            prompt = (
                "Convert this question into a single Pandas operation on a DataFrame named df. "
                f"Reply with only the code that follows 'df.', with no explanation: {query}"
            )
            if model_choice == "Mistral-7B":
                parsed_query = ask_mistral(prompt)
            elif model_choice == "Claude 3 Haiku":
                parsed_query = ask_claude(prompt)
            else:
                parsed_query = ask_gpt(prompt)
            # Strip whitespace and any markdown code fences the model may add
            parsed_query = parsed_query.strip().strip("`").strip()

            # Execute the expression. NOTE: eval() on LLM output is inherently
            # unsafe and is only acceptable in a trusted demo environment.
            try:
                result = eval(f"df.{parsed_query}")
                st.write("### Result:")
                st.write(result if isinstance(result, pd.DataFrame) else str(result))
                # Visualize numeric Series results
                if isinstance(result, pd.Series) and result.dtype in ["int64", "float64"]:
                    fig, ax = plt.subplots()
                    result.plot(kind="bar", ax=ax)
                    st.pyplot(fig)
            except Exception as e:
                st.error(f"Error executing query: {str(e)}")
            # Memory for context retention
            if "query_history" not in st.session_state:
                st.session_state.query_history = []
            st.session_state.query_history.append(query)
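# A minimal sketch of surfacing the stored history (an assumption about intent;
# the original records queries but never displays them):
if "query_history" in st.session_state and st.session_state.query_history:
    with st.sidebar:
        st.write("### Query History")
        for past_query in st.session_state.query_history:
            st.write(f"- {past_query}")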