DrishtiSharma committed on
Commit
32cea41
·
verified ·
1 Parent(s): 4f254c9

Update lab/app.py

Browse files
Files changed (1) hide show
  1. lab/app.py +111 -0
lab/app.py CHANGED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from llama_index.core import VectorStoreIndex, SimpleDirectoryReader
6
+ from llama_index.readers.file import PagedCSVReader
7
+ from llama_index.embeddings.openai import OpenAIEmbedding
8
+ from llama_index.llms.openai import OpenAI
9
+ from llama_index.vector_stores.faiss import FaissVectorStore
10
+ from llama_index.core.ingestion import IngestionPipeline
11
+ from langchain_community.document_loaders.csv_loader import CSVLoader
12
+ from langchain_community.vectorstores import FAISS as LangChainFAISS
13
+ from langchain.chains import create_retrieval_chain
14
+ from langchain.chains.combine_documents import create_stuff_documents_chain
15
+ from langchain_core.prompts import ChatPromptTemplate
16
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
17
+ import faiss
18
+
19
# ---------------------------------------------------------------------------
# Environment and global model configuration
# ---------------------------------------------------------------------------
# Actually load variables from a local .env file (the original imported
# load_dotenv but never called it, so .env files were silently ignored).
load_dotenv()

# Guard the key explicitly: assigning os.getenv(...) straight into os.environ
# raises "TypeError: str expected, not NoneType" when the variable is unset.
_api_key = os.getenv("OPENAI_API_KEY")
if not _api_key:
    st.error("OPENAI_API_KEY is not set. Add it to your environment or a .env file.")
    st.stop()
os.environ["OPENAI_API_KEY"] = _api_key

# Global OpenAI and FAISS settings.
# EMBED_DIMENSION must match every FAISS IndexFlatL2 built below AND the
# dimensionality of the embedding models used to populate those indexes.
EMBED_DIMENSION = 512
llama_llm = OpenAI(model="gpt-3.5-turbo")
llama_embedding_model = OpenAIEmbedding(model="text-embedding-3-small", dimensions=EMBED_DIMENSION)
langchain_llm = ChatOpenAI(model="gpt-3.5-turbo-0125")
# Streamlit app
st.title("Streamlit App with LangChain and LlamaIndex")

# File uploader
uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
if uploaded_file:
    data = pd.read_csv(uploaded_file)
    st.write("Preview of uploaded data:")
    st.dataframe(data)

    # BUG FIX: a Streamlit upload lives only in memory; CSVLoader and
    # SimpleDirectoryReader read from a path on disk, so passing the bare
    # `uploaded_file.name` fails unless a same-named file happens to exist in
    # the working directory. Persist the uploaded bytes to a temp dir first.
    import tempfile
    temp_dir = tempfile.mkdtemp()
    csv_path = os.path.join(temp_dir, os.path.basename(uploaded_file.name))
    with open(csv_path, "wb") as out:
        out.write(uploaded_file.getbuffer())

    # Tabs
    tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])

    # LangChain Tab
    with tab1:
        st.subheader("LangChain Query")
        loader = CSVLoader(file_path=csv_path)
        docs = loader.load_and_split()

        # Preview the first document (guarded: an empty CSV yields no docs,
        # and docs[0] would raise IndexError).
        if docs:
            st.write("Preview of a document chunk (LangChain):")
            st.text(docs[0].page_content)

        # LangChain FAISS VectorStore.
        # BUG FIX: OpenAIEmbeddings() defaults to 1536-dim vectors, which
        # cannot be added to a 512-dim IndexFlatL2. Request embeddings at
        # EMBED_DIMENSION so the vector size matches the index.
        langchain_index = faiss.IndexFlatL2(EMBED_DIMENSION)
        langchain_vector_store = LangChainFAISS(
            embedding_function=OpenAIEmbeddings(
                model="text-embedding-3-small", dimensions=EMBED_DIMENSION
            ),
            index=langchain_index,
        )
        langchain_vector_store.add_documents(docs)

        # LangChain Retrieval Chain
        retriever = langchain_vector_store.as_retriever()
        system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use three sentences maximum and keep the "
            "answer concise.\n\n{context}"
        )
        prompt = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("human", "{input}")]
        )
        question_answer_chain = create_stuff_documents_chain(langchain_llm, prompt)
        langchain_rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        # Query input for LangChain
        query = st.text_input("Ask a question about your data (LangChain):")
        if query:
            answer = langchain_rag_chain.invoke({"input": query})
            st.write(f"Answer: {answer['answer']}")

    # LlamaIndex Tab
    with tab2:
        st.subheader("LlamaIndex Query")
        # Use PagedCSVReader for CSV loading (one document per CSV row),
        # reading from the persisted temp file rather than the bare filename.
        csv_reader = PagedCSVReader()
        reader = SimpleDirectoryReader(
            input_files=[csv_path],
            file_extractor={".csv": csv_reader},
        )
        docs = reader.load_data()

        # Preview the first document (guarded against an empty CSV).
        if docs:
            st.write("Preview of a document chunk (LlamaIndex):")
            st.text(docs[0].text)

        # Initialize FAISS Vector Store (512-dim, matching EMBED_DIMENSION).
        llama_faiss_index = faiss.IndexFlatL2(EMBED_DIMENSION)
        llama_vector_store = FaissVectorStore(faiss_index=llama_faiss_index)

        # Create the ingestion pipeline and process the data.
        # BUG FIX: wire in the 512-dim embedding model explicitly; otherwise
        # the global default embedder (1536-dim) is used and its vectors do
        # not fit the 512-dim FAISS index built above.
        pipeline = IngestionPipeline(
            transformations=[llama_embedding_model],
            vector_store=llama_vector_store,
            documents=docs,
        )
        nodes = pipeline.run()

        # Create a query engine, pinning the module's configured models so
        # the previously-unused llama_llm / llama_embedding_model globals are
        # actually applied instead of library defaults.
        llama_index = VectorStoreIndex(nodes, embed_model=llama_embedding_model)
        query_engine = llama_index.as_query_engine(llm=llama_llm, similarity_top_k=2)

        # Query input for LlamaIndex
        query = st.text_input("Ask a question about your data (LlamaIndex):")
        if query:
            response = query_engine.query(query)
            st.write(f"Answer: {response.response}")