DrishtiSharma committed on
Commit
40e0a99
·
verified ·
1 Parent(s): 914e762

Create test.py

Browse files
Files changed (1) hide show
  1. test.py +145 -0
test.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit app: chat with an uploaded CSV file using LangChain or LlamaIndex.

The uploaded CSV is previewed with pandas and then indexed twice, once per
framework, each in its own tab. Both tabs embed the rows with OpenAI's
``text-embedding-3-small`` model and answer free-text questions over them.
"""

import streamlit as st
import pandas as pd
import io
import os
from dotenv import load_dotenv
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
from llama_index.readers.file.paged_csv.base import PagedCSVReader
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.core.ingestion import IngestionPipeline
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import FAISS as LangChainFAISS
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
import faiss
import tempfile

# Load environment variables (a local .env file is honoured when present).
load_dotenv()
if not os.getenv("OPENAI_API_KEY"):
    # Assigning os.getenv(...) back into os.environ raises TypeError when the
    # variable is missing, so report the problem and halt the script instead.
    st.error("OPENAI_API_KEY is not set. Add it to the environment or a .env file.")
    st.stop()

# Global settings for LlamaIndex
EMBED_MODEL = "text-embedding-3-small"
EMBED_DIMENSION = 512
Settings.llm = OpenAI(model="gpt-3.5-turbo")
Settings.embed_model = OpenAIEmbedding(model=EMBED_MODEL, dimensions=EMBED_DIMENSION)


def _remove_file(path: str | None) -> None:
    """Delete *path* if it is set and still exists on disk."""
    if path and os.path.exists(path):
        os.remove(path)


def _write_temp_csv(data: pd.DataFrame) -> str:
    """Write *data* to a fresh temporary ``.csv`` file and return its path.

    The caller owns the returned file and must delete it when done. If the
    write itself fails, the empty temporary file is removed before the
    exception propagates.
    """
    # Only reserve the name here; the handle is closed before pandas reopens
    # the path, because reopening a still-open temp file is not portable.
    with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as temp_file:
        temp_file_path = temp_file.name
    try:
        data.to_csv(temp_file_path, index=False)
    except Exception:
        _remove_file(temp_file_path)
        raise
    return temp_file_path


def _render_langchain_tab(data: pd.DataFrame) -> None:
    """Index *data* with LangChain + FAISS and answer a user question."""
    st.subheader("LangChain Query")
    temp_file_path = None
    try:
        # CSVLoader reads from a path, so persist the DataFrame first.
        temp_file_path = _write_temp_csv(data)
        docs = CSVLoader(file_path=temp_file_path).load_and_split()

        if not docs:
            st.warning("The uploaded CSV produced no documents to index.")
            return

        # Preview the first document
        st.write("Preview of a document chunk (LangChain):")
        st.text(docs[0].page_content)

        # LangChain FAISS VectorStore. from_documents creates the FAISS
        # index, docstore and id mapping together; the bare constructor
        # requires all of them. The embedding model and dimension are pinned
        # to the same values used for the LlamaIndex tab.
        langchain_vector_store = LangChainFAISS.from_documents(
            docs,
            OpenAIEmbeddings(model=EMBED_MODEL, dimensions=EMBED_DIMENSION),
        )

        # LangChain Retrieval Chain
        retriever = langchain_vector_store.as_retriever()
        system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use three sentences maximum and keep the "
            "answer concise.\n\n{context}"
        )
        prompt = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("human", "{input}")]
        )
        question_answer_chain = create_stuff_documents_chain(ChatOpenAI(), prompt)
        langchain_rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        # Query input for LangChain
        query = st.text_input("Ask a question about your data (LangChain):")
        if query:
            answer = langchain_rag_chain.invoke({"input": query})
            st.write(f"Answer: {answer['answer']}")
    except Exception as e:
        # Top-level UI boundary: surface the failure in the tab.
        st.error(f"Error processing with LangChain: {e}")
    finally:
        # Clean up the temporary file
        _remove_file(temp_file_path)


def _render_llamaindex_tab(data: pd.DataFrame) -> None:
    """Index *data* with LlamaIndex + FAISS and answer a user question."""
    st.subheader("LlamaIndex Query")
    temp_file_path = None
    try:
        # SimpleDirectoryReader reads from a path, so persist the DataFrame first.
        temp_file_path = _write_temp_csv(data)

        # Use PagedCSVReader for LlamaIndex
        reader = SimpleDirectoryReader(
            input_files=[temp_file_path],
            file_extractor={".csv": PagedCSVReader()},
        )
        docs = reader.load_data()

        if not docs:
            st.warning("The uploaded CSV produced no documents to index.")
            return

        # Preview the first document
        st.write("Preview of a document chunk (LlamaIndex):")
        st.text(docs[0].text)

        # Initialize FAISS Vector Store
        llama_faiss_index = faiss.IndexFlatL2(EMBED_DIMENSION)
        llama_vector_store = FaissVectorStore(faiss_index=llama_faiss_index)

        # Create the ingestion pipeline and process the data
        pipeline = IngestionPipeline(vector_store=llama_vector_store, documents=docs)
        nodes = pipeline.run()

        # Create a query engine.
        # NOTE(review): the query index is built from the returned nodes, not
        # from llama_vector_store itself — confirm that FAISS is the intended
        # retrieval backend here.
        llama_index = VectorStoreIndex(nodes)
        query_engine = llama_index.as_query_engine(similarity_top_k=3)

        # Query input for LlamaIndex
        query = st.text_input("Ask a question about your data (LlamaIndex):")
        if query:
            response = query_engine.query(query)
            st.write(f"Answer: {response.response}")
    except Exception as e:
        # Top-level UI boundary: surface the failure in the tab.
        st.error(f"Error processing with LlamaIndex: {e}")
    finally:
        # Clean up the temporary file
        _remove_file(temp_file_path)


# Streamlit app
st.title("Chat with CSV Files - LangChain vs LlamaIndex")

# File uploader
uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
if uploaded_file:
    try:
        # Read the CSV data using pandas
        data = pd.read_csv(uploaded_file)
    except Exception as e:
        st.error(f"Error reading uploaded file: {e}")
    else:
        # Only reached when parsing succeeded, so later failures are not
        # mislabelled as read errors.
        st.write("Preview of uploaded data:")
        st.dataframe(data)

        # Tabs
        tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])

        # LangChain Tab
        with tab1:
            _render_langchain_tab(data)

        # LlamaIndex Tab
        with tab2:
            _render_llamaindex_tab(data)