# interim.py — Streamlit document-retrieval demo
# (originally created by DrishtiSharma, commit 8f252e0)
import streamlit as st
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.schema import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from rank_bm25 import BM25Okapi
from langchain.retrievers import ContextualCompressionRetriever, BM25Retriever, EnsembleRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_groq import ChatGroq
from langchain.prompts import ChatPromptTemplate
import hashlib
from typing import List
# Contextual Retrieval Class
class ContextualRetrieval:
    """Bundle chunking, embedding, BM25 indexing and LLM answering.

    Holds a recursive character splitter, a CPU-hosted BGE embedding model
    and a Groq-backed chat LLM; exposes helpers to turn raw text into
    retrievable chunks and to answer a query from retrieved documents.
    """

    def __init__(self):
        # Overlapping 800-char windows keep context across chunk borders.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=800,
            chunk_overlap=100,
        )
        # BGE large embeddings on CPU; normalization is left disabled,
        # matching the original configuration.
        self.embeddings = HuggingFaceEmbeddings(
            model_name="BAAI/bge-large-en-v1.5",
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': False},
        )
        # temperature=0 for deterministic answers.
        self.llm = ChatGroq(model="llama-3.2-3b-preview", temperature=0)

    def process_document(self, document: str) -> List[Document]:
        """Split raw text into a list of overlapping Document chunks."""
        return self.text_splitter.create_documents([document])

    def create_vectorstore(self, chunks: List[Document]) -> FAISS:
        """Embed the chunks and build an in-memory FAISS index over them."""
        return FAISS.from_documents(chunks, self.embeddings)

    def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
        """Build a lexical BM25 retriever over the same chunks."""
        return BM25Retriever.from_documents(chunks)

    def generate_answer(self, query: str, docs: List[Document]) -> str:
        """Answer *query* from the concatenated content of *docs* via the LLM."""
        prompt = ChatPromptTemplate.from_template("""
Question: {query}
Relevant Information: {chunks}
Answer:""")
        context = "\n\n".join(doc.page_content for doc in docs)
        messages = prompt.format_messages(query=query, chunks=context)
        return self.llm.invoke(messages).content
# Streamlit UI
def main():
    """Streamlit entry point.

    Lets the user upload a text/markdown document, builds (and caches)
    vector + BM25 retrieval over it, then answers a free-form question
    with each retrieval method side by side.
    """
    st.title("Interactive Document Retrieval Analysis")
    st.write("Upload a document, experiment with retrieval methods, and analyze content interactively.")

    # Document Upload
    uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
    if uploaded_file:
        document = uploaded_file.read().decode("utf-8")
        st.success("Document successfully uploaded!")

        # Streamlit reruns the whole script on every widget interaction.
        # Rebuilding the embedding model, FAISS index and BM25 index on each
        # rerun is very expensive, so cache them in session_state, keyed by
        # a hash of the document content (rebuild only on a new document).
        doc_hash = hashlib.md5(document.encode("utf-8")).hexdigest()
        if st.session_state.get("doc_hash") != doc_hash:
            cr = ContextualRetrieval()
            chunks = cr.process_document(document)
            st.session_state["doc_hash"] = doc_hash
            st.session_state["cr"] = cr
            st.session_state["vectorstore"] = cr.create_vectorstore(chunks)
            st.session_state["bm25_retriever"] = cr.create_bm25_retriever(chunks)
        cr = st.session_state["cr"]
        vectorstore = st.session_state["vectorstore"]
        bm25_retriever = st.session_state["bm25_retriever"]

        # Query Input
        query = st.text_input("Enter your question about the document:")
        if query:
            # Retrieve Results
            with st.spinner("Fetching results..."):
                vector_results = vectorstore.similarity_search(query, k=3)
                bm25_results = bm25_retriever.get_relevant_documents(query)
                vector_answer = cr.generate_answer(query, vector_results)
                bm25_answer = cr.generate_answer(query, bm25_results)

            # Display Results
            st.subheader("Results from Vector Search")
            st.write(vector_answer)
            st.subheader("Results from BM25 Search")
            st.write(bm25_answer)

            # Display Sources
            st.subheader("Top Retrieved Chunks")
            st.write("**Vector Search Results:**")
            for i, doc in enumerate(vector_results, 1):
                st.write(f"{i}. {doc.page_content[:300]}...")
            st.write("**BM25 Search Results:**")
            for i, doc in enumerate(bm25_results, 1):
                st.write(f"{i}. {doc.page_content[:300]}...")
# Run the Streamlit App
# Guarded entry point so importing this module has no side effects.
if __name__ == "__main__":
    main()