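"""Streamlit medical chatbot: retrieval-augmented QA over PDFs.

Summary of what the code below does: local PDFs (and user uploads) are
chunked, embedded with OpenAI embeddings, indexed in FAISS, reranked with
FlashrankRerank, and answered through a ConversationalRetrievalChain backed
by gpt-4o-mini, with chat history kept in Streamlit session state.
"""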
import os
import uuid  # Unique keys used to reset the Streamlit file uploader

import streamlit as st
from dotenv import load_dotenv
from PIL import Image, ImageOps
from PyPDF2 import PdfReader  # PDF text extraction for uploaded files

from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory
from langchain.prompts import PromptTemplate
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders.pdf import PyPDFDirectoryLoader
from langchain_community.vectorstores import FAISS
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
# Hyperparameters
PDF_CHUNK_SIZE = 1024      # Defined for reference; the splitters below set their own chunk sizes
PDF_CHUNK_OVERLAP = 256    # Defined for reference; the splitters below set their own overlaps
k = 3                      # Number of chunks to retrieve per query
# Embedding model shared by the on-disk FAISS index and the uploaded-file index
embeddings = OpenAIEmbeddings(
    model="text-embedding-3-large",
    api_key=os.getenv("OPENAI_API_KEY"),
    # With the `text-embedding-3` class of models you can also request a
    # specific embedding size, e.g. dimensions=1024
)
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    max_tokens=None,
    timeout=None,
    max_retries=2,
    api_key=os.getenv("OPENAI_API_KEY"),  # pass the key directly if you prefer not to rely on env vars
    # base_url="...",
    # organization="...",
    # other params...
)
default_system_prompt = """
You are a helpful and knowledgeable assistant who is an expert in medical question answering.
Your role is to select the best answer for queries related to medical information.
YOU WILL ALWAYS ANSWER FROM THE CONTEXT PROVIDED. If the answer is not in the context, politely say that you are not aware of the answer.
"""
knowledge_base_prompt = """You have been provided with medical notes and books.
Your role is to provide the best answer for queries related to medical information.
YOU WILL ALWAYS ANSWER FROM THE CONTEXT PROVIDED. If the answer is not in the context, politely say that you are not aware of the answer.
"""
# Optional extra instruction for the prompts above: "- Keep answers short and direct."
# Function to ingest PDFs from the local directory
def data_ingestion():
    loader = PyPDFDirectoryLoader("finance_documents")
    documents = loader.load()
    # Split the text into chunks
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=4096, chunk_overlap=512)
    docs = text_splitter.split_documents(documents)
    return docs
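# Note (assumption about the deployment layout): PyPDFDirectoryLoader expects a
# "finance_documents" folder of PDFs next to this script; if it is missing or
# empty, no documents are loaded and the FAISS index cannot be built.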
# Function to create and save the vector store
def setup_vector_store(documents):
    # Create a vector store using the documents and embeddings
    vector_store = FAISS.from_documents(documents, embeddings)
    # Save the vector store locally
    vector_store.save_local("faiss_index_medical")

# Function to load the vector store, creating it first if it does not exist
def load_or_create_vector_store():
    # Check if the vector store file exists
    if os.path.exists("faiss_index_medical"):
        # Load the vector store
        vector_store = FAISS.load_local("faiss_index_medical", embeddings, allow_dangerous_deserialization=True)
        print("Loaded existing vector store.")
    else:
        # If the vector store doesn't exist, create it
        docs = data_ingestion()
        setup_vector_store(docs)
        vector_store = FAISS.load_local("faiss_index_medical", embeddings, allow_dangerous_deserialization=True)
        print("Created and loaded new vector store.")
    return vector_store
def load_and_pad_image(image_path, size=(64, 64)):
    img = Image.open(image_path)
    # Pad the image to a square canvas (pass color= to ImageOps.pad to change the fill)
    img_with_padding = ImageOps.pad(img, size)
    return img_with_padding
def LLM(llm, query):
    # Use the vector store built from uploaded files if available
    if 'vectorstore' in st.session_state and st.session_state['vectorstore'] is not None:
        system_prompt = knowledge_base_prompt
        vectorstore = st.session_state['vectorstore']
    else:
        system_prompt = default_system_prompt
        vectorstore = load_or_create_vector_store()
    knowledge_base = vectorstore
    # Rerank the retrieved chunks before they reach the LLM
    compressor = FlashrankRerank()
    retriever = knowledge_base.as_retriever(search_kwargs={"k": k})
    compression_retriever = ContextualCompressionRetriever(
        base_compressor=compressor, base_retriever=retriever
    )
    template = '''
    %s
    -------------------------------
    Context: {context}
    Current conversation:
    {chat_history}
    Question: {question}
    Answer:
    ''' % (system_prompt)
    PROMPT = PromptTemplate(
        template=template, input_variables=["context", "chat_history", "question"]
    )
    chain_type_kwargs = {"prompt": PROMPT}
    # Initialize memory to manage chat history if it doesn't exist
    if "memory" not in st.session_state:
        st.session_state.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    # Retrieve chat history from st.session_state.messages
    chat_history = [
        (msg["role"], msg["content"]) for msg in st.session_state.messages if msg["role"] in ["user", "assistant"]
    ]
    # Create the conversational chain with memory for chat history
    conversation_chain = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=compression_retriever,
        memory=st.session_state.memory,
        verbose=True,
        combine_docs_chain_kwargs=chain_type_kwargs
    )
    # Run the chain with the latest user query; the attached memory supplies
    # the chat_history and takes precedence over the value passed here.
    response = conversation_chain({"question": query, "chat_history": chat_history})
    return response.get("answer")
# Function to get text from a PDF
def get_pdf_text(pdf_file):
    pdf_reader = PdfReader(pdf_file)
    # extract_text() can return None for image-only pages, so guard with ""
    return "".join(page.extract_text() or "" for page in pdf_reader.pages)

def get_text_chunks(text, file_name, max_chars=16000):  # max_chars is roughly 4000 tokens
    # Initial large chunk size
    large_text_splitter = RecursiveCharacterTextSplitter(chunk_size=8000, chunk_overlap=512)
    docs = large_text_splitter.create_documents([text])
    # Check character length (as a proxy for tokens) and split further if a chunk exceeds the limit
    valid_docs = []
    for doc in docs:
        if len(doc.page_content) > max_chars:
            # Further split if the chunk exceeds max_chars
            smaller_text_splitter = RecursiveCharacterTextSplitter(chunk_size=2000, chunk_overlap=200)
            valid_docs.extend(smaller_text_splitter.create_documents([doc.page_content]))
        else:
            valid_docs.append(doc)
    # Add metadata to each document chunk
    for doc in valid_docs:
        doc.metadata["file_name"] = file_name
    return valid_docs
# Function to process uploaded files
def process_files(file_list):
    all_docs = []
    for file in file_list:
        file_extension = os.path.splitext(file.name)[1]
        file_name = os.path.splitext(file.name)[0]
        # Reset per file so chunks are attributed to the correct source
        raw_text = ""
        if file_extension == ".pdf":
            raw_text = get_pdf_text(file)
        elif file_extension in (".txt", ".csv"):
            raw_text = file.read().decode('utf-8')
        else:
            st.warning("File type not supported")
            continue
        # Now split the text into chunks
        docs = get_text_chunks(raw_text, file_name)
        for doc in docs:
            doc.metadata["extension"] = file_extension
            doc.metadata["source"] = file.name
        all_docs.extend(docs)
    if all_docs:
        # Create the vector store from the uploaded documents
        vectorstore = FAISS.from_documents(all_docs, embeddings)
        # Save the vector store in session state
        st.session_state['vectorstore'] = vectorstore
        st.success("Knowledge base updated with uploaded files!")
    else:
        st.warning("No valid files were uploaded. Please upload PDF, TXT, or CSV files.")
# Main function to set up the Streamlit chat interface
def main():
    load_dotenv()
    favicon_path = "medical.png"  # Replace with the actual path to your image file
    favicon_image = load_and_pad_image(favicon_path)
    st.set_page_config(
        page_title="Medical Chatbot",
        page_icon=favicon_image,
    )
    # Create two columns for the logo and title text
    col1, col2 = st.columns([1, 8])  # Adjust the column width ratios as needed
    with col1:
        st.image(favicon_image)  # Display the logo image
    with col2:
        # Reduce spacing by adding custom HTML with no margin/padding
        st.markdown("""
            <h1 style='text-align: left; margin-top: -12px;'>
                Medical Chatbot
            </h1>
            """, unsafe_allow_html=True)
    # Initialize the unique key for the file uploader
    if 'file_uploader_key' not in st.session_state:
        st.session_state['file_uploader_key'] = str(uuid.uuid4())
    # File upload components in the sidebar
    with st.sidebar:
        st.subheader("Your PDFs")
        pdf_docs = st.file_uploader(
            "Upload PDFs and click Process",
            type=["pdf", "txt", "csv"],
            accept_multiple_files=True,
            key=st.session_state['file_uploader_key']
        )
        if st.button("Process"):
            if pdf_docs is not None and len(pdf_docs) > 0:
                with st.spinner("Processing PDFs"):
                    process_files(pdf_docs)
            else:
                st.error("Please upload at least one file.")
        # Button to start a new session
        if st.button("New Session"):
            # Clear the chat history and memory
            st.session_state["messages"] = [{"role": "assistant", "content": "Hello there, how can I help you?"}]
            st.session_state.memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
            # Clear the vector store from session state
            st.session_state['vectorstore'] = None
            # Assign a new key to the file uploader to reset it
            st.session_state['file_uploader_key'] = str(uuid.uuid4())
            st.rerun()
    user_question = st.chat_input("Ask a Question")
    # Initialize or load chat history into session state
    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "Hello there, how can I help you?"}]
    # Display chat history
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    # Capture user input and update the chat history
    if user_question:
        st.session_state.messages.append({"role": "user", "content": user_question})
        with st.chat_message("user"):
            st.write(user_question)
        # Generate and display the assistant's response, updating the chat history
        with st.chat_message("assistant"):
            with st.spinner("Loading"):
                ai_response = LLM(llm, user_question)
                st.write(ai_response)
        st.session_state.messages.append({"role": "assistant", "content": ai_response})

if __name__ == '__main__':
    main()
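# To run locally (a sketch, assuming this file is saved as app.py and
# OPENAI_API_KEY is set in a .env file or the environment; package names are
# the usual PyPI names and may need adjusting to your setup):
#   pip install streamlit python-dotenv langchain langchain-community langchain-openai faiss-cpu pypdf PyPDF2 flashrank Pillow
#   streamlit run app.py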