DrishtiSharma committed
Commit 8f252e0 · verified · 1 Parent(s): eec1f16

Create interim.py

Files changed (1):
  interim.py +91 -0
interim.py ADDED
@@ -0,0 +1,91 @@
+ import streamlit as st
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.schema import Document
+ from langchain.vectorstores import FAISS
+ from rank_bm25 import BM25Okapi
+ from langchain.retrievers import ContextualCompressionRetriever, BM25Retriever, EnsembleRetriever
+ from langchain.retrievers.document_compressors import FlashrankRerank
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_groq import ChatGroq
+ from langchain.prompts import ChatPromptTemplate
+
+ import hashlib
+ from typing import List
+
+ # Contextual Retrieval Class
+ class ContextualRetrieval:
+     def __init__(self):
+         # Split documents into 800-character chunks with 100 characters of overlap
+         self.text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
+         model_name = "BAAI/bge-large-en-v1.5"
+         model_kwargs = {'device': 'cpu'}
+         encode_kwargs = {'normalize_embeddings': False}
+         self.embeddings = HuggingFaceEmbeddings(model_name=model_name, model_kwargs=model_kwargs, encode_kwargs=encode_kwargs)
+         self.llm = ChatGroq(model="llama-3.2-3b-preview", temperature=0)
+
+     def process_document(self, document: str) -> List[Document]:
+         # Chunk the raw text into Document objects
+         return self.text_splitter.create_documents([document])
+
+     def create_vectorstore(self, chunks: List[Document]) -> FAISS:
+         # Build a FAISS index over the chunk embeddings for semantic search
+         return FAISS.from_documents(chunks, self.embeddings)
+
+     def create_bm25_retriever(self, chunks: List[Document]) -> BM25Retriever:
+         # Build a keyword-based BM25 retriever over the same chunks
+         return BM25Retriever.from_documents(chunks)
+
+     def generate_answer(self, query: str, docs: List[Document]) -> str:
+         # Stuff the retrieved chunks into a prompt and ask the LLM
+         prompt = ChatPromptTemplate.from_template("""
+ Question: {query}
+ Relevant Information: {chunks}
+ Answer:""")
+         messages = prompt.format_messages(query=query, chunks="\n\n".join([doc.page_content for doc in docs]))
+         response = self.llm.invoke(messages)
+         return response.content
+
+ # Streamlit UI
+ def main():
+     st.title("Interactive Document Retrieval Analysis")
+     st.write("Upload a document, experiment with retrieval methods, and analyze content interactively.")
+
+     # Document Upload
+     uploaded_file = st.file_uploader("Upload a Text Document", type=['txt', 'md'])
+     if uploaded_file:
+         document = uploaded_file.read().decode("utf-8")
+         st.success("Document successfully uploaded!")
+
+         # Initialize Retrieval System
+         cr = ContextualRetrieval()
+         chunks = cr.process_document(document)
+         vectorstore = cr.create_vectorstore(chunks)
+         bm25_retriever = cr.create_bm25_retriever(chunks)
+
+         # Query Input
+         query = st.text_input("Enter your question about the document:")
+         if query:
+             # Retrieve Results
+             with st.spinner("Fetching results..."):
+                 vector_results = vectorstore.similarity_search(query, k=3)
+                 bm25_results = bm25_retriever.get_relevant_documents(query)
+
+                 vector_answer = cr.generate_answer(query, vector_results)
+                 bm25_answer = cr.generate_answer(query, bm25_results)
+
+             # Display Results
+             st.subheader("Results from Vector Search")
+             st.write(vector_answer)
+
+             st.subheader("Results from BM25 Search")
+             st.write(bm25_answer)
+
+             # Display Sources
+             st.subheader("Top Retrieved Chunks")
+             st.write("**Vector Search Results:**")
+             for i, doc in enumerate(vector_results, 1):
+                 st.write(f"{i}. {doc.page_content[:300]}...")
+
+             st.write("**BM25 Search Results:**")
+             for i, doc in enumerate(bm25_results, 1):
+                 st.write(f"{i}. {doc.page_content[:300]}...")
+
+ # Run the Streamlit App
+ if __name__ == "__main__":
+     main()
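
To try the commit locally, the script needs streamlit, langchain, langchain-community, langchain-groq, faiss-cpu, rank_bm25, and sentence-transformers installed (HuggingFaceEmbeddings loads BAAI/bge-large-en-v1.5 through sentence-transformers), plus a GROQ_API_KEY environment variable for ChatGroq. It is then launched with: streamlit run interim.py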
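
Streamlit reruns the whole script on every widget interaction, so the app above re-chunks and re-embeds the uploaded document each time the query changes, which is expensive with bge-large-en-v1.5 on CPU. A minimal sketch of one possible mitigation with st.cache_resource; the helper names here are hypothetical and not part of the commit:

    @st.cache_resource
    def get_retrieval_system() -> ContextualRetrieval:
        # Hypothetical helper: load the embedding model and LLM client once per process
        return ContextualRetrieval()

    @st.cache_resource
    def build_indexes(document: str):
        # Hypothetical helper: re-chunk and re-embed only when the uploaded text changes
        cr = get_retrieval_system()
        chunks = cr.process_document(document)
        return cr.create_vectorstore(chunks), cr.create_bm25_retriever(chunks)

main() could then call build_indexes(document) instead of constructing ContextualRetrieval and the indexes inline.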
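
The imports bring in ContextualCompressionRetriever, EnsembleRetriever, and FlashrankRerank, but main() never wires them in. A minimal sketch of how they could be combined into a hybrid, reranked retrieval path, assuming the flashrank package is installed and reusing the variable names from main() above:

    # Hypothetical extension, not part of this commit:
    # blend keyword and semantic retrieval, then rerank with FlashRank.
    ensemble = EnsembleRetriever(
        retrievers=[bm25_retriever, vectorstore.as_retriever(search_kwargs={"k": 3})],
        weights=[0.5, 0.5],  # equal weighting is an arbitrary starting point
    )
    reranked = ContextualCompressionRetriever(
        base_compressor=FlashrankRerank(),  # requires the flashrank package
        base_retriever=ensemble,
    )
    hybrid_results = reranked.get_relevant_documents(query)
    hybrid_answer = cr.generate_answer(query, hybrid_results)

This would give a third answer to compare against the pure vector and pure BM25 results already shown in the UI.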