|
import streamlit as st |
|
from langchain.text_splitter import RecursiveCharacterTextSplitter |
|
from langchain.schema import Document |
|
from langchain.embeddings import HuggingFaceEmbeddings |
|
from langchain.vectorstores import FAISS |
|
from rank_bm25 import BM25Okapi |
|
from langchain.retrievers import ContextualCompressionRetriever, BM25Retriever, EnsembleRetriever |
|
from langchain.retrievers.document_compressors import FlashrankRerank |
|
from langchain_community.embeddings import HuggingFaceEmbeddings |
|
from langchain_groq import ChatGroq |
|
from langchain.prompts import ChatPromptTemplate |
|
|
|
import hashlib |
|
from typing import List |
|
|
|
|
|
class ContextualRetrieval:
    """Build retrieval indexes over a raw text document and answer questions.

    Wraps three concerns used by the app:
      * chunking text into overlapping pieces,
      * indexing chunks for dense (FAISS) and sparse (BM25) retrieval,
      * generating an LLM answer grounded in retrieved chunks.
    """

    def __init__(
        self,
        chunk_size: int = 800,
        chunk_overlap: int = 100,
        embedding_model: str = "BAAI/bge-large-en-v1.5",
        llm_model: str = "llama-3.2-3b-preview",
    ):
        """Initialize splitter, embedding model, and chat LLM.

        Args:
            chunk_size: Maximum characters per chunk.
            chunk_overlap: Characters shared between consecutive chunks.
            embedding_model: HuggingFace model id used for dense embeddings.
            llm_model: Groq-hosted chat model id used for answer generation.
        """
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        # CPU inference; embeddings are left unnormalized, matching the
        # original configuration (FAISS is built on the raw vectors).
        self.embeddings = HuggingFaceEmbeddings(
            model_name=embedding_model,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': False},
        )
        # temperature=0 for deterministic, reproducible answers.
        self.llm = ChatGroq(model=llm_model, temperature=0)

    def process_document(self, document: str) -> List[Document]:
        """Split raw document text into overlapping chunk Documents."""
        return self.text_splitter.create_documents([document])

    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        """Embed the chunks and build a FAISS dense vector index."""
        return FAISS.from_documents(chunks, self.embeddings)

    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        """Build a BM25 keyword (sparse) retriever over the chunks."""
        return BM25Retriever.from_documents(chunks)

    def generate_answer(self, query: str, docs: List[Document]) -> str:
        """Ask the LLM to answer `query` using `docs` as the only context.

        Args:
            query: User question.
            docs: Retrieved chunks whose text is concatenated into the prompt.

        Returns:
            The LLM's answer text.
        """
        prompt = ChatPromptTemplate.from_template("""
        Question: {query}
        Relevant Information: {chunks}
        Answer:""")
        # Join chunk texts with blank lines so the model sees clear boundaries.
        context = "\n\n".join(doc.page_content for doc in docs)
        messages = prompt.format_messages(query=query, chunks=context)
        response = self.llm.invoke(messages)
        return response.content
|
|
|
|
|
def main():
    """Streamlit entry point.

    Lets the user upload a text document, builds dense (FAISS) and sparse
    (BM25) indexes over it, and shows the answers and top chunks produced by
    each retrieval strategy side by side for a user query.
    """
    st.title("Interactive Document Retrieval Analysis")
    st.write("Upload a document, experiment with retrieval methods, and analyze content interactively.")

    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if uploaded_file:
        document = uploaded_file.read().decode("utf-8")
        st.success("Document successfully uploaded!")

        # Streamlit re-runs this entire script on every widget interaction.
        # Rebuilding the embedding model, FAISS index, and BM25 index per
        # rerun is very expensive, so cache the built artifacts in
        # session_state keyed by a content hash, and rebuild only when a
        # different document is uploaded.
        doc_hash = hashlib.sha256(document.encode("utf-8")).hexdigest()
        if st.session_state.get("doc_hash") != doc_hash:
            cr = ContextualRetrieval()
            chunks = cr.process_document(document)
            st.session_state["doc_hash"] = doc_hash
            st.session_state["cr"] = cr
            st.session_state["vectorstore"] = cr.create_vectorstore(chunks)
            st.session_state["bm25_retriever"] = cr.create_bm25_retriever(chunks)

        cr = st.session_state["cr"]
        vectorstore = st.session_state["vectorstore"]
        bm25_retriever = st.session_state["bm25_retriever"]

        query = st.text_input("Enter your question about the document:")
        if query:
            with st.spinner("Fetching results..."):
                # Top-3 dense hits vs. BM25's default result set.
                vector_results = vectorstore.similarity_search(query, k=3)
                bm25_results = bm25_retriever.get_relevant_documents(query)

                vector_answer = cr.generate_answer(query, vector_results)
                bm25_answer = cr.generate_answer(query, bm25_results)

            st.subheader("Results from Vector Search")
            st.write(vector_answer)

            st.subheader("Results from BM25 Search")
            st.write(bm25_answer)

            st.subheader("Top Retrieved Chunks")
            st.write("**Vector Search Results:**")
            for i, doc in enumerate(vector_results, 1):
                # Truncate long chunks for readable display.
                st.write(f"{i}. {doc.page_content[:300]}...")

            st.write("**BM25 Search Results:**")
            for i, doc in enumerate(bm25_results, 1):
                st.write(f"{i}. {doc.page_content[:300]}...")
|
|
|
|
|
# Script entry point: launch the Streamlit app (run via `streamlit run <file>`).
if __name__ == "__main__":
    main()