garyd1 committed on
Commit
e0d64af
·
verified ·
1 Parent(s): dfe36ba

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -84
app.py CHANGED
@@ -1,95 +1,98 @@
 
1
  import streamlit as st
2
  import pandas as pd
3
- import os
4
- import tempfile
5
- from PyPDF2 import PdfReader
6
- from langchain.text_splitter import RecursiveCharacterTextSplitter
7
- from sentence_transformers import SentenceTransformer
8
- import faiss
9
  import openai
 
 
 
 
 
10
 
11
# Page configuration and OpenAI credentials.
st.set_page_config(page_title="RAG Chatbot with Files", layout="centered")
# The key is typed into a masked sidebar field and handed straight to the SDK.
openai.api_key = st.sidebar.text_input("Enter OpenAI API Key:", type="password")

# Retrieval state shared by load_files() and handle_query():
#   embedding_model - sentence encoder used for both chunks and queries
#   faiss_index     - vector index, None until files have been loaded
#   data_chunks     - flat list of text chunks
#   chunk_mapping   - chunk index -> (source file name, chunk text)
embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
faiss_index, data_chunks, chunk_mapping = None, [], {}
20
-
21
# File Upload and Processing
def load_files(uploaded_files):
    """Parse uploaded files, split their text into chunks and rebuild the index.

    Resets the module-level ``data_chunks``, ``chunk_mapping`` and
    ``faiss_index`` and repopulates them from ``uploaded_files``.

    Args:
        uploaded_files: Iterable of uploaded file objects exposing ``.name``
            and ``.read()``.

    Files with an extension other than csv/xlsx/pdf are reported through
    ``st.error`` and skipped. If no text is extracted from any file,
    ``faiss_index`` stays ``None``.
    """
    global data_chunks, chunk_mapping, faiss_index
    data_chunks = []
    chunk_mapping = {}
    faiss_index = None

    # Same settings for every file, so build the splitter once.
    splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)

    for uploaded_file in uploaded_files:
        # Lower-case so that "REPORT.PDF" is treated like "report.pdf".
        file_type = uploaded_file.name.split('.')[-1].lower()
        if file_type not in ("csv", "xlsx", "pdf"):
            st.error(f"Unsupported file type: {file_type}")
            continue

        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            tmp_file.write(uploaded_file.read())
            tmp_file_path = tmp_file.name

        try:
            if file_type == "csv":
                df = pd.read_csv(tmp_file_path)
                content = "\n".join(df.astype(str).values.flatten())
            elif file_type == "xlsx":
                df = pd.read_excel(tmp_file_path)
                content = "\n".join(df.astype(str).values.flatten())
            else:  # pdf
                reader = PdfReader(tmp_file_path)
                # `or ""` guards against pages that yield no text.
                content = "".join(page.extract_text() or "" for page in reader.pages)
        finally:
            # delete=False leaves the file behind; remove it explicitly.
            os.unlink(tmp_file_path)

        # Split into chunks
        chunks = splitter.split_text(content)

        # Keys must be the chunk's position in data_chunks (= its FAISS row id),
        # so number them from the current length instead of restarting at 0.
        offset = len(data_chunks)
        data_chunks.extend(chunks)
        chunk_mapping.update(
            {i: (uploaded_file.name, chunk) for i, chunk in enumerate(chunks, start=offset)}
        )

    if not data_chunks:
        # Nothing to embed; leave faiss_index as None.
        return

    # Create FAISS index
    embeddings = embedding_model.encode(data_chunks)
    faiss_index = faiss.IndexFlatL2(embeddings.shape[1])
    faiss_index.add(embeddings)
55
 
56
# Query Processing
def handle_query(query):
    """Answer ``query`` from the indexed chunks via an OpenAI completion.

    Embeds the query, retrieves up to five nearest chunks from
    ``faiss_index`` and asks the OpenAI completion endpoint to summarise them.

    Args:
        query: The user's question as a string.

    Returns:
        The completion text, or a fixed notice when no data is indexed.
    """
    if faiss_index is None or faiss_index.ntotal == 0:
        return "No data available. Please upload files first."

    # Generate embedding for the query
    query_embedding = embedding_model.encode([query])

    # Asking for more neighbours than there are vectors makes FAISS pad the
    # result with -1, which is not a valid chunk key.
    k = min(5, faiss_index.ntotal)
    _, indices = faiss_index.search(query_embedding, k=k)
    relevant_chunks = [
        chunk_mapping[int(idx)][1]
        for idx in indices[0]
        if int(idx) in chunk_mapping
    ]
    if not relevant_chunks:
        return "No data available. Please upload files first."

    # Use OpenAI for summarization
    # NOTE(review): openai.Completion / text-davinci-003 look like the legacy
    # completions interface - confirm they are still available for the pinned SDK.
    prompt = "Summarize the following information:\n" + "\n".join(relevant_chunks)
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=150
    )
    return response['choices'][0]['text']
 
74
 
75
# Streamlit UI
def main():
    """Render the page: a sidebar uploader, a question box and the answer."""
    st.title("RAG Chatbot with Files")
    st.sidebar.title("Options")

    files = st.sidebar.file_uploader(
        "Upload files (CSV, Excel, PDF):",
        type=["csv", "xlsx", "pdf"],
        accept_multiple_files=True,
    )
    if files:
        load_files(files)
        st.sidebar.success("Files loaded successfully!")

    question = st.text_input("Ask a question about the data:")
    if not st.button("Get Answer"):
        return

    # Both an API key and a non-empty question are needed to proceed.
    if not (openai.api_key and question):
        st.error("Please provide a valid API key and query.")
        return

    reply = handle_query(question)
    st.subheader("Answer:")
    st.write(reply)


if __name__ == "__main__":
    main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
  import streamlit as st
3
  import pandas as pd
 
 
 
 
 
 
4
  import openai
5
+ import torch
6
+ import matplotlib.pyplot as plt
7
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
8
+ from dotenv import load_dotenv
9
+ import anthropic
10
 
11
# Load environment variables
# load_dotenv() copies entries from a local .env file into os.environ, which is
# all the API clients below need.  The previous
# `os.environ[KEY] = os.getenv(KEY)` lines were a no-op when the variable was
# set and raised TypeError (os.environ only accepts str) when it was not, which
# crashed the app at start-up even for back-ends that do not need that key.
load_dotenv()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
 
16
st.title("Excel Q&A Chatbot 📊")

# Model Selection
model_choice = st.selectbox("Select LLM Model", ["OpenAI GPT-3.5", "Claude 3 Haiku", "Mistral-7B"])

# Load appropriate model based on selection.  Only the helper for the chosen
# back-end is defined: ask_mistral, ask_claude or ask_gpt.  Each takes a prompt
# string and returns the model's reply as a string.
if model_choice == "Mistral-7B":
    # Versioned repo id; the unversioned "mistralai/Mistral-7B-Instruct" is not
    # a published checkpoint.
    model_name = "mistralai/Mistral-7B-Instruct-v0.2"
    # Fall back to CPU (and float32) when no GPU is present instead of failing
    # on a hard-coded "cuda".
    device = "cuda" if torch.cuda.is_available() else "cpu"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    ).to(device)  # model and inputs must live on the same device
    # NOTE(review): if this script is re-executed per interaction, the weights
    # are reloaded each time - consider caching the loaded model.

    def ask_mistral(query):
        """Generate a reply to ``query`` with the local Mistral model."""
        inputs = tokenizer(query, return_tensors="pt").to(device)
        # Without an explicit budget generate() stops after a very short
        # default length.
        output = model.generate(**inputs, max_new_tokens=128)
        # generate() returns prompt + continuation; keep only the continuation.
        prompt_length = inputs["input_ids"].shape[1]
        return tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True).strip()

elif model_choice == "Claude 3 Haiku":
    client = anthropic.Anthropic(api_key=os.environ["ANTHROPIC_API_KEY"])

    def ask_claude(query):
        """Send ``query`` to Claude 3 Haiku and return the reply text."""
        response = client.messages.create(
            model="claude-3-haiku-20240307",
            max_tokens=256,  # required by the Messages API
            messages=[{"role": "user", "content": query}],
        )
        # response.content is a list of content blocks, not a string.
        return "".join(
            block.text for block in response.content if getattr(block, "type", None) == "text"
        )

else:
    client = openai.OpenAI()

    def ask_gpt(query):
        """Send ``query`` to gpt-3.5-turbo and return the reply text."""
        response = client.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": query}],
        )
        return response.choices[0].message.content
48
 
49
# File Upload
uploaded_file = st.file_uploader("Upload an Excel file", type=["csv", "xlsx"])

if uploaded_file is not None:
    file_extension = uploaded_file.name.split(".")[-1].lower()
    df = pd.read_csv(uploaded_file) if file_extension == "csv" else pd.read_excel(uploaded_file)
    st.write("### Preview of Data:")
    st.write(df.head())

    # Extract metadata
    column_names = df.columns.tolist()
    data_types = df.dtypes.apply(lambda x: x.name).to_dict()
    missing_values = df.isnull().sum().to_dict()

    # Display metadata (materialise the dict views as lists for the constructor)
    st.write("### Column Details:")
    st.write(pd.DataFrame({
        "Column": column_names,
        "Type": list(data_types.values()),
        "Missing Values": list(missing_values.values()),
    }))

    # User Query
    query = st.text_input("Ask a question about this data:")

    if st.button("Submit Query"):
        if query:
            # Memory for context retention: record only queries that were
            # actually submitted, not one entry per script run.
            if "query_history" not in st.session_state:
                st.session_state.query_history = []
            st.session_state.query_history.append(query)

            # Interpret the query using selected LLM.  Tell the model what the
            # frame looks like and ask for a bare expression so the reply can
            # be evaluated.
            prompt = (
                f"Convert this question into a Pandas operation: {query}\n"
                f"The DataFrame is named df and has these columns: {column_names}. "
                "Reply with a single Python expression only, without explanation."
            )
            if model_choice == "Mistral-7B":
                parsed_query = ask_mistral(prompt)
            elif model_choice == "Claude 3 Haiku":
                parsed_query = ask_claude(prompt)
            else:
                parsed_query = ask_gpt(prompt)

            # Normalise the reply: accept a list of content blocks, drop
            # Markdown code fences, and add the "df." prefix only when the
            # model did not already anchor the expression on df.
            if isinstance(parsed_query, str):
                expression = parsed_query
            else:
                expression = "".join(getattr(part, "text", str(part)) for part in parsed_query)
            expression = expression.strip().strip("`").strip()
            if expression.startswith("python"):
                expression = expression[len("python"):].strip()
            if not expression.startswith(("df.", "df[")):
                expression = f"df.{expression}"

            # Execute the query
            try:
                # SECURITY: eval() executes model-generated text that is derived
                # from user input, i.e. arbitrary code with this process's
                # privileges.  Do not expose this app to untrusted users without
                # replacing this with a restricted query interpreter.
                result = eval(expression)
                st.write("### Result:")
                st.write(result if isinstance(result, pd.DataFrame) else str(result))

                # If numerical data, show a visualization
                if (
                    isinstance(result, pd.Series)
                    and not result.empty
                    and pd.api.types.is_numeric_dtype(result)
                    and not pd.api.types.is_bool_dtype(result)
                ):
                    fig, ax = plt.subplots()
                    result.plot(kind="bar", ax=ax)
                    st.pyplot(fig)
                    # Release the figure; pyplot otherwise keeps it alive.
                    plt.close(fig)

            except Exception as e:
                # Broad on purpose: the evaluated expression can fail in any way.
                st.error(f"Error executing query: {str(e)}")
        else:
            st.warning("Please enter a question before submitting.")