# chat-w-csv / test.py
# DrishtiSharma's picture
# Create test.py
# 40e0a99 verified
# raw
# history blame
# 6.37 kB
import streamlit as st
import pandas as pd
import io
import os
from dotenv import load_dotenv
from llama_index.core import Settings, VectorStoreIndex, SimpleDirectoryReader
from llama_index.readers.file.paged_csv.base import PagedCSVReader
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.vector_stores.faiss import FaissVectorStore
from llama_index.core.ingestion import IngestionPipeline
from langchain_community.document_loaders.csv_loader import CSVLoader
from langchain_community.vectorstores import FAISS as LangChainFAISS
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
import faiss
import tempfile
# Load environment variables.
# ``load_dotenv`` copies entries from a local ``.env`` file (when one exists)
# into ``os.environ`` without overriding values that are already set, so the
# key can come either from the shell / hosting secrets or from ``.env``.
load_dotenv()
if not os.getenv("OPENAI_API_KEY"):
    # Assigning ``None`` into ``os.environ`` raises an opaque TypeError, and
    # the OpenAI clients configured below cannot work without a key anyway,
    # so halt this script run with a readable message instead.
    st.error(
        "OPENAI_API_KEY is not set. Add it to your environment or to a .env file."
    )
    st.stop()

# Global settings for LlamaIndex.
# EMBED_DIMENSION is the vector size requested from the embedding model; any
# FAISS index in this file that stores those vectors must use the same size.
EMBED_DIMENSION = 512
Settings.llm = OpenAI(model="gpt-3.5-turbo")
Settings.embed_model = OpenAIEmbedding(
    model="text-embedding-3-small", dimensions=EMBED_DIMENSION
)
# Streamlit app
#
# NOTE(review): Streamlit re-executes this script on every widget interaction,
# so both indexes are rebuilt (and the CSV re-embedded) each time a question is
# submitted. Consider caching the built chain/engine (e.g. st.cache_resource)
# if cost or latency becomes a problem.


def _export_and_load(frame, load):
    """Write *frame* to a throwaway CSV file and return ``load(csv_path)``.

    The loaders used below are given a filesystem path, so the DataFrame is
    exported into a temporary directory. The directory (and the file in it)
    is removed as soon as ``load`` returns or raises, so no manual cleanup
    is needed by the caller.

    Args:
        frame: The pandas DataFrame to export.
        load: Callable taking the CSV path and returning the loaded documents.

    Returns:
        Whatever ``load`` returns for the exported file.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        csv_path = os.path.join(tmp_dir, "uploaded.csv")
        # Written by path after the directory exists; no open handle is held
        # on the file while pandas writes it.
        frame.to_csv(csv_path, index=False)
        return load(csv_path)


def _load_langchain_docs(csv_path):
    """Load *csv_path* with LangChain's ``CSVLoader`` and split it into chunks."""
    return CSVLoader(file_path=csv_path).load_and_split()


def _load_llamaindex_docs(csv_path):
    """Load *csv_path* through ``SimpleDirectoryReader`` using ``PagedCSVReader``."""
    reader = SimpleDirectoryReader(
        input_files=[csv_path],
        file_extractor={".csv": PagedCSVReader()},
    )
    return reader.load_data()


st.title("Chat with CSV Files - LangChain vs LlamaIndex")

# File uploader
uploaded_file = st.file_uploader("Upload a CSV file", type=["csv"])

if uploaded_file:
    try:
        # Read and preview CSV data using pandas
        data = pd.read_csv(uploaded_file)
        st.write("Preview of uploaded data:")
        st.dataframe(data)

        # Tabs
        tab1, tab2 = st.tabs(
            ["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"]
        )

        # LangChain Tab
        with tab1:
            st.subheader("LangChain Query")
            try:
                docs = _export_and_load(data, _load_langchain_docs)

                if not docs:
                    # A header-only CSV yields no documents; building a FAISS
                    # store from an empty list would fail.
                    st.warning("The uploaded CSV contains no rows to index.")
                else:
                    # Preview the first document
                    st.write("Preview of a document chunk (LangChain):")
                    st.text(docs[0].page_content)

                    # LangChain FAISS VectorStore.
                    # ``from_documents`` creates the FAISS index, docstore and
                    # id mapping itself; the bare constructor requires all of
                    # them. The embedding model/dimension mirrors the
                    # LlamaIndex settings so both tabs embed the data the
                    # same way.
                    langchain_embeddings = OpenAIEmbeddings(
                        model="text-embedding-3-small",
                        dimensions=EMBED_DIMENSION,
                    )
                    langchain_vector_store = LangChainFAISS.from_documents(
                        docs, langchain_embeddings
                    )

                    # LangChain Retrieval Chain
                    retriever = langchain_vector_store.as_retriever()
                    system_prompt = (
                        "You are an assistant for question-answering tasks. "
                        "Use the following pieces of retrieved context to answer "
                        "the question. If you don't know the answer, say that you "
                        "don't know. Use three sentences maximum and keep the "
                        "answer concise.\n\n{context}"
                    )
                    prompt = ChatPromptTemplate.from_messages(
                        [("system", system_prompt), ("human", "{input}")]
                    )
                    question_answer_chain = create_stuff_documents_chain(
                        ChatOpenAI(), prompt
                    )
                    langchain_rag_chain = create_retrieval_chain(
                        retriever, question_answer_chain
                    )

                    # Query input for LangChain
                    query = st.text_input("Ask a question about your data (LangChain):")
                    if query:
                        answer = langchain_rag_chain.invoke({"input": query})
                        st.write(f"Answer: {answer['answer']}")
            except Exception as e:
                st.error(f"Error processing with LangChain: {e}")

        # LlamaIndex Tab
        with tab2:
            st.subheader("LlamaIndex Query")
            try:
                docs = _export_and_load(data, _load_llamaindex_docs)

                if not docs:
                    st.warning("The uploaded CSV contains no rows to index.")
                else:
                    # Preview the first document
                    st.write("Preview of a document chunk (LlamaIndex):")
                    st.text(docs[0].text)

                    # Initialize FAISS Vector Store
                    llama_faiss_index = faiss.IndexFlatL2(EMBED_DIMENSION)
                    llama_vector_store = FaissVectorStore(faiss_index=llama_faiss_index)

                    # Create the ingestion pipeline and process the data
                    pipeline = IngestionPipeline(
                        vector_store=llama_vector_store, documents=docs
                    )
                    nodes = pipeline.run()

                    # Create a query engine.
                    # NOTE(review): the index is built from ``nodes`` and is
                    # not handed ``llama_vector_store`` -- confirm whether
                    # queries are meant to be served from the FAISS store.
                    llama_index = VectorStoreIndex(nodes)
                    query_engine = llama_index.as_query_engine(similarity_top_k=3)

                    # Query input for LlamaIndex
                    query = st.text_input("Ask a question about your data (LlamaIndex):")
                    if query:
                        response = query_engine.query(query)
                        st.write(f"Answer: {response.response}")
            except Exception as e:
                st.error(f"Error processing with LlamaIndex: {e}")
    except Exception as e:
        st.error(f"Error reading uploaded file: {e}")