# ref: https://github.com/plaban1981/Agents/blob/main/Contextual_Retrieval_processing_prompt.ipynb
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
from typing import List
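# NOTE: BM25Retriever relies on the optional `rank_bm25` package, FlashrankRerank on
# `flashrank`, and HuggingFaceEmbeddings on `sentence-transformers`; install them
# alongside langchain / langchain-community / langchain-groq.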
# Contextual Retrieval Class
class ContextualRetrieval:
    def __init__(self):
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        model_name = "BAAI/bge-large-en-v1.5"
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'})
        self.llm = ChatGroq(model="llama-3.2-3b-preview", temperature=0)
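    # Split the raw document into overlapping chunks (800 characters, 100 overlap).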
    def process_document(self, document: str) -> List[Document]:
        return self.text_splitter.create_documents([document])
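    # Contextual retrieval: prepend a short, LLM-generated summary of where each chunk
    # fits in the full document, so both lexical and dense search see that context.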
    def generate_contextualized_chunks(self, document: str, chunks: List[Document]) -> List[Document]:
        contextualized_chunks = []
        for chunk in chunks:
            context = self._generate_context(document, chunk.page_content)
            contextualized_content = f"{context}\n\n{chunk.page_content}"
            contextualized_chunks.append(Document(page_content=contextualized_content))
        return contextualized_chunks
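    # One LLM call per chunk; the full document is sent each time, so cost grows with
    # document length times the number of chunks.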
    def _generate_context(self, document: str, chunk: str) -> str:
        prompt = ChatPromptTemplate.from_template("""
        Based on the document and a specific chunk of text, generate a 2-3 sentence summary that contextualizes the chunk:
        Document:
        {document}
        Chunk:
        {chunk}
        Context:
        """)
        messages = prompt.format_messages(document=document, chunk=chunk)
        response = self.llm.invoke(messages)
        return response.content.strip()
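    # Dense index over the chunks using the BGE embeddings configured above.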
    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        return FAISS.from_documents(chunks, self.embeddings)
    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        return BM25Retriever.from_documents(chunks)
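    # Compression retriever: pull the top 10 chunks by vector similarity, then let
    # FlashRank re-order them before they reach the LLM.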
    def create_reranker(self, vectorstore):
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return ContextualCompressionRetriever(base_compressor=FlashrankRerank(), base_retriever=retriever)
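    # Hybrid retrieval: fuse dense and BM25 results with equal weights
    # (EnsembleRetriever combines the ranked lists via reciprocal rank fusion).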
    def create_ensemble(self, vectorstore, bm25_retriever):
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )
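    # Stuff the retrieved chunks into a single prompt and ask the LLM to answer.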
    def generate_answer(self, query: str, docs: List[Document]) -> str:
        prompt = ChatPromptTemplate.from_template("""
        Answer the question using only the relevant information provided below.
        Question: {query}
        Relevant Information: {chunks}
        Answer:
        """)
        messages = prompt.format_messages(query=query, chunks="\n\n".join([doc.page_content for doc in docs]))
        response = self.llm.invoke(messages)
        return response.content.strip()
# Streamlit UI
def main():
    st.title("Interactive Ranking and Retrieval Analysis")
    st.write("Experiment with multiple retrieval methods, ranking techniques, and dynamic contextualization.")
    # Document Upload
    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if uploaded_file:
        document = uploaded_file.read().decode("utf-8")
        st.success("Document uploaded successfully!")
        # Initialize Retrieval System
        cr = ContextualRetrieval()
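        # NOTE: Streamlit reruns this script on every interaction, so the chunking and
        # per-chunk context generation below run again for each query; caching the results
        # (e.g. in st.session_state) would avoid the repeated LLM calls.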
        chunks = cr.process_document(document)
        contextualized_chunks = cr.generate_contextualized_chunks(document, chunks)
        # Create indexes and retrievers
        original_vectorstore = cr.create_vectorstore(chunks)
        contextualized_vectorstore = cr.create_vectorstore(contextualized_chunks)
        original_bm25_retriever = cr.create_bm25_retriever(chunks)
        contextualized_bm25_retriever = cr.create_bm25_retriever(contextualized_chunks)
        # Rerankers and Ensemble Retrievers
        original_reranker = cr.create_reranker(original_vectorstore)
        contextualized_reranker = cr.create_reranker(contextualized_vectorstore)
        original_ensemble = cr.create_ensemble(original_vectorstore, original_bm25_retriever)
        contextualized_ensemble = cr.create_ensemble(contextualized_vectorstore, contextualized_bm25_retriever)
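        # Eight pipelines are now available: {vector, BM25, reranked, ensemble} x {original, contextualized}.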
        # Query Input
        query = st.text_input("Enter your query:")
        if query:
            with st.spinner("Fetching results..."):
                # Retrieve results
                original_vector_results = original_vectorstore.similarity_search(query, k=3)
                contextualized_vector_results = contextualized_vectorstore.similarity_search(query, k=3)
                original_bm25_results = original_bm25_retriever.invoke(query)
                contextualized_bm25_results = contextualized_bm25_retriever.invoke(query)
                original_reranker_results = original_reranker.invoke(query)
                contextualized_reranker_results = contextualized_reranker.invoke(query)
                original_ensemble_results = original_ensemble.invoke(query)
                contextualized_ensemble_results = contextualized_ensemble.invoke(query)
                # Generate answers
                original_vector_answer = cr.generate_answer(query, original_vector_results)
                contextualized_vector_answer = cr.generate_answer(query, contextualized_vector_results)
                original_bm25_answer = cr.generate_answer(query, original_bm25_results)
                contextualized_bm25_answer = cr.generate_answer(query, contextualized_bm25_results)
                original_reranker_answer = cr.generate_answer(query, original_reranker_results)
                contextualized_reranker_answer = cr.generate_answer(query, contextualized_reranker_results)
                original_ensemble_answer = cr.generate_answer(query, original_ensemble_results)
                contextualized_ensemble_answer = cr.generate_answer(query, contextualized_ensemble_results)
            # Display Results
            st.subheader("Results Comparison")
            col1, col2 = st.columns(2)
            with col1:
                st.write("### Original Results")
                st.write("**Vector Search Answer**")
                st.info(original_vector_answer)
                st.write("**BM25 Search Answer**")
                st.info(original_bm25_answer)
                st.write("**Reranker Answer**")
                st.info(original_reranker_answer)
                st.write("**Ensemble Answer**")
                st.info(original_ensemble_answer)
            with col2:
                st.write("### Contextualized Results")
                st.write("**Vector Search Answer**")
                st.info(contextualized_vector_answer)
                st.write("**BM25 Search Answer**")
                st.info(contextualized_bm25_answer)
                st.write("**Reranker Answer**")
                st.info(contextualized_reranker_answer)
                st.write("**Ensemble Answer**")
                st.info(contextualized_ensemble_answer)
if __name__ == "__main__":
    main()
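# To try the app locally (assuming the usual Streamlit workflow and a Groq API key):
#   export GROQ_API_KEY=...   # read by ChatGroq
#   streamlit run app.py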