Upload app.py
app.py (ADDED)

import os
import uuid  # Unique keys used to reset the Streamlit file uploader

import streamlit as st
from dotenv import load_dotenv
from PIL import Image, ImageOps
from PyPDF2 import PdfReader

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings

# Hyperparameters
# NOTE: only k is used below; the splitters define their own chunk sizes.
PDF_CHUNK_SIZE = 1024
PDF_CHUNK_OVERLAP = 256
k = 3  # Number of candidate chunks fetched from the vector store per query

embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    api_key=os.getenv("OPENAI_API_KEY"),
    # The text-embedding-3 models accept an optional `dimensions` argument
    # to shorten the returned vectors, e.g. dimensions=1024.
)

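# Usage sketch (illustrative; not executed by the app). `embed_query` maps a
# single string to one vector; for text-embedding-3-large the default
# dimensionality is 3072 unless `dimensions` is set above:
#
#   vec = embeddings.embed_query("first-line treatment for hypertension")
#   print(len(vec))  # 3072
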
llm = ChatOpenAI(
    model="gpt-4o",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    api_key=os.getenv("OPENAI_API_KEY"),  # or rely on the env var directly
    # base_url="...",
    # organization="...",
    # other params...
)

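# Usage sketch (illustrative; not executed by the app). ChatOpenAI is a
# Runnable, so a plain string can be passed to `invoke`; the reply is an
# AIMessage whose text is in `.content`:
#
#   msg = llm.invoke("Reply with the single word OK.")
#   print(msg.content)
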
default_system_prompt = """
You are a helpful and knowledgeable assistant who is an expert on medical question answering.
Your role is to select the best answer for queries related to medical information.
YOU WILL ALWAYS ANSWER FROM THE CONTEXT PROVIDED. If the answer is not in the context, politely say that you are not aware of the answer.
"""


knowledge_base_prompt = """You have been provided with medical notes and books.
Your role is to provide the best answer for queries related to medical information.
YOU WILL ALWAYS ANSWER FROM THE CONTEXT PROVIDED. If the answer is not in the context, politely say that you are not aware of the answer.
"""
# Optional extra instruction: "- Keep answers short and direct."

# Function to ingest PDFs from the local knowledge-base directory
def data_ingestion():
    loader = PyPDFDirectoryLoader("finance_documents")  # Directory of default source PDFs
    documents = loader.load()
    # Split the text into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=4096, chunk_overlap=512)
    docs = text_splitter.split_documents(documents)
    return docs

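# Rough intuition for the splitter settings above (assumed arithmetic, not a
# guarantee): with chunk_size=4096 and chunk_overlap=512 the effective stride
# is 4096 - 512 = 3584 characters, so a ~10,000-character document yields
# about 3 chunks. The splitter also prefers paragraph and sentence
# boundaries, so real counts vary.
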
# Function to create and save the vector store
def setup_vector_store(documents):
    # Create a vector store from the documents and embeddings
    vector_store = FAISS.from_documents(documents, embeddings)
    # Save the vector store locally
    vector_store.save_local("faiss_index_medical")

# Function to load the vector store, creating it first if needed
def load_or_create_vector_store():
    # Check if the vector store directory exists
    if os.path.exists("faiss_index_medical"):
        # Load the vector store
        vector_store = FAISS.load_local("faiss_index_medical", embeddings, allow_dangerous_deserialization=True)
        print("Loaded existing vector store.")
    else:
        # If the vector store doesn't exist, create it
        docs = data_ingestion()
        setup_vector_store(docs)
        vector_store = FAISS.load_local("faiss_index_medical", embeddings, allow_dangerous_deserialization=True)
        print("Created and loaded new vector store.")

    return vector_store

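# Usage sketch (illustrative; not executed by the app). The FAISS store can
# be queried directly; each hit is a Document carrying the metadata attached
# at indexing time:
#
#   store = load_or_create_vector_store()
#   for doc in store.similarity_search("diabetes management", k=3):
#       print(doc.metadata, doc.page_content[:80])
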
def load_and_pad_image(image_path, size=(64, 64)):
    img = Image.open(image_path)

    # Make the image square by padding it (pass color=... to choose the background)
    img_with_padding = ImageOps.pad(img, size)
    return img_with_padding

def LLM(llm, query):
    # Use the vector store built from uploaded files if one is available
    if 'vectorstore' in st.session_state and st.session_state['vectorstore'] is not None:
        system_prompt = knowledge_base_prompt
        vectorstore = st.session_state['vectorstore']
    else:
        system_prompt = default_system_prompt
        vectorstore = load_or_create_vector_store()
    knowledge_base = vectorstore

    # Two-stage retrieval: FAISS fetches the top-k candidates by embedding
    # similarity, then FlashrankRerank re-orders them by relevance
    compressor = FlashrankRerank()
    retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )

    template = '''
    %s
    -------------------------------
    Context: {context}

    Current conversation:
    {chat_history}

    Question: {question}
    Answer:
    ''' % (system_prompt)

    PROMPT = PromptTemplate(
        template=template, input_variables=["context", "chat_history", "question"]
    )
    chain_type_kwargs = {"prompt": PROMPT}

    # Initialize memory to manage chat history if it doesn't exist
    if "memory" not in st.session_state:
        st.session_state.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

    # Create the conversational chain; the memory object supplies
    # {chat_history}, so it does not need to be passed in explicitly
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=compression_retriever,
        memory=st.session_state.memory,
        verbose=True,
        combine_docs_chain_kwargs=chain_type_kwargs
    )

    # Run the chain with the latest user query and return the response
    response = conversation_chain.invoke({"question": query})
    return response.get("answer")

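# Retrieval-pipeline sketch (illustrative; not executed by the app). The
# two-stage retriever built inside LLM() also works standalone: FAISS fetches
# the top-k candidates, then FlashrankRerank re-scores them with a
# cross-encoder and keeps the most relevant:
#
#   store = load_or_create_vector_store()
#   reranked = ContextualCompressionRetriever(
#       base_compressor=FlashrankRerank(),
#       base_retriever=store.as_retriever(search_kwargs={"k": k}),
#   ).invoke("common side effects of metformin")
#   for doc in reranked:
#       print(doc.metadata.get("relevance_score"), doc.page_content[:80])
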
# Function to get text from a PDF
def get_pdf_text(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() can return None for image-only pages, hence the `or ""`
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)


def get_text_chunks(text, file_name, max_chars=16000):  # Approx. 4000 tokens
    # Initial large chunk size
    large_text_splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=512)
    docs = large_text_splitter.create_documents([text])

    # Check character length (as a proxy for tokens) and re-split any chunk
    # that exceeds the limit
    valid_docs = []
    for doc in docs:
        if len(doc.page_content) > max_chars:
            # Further split if the chunk exceeds max_chars
            smaller_text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
            valid_docs.extend(smaller_text_splitter.create_documents([doc.page_content]))
        else:
            valid_docs.append(doc)

    # Add metadata to each document chunk
    for doc in valid_docs:
        doc.metadata["file_name"] = file_name
    return valid_docs

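# Usage sketch (illustrative; not executed by the app). Every chunk comes
# back as a Document tagged with the originating file name:
#
#   chunks = get_text_chunks("some long extracted text ...", "cardiology_notes")
#   print(chunks[0].metadata)  # {'file_name': 'cardiology_notes'}
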
# Function to process uploaded files
def process_files(file_list):
    all_docs = []
    for file in file_list:
        file_extension = os.path.splitext(file.name)[1]
        file_name = os.path.splitext(file.name)[0]
        # Extract raw text per file so earlier files are not re-chunked
        if file_extension == ".pdf":
            raw_text = get_pdf_text(file)
        elif file_extension in (".txt", ".csv"):
            raw_text = file.read().decode('utf-8')
        else:
            st.warning("File type not supported")
            continue

        # Split this file's text into chunks
        docs = get_text_chunks(raw_text, file_name)
        for doc in docs:
            doc.metadata["extension"] = file_extension
            doc.metadata["source"] = file.name
        all_docs.extend(docs)
    if all_docs:
        # Create the vector store and keep it in session state
        vectorstore = FAISS.from_documents(all_docs, embeddings)
        st.session_state['vectorstore'] = vectorstore
        st.success("Knowledge base updated with uploaded files!")
    else:
        st.warning("No valid files were uploaded. Please upload PDF, TXT, or CSV files.")

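# After process_files() runs, every indexed chunk carries 'file_name',
# 'extension', and 'source' metadata, so retrieved passages can be traced
# back to the uploaded file that produced them.
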
# Main function to set up the Streamlit chat interface
def main():
    load_dotenv()

    favicon_path = "medical.png"  # Replace with the actual path to your image file
    favicon_image = load_and_pad_image(favicon_path)

    st.set_page_config(
        page_title="Medical Chatbot",
        page_icon=favicon_image,
    )
    # Create two columns for the logo and title text
    col1, col2 = st.columns([1, 8])  # Adjust the column width ratios as needed

    with col1:
        st.image(favicon_image)  # Display the logo image

    with col2:
        # Custom HTML with a negative top margin to pull the title up
        st.markdown("""
            <h1 style='text-align: left; margin-top: -12px;'>
                Medical Chatbot
            </h1>
            """, unsafe_allow_html=True)

    # Initialize the unique key for the file uploader
    if 'file_uploader_key' not in st.session_state:
        st.session_state['file_uploader_key'] = str(uuid.uuid4())

    # Add the file upload component in the sidebar
    with st.sidebar:
        st.subheader("Your documents")
        pdf_docs = st.file_uploader(
            "Upload PDF, TXT, or CSV files and click Process",
            type=["pdf", "txt", "csv"],
            accept_multiple_files=True,
            key=st.session_state['file_uploader_key']
        )
        if st.button("Process"):
            if pdf_docs is not None and len(pdf_docs) > 0:
                with st.spinner("Processing files"):
                    process_files(pdf_docs)
            else:
                st.error("Please upload at least one file.")

        # Button to start a new session
        if st.button("New Session"):
            # Clear the chat history and memory
            st.session_state["messages"] = [{"role": "assistant", "content": "Hello there, how can I help you?"}]
            st.session_state.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
            # Clear the vectorstore from session state
            st.session_state['vectorstore'] = None
            # Assign a new key to the file uploader to reset it
            st.session_state['file_uploader_key'] = str(uuid.uuid4())
            st.rerun()

    user_question = st.chat_input("Ask a Question")

    # Initialize the chat history in session state on first run
    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "Hello there, how can I help you?"}]

    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Capture user input and update the chat history
    if user_question:
        st.session_state.messages.append({"role": "user", "content": user_question})
        with st.chat_message("user"):
            st.write(user_question)

        # Generate and display the assistant's response, updating the chat history
        with st.chat_message("assistant"):
            with st.spinner("Loading"):
                ai_response = LLM(llm, user_question)
                st.write(ai_response)

        st.session_state.messages.append({"role": "assistant", "content": ai_response})

if __name__ == '__main__':
    main()