# ref: https://github.com/plaban1981/Agents/blob/main/Contextual_Retrieval_processing_prompt.ipynb
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
import hashlib
from typing import List


# Contextual Retrieval Class
class ContextualRetrieval:
    """Build and query plain vs. LLM-contextualized retrieval indexes.

    Provides chunking, per-chunk LLM contextualization (Anthropic-style
    "contextual retrieval"), FAISS/BM25 index construction, reranking and
    ensemble retrievers, and answer generation over retrieved chunks.
    """

    def __init__(self):
        # 800/100 chunking keeps each chunk small enough for the context prompt.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        model_name = "BAAI/bge-large-en-v1.5"
        # CPU inference so the app runs without a GPU.
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'})
        # temperature=0 for deterministic context summaries and answers.
        self.llm = ChatGroq(model="llama-3.2-3b-preview", temperature=0)

    def process_document(self, document: str) -> List[Document]:
        """Split raw text into overlapping Document chunks."""
        return self.text_splitter.create_documents([document])

    def generate_contextualized_chunks(self, document: str, chunks: List[Document]) -> List[Document]:
        """Prepend an LLM-generated context summary to each chunk.

        NOTE: issues one LLM call per chunk — callers should cache the result.
        """
        contextualized_chunks = []
        for chunk in chunks:
            context = self._generate_context(document, chunk.page_content)
            contextualized_content = f"{context}\n\n{chunk.page_content}"
            contextualized_chunks.append(Document(page_content=contextualized_content))
        return contextualized_chunks

    def _generate_context(self, document: str, chunk: str) -> str:
        """Ask the LLM for a 2-3 sentence summary situating *chunk* in *document*."""
        prompt = ChatPromptTemplate.from_template("""
        Based on the document and a specific chunk of text, generate a 2-3 sentence summary that contextualizes the chunk:
        Document: {document}
        Chunk: {chunk}
        Context:
        """)
        messages = prompt.format_messages(document=document, chunk=chunk)
        response = self.llm.invoke(messages)
        return response.content.strip()

    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        """Embed chunks into a FAISS vector index."""
        return FAISS.from_documents(chunks, self.embeddings)

    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        """Build a lexical BM25 retriever over the chunks."""
        return BM25Retriever.from_documents(chunks)

    def create_reranker(self, vectorstore):
        """Wrap a wide (k=10) vector retriever with a Flashrank reranker."""
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return ContextualCompressionRetriever(base_compressor=FlashrankRerank(), base_retriever=retriever)

    def create_ensemble(self, vectorstore, bm25_retriever):
        """Combine vector and BM25 retrievers with equal weights."""
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )

    def generate_answer(self, query: str, docs: List[Document]) -> str:
        """Answer *query* from the concatenated page contents of *docs*."""
        prompt = ChatPromptTemplate.from_template("""
        Question: {query}
        Relevant Information: {chunks}
        Answer:
        """)
        messages = prompt.format_messages(query=query, chunks="\n\n".join([doc.page_content for doc in docs]))
        response = self.llm.invoke(messages)
        return response.content.strip()


def _build_indexes(document: str) -> None:
    """Build all retrievers for *document* and cache them in session_state.

    Streamlit re-runs the whole script on every widget interaction; without
    caching, each rerun would repeat one LLM call per chunk (contextualization)
    plus two full embedding passes. Indexes are keyed by the document's MD5
    digest so a newly uploaded document triggers a rebuild.
    """
    doc_key = hashlib.md5(document.encode("utf-8")).hexdigest()
    if st.session_state.get("doc_key") == doc_key:
        return  # same document — reuse cached indexes

    cr = ContextualRetrieval()
    chunks = cr.process_document(document)
    contextualized_chunks = cr.generate_contextualized_chunks(document, chunks)

    variants = {}
    for name, variant_chunks in (("original", chunks), ("contextualized", contextualized_chunks)):
        vectorstore = cr.create_vectorstore(variant_chunks)
        bm25 = cr.create_bm25_retriever(variant_chunks)
        variants[name] = {
            "vectorstore": vectorstore,
            "bm25": bm25,
            "reranker": cr.create_reranker(vectorstore),
            "ensemble": cr.create_ensemble(vectorstore, bm25),
        }

    st.session_state["doc_key"] = doc_key
    st.session_state["cr"] = cr
    st.session_state["variants"] = variants


def _retrieve_and_answer(query: str) -> dict:
    """Run every retrieval method for both variants and generate answers.

    Returns {variant: {method_label: answer}} for display.
    """
    cr = st.session_state["cr"]
    answers = {}
    for variant, idx in st.session_state["variants"].items():
        docs_by_method = {
            "Vector Search": idx["vectorstore"].similarity_search(query, k=3),
            # .invoke is the current retriever API (get_relevant_documents is deprecated)
            "BM25 Search": idx["bm25"].invoke(query),
            "Reranker": idx["reranker"].invoke(query),
            "Ensemble": idx["ensemble"].invoke(query),
        }
        answers[variant] = {method: cr.generate_answer(query, docs)
                            for method, docs in docs_by_method.items()}
    return answers


# Streamlit UI
def main():
    st.title("Interactive Ranking and Retrieval Analysis")
    st.write("Experiment with multiple retrieval methods, ranking techniques, and dynamic contextualization.")

    # Document Upload
    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if not uploaded_file:
        return
    document = uploaded_file.read().decode("utf-8")
    st.success("Document uploaded successfully!")

    # Build (or reuse cached) indexes and retrievers for both variants.
    _build_indexes(document)

    # Query Input
    query = st.text_input("Enter your query:")
    if not query:
        return
    with st.spinner("Fetching results..."):
        answers = _retrieve_and_answer(query)

    # Display Results
    st.subheader("Results Comparison")
    col1, col2 = st.columns(2)
    for col, variant, heading in (
        (col1, "original", "### Original Results"),
        (col2, "contextualized", "### Contextualized Results"),
    ):
        with col:
            st.write(heading)
            for method, answer in answers[variant].items():
                st.write(f"**{method} Answer**")
                st.info(answer)


if __name__ == "__main__":
    main()