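"""PDF chatbot: a Streamlit app for question answering over uploaded PDFs.

Documents are split into chunks, embedded into a FAISS vector store, and
queried through a LangChain ConversationalRetrievalChain backed by a
Hugging Face inference endpoint.
"""
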
import os
import tempfile

import streamlit as st
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_huggingface import HuggingFaceEndpoint  # Updated import
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hugging Face Inference API token, read from the environment (e.g. Space secrets)
api_token = os.getenv("HF_TOKEN")

# Available instruction-tuned models; the short names are shown in the UI
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]


def load_doc(uploaded_files):
    """Save uploaded PDFs to temp files, load their pages, and split them into chunks."""
    try:
        temp_files = []
        for uploaded_file in uploaded_files:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            temp_file.write(uploaded_file.read())
            temp_file.close()
            temp_files.append(temp_file.name)
        loaders = [PyPDFLoader(x) for x in temp_files]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        doc_splits = text_splitter.split_documents(pages)
        for temp_file in temp_files:
            os.remove(temp_file)  # Clean up temporary files
        return doc_splits
    except Exception as e:
        st.error(f"Error loading document: {e}")
        return []


def create_db(splits):
    """Embed the document chunks and index them in a FAISS vector store."""
    try:
        # With no model_name argument, HuggingFaceEmbeddings falls back to its
        # default sentence-transformers model
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(splits, embeddings)
        return vectordb
    except Exception as e:
        st.error(f"Error creating vector database: {e}")
        return None


def initialize_llmchain(llm_model, vector_db):
    """Build a ConversationalRetrievalChain around a Hugging Face inference endpoint."""
    try:
        # Guard against the user clicking "Initialize" before the database exists
        if vector_db is None:
            st.error("Create the vector database before initializing the chatbot.")
            return None
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=0.5,
            max_new_tokens=4096,
            top_k=3,
        )
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            return_messages=True,
        )
        retriever = vector_db.as_retriever()
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            chain_type="stuff",
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
        return qa_chain
    except Exception as e:
        st.error(f"Error initializing LLM chain: {e}")
        return None


def initialize_database(uploaded_files):
    """Load the uploaded PDFs and build the vector database, reporting status."""
    try:
        doc_splits = load_doc(uploaded_files)
        if not doc_splits:
            return None, "Failed to load documents."
        vector_db = create_db(doc_splits)
        if vector_db is None:
            return None, "Failed to create vector database."
        return vector_db, "Database created!"
    except Exception as e:
        st.error(f"Error initializing database: {e}")
        return None, "Failed to initialize database."


def initialize_LLM(llm_option, vector_db):
    """Resolve the selected model name and initialize the QA chain."""
    try:
        llm_name = list_llm[llm_option]
        qa_chain = initialize_llmchain(llm_name, vector_db)
        if qa_chain is None:
            return None, "Failed to initialize QA chain."
        return qa_chain, "QA chain initialized. Chatbot is ready!"
    except Exception as e:
        st.error(f"Error initializing LLM: {e}")
        return None, "Failed to initialize LLM."


def format_chat_history(chat_history):
    """Flatten (user, assistant) pairs into plain-text turns for the chain."""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
    return formatted_chat_history


def conversation(qa_chain, message, history):
    """Send a question to the chain and collect the answer plus source passages."""
    try:
        formatted_chat_history = format_chat_history(history)
        # The chain's ConversationBufferMemory also tracks history and takes
        # precedence when set, so the explicit chat_history value here is
        # largely redundant; it is kept from the original code
        response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
        response_answer = response["answer"]
        response_sources = response["source_documents"]
        sources = []
        for doc in response_sources:
            sources.append({
                "content": doc.page_content.strip(),
                "page": doc.metadata["page"] + 1,  # PyPDFLoader pages are 0-indexed
            })
        new_history = history + [(message, response_answer)]
        return qa_chain, new_history, response_answer, sources
    except Exception as e:
        st.error(f"Error in conversation: {e}")
        return qa_chain, history, "", []


def main():
    st.sidebar.title("PDF Chatbot")
    st.sidebar.markdown("### Step 1 - Upload PDF documents and create the vector database")
    uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)

    if uploaded_files:
        if st.sidebar.button("Create vector database"):
            with st.spinner("Creating vector database..."):
                vector_db, db_message = initialize_database(uploaded_files)
                st.sidebar.success(db_message)
                st.session_state['vector_db'] = vector_db

    # Default session state for values that survive Streamlit reruns
    if 'vector_db' not in st.session_state:
        st.session_state['vector_db'] = None
    if 'qa_chain' not in st.session_state:
        st.session_state['qa_chain'] = None
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    st.sidebar.markdown("### Select Large Language Model (LLM)")
    llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)

    if st.sidebar.button("Initialize Question Answering Chatbot"):
        with st.spinner("Initializing QA chatbot..."):
            qa_chain, llm_message = initialize_LLM(list_llm_simple.index(llm_option), st.session_state['vector_db'])
            st.session_state['qa_chain'] = qa_chain
            st.sidebar.success(llm_message)

    st.title("Chat with your Document")
    sources = []  # Source passages for the most recent answer
    if st.session_state['qa_chain']:
        message = st.text_input("Ask a question")
        if st.button("Submit"):
            with st.spinner("Generating response..."):
                qa_chain, chat_history, response_answer, sources = conversation(
                    st.session_state['qa_chain'], message, st.session_state['chat_history']
                )
                st.session_state['qa_chain'] = qa_chain
                st.session_state['chat_history'] = chat_history

        st.markdown("### Chatbot Response")
        # Display the chat history in a chat-like interface
        for user_msg, bot_msg in st.session_state['chat_history']:
            st.markdown(f"**User:** {user_msg}")
            st.markdown(f"**Assistant:** {bot_msg}")

        with st.expander("Relevant context from the source document"):
            # An explicit key avoids duplicate-widget errors when two sources
            # share the same page number and content
            for i, source in enumerate(sources):
                st.text_area(f"Source - Page {source['page']}", value=source["content"], height=100, key=f"source_{i}")


if __name__ == "__main__":
    main()
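
# A suggested local invocation (the filename "app.py" is an assumption,
# not part of the original source):
#   export HF_TOKEN=<your Hugging Face token>
#   streamlit run app.py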