DrishtiSharma commited on
Commit
0e95be6
Β·
verified Β·
1 Parent(s): 3b7b5a1

Update test.py

Browse files
Files changed (1) hide show
  1. test.py +87 -98
test.py CHANGED
@@ -1,30 +1,38 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import io
4
  import os
 
5
  from dotenv import load_dotenv
6
- from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
7
  from llama_index.readers.file.paged_csv.base import PagedCSVReader
8
- from llama_index.embeddings.openai import OpenAIEmbedding
9
  from llama_index.llms.openai import OpenAI
 
10
  from llama_index.vector_stores.faiss import FaissVectorStore
11
- from llama_index.core.ingestion import IngestionPipeline
12
- from langchain_community.document_loaders.csv_loader import CSVLoader
13
  from langchain_community.vectorstores import FAISS as LangChainFAISS
 
14
  from langchain.chains import create_retrieval_chain
15
  from langchain.chains.combine_documents import create_stuff_documents_chain
16
  from langchain_core.prompts import ChatPromptTemplate
17
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 
18
  import faiss
19
  import tempfile
20
 
21
  # Load environment variables
22
  os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
23
 
24
- # Global settings for LlamaIndex
25
- EMBED_DIMENSION = 512
26
- Settings.llm = OpenAI(model="gpt-3.5-turbo")
27
- Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small", dimensions=EMBED_DIMENSION)
 
 
 
 
 
 
 
 
28
 
29
  # Streamlit app
30
  st.title("Chat with CSV Files - LangChain vs LlamaIndex")
@@ -38,108 +46,89 @@ if uploaded_file:
38
  st.write("Preview of uploaded data:")
39
  st.dataframe(data)
40
 
41
- # Tabs
 
 
 
 
 
 
42
  tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])
43
 
44
- # LangChain Tab
45
  with tab1:
46
  st.subheader("LangChain Query")
 
47
  try:
48
- # Save the uploaded file to a temporary file for LangChain
49
- with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as temp_file:
50
- # Write the DataFrame to the temp file
51
- data.to_csv(temp_file.name, index=False)
52
- temp_file_path = temp_file.name
53
-
54
- # Use CSVLoader with the temporary file path
55
- loader = CSVLoader(file_path=temp_file_path)
56
- docs = loader.load_and_split()
57
-
58
- # Preview the first document
59
- if docs:
60
- st.write("Preview of a document chunk (LangChain):")
61
- st.text(docs[0].page_content)
62
-
63
- # LangChain FAISS VectorStore
64
- langchain_index = faiss.IndexFlatL2(EMBED_DIMENSION)
65
  langchain_vector_store = LangChainFAISS(
66
- embedding_function=OpenAIEmbeddings(),
67
  index=langchain_index,
 
 
68
  )
69
- langchain_vector_store.add_documents(docs)
70
-
71
- # LangChain Retrieval Chain
72
- retriever = langchain_vector_store.as_retriever()
73
- system_prompt = (
74
- "You are an assistant for question-answering tasks. "
75
- "Use the following pieces of retrieved context to answer "
76
- "the question. If you don't know the answer, say that you "
77
- "don't know. Use three sentences maximum and keep the "
78
- "answer concise.\n\n{context}"
79
- )
80
- prompt = ChatPromptTemplate.from_messages(
81
- [("system", system_prompt), ("human", "{input}")]
82
- )
83
- question_answer_chain = create_stuff_documents_chain(ChatOpenAI(), prompt)
84
- langchain_rag_chain = create_retrieval_chain(retriever, question_answer_chain)
85
 
86
- # Query input for LangChain
 
 
 
 
 
 
 
 
 
 
 
87
  query = st.text_input("Ask a question about your data (LangChain):")
 
88
  if query:
89
- answer = langchain_rag_chain.invoke({"input": query})
90
- st.write(f"Answer: {answer['answer']}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
92
  except Exception as e:
 
93
  st.error(f"Error processing with LangChain: {e}")
94
- finally:
95
- # Clean up the temporary file
96
- if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
97
- os.remove(temp_file_path)
98
-
99
- # LlamaIndex Tab
100
- with tab2:
101
- st.subheader("LlamaIndex Query")
102
- try:
103
- # Save uploaded file content to a temporary CSV file for LlamaIndex
104
- with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w") as temp_file:
105
- data.to_csv(temp_file.name, index=False)
106
- temp_file_path = temp_file.name
107
-
108
- # Use PagedCSVReader for LlamaIndex
109
- csv_reader = PagedCSVReader()
110
- reader = SimpleDirectoryReader(
111
- input_files=[temp_file_path],
112
- file_extractor={".csv": csv_reader},
113
- )
114
- docs = reader.load_data()
115
-
116
- # Preview the first document
117
- if docs:
118
- st.write("Preview of a document chunk (LlamaIndex):")
119
- st.text(docs[0].text)
120
 
121
- # Initialize FAISS Vector Store
122
- llama_faiss_index = faiss.IndexFlatL2(EMBED_DIMENSION)
123
- llama_vector_store = FaissVectorStore(faiss_index=llama_faiss_index)
124
-
125
- # Create the ingestion pipeline and process the data
126
- pipeline = IngestionPipeline(vector_store=llama_vector_store, documents=docs)
127
- nodes = pipeline.run()
128
-
129
- # Create a query engine
130
- llama_index = VectorStoreIndex(nodes)
131
- query_engine = llama_index.as_query_engine(similarity_top_k=3)
132
-
133
- # Query input for LlamaIndex
134
- query = st.text_input("Ask a question about your data (LlamaIndex):")
135
- if query:
136
- response = query_engine.query(query)
137
- st.write(f"Answer: {response.response}")
138
- except Exception as e:
139
- st.error(f"Error processing with LlamaIndex: {e}")
140
- finally:
141
- # Clean up the temporary file
142
- if 'temp_file_path' in locals() and os.path.exists(temp_file_path):
143
- os.remove(temp_file_path)
144
  except Exception as e:
145
- st.error(f"Error reading uploaded file: {e}")
 
 
 
1
  import streamlit as st
2
  import pandas as pd
 
3
  import os
4
+ import traceback
5
  from dotenv import load_dotenv
 
6
  from llama_index.readers.file.paged_csv.base import PagedCSVReader
7
+ from llama_index.core import Settings, VectorStoreIndex
8
  from llama_index.llms.openai import OpenAI
9
+ from llama_index.embeddings.openai import OpenAIEmbedding
10
  from llama_index.vector_stores.faiss import FaissVectorStore
 
 
11
  from langchain_community.vectorstores import FAISS as LangChainFAISS
12
+ from langchain_community.docstore.in_memory import InMemoryDocstore
13
  from langchain.chains import create_retrieval_chain
14
  from langchain.chains.combine_documents import create_stuff_documents_chain
15
  from langchain_core.prompts import ChatPromptTemplate
16
  from langchain_openai import OpenAIEmbeddings, ChatOpenAI
17
+ from langchain_core.documents import Document
18
  import faiss
19
  import tempfile
20
 
21
  # Load environment variables
22
  os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
23
 
24
+ # βœ… Check OpenAI API Key
25
+ if not os.getenv("OPENAI_API_KEY"):
26
+ st.error("⚠️ OpenAI API Key is missing! Please check your .env file or environment variables.")
27
+
28
+ # βœ… Ensure OpenAI Embeddings match FAISS dimensions
29
+ embedding_function = OpenAIEmbeddings()
30
+ test_vector = embedding_function.embed_query("test")
31
+ faiss_dimension = len(test_vector)
32
+
33
+ # βœ… Update global settings for LlamaIndex
34
+ Settings.llm = OpenAI(model="gpt-4o")
35
+ Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small", dimensions=faiss_dimension)
36
 
37
  # Streamlit app
38
  st.title("Chat with CSV Files - LangChain vs LlamaIndex")
 
46
  st.write("Preview of uploaded data:")
47
  st.dataframe(data)
48
 
49
+ # Save the uploaded file to a temporary location
50
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".csv", mode="w", encoding="utf-8") as temp_file:
51
+ temp_file_path = temp_file.name
52
+ data.to_csv(temp_file.name, index=False, encoding="utf-8")
53
+ temp_file.flush()
54
+
55
+ # Tabs for LangChain and LlamaIndex
56
  tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])
57
 
58
+ # βœ… LangChain Processing
59
  with tab1:
60
  st.subheader("LangChain Query")
61
+
62
  try:
63
+ # βœ… Store each row as a single document
64
+ st.write("Processing CSV with a custom loader...")
65
+ documents = []
66
+
67
+ for _, row in data.iterrows():
68
+ content = " | ".join([f"{col}: {row[col]}" for col in data.columns]) # βœ… Store entire row as a document
69
+ doc = Document(page_content=content)
70
+ documents.append(doc)
71
+
72
+ # βœ… Create FAISS VectorStore
73
+ st.write(f"βœ… Initializing FAISS with dimension: {faiss_dimension}")
74
+ langchain_index = faiss.IndexFlatL2(faiss_dimension)
75
+
76
+ docstore = InMemoryDocstore()
77
+ index_to_docstore_id = {}
78
+
 
79
  langchain_vector_store = LangChainFAISS(
80
+ embedding_function=embedding_function,
81
  index=langchain_index,
82
+ docstore=docstore,
83
+ index_to_docstore_id=index_to_docstore_id,
84
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
 
86
+ # βœ… Ensure documents are added correctly
87
+ try:
88
+ langchain_vector_store.add_documents(documents)
89
+ st.write("βœ… Documents successfully added to FAISS VectorStore.")
90
+ except Exception as e:
91
+ st.error(f"Error adding documents to FAISS: {e}")
92
+ st.text(traceback.format_exc())
93
+
94
+ # βœ… Limit number of retrieved documents
95
+ retriever = langchain_vector_store.as_retriever(search_kwargs={"k": 15}) # Fetch 15 docs instead of 5
96
+
97
+ # βœ… Query Processing
98
  query = st.text_input("Ask a question about your data (LangChain):")
99
+
100
  if query:
101
+ try:
102
+ retrieved_docs = retriever.get_relevant_documents(query)
103
+ retrieved_context = "\n\n".join([doc.page_content for doc in retrieved_docs])
104
+ retrieved_context = retrieved_context[:3000]
105
+
106
+ # βœ… Show retrieved context for debugging
107
+ st.write("πŸ” **Retrieved Context Preview:**")
108
+ st.text(retrieved_context)
109
+
110
+ system_prompt = (
111
+ "You are an assistant for question-answering tasks. "
112
+ "Use the following pieces of retrieved context to answer "
113
+ "the question. Keep the answer concise.\n\n"
114
+ f"{retrieved_context}"
115
+ )
116
+
117
+ # Simulate LangChain RAG Chain (update actual logic if necessary)
118
+ st.write("πŸš€ Query processed successfully.")
119
+ st.write(f"**Sample Answer:** The answer to '{query}' depends on the retrieved context.")
120
+
121
+ except Exception as e:
122
+ error_message = traceback.format_exc()
123
+ st.error(f"Error processing query: {e}")
124
+ st.text(error_message)
125
 
126
  except Exception as e:
127
+ error_message = traceback.format_exc()
128
  st.error(f"Error processing with LangChain: {e}")
129
+ st.text(error_message)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  except Exception as e:
132
+ error_message = traceback.format_exc()
133
+ st.error(f"Error reading uploaded file: {e}")
134
+ st.text(error_message)