|
|
|
import streamlit as st |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.schema import Document |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.vectorstores import FAISS |
|
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever |
|
from langchain.retrievers.document_compressors import FlashrankRerank |
|
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter |
|
from langchain_groq import ChatGroq |
|
from langchain.prompts import ChatPromptTemplate |
|
import hashlib |
|
from typing import List |
|
|
|
|
|
class ContextualRetrieval:
    """Build and query retrieval pipelines over a single document.

    Wraps chunking, optional LLM-generated per-chunk context, embedding,
    and several retriever flavors (FAISS vector search, BM25, reranking,
    ensemble) behind small factory methods. One instance holds the text
    splitter, the embedding model, and the chat LLM.
    """

    # Prompts are built once at class-creation time instead of on every
    # call — ChatPromptTemplate construction is pure and reusable.
    _CONTEXT_PROMPT = ChatPromptTemplate.from_template("""
Based on the document and a specific chunk of text, generate a 2-3 sentence summary that contextualizes the chunk:

Document:
{document}

Chunk:
{chunk}

Context:
""")

    _ANSWER_PROMPT = ChatPromptTemplate.from_template("""
Question: {query}

Relevant Information: {chunks}

Answer:
""")

    def __init__(self):
        # ~800-char chunks with 100-char overlap so sentences straddling a
        # boundary appear in both neighbors.
        self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
        model_name = "BAAI/bge-large-en-v1.5"
        # CPU inference keeps the app runnable without a GPU; swap
        # model_kwargs to use CUDA if available.
        self.embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs={'device': 'cpu'})
        self.llm = ChatGroq(model="llama-3.2-3b-preview", temperature=0)

    def process_document(self, document: str) -> List[Document]:
        """Split raw text into overlapping Document chunks."""
        return self.text_splitter.create_documents([document])

    def generate_contextualized_chunks(self, document: str, chunks: List[Document]) -> List[Document]:
        """Prefix each chunk with an LLM-generated summary situating it in the whole document.

        Makes one LLM call per chunk, so cost scales linearly with the
        number of chunks. Each chunk's original metadata is preserved on
        the new Document (the naive rewrite would silently drop it).
        """
        contextualized_chunks = []
        for chunk in chunks:
            context = self._generate_context(document, chunk.page_content)
            contextualized_content = f"{context}\n\n{chunk.page_content}"
            # Copy metadata through so provenance survives contextualization.
            contextualized_chunks.append(
                Document(page_content=contextualized_content, metadata=dict(chunk.metadata))
            )
        return contextualized_chunks

    def _generate_context(self, document: str, chunk: str) -> str:
        """Ask the LLM for a 2-3 sentence summary that situates *chunk* within *document*."""
        messages = self._CONTEXT_PROMPT.format_messages(document=document, chunk=chunk)
        response = self.llm.invoke(messages)
        return response.content.strip()

    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        """Embed the chunks and index them in an in-memory FAISS store."""
        return FAISS.from_documents(chunks, self.embeddings)

    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        """Build a lexical (term-frequency) BM25 retriever over the chunks."""
        return BM25Retriever.from_documents(chunks)

    def create_reranker(self, vectorstore):
        """Wrap the vector store in a FlashRank cross-encoder reranking stage.

        Retrieves a wide candidate set (k=10) and lets the reranker compress
        it to the most relevant documents.
        """
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return ContextualCompressionRetriever(base_compressor=FlashrankRerank(), base_retriever=retriever)

    def create_ensemble(self, vectorstore, bm25_retriever):
        """Combine dense (vector) and sparse (BM25) retrieval with equal weight."""
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=[0.5, 0.5],
        )

    def generate_answer(self, query: str, docs: List[Document]) -> str:
        """Answer *query* from the concatenated contents of *docs* via the LLM."""
        messages = self._ANSWER_PROMPT.format_messages(
            query=query,
            chunks="\n\n".join(doc.page_content for doc in docs),
        )
        response = self.llm.invoke(messages)
        return response.content.strip()
|
|
|
|
|
def main():
    """Streamlit entry point: upload a document, build original vs. contextualized
    retrieval pipelines, and compare answers across four retrieval strategies.

    The pipeline build (one LLM call per chunk plus embedding two corpora) is
    expensive, so it is cached in ``st.session_state`` keyed by a hash of the
    uploaded document — Streamlit reruns the whole script on every widget
    interaction, and without the cache each keystroke-triggered rerun would
    redo all of that work.
    """
    st.title("Interactive Ranking and Retrieval Analysis")
    st.write("Experiment with multiple retrieval methods, ranking techniques, and dynamic contextualization.")

    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if not uploaded_file:
        return

    document = uploaded_file.read().decode("utf-8")
    st.success("Document uploaded successfully!")

    # Rebuild the pipelines only when the document actually changes.
    doc_hash = hashlib.sha256(document.encode("utf-8")).hexdigest()
    if st.session_state.get("doc_hash") != doc_hash:
        cr = ContextualRetrieval()
        chunks = cr.process_document(document)
        contextualized_chunks = cr.generate_contextualized_chunks(document, chunks)

        original_vectorstore = cr.create_vectorstore(chunks)
        contextualized_vectorstore = cr.create_vectorstore(contextualized_chunks)
        original_bm25 = cr.create_bm25_retriever(chunks)
        contextualized_bm25 = cr.create_bm25_retriever(contextualized_chunks)

        st.session_state.doc_hash = doc_hash
        st.session_state.cr = cr
        # One retriever bundle per corpus variant, keyed by display name.
        st.session_state.pipelines = {
            "Original": {
                "vectorstore": original_vectorstore,
                "bm25": original_bm25,
                "reranker": cr.create_reranker(original_vectorstore),
                "ensemble": cr.create_ensemble(original_vectorstore, original_bm25),
            },
            "Contextualized": {
                "vectorstore": contextualized_vectorstore,
                "bm25": contextualized_bm25,
                "reranker": cr.create_reranker(contextualized_vectorstore),
                "ensemble": cr.create_ensemble(contextualized_vectorstore, contextualized_bm25),
            },
        }

    cr = st.session_state.cr
    pipelines = st.session_state.pipelines

    query = st.text_input("Enter your query:")
    if not query:
        return

    with st.spinner("Fetching results..."):
        answers = {}
        for variant, p in pipelines.items():
            # ``.invoke`` replaces the deprecated ``get_relevant_documents``
            # and matches how the reranker/ensemble retrievers are called.
            docs_by_method = {
                "Vector Search": p["vectorstore"].similarity_search(query, k=3),
                "BM25 Search": p["bm25"].invoke(query),
                "Reranker": p["reranker"].invoke(query),
                "Ensemble": p["ensemble"].invoke(query),
            }
            answers[variant] = {
                method: cr.generate_answer(query, docs)
                for method, docs in docs_by_method.items()
            }

    st.subheader("Results Comparison")

    col1, col2 = st.columns(2)
    for col, variant in zip((col1, col2), ("Original", "Contextualized")):
        with col:
            st.write(f"### {variant} Results")
            for method in ("Vector Search", "BM25 Search", "Reranker", "Ensemble"):
                st.write(f"**{method} Answer**")
                st.info(answers[variant][method])


if __name__ == "__main__":
    main()
|
|