File size: 7,601 Bytes
71c8775
4c06776
 
 
 
 
eec1f16
4c06776
e0dc0b2
4c06776
 
 
 
 
 
 
 
 
 
8c605d7
4c06776
eec1f16
4c06776
 
 
eec1f16
 
 
8c605d7
 
 
eec1f16
 
8c605d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c06776
 
e0dc0b2
cfd0d9a
 
eec1f16
 
 
 
 
 
 
e0dc0b2
eec1f16
 
 
4c06776
 
 
 
 
8c605d7
 
4c06776
 
8c605d7
4c06776
 
 
83ad364
8c605d7
4c06776
 
 
 
 
eec1f16
4c06776
 
 
 
eec1f16
 
cfd0d9a
eec1f16
 
cfd0d9a
 
eec1f16
 
 
 
cfd0d9a
 
4c06776
 
eec1f16
4c06776
 
eec1f16
 
 
cfd0d9a
 
eec1f16
 
 
 
4c06776
eec1f16
 
 
 
 
 
 
 
 
4c06776
 
e0dc0b2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4c06776
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# ref: https://github.com/plaban1981/Agents/blob/main/Contextual_Retrieval_processing_prompt.ipynb
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.retrievers import BM25Retriever, ContextualCompressionRetriever, EnsembleRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.document_transformers.embeddings_redundant_filter import EmbeddingsRedundantFilter
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
import hashlib
from typing import List

# Contextual Retrieval Class
class ContextualRetrieval:
    """Contextual-retrieval pipeline: splits a document into chunks, prefixes
    each chunk with an LLM-generated context summary, and builds the retrieval
    primitives (FAISS vector store, BM25, reranker, ensemble) used to compare
    plain vs. contextualized retrieval.
    """

    def __init__(
        self,
        chunk_size: int = 800,
        chunk_overlap: int = 100,
        embedding_model: str = "BAAI/bge-large-en-v1.5",
        llm_model: str = "llama-3.2-3b-preview",
    ):
        """Initialize splitter, embeddings, and LLM.

        Args:
            chunk_size: Maximum characters per chunk (default preserves the
                original hard-coded value).
            chunk_overlap: Characters of overlap between adjacent chunks.
            embedding_model: HuggingFace sentence-embedding model name.
            llm_model: Groq-hosted chat model used for contextualization
                and answer generation.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        # CPU device keeps the demo runnable without a GPU.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model, model_kwargs={'device': 'cpu'}
        )
        # temperature=0 for deterministic context summaries and answers.
        self.llm = ChatGroq(model=llm_model, temperature=0)

    def process_document(self, document: str) -> List[Document]:
        """Split raw text into overlapping Document chunks."""
        return self.text_splitter.create_documents([document])

    def generate_contextualized_chunks(self, document: str, chunks: List[Document]) -> List[Document]:
        """Return new Documents whose content is an LLM-generated context
        summary followed by the original chunk text.

        Note: issues one LLM call per chunk — O(len(chunks)) API requests.
        """
        contextualized_chunks = []
        for chunk in chunks:
            context = self._generate_context(document, chunk.page_content)
            contextualized_content = f"{context}\n\n{chunk.page_content}"
            # Propagate metadata so source attribution survives contextualization
            # (the original implementation silently dropped it).
            contextualized_chunks.append(
                Document(page_content=contextualized_content, metadata=dict(chunk.metadata))
            )
        return contextualized_chunks

    def _generate_context(self, document: str, chunk: str) -> str:
        """Ask the LLM for a 2-3 sentence summary situating *chunk* within *document*."""
        prompt = ChatPromptTemplate.from_template("""
        Based on the document and a specific chunk of text, generate a 2-3 sentence summary that contextualizes the chunk:
        Document:
        {document}

        Chunk:
        {chunk}

        Context:
        """)
        messages = prompt.format_messages(document=document, chunk=chunk)
        response = self.llm.invoke(messages)
        return response.content.strip()

    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        """Build a FAISS vector store over *chunks* using the configured embeddings."""
        return FAISS.from_documents(chunks, self.embeddings)

    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        """Build a lexical BM25 retriever over *chunks*."""
        return BM25Retriever.from_documents(chunks)

    def create_reranker(self, vectorstore):
        """Wrap a top-10 vector retriever with a FlashRank cross-encoder reranker."""
        retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return ContextualCompressionRetriever(base_compressor=FlashrankRerank(), base_retriever=retriever)

    def create_ensemble(self, vectorstore, bm25_retriever):
        """Combine dense (vector) and sparse (BM25) retrieval with equal weights."""
        vector_retriever = vectorstore.as_retriever(search_kwargs={"k": 10})
        return EnsembleRetriever(
            retrievers=[vector_retriever, bm25_retriever],
            weights=[0.5, 0.5]
        )

    def generate_answer(self, query: str, docs: List[Document]) -> str:
        """Answer *query* from the retrieved *docs* via a simple stuffing prompt."""
        prompt = ChatPromptTemplate.from_template("""
        Question: {query}
        Relevant Information: {chunks}
        Answer:
        """)
        messages = prompt.format_messages(query=query, chunks="\n\n".join([doc.page_content for doc in docs]))
        response = self.llm.invoke(messages)
        return response.content.strip()

# Streamlit UI
def main():
    """Streamlit UI: upload a document, build (and cache) retrieval indexes,
    then compare answers from vector / BM25 / reranker / ensemble retrieval
    over plain vs. contextualized chunks.
    """
    st.title("Interactive Ranking and Retrieval Analysis")
    st.write("Experiment with multiple retrieval methods, ranking techniques, and dynamic contextualization.")

    # Document Upload
    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if uploaded_file:
        document = uploaded_file.read().decode("utf-8")
        st.success("Document uploaded successfully!")

        # Streamlit re-runs this script on every widget interaction, so without
        # caching the expensive pipeline (one LLM call per chunk plus two FAISS
        # builds) would be repeated on every query. Cache all derived resources
        # in session_state, keyed by a content hash of the document.
        doc_hash = hashlib.sha256(document.encode("utf-8")).hexdigest()
        if st.session_state.get("doc_hash") != doc_hash:
            with st.spinner("Indexing document (runs once per upload)..."):
                cr = ContextualRetrieval()
                chunks = cr.process_document(document)
                contextualized_chunks = cr.generate_contextualized_chunks(document, chunks)

                # Create indexes and retrievers
                original_vectorstore = cr.create_vectorstore(chunks)
                contextualized_vectorstore = cr.create_vectorstore(contextualized_chunks)
                original_bm25_retriever = cr.create_bm25_retriever(chunks)
                contextualized_bm25_retriever = cr.create_bm25_retriever(contextualized_chunks)

                st.session_state["doc_hash"] = doc_hash
                st.session_state["resources"] = {
                    "cr": cr,
                    "original_vectorstore": original_vectorstore,
                    "contextualized_vectorstore": contextualized_vectorstore,
                    "original_bm25_retriever": original_bm25_retriever,
                    "contextualized_bm25_retriever": contextualized_bm25_retriever,
                    # Rerankers and Ensemble Retrievers
                    "original_reranker": cr.create_reranker(original_vectorstore),
                    "contextualized_reranker": cr.create_reranker(contextualized_vectorstore),
                    "original_ensemble": cr.create_ensemble(original_vectorstore, original_bm25_retriever),
                    "contextualized_ensemble": cr.create_ensemble(contextualized_vectorstore, contextualized_bm25_retriever),
                }

        res = st.session_state["resources"]
        cr = res["cr"]

        # Query Input
        query = st.text_input("Enter your query:")
        if query:
            with st.spinner("Fetching results..."):
                # Retrieve results
                original_vector_results = res["original_vectorstore"].similarity_search(query, k=3)
                contextualized_vector_results = res["contextualized_vectorstore"].similarity_search(query, k=3)
                original_bm25_results = res["original_bm25_retriever"].get_relevant_documents(query)
                contextualized_bm25_results = res["contextualized_bm25_retriever"].get_relevant_documents(query)
                original_reranker_results = res["original_reranker"].invoke(query)
                contextualized_reranker_results = res["contextualized_reranker"].invoke(query)
                original_ensemble_results = res["original_ensemble"].invoke(query)
                contextualized_ensemble_results = res["contextualized_ensemble"].invoke(query)

                # Generate answers
                original_vector_answer = cr.generate_answer(query, original_vector_results)
                contextualized_vector_answer = cr.generate_answer(query, contextualized_vector_results)
                original_bm25_answer = cr.generate_answer(query, original_bm25_results)
                contextualized_bm25_answer = cr.generate_answer(query, contextualized_bm25_results)
                original_reranker_answer = cr.generate_answer(query, original_reranker_results)
                contextualized_reranker_answer = cr.generate_answer(query, contextualized_reranker_results)
                original_ensemble_answer = cr.generate_answer(query, original_ensemble_results)
                contextualized_ensemble_answer = cr.generate_answer(query, contextualized_ensemble_results)

            # Display Results
            st.subheader("Results Comparison")

            col1, col2 = st.columns(2)
            with col1:
                st.write("### Original Results")
                st.write("**Vector Search Answer**")
                st.info(original_vector_answer)
                st.write("**BM25 Search Answer**")
                st.info(original_bm25_answer)
                st.write("**Reranker Answer**")
                st.info(original_reranker_answer)
                st.write("**Ensemble Answer**")
                st.info(original_ensemble_answer)

            with col2:
                st.write("### Contextualized Results")
                st.write("**Vector Search Answer**")
                st.info(contextualized_vector_answer)
                st.write("**BM25 Search Answer**")
                st.info(contextualized_bm25_answer)
                st.write("**Reranker Answer**")
                st.info(contextualized_reranker_answer)
                st.write("**Ensemble Answer**")
                st.info(contextualized_ensemble_answer)

# Run the Streamlit app only when executed as a script (streamlit run app.py).
if __name__ == "__main__":
    main()