DrishtiSharma commited on
Commit
944593e
·
verified ·
1 Parent(s): feb602e

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -0
app.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
"""Streamlit app: chat with an uploaded CSV through two parallel RAG pipelines,
one built with LangChain and one with LlamaIndex, both backed by FAISS."""

import os
import tempfile

import faiss
import pandas as pd
import streamlit as st
from dotenv import load_dotenv

# --- LangChain ---
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import FAISS as LangChainFAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# --- LlamaIndex ---
# Use the modern `llama_index.core` namespace, consistent with the
# `llama_index.embeddings/llms/vector_stores` packages imported below.
from llama_index.core import Settings, SimpleDirectoryReader, VectorStoreIndex
from llama_index.core.ingestion import IngestionPipeline
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.readers.file import PagedCSVReader
from llama_index.vector_stores.faiss import FaissVectorStore

# Read a local .env file if present (load_dotenv was imported but never
# called in the original, so .env keys were silently ignored).
load_dotenv()

# Fail with a readable message instead of the TypeError raised by
# `os.environ[...] = os.getenv(...)` when the variable is unset.
if not os.getenv("OPENAI_API_KEY"):
    st.error("OPENAI_API_KEY is not set. Add it to your environment or .env file.")
    st.stop()

# Embedding dimensionality shared by BOTH FAISS indexes below; every
# embedding model used must be configured to emit vectors of this size.
EMBED_DIMENSION = 512

# Wire the LlamaIndex models into its global Settings (the original
# created these objects but never attached them to anything).
Settings.llm = OpenAI(model="gpt-3.5-turbo")
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-3-small", dimensions=EMBED_DIMENSION
)

# Chat model for the LangChain pipeline.
langchain_llm = ChatOpenAI(model="gpt-3.5-turbo")

# ---------------------------------------------------------------- UI ----
st.title("Streamlit App with LangChain and LlamaIndex")

uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
if uploaded_file:
    data = pd.read_csv(uploaded_file)
    st.write("Preview of uploaded data:")
    st.dataframe(data)

    # CSVLoader and SimpleDirectoryReader require a real filesystem path,
    # not Streamlit's in-memory UploadedFile, so persist the upload first.
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".csv", delete=False) as tmp:
        tmp.write(uploaded_file.getvalue())
        csv_path = tmp.name

    tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])

    # ------------------------------------------------ LangChain tab ----
    with tab1:
        st.subheader("LangChain Query")
        loader = CSVLoader(file_path=csv_path)
        docs = loader.load_and_split()

        # FAISS store. The embedding model is pinned to EMBED_DIMENSION:
        # the default OpenAIEmbeddings model emits 1536-d vectors, which
        # a 512-d IndexFlatL2 would reject on add_documents.
        langchain_index = faiss.IndexFlatL2(EMBED_DIMENSION)
        langchain_vector_store = LangChainFAISS(
            embedding_function=OpenAIEmbeddings(
                model="text-embedding-3-small", dimensions=EMBED_DIMENSION
            ),
            index=langchain_index,
            docstore=InMemoryDocstore(),
            index_to_docstore_id={},
        )
        langchain_vector_store.add_documents(docs)

        # Retrieval chain: FAISS retriever feeding a stuff-documents QA chain.
        retriever = langchain_vector_store.as_retriever()
        system_prompt = (
            "You are an assistant for question-answering tasks. "
            "Use the following pieces of retrieved context to answer "
            "the question. If you don't know the answer, say that you "
            "don't know. Use three sentences maximum and keep the "
            "answer concise.\n\n{context}"
        )
        prompt = ChatPromptTemplate.from_messages(
            [("system", system_prompt), ("human", "{input}")]
        )
        question_answer_chain = create_stuff_documents_chain(langchain_llm, prompt)
        langchain_rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        query = st.text_input("Ask a question about your data (LangChain):")
        if query:
            answer = langchain_rag_chain.invoke({"input": query})
            st.write(f"Answer: {answer['answer']}")

    # ----------------------------------------------- LlamaIndex tab ----
    with tab2:
        st.subheader("LlamaIndex Query")
        # PagedCSVReader yields one Document per CSV row.
        csv_reader = SimpleDirectoryReader(
            input_files=[csv_path],
            file_extractor={".csv": PagedCSVReader()},
        )
        docs = csv_reader.load_data()

        # FAISS-backed ingestion. Documents are passed to run(), not to
        # the IngestionPipeline constructor, which takes configuration.
        llama_faiss_index = faiss.IndexFlatL2(EMBED_DIMENSION)
        llama_vector_store = FaissVectorStore(faiss_index=llama_faiss_index)
        pipeline = IngestionPipeline(vector_store=llama_vector_store)
        nodes = pipeline.run(documents=docs)

        # Query engine over the ingested nodes; top-2 nearest neighbors.
        llama_index = VectorStoreIndex(nodes)
        query_engine = llama_index.as_query_engine(similarity_top_k=2)

        query = st.text_input("Ask a question about your data (LlamaIndex):")
        if query:
            response = query_engine.query(query)
            st.write(f"Answer: {response.response}")