# ref: https://github.com/plaban1981/Agents/blob/main/Contextual_Retrieval_processing_prompt.ipynb
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
import hashlib
from typing import List
# Contextual Retrieval Class
class ContextualRetrieval:
    """Contextual-retrieval pipeline helper.

    Splits a document into chunks, prefixes each chunk with an LLM-generated
    summary of how it fits into the whole document (Anthropic-style
    "contextual retrieval"), and builds the retrieval components used to
    compare plain vs. contextualized indexing: FAISS vector stores, BM25
    retrievers, Flashrank rerankers, and vector+BM25 ensembles.
    """

    def __init__(self, chunk_size: int = 800, chunk_overlap: int = 100):
        """Initialize splitter, embeddings, and LLM.

        Args:
            chunk_size: Character budget per chunk (default matches the
                original hard-coded 800).
            chunk_overlap: Character overlap between consecutive chunks
                (default matches the original hard-coded 100).
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        model_name = "BAAI/bge-large-en-v1.5"
        # CPU device keeps the demo runnable without a GPU.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=model_name, model_kwargs={'device': 'cpu'}
        )
        # temperature=0 so context summaries and answers are deterministic.
        self.llm = ChatGroq(model="llama-3.2-3b-preview", temperature=0)

    def process_document(self, document: str) -> List[Document]:
        """Split raw text into chunk Documents."""
        return self.text_splitter.create_documents([document])

    def generate_contextualized_chunks(self, document: str, chunks: List[Document]) -> List[Document]:
        """Return new Documents with an LLM-generated context prepended.

        NOTE: issues one LLM call per chunk — O(len(chunks)) API requests.
        """
        contextualized_chunks = []
        for chunk in chunks:
            context = self._generate_context(document, chunk.page_content)
            contextualized_chunks.append(
                Document(
                    page_content=f"{context}\n\n{chunk.page_content}",
                    # Fix: preserve the source chunk's metadata instead of
                    # silently dropping it when rebuilding the Document.
                    metadata=dict(chunk.metadata),
                )
            )
        return contextualized_chunks

    def _generate_context(self, document: str, chunk: str) -> str:
        """Ask the LLM for a 2-3 sentence summary situating `chunk` in `document`."""
        prompt = ChatPromptTemplate.from_template("""
        Based on the document and a specific chunk of text, generate a 2-3 sentence summary that contextualizes the chunk:
        Document:
        {document}
        Chunk:
        {chunk}
        Context:
        """)
        messages = prompt.format_messages(document=document, chunk=chunk)
        response = self.llm.invoke(messages)
        return response.content.strip()

    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        """Build a FAISS index over `chunks` using the BGE embeddings."""
        return FAISS.from_documents(chunks, self.embeddings)

    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        """Build a lexical BM25 retriever over `chunks`."""
        return BM25Retriever.from_documents(chunks)

    def create_reranker(self, vectorstore):
        """Wrap a top-10 vector retriever with a Flashrank reranking compressor."""
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return ContextualCompressionRetriever(
            base_compressor=FlashrankRerank(), base_retriever=retriever
        )

    def create_ensemble(self, vectorstore, bm25_retriever):
        """Combine vector and BM25 retrieval with equal (0.5/0.5) weights."""
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )

    def generate_answer(self, query: str, docs: List[Document]) -> str:
        """Answer `query` from the concatenated contents of `docs`."""
        prompt = ChatPromptTemplate.from_template("""
        Question: {query}
        Relevant Information: {chunks}
        Answer:
        """)
        messages = prompt.format_messages(
            query=query,
            chunks="\n\n".join([doc.page_content for doc in docs]),
        )
        response = self.llm.invoke(messages)
        return response.content.strip()
# Streamlit UI
def _build_retrieval_state(document: str) -> dict:
    """Build all indexes and retrievers for `document` (expensive: one LLM call per chunk)."""
    cr = ContextualRetrieval()
    chunks = cr.process_document(document)
    contextualized_chunks = cr.generate_contextualized_chunks(document, chunks)
    state = {"cr": cr}
    for label, docs in (("original", chunks), ("contextualized", contextualized_chunks)):
        vectorstore = cr.create_vectorstore(docs)
        bm25 = cr.create_bm25_retriever(docs)
        state[label] = {
            "vectorstore": vectorstore,
            "bm25": bm25,
            "reranker": cr.create_reranker(vectorstore),
            "ensemble": cr.create_ensemble(vectorstore, bm25),
        }
    return state


def _answers_for(cr: "ContextualRetrieval", query: str, side: dict) -> dict:
    """Run the four retrieval strategies for one side and generate an answer per strategy."""
    results = {
        "Vector Search": side["vectorstore"].similarity_search(query, k=3),
        # Fix: `.invoke` is the current Runnable retriever API;
        # `get_relevant_documents` is deprecated. Matches the reranker/ensemble calls.
        "BM25 Search": side["bm25"].invoke(query),
        "Reranker": side["reranker"].invoke(query),
        "Ensemble": side["ensemble"].invoke(query),
    }
    return {name: cr.generate_answer(query, docs) for name, docs in results.items()}


def _render_column(header: str, answers: dict) -> None:
    """Render one result column: header plus an info box per strategy answer."""
    st.write(header)
    for name, answer in answers.items():
        st.write(f"**{name} Answer**")
        st.info(answer)


def main():
    """Streamlit entry point: upload a document, query it, compare retrieval strategies."""
    st.title("Interactive Ranking and Retrieval Analysis")
    st.write("Experiment with multiple retrieval methods, ranking techniques, and dynamic contextualization.")
    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if uploaded_file:
        document = uploaded_file.read().decode("utf-8")
        st.success("Document uploaded successfully!")
        # Fix: Streamlit reruns this script on every interaction; the original
        # rebuilt every index (and re-ran one LLM call per chunk) on each rerun.
        # Cache the built state in session_state, keyed by a hash of the document
        # (this is what the imported-but-unused `hashlib` was for).
        doc_key = hashlib.sha256(document.encode("utf-8")).hexdigest()
        if st.session_state.get("doc_key") != doc_key:
            with st.spinner("Indexing document..."):
                st.session_state["retrieval"] = _build_retrieval_state(document)
                st.session_state["doc_key"] = doc_key
        state = st.session_state["retrieval"]
        query = st.text_input("Enter your query:")
        if query:
            with st.spinner("Fetching results..."):
                original_answers = _answers_for(state["cr"], query, state["original"])
                contextualized_answers = _answers_for(state["cr"], query, state["contextualized"])
            st.subheader("Results Comparison")
            col1, col2 = st.columns(2)
            with col1:
                _render_column("### Original Results", original_answers)
            with col2:
                _render_column("### Contextualized Results", contextualized_answers)


if __name__ == "__main__":
    main()
|