Towhidul committed on
Commit 01a6269 · verified · 1 Parent(s): a74c56e

Upload app.py

Files changed (1)
app.py +332 -0
app.py ADDED
@@ -0,0 +1,332 @@
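+ """Medical Chatbot: a Streamlit RAG app.
+
+ Answers medical questions with GPT-4o over a FAISS vector index, built either
+ from a local PDF directory at startup or from files the user uploads in the
+ sidebar. Requires OPENAI_API_KEY (loaded from a .env file by python-dotenv).
+ Run with: streamlit run app.py
+ """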
+
+ import os
+ import uuid  # unique keys used to reset the Streamlit file uploader
+
+ import streamlit as st
+ from dotenv import load_dotenv
+ from PIL import Image, ImageOps
+ from PyPDF2 import PdfReader  # text extraction from uploaded PDFs
+
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.prompts import PromptTemplate
+ from langchain.memory import ConversationBufferMemory
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.retrievers import ContextualCompressionRetriever
+ from langchain.retrievers.document_compressors import FlashrankRerank
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
+ from langchain_openai import ChatOpenAI, OpenAIEmbeddings
+
+ # Hyperparameters
+ PDF_CHUNK_SIZE = 1024
+ PDF_CHUNK_OVERLAP = 256
+ k = 3
+
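+ # k is how many chunks the base retriever returns before reranking. Note that
+ # PDF_CHUNK_SIZE and PDF_CHUNK_OVERLAP are defined for tuning, but the text
+ # splitters below currently use their own hard-coded sizes.
+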
+ embeddings = OpenAIEmbeddings(
+     model="text-embedding-3-large",
+     api_key=os.getenv("OPENAI_API_KEY"),
+     # With the `text-embedding-3` class of models, you can specify the size
+     # of the embeddings you want returned, e.g. dimensions=1024.
+ )
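+ # text-embedding-3-large returns 3072-dimensional vectors by default; passing
+ # dimensions=1024 would trade a little recall for a smaller FAISS index.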
+
+ llm = ChatOpenAI(
+     model="gpt-4o",
+     temperature=0,
+     max_tokens=None,
+     timeout=None,
+     max_retries=2,
+     api_key=os.getenv("OPENAI_API_KEY"),  # pass the API key directly instead of relying on env var pickup
+     # base_url="...",
+     # organization="...",
+     # other params...
+ )
84
+
85
+ default_system_prompt = """
86
+ You are a helpful and knowledgeable assistant who is expert on medical question answering.
87
+ Your role is select the best answer for queries related to medical information.
88
+ YOU WILL ALWAYS ANSWER FROM THE CONTEXT PROVIDED. If answer is not provided, politely say that you are not aware of the answer.
89
+ """
90
+
91
+
92
+ knowledge_base_prompt = """You have been provided with medical notes and books.
93
+ Your role is provide the best answer for queries related to medical information.
94
+ YOU WILL ALWAYS ANSWER FROM THE CONTEXT PROVIDED. If answer is not provided, politely say that you are not aware of the answer.
95
+ """
96
+ #- Keep answers short and direct.
97
+
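+ # Prompt selection happens in LLM(): knowledge_base_prompt is used when the
+ # user has uploaded files into a session knowledge base; otherwise
+ # default_system_prompt is paired with the prebuilt index on disk.
+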
+ # Function to ingest PDFs from the directory
+ def data_ingestion():
+     loader = PyPDFDirectoryLoader("finance_documents")
+     documents = loader.load()
+     # Split the text into chunks
+     text_splitter = RecursiveCharacterTextSplitter(chunk_size=4096, chunk_overlap=512)
+     docs = text_splitter.split_documents(documents)
+     return docs
+
+ # Function to create and save vector store
+ def setup_vector_store(documents):
+     # Create a vector store using the documents and embeddings
+     vector_store = FAISS.from_documents(documents, embeddings)
+     # Save the vector store locally
+     vector_store.save_local("faiss_index_medical")
+
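+ # Note: save_local() persists the index to disk so later runs can skip
+ # re-embedding; load_local() needs allow_dangerous_deserialization=True
+ # because the docstore is pickled, so only load indexes you created yourself.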
+ # Function to load or create vector store
+ def load_or_create_vector_store():
+     # Check if the vector store file exists
+     if os.path.exists("faiss_index_medical"):
+         # Load the vector store
+         vector_store = FAISS.load_local("faiss_index_medical", embeddings, allow_dangerous_deserialization=True)
+         print("Loaded existing vector store.")
+     else:
+         # If the vector store doesn't exist, create it
+         docs = data_ingestion()
+         setup_vector_store(docs)
+         vector_store = FAISS.load_local("faiss_index_medical", embeddings, allow_dangerous_deserialization=True)
+         print("Created and loaded new vector store.")
+
+     return vector_store
+
+ def load_and_pad_image(image_path, size=(64, 64)):
+     img = Image.open(image_path)
+
+     # Make the image square by padding it with white or any background color you like
+     img_with_padding = ImageOps.pad(img, size)  # Change color if needed
+     return img_with_padding
+
+ def LLM(llm, query):
+     # Use vectorstore from uploaded files if available
+     if 'vectorstore' in st.session_state and st.session_state['vectorstore'] is not None:
+         system_prompt = knowledge_base_prompt
+         vectorstore = st.session_state['vectorstore']
+     else:
+         system_prompt = default_system_prompt
+         vectorstore = load_or_create_vector_store()
+     knowledge_base = vectorstore
+     compressor = FlashrankRerank()
+     retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
+     compression_retriever = ContextualCompressionRetriever(
+         base_compressor=compressor, base_retriever=retriever
+     )
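+     # The wrapper above implements retrieve-then-rerank: FAISS returns the k
+     # nearest chunks by embedding similarity, and FlashrankRerank re-scores
+     # them with a small local reranking model, keeping only the most relevant.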
+
+     template = '''
+     %s
+     -------------------------------
+     Context: {context}
+
+     Current conversation:
+     {chat_history}
+
+     Question: {question}
+     Answer:
+     ''' % (system_prompt)
+
+     PROMPT = PromptTemplate(
+         template=template, input_variables=["context", "chat_history", "question"]
+     )
+     chain_type_kwargs = {"prompt": PROMPT}
+
+     # Initialize memory to manage chat history if it doesn't exist
+     if "memory" not in st.session_state:
+         st.session_state.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
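+     # Keeping the memory in st.session_state means it survives Streamlit's
+     # script reruns; ConversationalRetrievalChain reads and writes its
+     # "chat_history" variable through it on every call.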
+
+     # Retrieve chat history from st.session_state.messages
+     chat_history = [
+         (msg["role"], msg["content"]) for msg in st.session_state.messages if msg["role"] in ["user", "assistant"]
+     ]
+
+     # Create the conversational chain with memory for chat history
+     conversation_chain = ConversationalRetrievalChain.from_llm(
+         llm=llm,
+         retriever=compression_retriever,
+         memory=st.session_state.memory,
+         verbose=True,
+         combine_docs_chain_kwargs=chain_type_kwargs
+     )
+
+     # Run the conversation chain with the latest user query and retrieve the response.
+     # Note: since the chain has memory attached, it loads chat_history from
+     # memory, so the explicit value passed here is effectively superseded.
+     response = conversation_chain({"question": query, "chat_history": chat_history})
+     return response.get("answer")
+
+ # Function to get text from PDF
+ def get_pdf_text(pdf_file):
+     pdf_reader = PdfReader(pdf_file)
+     # extract_text() returns None for image-only pages, so fall back to ""
+     return "".join(page.extract_text() or "" for page in pdf_reader.pages)
+
+
+ def get_text_chunks(text, file_name, max_chars=16000):  # approx. 4000 tokens
+     # Initial large chunk size
+     large_text_splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=512)
+     docs = large_text_splitter.create_documents([text])
+
+     # Check character length (as a proxy for tokens) and split further if a chunk exceeds the limit
+     valid_docs = []
+     for doc in docs:
+         if len(doc.page_content) > max_chars:
+             # Further split if the chunk exceeds max_chars
+             smaller_text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
+             valid_docs.extend(smaller_text_splitter.create_documents([doc.page_content]))
+         else:
+             valid_docs.append(doc)
+
+     # Add metadata to each document chunk
+     for doc in valid_docs:
+         doc.metadata["file_name"] = file_name
+     return valid_docs
+
+ # Function to process uploaded files
+ def process_files(file_list):
+     all_docs = []
+     for file in file_list:
+         file_extension = os.path.splitext(file.name)[1]
+         file_name = os.path.splitext(file.name)[0]
+         if file_extension == ".pdf":
+             raw_text = get_pdf_text(file)
+         elif file_extension in (".txt", ".csv"):
+             raw_text = file.read().decode('utf-8')
+         else:
+             st.warning(f"File type not supported: {file.name}")
+             continue
+
+         # Split each file's text into chunks separately, so every chunk is
+         # attributed to the file it actually came from
+         docs = get_text_chunks(raw_text, file_name)
+         for doc in docs:
+             doc.metadata["extension"] = file_extension
+             doc.metadata["source"] = file.name
+         all_docs.extend(docs)
+     if all_docs:
+         # Create vectorstore
+         vectorstore = FAISS.from_documents(all_docs, embeddings)
+         # Save vectorstore in session state
+         st.session_state['vectorstore'] = vectorstore
+         st.success("Knowledge base updated with uploaded files!")
+     else:
+         st.warning("No valid files were uploaded. Please upload PDF, TXT, or CSV files.")
+
+ # Main function to set up Streamlit chat interface
+ def main():
+     load_dotenv()
+
+     favicon_path = "medical.png"  # Replace with the actual path to your image file
+     favicon_image = load_and_pad_image(favicon_path)
+
+     st.set_page_config(
+         page_title="Medical Chatbot",
+         page_icon=favicon_image,
+     )
+     # Create two columns for the logo and title text
+     col1, col2 = st.columns([1, 8])  # Adjust the column width ratios as needed
+
+     # Reduce spacing by adjusting padding
+     with col1:
+         st.image(favicon_image)  # Display the logo image
+
+     with col2:
+         # Reduce spacing by adding custom HTML with no margin/padding
+         st.markdown("""
+         <h1 style='text-align: left; margin-top: -12px;'>
+         Medical Chatbot
+         </h1>
+         """, unsafe_allow_html=True)
+
+     # Initialize the unique key for the file uploader
+     if 'file_uploader_key' not in st.session_state:
+         st.session_state['file_uploader_key'] = str(uuid.uuid4())
+
+     # Add file upload component in the sidebar
+     with st.sidebar:
+         st.subheader("Your PDFs")
+         pdf_docs = st.file_uploader(
+             "Upload PDFs and click process",
+             type=["pdf", "txt", "csv"],
+             accept_multiple_files=True,
+             key=st.session_state['file_uploader_key']
+         )
+         if st.button("Process"):
+             if pdf_docs is not None and len(pdf_docs) > 0:
+                 with st.spinner("Processing PDFs"):
+                     process_files(pdf_docs)
+             else:
+                 st.error("Please upload at least one file.")
+
+         # Button to start a new session
+         if st.button("New Session"):
+             # Clear the chat history and memory
+             st.session_state["messages"] = [{"role": "assistant", "content": "Hello there, how can I help you?"}]
+             st.session_state.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
+             # Clear the vectorstore from session state
+             st.session_state['vectorstore'] = None
+             # Assign a new key to the file uploader to reset it
+             st.session_state['file_uploader_key'] = str(uuid.uuid4())
+             st.rerun()
+
+     user_question = st.chat_input("Ask a Question")
+
+     # Initialize or load chat history into session state
+     if "messages" not in st.session_state:
+         st.session_state["messages"] = [{"role": "assistant", "content": "Hello there, how can I help you?"}]
+
+     # Display chat history
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.write(message["content"])
+
+     # Capture user input and update the chat history
+     if user_question:
+         st.session_state.messages.append({"role": "user", "content": user_question})
+         with st.chat_message("user"):
+             st.write(user_question)
+
+         # Generate and display assistant's response, updating the chat history
+         with st.chat_message("assistant"):
+             with st.spinner("Loading"):
+                 ai_response = LLM(llm, user_question)
+                 st.write(ai_response)
+
+         st.session_state.messages.append({"role": "assistant", "content": ai_response})
+
+ if __name__ == '__main__':
+     main()