mercybabs commited on
Commit
44cbe88
·
1 Parent(s): 04ad69c

Initial commit - added my RAG Streamlit app

Browse files
Files changed (3) hide show
  1. Dockerfile +23 -0
  2. app.py +286 -0
  3. requirements.txt +152 -0
Dockerfile ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use an official Python image
2
+ FROM python:3.12
3
+
4
+ # Set working directory inside the container
5
+ WORKDIR /app
6
+
7
+ # Copy the application files into the container
8
+ COPY . /app
9
+
10
+ # Install system dependencies
11
+ RUN apt update && apt install -y \
12
+ tesseract-ocr \
13
+ poppler-utils \
14
+ && rm -rf /var/lib/apt/lists/*
15
+
16
+ # Install Python dependencies
17
+ RUN pip install --no-cache-dir -r requirements.txt
18
+
19
+ # Expose the port Streamlit runs on
20
+ EXPOSE 8501
21
+
22
+ # Set the default command to run the app
23
+ CMD ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]
app.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #Import Library
2
+ from unstructured.partition.pdf import partition_pdf
3
+ from langchain_openai import ChatOpenAI
4
+ from langchain_core.messages import SystemMessage, HumanMessage
5
+ from langchain_core.prompts import ChatPromptTemplate
6
+ from langchain_core.output_parsers import StrOutputParser
7
+ from langchain.schema.runnable import RunnablePassthrough,RunnableLambda
8
+
9
+ from langchain_postgres.vectorstores import PGVector
10
+ from database import COLLECTION_NAME, CONNECTION_STRING
11
+ from langchain_community.storage import RedisStore
12
+ from langchain.schema.document import Document
13
+ from langchain_openai import OpenAIEmbeddings
14
+ from langchain.retrievers.multi_vector import MultiVectorRetriever
15
+ from pathlib import Path
16
+ from IPython.display import display, HTML
17
+ from base64 import b64decode
18
+ import os, hashlib, shutil, uuid, json, time
19
+ import torch, redis, streamlit as st
20
+ import logging
21
+ import openai
22
+
23
+
24
+ # from dotenv import load_dotenv
25
+ # load_dotenv()
26
+
27
+ openai_api_key = os.getenv("OPENAI_API_KEY")
28
+
29
+ # Ensure PyTorch module path is correctly set
30
+ torch.classes.__path__ = [os.path.join(torch.__path__[0], torch.classes.__file__)]
31
+
32
+ # Configure logging
33
+ logging.basicConfig(level=logging.INFO)
34
+
35
+ # Initialize Redis client
36
+ client = redis.Redis(host="localhost", port=6379, db=0)
37
+
38
+
39
+
40
+
41
+ #Data Loading
42
+ def load_pdf_data(file_path):
43
+ logging.info(f"Data ready to be partitioned and loaded ")
44
+ raw_pdf_elements = partition_pdf(
45
+ filename=file_path,
46
+
47
+ infer_table_structure=True,
48
+ strategy = "hi_res",
49
+
50
+ extract_image_block_types = ["Image"],
51
+ extract_image_block_to_payload = True,
52
+
53
+ chunking_strategy="by_title",
54
+ mode='elements',
55
+ max_characters=10000,
56
+ new_after_n_chars=5000,
57
+ combine_text_under_n_chars=2000,
58
+ image_output_dir_path="data/",
59
+ )
60
+ logging.info(f"Pdf data finish loading, chunks now available!")
61
+ return raw_pdf_elements
62
+
63
+ # Generate a unique hash for a PDF file
64
+ def get_pdf_hash(pdf_path):
65
+ """Generate a SHA-256 hash of the PDF file content."""
66
+ with open(pdf_path, "rb") as f:
67
+ pdf_bytes = f.read()
68
+ return hashlib.sha256(pdf_bytes).hexdigest()
69
+
70
+ # Summarize extracted text and tables using LLM
71
+ def summarize_text_and_tables(text, tables):
72
+ logging.info("Ready to summarize data with LLM")
73
+ prompt_text = """You are an assistant tasked with summarizing text and tables. \
74
+
75
+ You are to give a concise summary of the table or text and do nothing else.
76
+ Table or text chunk: {element} """
77
+ prompt = ChatPromptTemplate.from_template(prompt_text)
78
+ model = ChatOpenAI(temperature=0.6, model="gpt-4o-mini", openai_api_key=openai_api_key)
79
+ summarize_chain = {"element": RunnablePassthrough()}| prompt | model | StrOutputParser()
80
+ logging.info(f"{model} done with summarization")
81
+ return {
82
+ "text": summarize_chain.batch(text, {"max_concurrency": 5}),
83
+ "table": summarize_chain.batch(tables, {"max_concurrency": 5})
84
+ }
85
+
86
+ #Initialize a pgvector and retriever for storing and searching documents
87
+ def initialize_retriever():
88
+
89
+ store = RedisStore(client=client)
90
+ id_key = "doc_id"
91
+ vectorstore = PGVector(
92
+ embeddings=OpenAIEmbeddings(),
93
+ collection_name=COLLECTION_NAME,
94
+ connection=CONNECTION_STRING,
95
+ use_jsonb=True,
96
+ )
97
+ retrieval_loader = MultiVectorRetriever(vectorstore=vectorstore, docstore=store, id_key="doc_id")
98
+ return retrieval_loader
99
+
100
+
101
+ # Store text, tables, and their summaries in the retriever
102
+
103
+ def store_docs_in_retriever(text, text_summary, table, table_summary, retriever):
104
+ """Store text and table documents along with their summaries in the retriever."""
105
+
106
+ def add_documents_to_retriever(documents, summaries, retriever, id_key = "doc_id"):
107
+ """Helper function to add documents and their summaries to the retriever."""
108
+ if not summaries:
109
+ return None, []
110
+
111
+ doc_ids = [str(uuid.uuid4()) for _ in documents]
112
+ summary_docs = [
113
+ Document(page_content=summary, metadata={id_key: doc_ids[i]})
114
+ for i, summary in enumerate(summaries)
115
+ ]
116
+
117
+ retriever.vectorstore.add_documents(summary_docs, ids=doc_ids)
118
+ retriever.docstore.mset(list(zip(doc_ids, documents)))
119
+
120
+ # Add text, table, and image summaries to the retriever
121
+ add_documents_to_retriever(text, text_summary, retriever)
122
+ add_documents_to_retriever(table, table_summary, retriever)
123
+ return retriever
124
+
125
+
126
+ # Parse the retriever output
127
+ def parse_retriver_output(data):
128
+ parsed_elements = []
129
+ for element in data:
130
+ # Decode bytes to string if necessary
131
+ if isinstance(element, bytes):
132
+ element = element.decode("utf-8")
133
+
134
+ parsed_elements.append(element)
135
+
136
+ return parsed_elements
137
+
138
+
139
+ # Chat with the LLM using retrieved context
140
+
141
+ def chat_with_llm(retriever):
142
+
143
+ logging.info(f"Context ready to send to LLM ")
144
+ prompt_text = """
145
+ You are an AI Assistant tasked with understanding detailed
146
+ information from text and tables. You are to answer the question based on the
147
+ context provided to you. You must not go beyond the context given to you.
148
+
149
+ Context:
150
+ {context}
151
+
152
+ Question:
153
+ {question}
154
+ """
155
+
156
+ prompt = ChatPromptTemplate.from_template(prompt_text)
157
+ model = ChatOpenAI(temperature=0.6, model="gpt-4o-mini", openai_api_key=openai_api_key)
158
+
159
+ rag_chain = ({
160
+ "context": retriever | RunnableLambda(parse_retriver_output), "question": RunnablePassthrough(),
161
+ }
162
+ | prompt
163
+ | model
164
+ | StrOutputParser()
165
+ )
166
+
167
+ logging.info(f"Completed! ")
168
+
169
+ return rag_chain
170
+
171
+ # Generate temporary file path of uploaded docs
172
+ def _get_file_path(file_upload):
173
+
174
+ temp_dir = "temp"
175
+ os.makedirs(temp_dir, exist_ok=True) # Ensure the directory exists
176
+
177
+ if isinstance(file_upload, str):
178
+ file_path = file_upload # Already a string path
179
+ else:
180
+ file_path = os.path.join(temp_dir, file_upload.name)
181
+ with open(file_path, "wb") as f:
182
+ f.write(file_upload.getbuffer())
183
+ return file_path
184
+
185
+
186
+ # Process uploaded PDF file
187
+ def process_pdf(file_upload):
188
+ print('Processing PDF hash info...')
189
+
190
+ file_path = _get_file_path(file_upload)
191
+ pdf_hash = get_pdf_hash(file_path)
192
+
193
+ load_retriever = initialize_retriever()
194
+ existing = client.exists(f"pdf:{pdf_hash}")
195
+ print(f"Checking Redis for hash {pdf_hash}: {'Exists' if existing else 'Not found'}")
196
+
197
+ if existing:
198
+ print(f"PDF already exists with hash {pdf_hash}. Skipping upload.")
199
+ return load_retriever
200
+
201
+ print(f"New PDF detected. Processing... {pdf_hash}")
202
+
203
+ pdf_elements = load_pdf_data(file_path)
204
+
205
+ tables = [element.metadata.text_as_html for element in
206
+ pdf_elements if 'Table' in str(type(element))]
207
+
208
+ text = [element.text for element in pdf_elements if
209
+ 'CompositeElement' in str(type(element))]
210
+
211
+ summaries = summarize_text_and_tables(text, tables)
212
+ retriever = store_docs_in_retriever(text, summaries['text'], tables, summaries['table'], load_retriever)
213
+
214
+ # Store the PDF hash in Redis
215
+ client.set(f"pdf:{pdf_hash}", json.dumps({"text": "PDF processed"}))
216
+
217
+ # Debug: Check if Redis stored the key
218
+ stored = client.exists(f"pdf:{pdf_hash}")
219
+ # #remove temp directory
220
+ # shutil.rmtree("dir")
221
+ print(f"Stored PDF hash in Redis: {'Success' if stored else 'Failed'}")
222
+ return retriever
223
+
224
+
225
+ #Invoke chat with LLM based on uploaded PDF and user query
226
+ def invoke_chat(file_upload, message):
227
+
228
+ retriever =process_pdf(file_upload)
229
+ rag_chain = chat_with_llm(retriever)
230
+ response = rag_chain.invoke(message)
231
+ response_placeholder = st.empty()
232
+ response_placeholder.write(response)
233
+ return response
234
+
235
+
236
+ # Main application interface using Streamlit
237
+ def main():
238
+
239
+ st.title("PDF Chat Assistant ")
240
+ logging.info("App started")
241
+
242
+ if 'messages' not in st.session_state:
243
+ st.session_state.messages = []
244
+
245
+
246
+ file_upload = st.sidebar.file_uploader(
247
+ label="Upload", type=["pdf"],
248
+ accept_multiple_files=False,
249
+ key="pdf_uploader"
250
+ )
251
+
252
+ if file_upload:
253
+ st.success("File uploaded successfully! You can now ask your question.")
254
+
255
+ # Prompt for user input
256
+ if prompt := st.chat_input("Your question"):
257
+ st.session_state.messages.append({"role": "user", "content": prompt})
258
+
259
+ # Display chat history
260
+ for message in st.session_state.messages:
261
+ with st.chat_message(message["role"]):
262
+ st.write(message["content"])
263
+
264
+ # Generate response if last message is not from assistant
265
+ if st.session_state.messages and st.session_state.messages[-1]["role"] != "assistant":
266
+ with st.chat_message("assistant"):
267
+ start_time = time.time()
268
+ logging.info("Generating response...")
269
+ with st.spinner("Processing..."):
270
+ user_message = " ".join([msg["content"] for msg in st.session_state.messages if msg])
271
+ response_message = invoke_chat(file_upload, user_message)
272
+
273
+ duration = time.time() - start_time
274
+ response_msg_with_duration = f"{response_message}\n\nDuration: {duration:.2f} seconds"
275
+
276
+ st.session_state.messages.append({"role": "assistant", "content": response_msg_with_duration})
277
+ st.write(f"Duration: {duration:.2f} seconds")
278
+ logging.info(f"Response: {response_message}, Duration: {duration:.2f} s")
279
+
280
+
281
+
282
+
283
+
284
+
285
+ if __name__ == "__main__":
286
+ main()
requirements.txt ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==24.1.0
2
+ aiohappyeyeballs==2.6.1
3
+ aiohttp==3.11.14
4
+ aiosignal==1.3.2
5
+ altair==5.5.0
6
+ annotated-types==0.7.0
7
+ antlr4-python3-runtime==4.9.3
8
+ backoff==2.2.1
9
+ blinker==1.9.0
10
+ cachetools==5.5.2
11
+ chardet==5.2.0
12
+ click==8.1.8
13
+ coloredlogs==15.0.1
14
+ contourpy==1.3.1
15
+ cycler==0.12.1
16
+ dataclasses-json==0.6.7
17
+ Deprecated==1.2.18
18
+ distro==1.9.0
19
+ effdet==0.4.1
20
+ emoji==2.14.1
21
+ et_xmlfile==2.0.0
22
+ eval_type_backport==0.2.2
23
+ filetype==1.2.0
24
+ flatbuffers==25.2.10
25
+ fonttools==4.56.0
26
+ frozenlist==1.5.0
27
+ fsspec==2025.3.0
28
+ gitdb==4.0.12
29
+ GitPython==3.1.44
30
+ google-api-core==2.24.2
31
+ google-auth==2.38.0
32
+ google-cloud-vision==3.10.1
33
+ googleapis-common-protos==1.69.2
34
+ greenlet==3.1.1
35
+ grpcio==1.71.0
36
+ grpcio-status==1.71.0
37
+ html5lib==1.1
38
+ httpx-sse==0.4.0
39
+ huggingface-hub==0.29.3
40
+ humanfriendly==10.0
41
+ iopath==0.1.10
42
+ jiter==0.9.0
43
+ joblib==1.4.2
44
+ jsonpatch==1.33
45
+ kiwisolver==1.4.8
46
+ langchain==0.3.20
47
+ langchain-community==0.3.19
48
+ langchain-core==0.3.45
49
+ langchain-openai==0.3.9
50
+ langchain-postgres==0.0.13
51
+ langchain-text-splitters==0.3.6
52
+ langdetect==1.0.9
53
+ langsmith==0.3.15
54
+ layoutparser==0.3.4
55
+ lxml==5.3.1
56
+ Markdown==3.7
57
+ marshmallow==3.26.1
58
+ matplotlib==3.10.1
59
+ mpmath==1.3.0
60
+ multidict==6.2.0
61
+ narwhals==1.31.0
62
+ networkx==3.4.2
63
+ nltk==3.9.1
64
+ numpy==1.26.4
65
+ nvidia-cublas-cu12==12.4.5.8
66
+ nvidia-cuda-cupti-cu12==12.4.127
67
+ nvidia-cuda-nvrtc-cu12==12.4.127
68
+ nvidia-cuda-runtime-cu12==12.4.127
69
+ nvidia-cudnn-cu12==9.1.0.70
70
+ nvidia-cufft-cu12==11.2.1.3
71
+ nvidia-curand-cu12==10.3.5.147
72
+ nvidia-cusolver-cu12==11.6.1.9
73
+ nvidia-cusparse-cu12==12.3.1.170
74
+ nvidia-cusparselt-cu12==0.6.2
75
+ nvidia-nccl-cu12==2.21.5
76
+ nvidia-nvjitlink-cu12==12.4.127
77
+ nvidia-nvtx-cu12==12.4.127
78
+ olefile==0.47
79
+ omegaconf==2.3.0
80
+ onnx==1.17.0
81
+ onnxruntime==1.21.0
82
+ openai==1.66.3
83
+ opencv-python==4.11.0.86
84
+ openpyxl==3.1.5
85
+ orjson==3.10.15
86
+ pandas==2.2.3
87
+ pdf2image==1.17.0
88
+ pdfminer.six==20231228
89
+ pdfplumber==0.11.5
90
+ pgvector==0.3.6
91
+ pi_heif==0.22.0
92
+ pikepdf==9.5.2
93
+ pillow==11.1.0
94
+ portalocker==3.1.1
95
+ propcache==0.3.0
96
+ proto-plus==1.26.1
97
+ protobuf==5.29.3
98
+ psycopg==3.2.6
99
+ psycopg-pool==3.2.6
100
+ pyarrow==19.0.1
101
+ pyasn1==0.6.1
102
+ pyasn1_modules==0.4.1
103
+ pycocotools==2.0.8
104
+ pydantic==2.10.6
105
+ pydantic-settings==2.8.1
106
+ pydantic_core==2.27.2
107
+ pydeck==0.9.1
108
+ pypandoc==1.15
109
+ pyparsing==3.2.1
110
+ pypdf==5.4.0
111
+ pypdfium2==4.30.1
112
+ python-docx==1.1.2
113
+ python-dotenv==1.0.1
114
+ python-iso639==2025.2.18
115
+ python-magic==0.4.27
116
+ python-multipart==0.0.20
117
+ python-oxmsg==0.0.2
118
+ python-pptx==1.0.2
119
+ pytz==2025.1
120
+ RapidFuzz==3.12.2
121
+ redis==5.2.1
122
+ regex==2024.11.6
123
+ rsa==4.9
124
+ safetensors==0.5.3
125
+ scipy==1.15.2
126
+ smmap==5.0.2
127
+ SQLAlchemy==2.0.39
128
+ streamlit==1.43.2
129
+ sympy==1.13.1
130
+ tenacity==9.0.0
131
+ tiktoken==0.9.0
132
+ timm==1.0.15
133
+ tokenizers==0.21.1
134
+ toml==0.10.2
135
+ torch==2.6.0
136
+ torchvision==0.21.0
137
+ tqdm==4.67.1
138
+ transformers==4.49.0
139
+ triton==3.2.0
140
+ typing-inspect==0.9.0
141
+ typing-inspection==0.4.0
142
+ tzdata==2025.1
143
+ unstructured==0.16.10
144
+ unstructured-client==0.31.1
145
+ unstructured-inference==0.8.1
146
+ unstructured.pytesseract==0.3.15
147
+ watchdog==6.0.0
148
+ wrapt==1.17.2
149
+ xlrd==2.0.1
150
+ XlsxWriter==3.2.2
151
+ yarl==1.18.3
152
+ zstandard==0.23.0