Upload 4 files
- app.py +62 -0
- data_preprocessing.py +133 -0
- rag.py +112 -0
- requirements.txt +12 -0
app.py
ADDED
@@ -0,0 +1,62 @@
import streamlit as st
from pathlib import Path
from data_preprocessing import process_docs
from rag import create_rag_chain
import time

# Streamed response emulator: yields the answer word by word with a short delay
def response_generator(prompt, chain):
    response = chain.invoke(prompt)
    for word in response.split():
        yield word + " "
        time.sleep(0.05)


# Directory and path used to save the uploaded file
save_directory = "docs"
save_path = "docs/file.pdf"

st.title("📝 InsureAgent")
with st.sidebar:
    uploaded_file = st.file_uploader("Upload a document", type="pdf")
    if uploaded_file is not None:
        with open(save_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.success(f"File saved successfully: {save_path}")
        retriever = process_docs(save_path)
        chain, chain_with_sources = create_rag_chain(retriever)

# Initialize chat history
if "messages" not in st.session_state:
    st.session_state.messages = []
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])

if prompt := st.chat_input("What is up?"):
    if uploaded_file is None:
        st.warning("Please upload a PDF before asking a question.")
    else:
        # Add user message to chat history
        st.session_state.messages.append({"role": "user", "content": prompt})
        # Display user message in chat message container
        with st.chat_message("user"):
            st.markdown(prompt)

        # Display assistant response in chat message container
        with st.chat_message("assistant"):
            response = st.write_stream(response_generator(prompt, chain))
        # Add assistant response to chat history
        st.session_state.messages.append({"role": "assistant", "content": response})
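Note: app.py writes the upload to docs/file.pdf but never creates the docs directory, so the first save can fail on a fresh checkout. A minimal sketch of a guard that could run near the top of app.py (it reuses the Path import already present there); this is a suggestion, not part of the uploaded code:

# Hypothetical guard, not in the upload: make sure the save directory exists
# before the sidebar uploader writes docs/file.pdf.
from pathlib import Path

Path("docs").mkdir(parents=True, exist_ok=True)

With that in place the app is started the usual Streamlit way, e.g. streamlit run app.py.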
data_preprocessing.py
ADDED
@@ -0,0 +1,133 @@
import pdfplumber
import uuid
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain.storage import InMemoryStore
from langchain.schema.document import Document
from langchain.embeddings import OpenAIEmbeddings
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain_huggingface import HuggingFaceEmbeddings
from pinecone import Pinecone as pineC, ServerlessSpec
from langchain_pinecone import Pinecone
import os
from dotenv import load_dotenv

load_dotenv()


def extract_pdf(file_path):
    texts = []
    tables = []
    # Open the PDF and extract text and tables page by page
    with pdfplumber.open(file_path) as pdf:
        for page in pdf.pages:
            texts.append(page.extract_text())
            if page.extract_tables():
                tables.append(page.extract_tables())
    return texts, tables


def summarize_data(texts, tables):
    prompt_text = """
    You are an assistant tasked with summarizing tables and text.
    Give a concise summary of the table or text chunk; the first 2 sentences should describe it precisely.

    Respond only with the summary, no additional comment.
    Do not start your message by saying "Here is a summary" or anything like that.
    Just give the summary as it is.

    Table or text chunk: {element}
    """
    prompt = ChatPromptTemplate.from_template(prompt_text)

    # Summary chain
    model = ChatGroq(temperature=0, model="llama-3.1-8b-instant", api_key=os.environ["GROQ_API_KEY"])
    summarize_chain = {"element": lambda x: x} | prompt | model | StrOutputParser()

    # Summarize extracted text
    text_summaries = []
    if texts:
        text_summaries = summarize_chain.batch(texts, {"max_concurrency": 5})

    # Summarize extracted tables
    tables_html = [str(table) for table in tables]  # Convert tables to string format
    table_summaries = []
    if tables_html:
        table_summaries = summarize_chain.batch(tables_html, {"max_concurrency": 5})
    return texts, text_summaries, tables, table_summaries


def create_vectorstore():
    model_name = "intfloat/multilingual-e5-large-instruct"
    model_kwargs = {'device': 'cpu'}
    encode_kwargs = {'normalize_embeddings': False}
    hf = HuggingFaceEmbeddings(
        model_name=model_name,
        model_kwargs=model_kwargs,
        encode_kwargs=encode_kwargs
    )

    # The storage layer for the parent documents
    store = InMemoryStore()
    id_key = "doc_id"

    pc = pineC(api_key=os.environ["PINECONE_API_KEY"])

    index_name = "gaidorag"
    text_field = "text"
    cloud = 'aws'
    region = 'us-east-1'

    spec = ServerlessSpec(cloud=cloud, region=region)
    # Check whether the index already exists (it shouldn't on the first run)
    if index_name not in pc.list_indexes().names():
        # If it does not exist, create the index
        pc.create_index(
            index_name,
            dimension=1024,  # dimensionality of multilingual-e5-large-instruct embeddings
            metric='cosine',
            spec=spec
        )
    # Connect to the index for LangChain
    index = pc.Index(index_name)

    # The vectorstore used to index the summary (child) documents
    vectorstore = Pinecone(
        index, hf, text_field
    )

    # The retriever (empty to start)
    retriever = MultiVectorRetriever(
        vectorstore=vectorstore,
        docstore=store,
        id_key=id_key,
    )
    return retriever


def embed_docs(retriever, texts, text_summaries, tables, table_summaries):
    # Add texts
    id_key = "doc_id"
    doc_ids = [str(uuid.uuid4()) for _ in texts]
    summary_texts = [
        Document(page_content=summary, metadata={id_key: doc_ids[i]}) for i, summary in enumerate(text_summaries)
    ]
    retriever.vectorstore.add_documents(summary_texts)
    retriever.docstore.mset(list(zip(doc_ids, texts)))

    # Add tables
    table_ids = [str(uuid.uuid4()) for _ in tables]
    summary_tables = [
        Document(page_content=summary, metadata={id_key: table_ids[i]}) for i, summary in enumerate(table_summaries)
    ]
    retriever.vectorstore.add_documents(summary_tables)
    retriever.docstore.mset(list(zip(table_ids, tables)))


def process_docs(file_path):
    texts, tables = extract_pdf(file_path)
    texts, text_summaries, tables, table_summaries = summarize_data(texts, tables)
    retriever = create_vectorstore()
    embed_docs(retriever, texts, text_summaries, tables, table_summaries)
    return retriever
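Note: process_docs ties the module together (extract, summarize, index, return a MultiVectorRetriever). A hedged usage sketch, assuming a sample PDF already sits at docs/file.pdf and the GROQ and Pinecone keys are set in the environment; it is not part of the upload:

# Hypothetical standalone check of the ingestion pipeline.
from data_preprocessing import process_docs

retriever = process_docs("docs/file.pdf")
# Retrievers are Runnables in LangChain, so a similarity lookup is a plain invoke;
# the MultiVectorRetriever returns the parent chunks stored in the InMemoryStore.
docs = retriever.invoke("What is the policy start and expiry date?")
print(docs[:2])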
rag.py
ADDED
@@ -0,0 +1,112 @@
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain_core.messages import SystemMessage, HumanMessage
from langchain_groq import ChatGroq
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from base64 import b64decode
import os
from dotenv import load_dotenv

load_dotenv()


def parse_docs(docs):
    """Split retrieved documents into base64-encoded images and plain texts."""
    b64 = []
    text = []
    for doc in docs:
        try:
            b64decode(doc)
            b64.append(doc)
        except Exception:
            text.append(doc)
    return {"images": b64, "texts": text}


def build_prompt(kwargs):
    docs_by_type = kwargs["context"]
    user_question = kwargs["question"]

    # Extract text context
    context_text = ""
    if docs_by_type.get("texts"):
        for text_element in docs_by_type["texts"]:
            if isinstance(text_element, list):
                # Flatten nested lists before joining
                flat_text = " ".join(
                    " ".join(map(str, sub_element)) if isinstance(sub_element, list) else str(sub_element)
                    for sub_element in text_element
                )
                context_text += flat_text + "\n"
            else:
                context_text += str(text_element) + "\n"

    # Extract table context
    context_tables = ""
    if docs_by_type.get("tables"):
        for table in docs_by_type["tables"]:
            table_str = "\n".join([" | ".join(map(str, row)) for row in table])  # Convert table rows to strings
            context_tables += f"\nTable:\n{table_str}\n"

    # Construct prompt with context (including tables)
    prompt_template = f"""
    Answer the question based only on the following context, which includes text and tables.
    If you don't know the answer, just say that you don't know; don't try to make up an answer.
    Be specific.

    Context:
    {context_text}

    {context_tables}

    Question: {user_question}
    """

    prompt_content = [{"type": "text", "text": prompt_template}]

    # If images are provided, include them
    # if docs_by_type.get("images"):
    #     for image in docs_by_type["images"]:
    #         prompt_content.append(
    #             {
    #                 "type": "image_url",
    #                 "image_url": {"url": f"data:image/jpeg;base64,{image}"},
    #             }
    #         )

    return ChatPromptTemplate.from_messages(
        [
            HumanMessage(content=prompt_content),
        ]
    )


def create_rag_chain(retriever):
    chain = (
        {
            "context": retriever | RunnableLambda(parse_docs),
            "question": RunnablePassthrough(),
        }
        | RunnableLambda(build_prompt)
        | ChatGroq(model="llama-3.3-70b-versatile", api_key=os.environ["GROQ_API_KEY"])
        | StrOutputParser()
    )

    chain_with_sources = {
        "context": retriever | RunnableLambda(parse_docs),
        "question": RunnablePassthrough(),
    } | RunnablePassthrough().assign(
        response=(
            RunnableLambda(build_prompt)
            | ChatGroq(model="llama-3.3-70b-versatile", api_key=os.environ["GROQ_API_KEY"])
            | StrOutputParser()
        )
    )
    return chain, chain_with_sources


def invoke_chain(chain):
    response = chain.invoke(
        "What is the policy start and expiry date?"
    )
    return response
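Note: create_rag_chain returns two runnables: chain, which yields only the answer string, and chain_with_sources, which keeps the retrieved context alongside the answer. A sketch of how the two modules fit together, assuming a PDF was already saved to docs/file.pdf; not part of the upload:

# Hypothetical end-to-end run outside Streamlit.
from data_preprocessing import process_docs
from rag import create_rag_chain, invoke_chain

retriever = process_docs("docs/file.pdf")
chain, chain_with_sources = create_rag_chain(retriever)

# Plain answer string
print(invoke_chain(chain))

# chain_with_sources returns a dict with "context", "question" and "response" keys,
# because the assign() step adds the answer next to the passthrough inputs.
result = chain_with_sources.invoke("What is the policy start and expiry date?")
print(result["response"])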
requirements.txt
ADDED
@@ -0,0 +1,12 @@
pdfplumber
tiktoken
langchain
langchain-community
langchain-openai
langchain-groq
python-dotenv
langchain-huggingface
ragas
datasets
langchain-pinecone
streamlit
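Note: after pip install -r requirements.txt, the modules still expect GROQ_API_KEY and PINECONE_API_KEY, read via os.environ after load_dotenv(), so a .env file or exported shell variables are needed. A small hypothetical pre-flight check (not part of the upload):

# check_env.py: hypothetical helper that verifies the two keys the code reads.
import os
from dotenv import load_dotenv

load_dotenv()
missing = [k for k in ("GROQ_API_KEY", "PINECONE_API_KEY") if not os.getenv(k)]
if missing:
    raise SystemExit(f"Missing environment variables: {', '.join(missing)}")
print("Environment OK. Launch the app with: streamlit run app.py")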