varun1011 committed
Commit fc540fe · verified · 1 Parent(s): 6c5f507

Upload 4 files

Files changed (4)
  1. app.py +62 -0
  2. data_preprocessing.py +133 -0
  3. rag.py +112 -0
  4. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,62 @@
+ import streamlit as st
+ from pathlib import Path
+ from data_preprocessing import process_docs
+ from rag import create_rag_chain
+ import time
+
+ def response_generator(prompt, chain):
+     """Invoke the RAG chain and stream the answer word by word."""
+     response = chain.invoke(prompt)
+     for word in response.split():
+         yield word + " "
+         time.sleep(0.05)
+
+ # Directory and fixed path where the uploaded PDF is saved
+ save_directory = "docs"
+ save_path = "docs/file.pdf"
+ Path(save_directory).mkdir(exist_ok=True)
+
+ st.title("📝 InsureAgent")
+
+ chain = None  # populated once a document has been uploaded and indexed
+ with st.sidebar:
+     uploaded_file = st.file_uploader("Upload a document", type="pdf")
+     if uploaded_file is not None:
+         with open(save_path, "wb") as f:
+             f.write(uploaded_file.getbuffer())
+         st.success(f"File saved successfully: {save_path}")
+         retriever = process_docs(save_path)
+         chain, chain_with_sources = create_rag_chain(retriever)
+
+ # Initialize chat history and replay it on every rerun
+ if "messages" not in st.session_state:
+     st.session_state.messages = []
+ for message in st.session_state.messages:
+     with st.chat_message(message["role"]):
+         st.markdown(message["content"])
+
+ if prompt := st.chat_input("What is up?"):
+     # Add user message to chat history
+     st.session_state.messages.append({"role": "user", "content": prompt})
+     # Display user message in chat message container
+     with st.chat_message("user"):
+         st.markdown(prompt)
+
+     # Display assistant response in chat message container
+     with st.chat_message("assistant"):
+         if chain is None:
+             response = "Please upload a document before asking questions."
+             st.markdown(response)
+         else:
+             response = st.write_stream(response_generator(prompt, chain))
+     # Add assistant response to chat history
+     st.session_state.messages.append({"role": "assistant", "content": response})
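
To try the app locally, install the packages from requirements.txt and, since both modules load their keys via python-dotenv, place a .env file next to the code defining GROQ_API_KEY and PINECONE_API_KEY; the app is then started with `streamlit run app.py`, and a policy PDF must be uploaded from the sidebar before asking questions.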
data_preprocessing.py ADDED
@@ -0,0 +1,133 @@
+ import pdfplumber
+ import uuid
+ from langchain_groq import ChatGroq
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain.storage import InMemoryStore
+ from langchain.schema.document import Document
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from pinecone import Pinecone as pineC, ServerlessSpec
+ from langchain_pinecone import Pinecone
+ import os
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ def extract_pdf(file_path):
+     """Extract plain text and tables from every page of the PDF."""
+     texts = []
+     tables = []
+     with pdfplumber.open(file_path) as pdf:
+         for page in pdf.pages:
+             texts.append(page.extract_text())
+             if page.extract_tables():
+                 tables.append(page.extract_tables())
+     return texts, tables
+
+ def summarize_data(texts, tables):
+     """Summarize each text chunk and table with a Groq-hosted Llama model."""
+     prompt_text = """
+     You are an assistant tasked with summarizing tables and text.
+     Give a concise summary of the table or text that perfectly describes the table in the starting 2 sentences.
+
+     Respond only with the summary, no additional comment.
+     Do not start your message by saying "Here is a summary" or anything like that.
+     Just give the summary as it is.
+
+     Table or text chunk: {element}
+     """
+     prompt = ChatPromptTemplate.from_template(prompt_text)
+
+     # Summary chain
+     model = ChatGroq(temperature=0, model="llama-3.1-8b-instant", api_key=os.environ["GROQ_API_KEY"])
+     summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()
+
+     # Summarize extracted text
+     text_summaries = []
+     if texts:
+         text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})
+
+     # Summarize extracted tables
+     tables_html = [str(table) for table in tables]  # Convert tables to string format
+     table_summaries = []
+     if tables_html:
+         table_summaries = summarize_chain.batch(tables_html, {"max_concurrency": 5})
+     return texts, text_summaries, tables, table_summaries
+
+ def create_vectorstore():
+     model_name = "intfloat/multilingual-e5-large-instruct"
+     model_kwargs = {'device': 'cpu'}
+     encode_kwargs = {'normalize_embeddings': False}
+     hf = HuggingFaceEmbeddings(
+         model_name=model_name,
+         model_kwargs=model_kwargs,
+         encode_kwargs=encode_kwargs
+     )
+     # The vectorstore to use to index the child chunks
+     # vectorstore = Chroma(collection_name="multi_modal_rag", embedding_function=hf)
+
+     # The storage layer for the parent documents
+     store = InMemoryStore()
+     id_key = "doc_id"
+
+     pc = pineC(api_key=os.environ["PINECONE_API_KEY"])
+
+     index_name = "gaidorag"
+     text_field = "text"
+     cloud = 'aws'
+     region = 'us-east-1'
+
+     spec = ServerlessSpec(cloud=cloud, region=region)
+     # Check if the index already exists (it shouldn't on the first run)
+     if index_name not in pc.list_indexes().names():
+         # If it does not exist, create the index
+         pc.create_index(
+             index_name,
+             dimension=1024,  # dimensionality of multilingual-e5-large-instruct embeddings
+             metric='cosine',
+             spec=spec
+         )
+     # Switch back to a normal index handle for LangChain
+     index = pc.Index(index_name)
+
+     vectorstore = Pinecone(
+         index, hf, text_field
+     )
+
+     # The retriever (empty to start)
+     retriever = MultiVectorRetriever(
+         vectorstore=vectorstore,
+         docstore=store,
+         id_key=id_key,
+     )
+     return retriever
+
+ def embed_docs(retriever, texts, text_summaries, tables, table_summaries):
+     """Index the summaries in the vectorstore and keep the raw texts/tables in the docstore."""
+     id_key = "doc_id"
+
+     # Add texts
+     doc_ids = [str(uuid.uuid4()) for _ in texts]
+     summary_texts = [
+         Document(page_content=summary, metadata={id_key: doc_ids[i]}) for i, summary in enumerate(text_summaries)
+     ]
+     retriever.vectorstore.add_documents(summary_texts)
+     retriever.docstore.mset(list(zip(doc_ids, texts)))
+
+     # Add tables
+     table_ids = [str(uuid.uuid4()) for _ in tables]
+     summary_tables = [
+         Document(page_content=summary, metadata={id_key: table_ids[i]}) for i, summary in enumerate(table_summaries)
+     ]
+     retriever.vectorstore.add_documents(summary_tables)
+     retriever.docstore.mset(list(zip(table_ids, tables)))
+
+ def process_docs(file_path):
+     """End-to-end preprocessing: extract, summarize, and index a PDF; returns the retriever."""
+     texts, tables = extract_pdf(file_path)
+     texts, text_summaries, tables, table_summaries = summarize_data(texts, tables)
+     retriever = create_vectorstore()
+     embed_docs(retriever, texts, text_summaries, tables, table_summaries)
+     return retriever
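
For orientation, a minimal sketch of how this preprocessing pipeline can be exercised on its own. The file path and query string are illustrative, GROQ_API_KEY and PINECONE_API_KEY must be available (e.g. via the .env file that load_dotenv() reads), and a recent LangChain is assumed so the retriever exposes .invoke:

from data_preprocessing import process_docs

# Extract text/tables, summarize them with Groq, and index the summaries in Pinecone
retriever = process_docs("docs/file.pdf")

# Queries match against the summary embeddings, but the retriever returns the
# corresponding raw page texts / tables kept in the in-memory docstore
docs = retriever.invoke("What does the policy cover?")
for doc in docs:
    print(doc)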
rag.py ADDED
@@ -0,0 +1,112 @@
+ from langchain_core.runnables import RunnablePassthrough, RunnableLambda
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_groq import ChatGroq
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.output_parsers import StrOutputParser
+ from base64 import b64decode
+ import os
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ def parse_docs(docs):
+     """Split retrieved documents into base64-encoded images and plain texts."""
+     b64 = []
+     text = []
+     for doc in docs:
+         try:
+             b64decode(doc)
+             b64.append(doc)
+         except Exception:
+             text.append(doc)
+     return {"images": b64, "texts": text}
+
+ def build_prompt(kwargs):
+     docs_by_type = kwargs["context"]
+     user_question = kwargs["question"]
+
+     # Extract text context
+     context_text = ""
+     if docs_by_type.get("texts"):
+         for text_element in docs_by_type["texts"]:
+             if isinstance(text_element, list):
+                 # Flatten nested lists before joining
+                 flat_text = " ".join(
+                     " ".join(map(str, sub_element)) if isinstance(sub_element, list) else str(sub_element)
+                     for sub_element in text_element
+                 )
+                 context_text += flat_text + "\n"
+             else:
+                 context_text += str(text_element) + "\n"
+
+     # Extract table context
+     context_tables = ""
+     if docs_by_type.get("tables"):
+         for table in docs_by_type["tables"]:
+             table_str = "\n".join([" | ".join(map(str, row)) for row in table])  # Convert table rows to strings
+             context_tables += f"\nTable:\n{table_str}\n"
+
+     # Construct prompt with context (including tables)
+     prompt_template = f"""
+     Answer the question based only on the following context, which includes text and tables.
+     If you don't know the answer, just say that you don't know, don't try to make up an answer.
+     Be specific.
+
+     Context:
+     {context_text}
+
+     {context_tables}
+
+     Question: {user_question}
+     """
+
+     prompt_content = [{"type": "text", "text": prompt_template}]
+
+     # If images are provided, include them
+     # if docs_by_type.get("images"):
+     #     for image in docs_by_type["images"]:
+     #         prompt_content.append(
+     #             {
+     #                 "type": "image_url",
+     #                 "image_url": {"url": f"data:image/jpeg;base64,{image}"},
+     #             }
+     #         )
+
+     return ChatPromptTemplate.from_messages(
+         [
+             HumanMessage(content=prompt_content),
+         ]
+     )
+
+ def create_rag_chain(retriever):
+     """Return a plain answer chain and a variant that also exposes the retrieved context."""
+     chain = (
+         {
+             "context": retriever | RunnableLambda(parse_docs),
+             "question": RunnablePassthrough(),
+         }
+         | RunnableLambda(build_prompt)
+         | ChatGroq(model="llama-3.3-70b-versatile", api_key=os.environ["GROQ_API_KEY"])
+         | StrOutputParser()
+     )
+
+     chain_with_sources = {
+         "context": retriever | RunnableLambda(parse_docs),
+         "question": RunnablePassthrough(),
+     } | RunnablePassthrough().assign(
+         response=(
+             RunnableLambda(build_prompt)
+             | ChatGroq(model="llama-3.3-70b-versatile", api_key=os.environ["GROQ_API_KEY"])
+             | StrOutputParser()
+         )
+     )
+     return chain, chain_with_sources
+
+ def invoke_chain(chain):
+     response = chain.invoke("What is the policy start and expiry date?")
+     return response
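
A minimal usage sketch for the two chains. The question text is illustrative, and the retriever is assumed to come from process_docs in data_preprocessing.py:

from data_preprocessing import process_docs
from rag import create_rag_chain

retriever = process_docs("docs/file.pdf")
chain, chain_with_sources = create_rag_chain(retriever)

# Plain chain: returns only the answer string
print(chain.invoke("What is the policy start and expiry date?"))

# Source-aware chain: returns a dict holding the retrieved context,
# the question, and the generated response
result = chain_with_sources.invoke("What is the policy start and expiry date?")
print(result["response"])
print(result["context"]["texts"])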
requirements.txt ADDED
@@ -0,0 +1,12 @@
+ pdfplumber
+ tiktoken
+ langchain
+ langchain-community
+ langchain-openai
+ langchain-groq
+ python-dotenv
+ langchain-huggingface
+ ragas
+ datasets
+ langchain-pinecone
+ streamlit