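"""Streamlit RAG chatbot: upload PDF documents, split and index them in a
FAISS vector store, and chat with them through a Hugging Face-hosted LLM
using LangChain's ConversationalRetrievalChain."""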
import streamlit as st
import os
import tempfile
from langchain_community.vectorstores import FAISS
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hugging Face Inference API token, read from the environment
api_token = os.getenv("HF_TOKEN")

# Available instruction-tuned models and their short display names
list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
list_llm_simple = [os.path.basename(llm) for llm in list_llm]

def load_doc(uploaded_files):
    """Write uploaded PDFs to temp files, load them, and split into chunks."""
    try:
        temp_files = []
        for uploaded_file in uploaded_files:
            temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
            temp_file.write(uploaded_file.read())
            temp_file.close()
            temp_files.append(temp_file.name)
        loaders = [PyPDFLoader(x) for x in temp_files]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
        doc_splits = text_splitter.split_documents(pages)
        for temp_file in temp_files:
            os.remove(temp_file)  # Clean up temporary files
        return doc_splits
    except Exception as e:
        st.error(f"Error loading document: {e}")
        return []

def create_db(splits):
    """Embed the document chunks and index them in a FAISS vector store."""
    try:
        # Uses LangChain's default sentence-transformers embedding model
        embeddings = HuggingFaceEmbeddings()
        vectordb = FAISS.from_documents(splits, embeddings)
        return vectordb
    except Exception as e:
        st.error(f"Error creating vector database: {e}")
        return None

def initialize_llmchain(llm_model, vector_db):
    """Build a conversational retrieval chain around a Hugging Face endpoint."""
    try:
        llm = HuggingFaceEndpoint(
            repo_id=llm_model,
            huggingfacehub_api_token=api_token,
            temperature=0.5,
            max_new_tokens=4096,
            top_k=3,
        )
        # Buffer memory carries the running conversation for follow-up questions
        memory = ConversationBufferMemory(
            memory_key="chat_history",
            output_key="answer",
            return_messages=True,
        )
        retriever = vector_db.as_retriever()
        qa_chain = ConversationalRetrievalChain.from_llm(
            llm,
            retriever=retriever,
            chain_type="stuff",  # Stuff retrieved chunks directly into the prompt
            memory=memory,
            return_source_documents=True,
            verbose=False,
        )
        return qa_chain
    except Exception as e:
        st.error(f"Error initializing LLM chain: {e}")
        return None

def initialize_database(uploaded_files):
    """Load the uploaded PDFs and build the vector database from them."""
    try:
        doc_splits = load_doc(uploaded_files)
        if not doc_splits:
            return None, "Failed to load documents."
        vector_db = create_db(doc_splits)
        if vector_db is None:
            return None, "Failed to create vector database."
        return vector_db, "Database created!"
    except Exception as e:
        st.error(f"Error initializing database: {e}")
        return None, "Failed to initialize database."

def initialize_LLM(llm_option, vector_db):
    """Resolve the selected model name and build the QA chain."""
    try:
        llm_name = list_llm[llm_option]
        qa_chain = initialize_llmchain(llm_name, vector_db)
        if qa_chain is None:
            return None, "Failed to initialize QA chain."
        return qa_chain, "QA chain initialized. Chatbot is ready!"
    except Exception as e:
        st.error(f"Error initializing LLM: {e}")
        return None, "Failed to initialize LLM."

def format_chat_history(chat_history):
    """Flatten (user, assistant) tuples into plain-text turns for the chain."""
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
    return formatted_chat_history

def conversation(qa_chain, message, history):
    """Run one question through the chain; return the answer and its sources."""
    try:
        formatted_chat_history = format_chat_history(history)
        response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
        response_answer = response["answer"]
        response_sources = response["source_documents"]
        sources = []
        for doc in response_sources:
            sources.append({
                "content": doc.page_content.strip(),
                "page": doc.metadata["page"] + 1,  # PyPDFLoader pages are 0-indexed
            })
        new_history = history + [(message, response_answer)]
        return qa_chain, new_history, response_answer, sources
    except Exception as e:
        st.error(f"Error in conversation: {e}")
        return qa_chain, history, "", []

def main():
    st.sidebar.title("PDF Chatbot")

    # Initialize session state keys before any widget handler writes to them
    if 'vector_db' not in st.session_state:
        st.session_state['vector_db'] = None
    if 'qa_chain' not in st.session_state:
        st.session_state['qa_chain'] = None
    if 'chat_history' not in st.session_state:
        st.session_state['chat_history'] = []

    st.sidebar.markdown("### Step 1 - Upload PDF documents and initialize the RAG pipeline")
    uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)
    if uploaded_files:
        if st.sidebar.button("Create vector database"):
            with st.spinner("Creating vector database..."):
                vector_db, db_message = initialize_database(uploaded_files)
            st.sidebar.success(db_message)
            st.session_state['vector_db'] = vector_db

    st.sidebar.markdown("### Step 2 - Select Large Language Model (LLM)")
    llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)
    if st.sidebar.button("Initialize Question Answering Chatbot"):
        with st.spinner("Initializing QA chatbot..."):
            qa_chain, llm_message = initialize_LLM(list_llm_simple.index(llm_option), st.session_state['vector_db'])
        st.session_state['qa_chain'] = qa_chain
        st.sidebar.success(llm_message)

    st.title("Chat with your Document")
    if st.session_state['qa_chain']:
        st.markdown("### Chatbot Response")
        # Display the chat history in a chat-like interface
        for user_msg, bot_msg in st.session_state['chat_history']:
            st.markdown(f"**User:** {user_msg}")
            st.markdown(f"**Assistant:** {bot_msg}")
        with st.expander("Relevant context from the source document"):
            if 'sources' in st.session_state:
                for i, source in enumerate(st.session_state['sources']):
                    st.text_area(f"Source {i + 1} - Page {source['page']}", value=source["content"], height=100)
        with st.form(key="question_form"):
            message = st.text_input("Ask a question", key="message")
            submit_button = st.form_submit_button(label="Submit")
        if submit_button and message:
            with st.spinner("Generating response..."):
                qa_chain, chat_history, response_answer, sources = conversation(
                    st.session_state['qa_chain'], message, st.session_state['chat_history']
                )
            st.session_state['qa_chain'] = qa_chain
            st.session_state['chat_history'] = chat_history
            st.session_state['sources'] = sources
            st.rerun()  # Rerun so the new exchange and its sources render above

if __name__ == "__main__":
    main()
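
# To run locally (a minimal sketch; "app.py" is an assumed filename, and
# HF_TOKEN must hold a Hugging Face API token with access to the listed models):
#   export HF_TOKEN=<your token>
#   streamlit run app.py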