DrishtiSharma commited on
Commit
a756b7d
Β·
verified Β·
1 Parent(s): 14448f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -35
app.py CHANGED
@@ -18,6 +18,7 @@ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
18
  from langchain_core.documents import Document
19
  import faiss
20
  import tempfile
 
21
 
22
  # Load environment variables
23
  os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
@@ -28,8 +29,8 @@ if not os.getenv("OPENAI_API_KEY"):
28
 
29
  # βœ… Ensure OpenAI Embeddings match FAISS dimensions
30
  embedding_function = OpenAIEmbeddings()
31
- test_vector = embedding_function.embed_query("test") # Sample embedding
32
- faiss_dimension = len(test_vector) # βœ… Dynamically detect correct dimension
33
 
34
  # βœ… Update global settings for LlamaIndex
35
  Settings.llm = OpenAI(model="gpt-4o")
@@ -53,31 +54,27 @@ if uploaded_file:
53
  data.to_csv(temp_file.name, index=False, encoding="utf-8")
54
  temp_file.flush()
55
 
56
- # Debugging: Verify the temporary file (Display partial content)
57
- st.write("Temporary file path:", temp_file_path)
58
- with open(temp_file_path, "r") as f:
59
- content = f.read()
60
- st.write("Partial file content (first 500 characters):")
61
- st.text(content[:500])
62
-
63
  # Tabs for LangChain and LlamaIndex
64
- tab1, tab2 = st.tabs(["LangChain", "LlamaIndex"])
65
 
66
  # βœ… LangChain Processing
67
  with tab1:
68
  st.subheader("LangChain Query")
69
 
70
  try:
71
- # βœ… Convert CSV rows into LangChain Document objects
72
  st.write("Processing CSV with a custom loader...")
 
73
  documents = []
 
74
  for _, row in data.iterrows():
75
  content = "\n".join([f"{col}: {row[col]}" for col in data.columns])
76
- doc = Document(page_content=content)
77
- documents.append(doc)
 
 
78
 
79
-
80
- # βœ… Create FAISS VectorStore with Correct Dimensions
81
  st.write(f"βœ… Initializing FAISS with dimension: {faiss_dimension}")
82
  langchain_index = faiss.IndexFlatL2(faiss_dimension)
83
 
@@ -98,27 +95,24 @@ if uploaded_file:
98
  except Exception as e:
99
  st.error(f"Error adding documents to FAISS: {e}")
100
 
101
- # βœ… Create LangChain Query Execution Pipeline
102
- retriever = langchain_vector_store.as_retriever()
103
- system_prompt = (
104
- "You are an assistant for question-answering tasks. "
105
- "Use the following pieces of retrieved context to answer "
106
- "the question. If you don't know the answer, say that you "
107
- "don't know. Use three sentences maximum and keep the "
108
- "answer concise.\n\n{context}"
109
- )
110
- prompt = ChatPromptTemplate.from_messages(
111
- [("system", system_prompt), ("human", "{input}")]
112
- )
113
- question_answer_chain = create_stuff_documents_chain(ChatOpenAI(model="gpt-4o"), prompt)
114
- langchain_rag_chain = create_retrieval_chain(retriever, question_answer_chain)
115
 
116
  # βœ… Query Processing
117
  query = st.text_input("Ask a question about your data (LangChain):")
118
 
119
  if query:
120
  try:
121
- st.write("Processing your question...")
 
 
 
 
 
 
 
 
 
122
  answer = langchain_rag_chain.invoke({"input": query})
123
  st.write(f"**Answer:** {answer['answer']}")
124
  except Exception as e:
@@ -130,8 +124,3 @@ if uploaded_file:
130
  error_message = traceback.format_exc()
131
  st.error(f"Error processing with LangChain: {e}")
132
  st.text(error_message)
133
-
134
- except Exception as e:
135
- error_message = traceback.format_exc()
136
- st.error(f"Error reading uploaded file: {e}")
137
- st.text(error_message)
 
18
  from langchain_core.documents import Document
19
  import faiss
20
  import tempfile
21
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
22
 
23
  # Load environment variables
24
  os.environ["OPENAI_API_KEY"] = os.getenv("OPENAI_API_KEY")
 
29
 
30
  # βœ… Ensure OpenAI Embeddings match FAISS dimensions
31
  embedding_function = OpenAIEmbeddings()
32
+ test_vector = embedding_function.embed_query("test")
33
+ faiss_dimension = len(test_vector)
34
 
35
  # βœ… Update global settings for LlamaIndex
36
  Settings.llm = OpenAI(model="gpt-4o")
 
54
  data.to_csv(temp_file.name, index=False, encoding="utf-8")
55
  temp_file.flush()
56
 
 
 
 
 
 
 
 
57
  # Tabs for LangChain and LlamaIndex
58
+ tab1, tab2 = st.tabs(["Chat w CSV using LangChain", "Chat w CSV using LlamaIndex"])
59
 
60
  # βœ… LangChain Processing
61
  with tab1:
62
  st.subheader("LangChain Query")
63
 
64
  try:
65
+ # βœ… Convert CSV rows into LangChain Document objects with chunking
66
  st.write("Processing CSV with a custom loader...")
67
+ text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=90)
68
  documents = []
69
+
70
  for _, row in data.iterrows():
71
  content = "\n".join([f"{col}: {row[col]}" for col in data.columns])
72
+ chunks = text_splitter.split_text(content)
73
+ for chunk in chunks:
74
+ doc = Document(page_content=chunk)
75
+ documents.append(doc)
76
 
77
+ # βœ… Create FAISS VectorStore
 
78
  st.write(f"βœ… Initializing FAISS with dimension: {faiss_dimension}")
79
  langchain_index = faiss.IndexFlatL2(faiss_dimension)
80
 
 
95
  except Exception as e:
96
  st.error(f"Error adding documents to FAISS: {e}")
97
 
98
+ # βœ… Limit number of retrieved documents
99
+ retriever = langchain_vector_store.as_retriever(search_kwargs={"k": 5})
 
 
 
 
 
 
 
 
 
 
 
 
100
 
101
  # βœ… Query Processing
102
  query = st.text_input("Ask a question about your data (LangChain):")
103
 
104
  if query:
105
  try:
106
+ retrieved_context = "\n\n".join([doc.page_content for doc in retriever.get_relevant_documents(query)])
107
+ retrieved_context = retrieved_context[:3000]
108
+
109
+ system_prompt = (
110
+ "You are an assistant for question-answering tasks. "
111
+ "Use the following pieces of retrieved context to answer "
112
+ "the question. Keep the answer concise.\n\n"
113
+ f"{retrieved_context}"
114
+ )
115
+
116
  answer = langchain_rag_chain.invoke({"input": query})
117
  st.write(f"**Answer:** {answer['answer']}")
118
  except Exception as e:
 
124
  error_message = traceback.format_exc()
125
  st.error(f"Error processing with LangChain: {e}")
126
  st.text(error_message)