DrishtiSharma committed on
Commit
556425f
·
verified ·
1 Parent(s): 9f9c5f9

Create oh_yeah.py

Browse files
Files changed (1) hide show
  1. lab/oh_yeah.py +121 -0
lab/oh_yeah.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import os
4
+ from dotenv import load_dotenv
5
+ from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
6
+ from llama_index.core.readers.base import BaseReader
7
+ from llama_index.readers.file.paged_csv.base import PagedCSVReader
8
+ from llama_index.embeddings.openai import OpenAIEmbedding
9
+ from llama_index.llms.openai import OpenAI
10
+ from llama_index.vector_stores.faiss import FaissVectorStore
11
+ from llama_index.core.ingestion import IngestionPipeline
12
+ from langchain_community.document_loaders.csv_loader import CSVLoader
13
+ from langchain_community.vectorstores import FAISS as LangChainFAISS
14
+ from langchain.chains import create_retrieval_chain
15
+ from langchain.chains.combine_documents import create_stuff_documents_chain
16
+ from langchain_core.prompts import ChatPromptTemplate
17
+ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
18
+ import faiss
19
+
# Load environment variables.
# `load_dotenv()` pulls values from a local .env file (if one exists) into the
# process environment, so the key may come from either the shell or .env.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")
if not openai_api_key:
    # `os.environ` only accepts str values, so assigning the `None` returned
    # by `os.getenv` for a missing variable would raise TypeError. Show an
    # actionable message and halt this script run instead.
    st.error("OPENAI_API_KEY is not set. Add it to your environment or to a .env file.")
    st.stop()
os.environ["OPENAI_API_KEY"] = openai_api_key

# Global settings for LlamaIndex.
# EMBED_DIMENSION is the vector size requested from the embedding model and
# the size used when creating the FAISS indexes further down in this script.
EMBED_DIMENSION = 512
Settings.llm = OpenAI(model="gpt-3.5-turbo")
Settings.embed_model = OpenAIEmbedding(model="text-embedding-3-small", dimensions=EMBED_DIMENSION)
27
+
# Streamlit app: upload a CSV, preview it, then query it through two
# independent retrieval pipelines (LangChain and LlamaIndex), one per tab.
st.title("Chat w CSV Files - LangChain Vs LlamaIndex ")

# File uploader
uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])
if uploaded_file:
    # Both loaders below read from a path, so the upload is written to disk.
    # Only the base name of the client-supplied filename is used, so the
    # temporary file always lands in the working directory.
    temp_file_path = f"temp_{os.path.basename(uploaded_file.name)}"

    try:
        # Save the uploaded file temporarily
        with open(temp_file_path, "wb") as temp_file:
            temp_file.write(uploaded_file.getbuffer())

        # Read and preview CSV data
        data = pd.read_csv(temp_file_path)
        st.write("Preview of uploaded data:")
        st.dataframe(data)

        # Tabs
        tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])

        # LangChain Tab
        with tab1:
            st.subheader("LangChain Query")
            loader = CSVLoader(file_path=temp_file_path)
            langchain_docs = loader.load_and_split()

            if not langchain_docs:
                # Nothing to embed or index; building a store from an empty
                # list would fail, so stop this tab here.
                st.warning("No rows could be loaded from the CSV file (LangChain).")
            else:
                # Preview the first document
                st.write("Preview of a document chunk (LangChain):")
                st.text(langchain_docs[0].page_content)

                # LangChain FAISS VectorStore.
                # `from_documents` creates the FAISS index together with the
                # docstore and id mapping that the bare constructor requires,
                # and sizes the index from the embeddings it receives. The
                # embedding model and dimension mirror the LlamaIndex
                # settings so both tabs are compared like for like.
                langchain_embeddings = OpenAIEmbeddings(
                    model="text-embedding-3-small",
                    dimensions=EMBED_DIMENSION,
                )
                langchain_vector_store = LangChainFAISS.from_documents(
                    langchain_docs,
                    langchain_embeddings,
                )

                # LangChain Retrieval Chain
                retriever = langchain_vector_store.as_retriever()
                system_prompt = (
                    "You are an assistant for question-answering tasks. "
                    "Use the following pieces of retrieved context to answer "
                    "the question. If you don't know the answer, say that you "
                    "don't know. Use three sentences maximum and keep the "
                    "answer concise.\n\n{context}"
                )
                prompt = ChatPromptTemplate.from_messages(
                    [("system", system_prompt), ("human", "{input}")]
                )
                question_answer_chain = create_stuff_documents_chain(ChatOpenAI(), prompt)
                langchain_rag_chain = create_retrieval_chain(retriever, question_answer_chain)

                # Query input for LangChain
                langchain_query = st.text_input("Ask a question about your data (LangChain):")
                if langchain_query:
                    answer = langchain_rag_chain.invoke({"input": langchain_query})
                    st.write(f"Answer: {answer['answer']}")

        # LlamaIndex Tab
        with tab2:
            st.subheader("LlamaIndex Query")
            csv_reader = PagedCSVReader()
            reader = SimpleDirectoryReader(
                input_files=[temp_file_path],
                file_extractor={".csv": csv_reader},
            )
            llama_docs = reader.load_data()

            if not llama_docs:
                # Nothing to ingest; skip index construction for this tab.
                st.warning("No rows could be loaded from the CSV file (LlamaIndex).")
            else:
                # Preview the first document
                st.write("Preview of a document chunk (LlamaIndex):")
                st.text(llama_docs[0].text)

                # Initialize FAISS Vector Store
                llama_faiss_index = faiss.IndexFlatL2(EMBED_DIMENSION)
                llama_vector_store = FaissVectorStore(faiss_index=llama_faiss_index)

                # Create the ingestion pipeline and process the data
                pipeline = IngestionPipeline(vector_store=llama_vector_store, documents=llama_docs)
                nodes = pipeline.run()

                # Create a query engine over the nodes returned by the pipeline
                llama_index = VectorStoreIndex(nodes)
                query_engine = llama_index.as_query_engine(similarity_top_k=3)

                # Query input for LlamaIndex
                llama_query = st.text_input("Ask a question about your data (LlamaIndex):")
                if llama_query:
                    response = query_engine.query(llama_query)
                    st.write(f"Answer: {response.response}")
    finally:
        # Cleanup temporary file. Running this in `finally` removes the file
        # even when loading, indexing or querying raises; the existence check
        # covers the case where the write itself never created it.
        if os.path.exists(temp_file_path):
            os.remove(temp_file_path)