Sharal committed
Commit 3b0feea · verified · 1 Parent(s): f013f91

Create app.py

Files changed (1)
  1. app.py +183 -0
app.py ADDED
@@ -0,0 +1,183 @@
+ import streamlit as st
+ import os
+ from langchain_community.vectorstores import FAISS
+ from langchain_community.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_huggingface import HuggingFaceEndpoint  # Updated import
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain.memory import ConversationBufferMemory
+ import tempfile
+
+ api_token = os.getenv("HF_TOKEN")
+ list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
+ list_llm_simple = [os.path.basename(llm) for llm in list_llm]
+
+ def load_doc(uploaded_files):
+     try:
+         temp_files = []
+         for uploaded_file in uploaded_files:
+             temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
+             temp_file.write(uploaded_file.read())
+             temp_file.close()
+             temp_files.append(temp_file.name)
+
+         loaders = [PyPDFLoader(x) for x in temp_files]
+         pages = []
+         for loader in loaders:
+             pages.extend(loader.load())
+
+         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
+         doc_splits = text_splitter.split_documents(pages)
+
+         for temp_file in temp_files:
+             os.remove(temp_file)  # Clean up temporary files
+
+         return doc_splits
+     except Exception as e:
+         st.error(f"Error loading document: {e}")
+         return []
+
+ def create_db(splits):
+     try:
+         embeddings = HuggingFaceEmbeddings()
+         vectordb = FAISS.from_documents(splits, embeddings)
+         return vectordb
+     except Exception as e:
+         st.error(f"Error creating vector database: {e}")
+         return None
+
+ def initialize_llmchain(llm_model, vector_db):
+     try:
+         llm = HuggingFaceEndpoint(
+             repo_id=llm_model,
+             huggingfacehub_api_token=api_token,
+             temperature=0.5,
+             max_new_tokens=4096,
+             top_k=3,
+         )
+         memory = ConversationBufferMemory(
+             memory_key="chat_history",
+             output_key='answer',
+             return_messages=True
+         )
+
+         retriever = vector_db.as_retriever()
+         qa_chain = ConversationalRetrievalChain.from_llm(
+             llm,
+             retriever=retriever,
+             chain_type="stuff",
+             memory=memory,
+             return_source_documents=True,
+             verbose=False,
+         )
+         return qa_chain
+     except Exception as e:
+         st.error(f"Error initializing LLM chain: {e}")
+         return None
+
+ def initialize_database(uploaded_files):
+     try:
+         doc_splits = load_doc(uploaded_files)
+         if not doc_splits:
+             return None, "Failed to load documents."
+         vector_db = create_db(doc_splits)
+         if vector_db is None:
+             return None, "Failed to create vector database."
+         return vector_db, "Database created!"
+     except Exception as e:
+         st.error(f"Error initializing database: {e}")
+         return None, "Failed to initialize database."
+
+ def initialize_LLM(llm_option, vector_db):
+     try:
+         llm_name = list_llm[llm_option]
+         qa_chain = initialize_llmchain(llm_name, vector_db)
+         if qa_chain is None:
+             return None, "Failed to initialize QA chain."
+         return qa_chain, "QA chain initialized. Chatbot is ready!"
+     except Exception as e:
+         st.error(f"Error initializing LLM: {e}")
+         return None, "Failed to initialize LLM."
+
+ def format_chat_history(chat_history):
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
+     return formatted_chat_history
+
+ def conversation(qa_chain, message, history):
+     try:
+         formatted_chat_history = format_chat_history(history)
+         response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
+         response_answer = response["answer"]
+         response_sources = response["source_documents"]
+
+         sources = []
+         for doc in response_sources:
+             sources.append({
+                 "content": doc.page_content.strip(),
+                 "page": doc.metadata["page"] + 1  # PyPDFLoader pages are zero-indexed
+             })
+
+         new_history = history + [(message, response_answer)]
+         return qa_chain, new_history, response_answer, sources
+     except Exception as e:
+         st.error(f"Error in conversation: {e}")
+         return qa_chain, history, "", []
+
+ def main():
+     st.sidebar.title("PDF Chatbot")
+
+     st.sidebar.markdown("### Step 1 - Upload PDF documents and Initialize RAG pipeline")
+     uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)
+
+     if uploaded_files:
+         if st.sidebar.button("Create vector database"):
+             with st.spinner("Creating vector database..."):
+                 vector_db, db_message = initialize_database(uploaded_files)
+             st.sidebar.success(db_message)
+             st.session_state['vector_db'] = vector_db
+
+     if 'vector_db' not in st.session_state:
+         st.session_state['vector_db'] = None
+
+     if 'qa_chain' not in st.session_state:
+         st.session_state['qa_chain'] = None
+
+     if 'chat_history' not in st.session_state:
+         st.session_state['chat_history'] = []
+
+     st.sidebar.markdown("### Select Large Language Model (LLM)")
+     llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)
+
+     if st.sidebar.button("Initialize Question Answering Chatbot"):
+         with st.spinner("Initializing QA chatbot..."):
+             qa_chain, llm_message = initialize_LLM(list_llm_simple.index(llm_option), st.session_state['vector_db'])
+         st.session_state['qa_chain'] = qa_chain
+         st.sidebar.success(llm_message)
+
+     st.title("Chat with your Document")
+
+     if st.session_state['qa_chain']:
+         message = st.text_input("Ask a question")
+
+         if st.button("Submit"):
+             with st.spinner("Generating response..."):
+                 qa_chain, chat_history, response_answer, sources = conversation(st.session_state['qa_chain'], message, st.session_state['chat_history'])
+             st.session_state['qa_chain'] = qa_chain
+             st.session_state['chat_history'] = chat_history
+
+             st.markdown("### Chatbot Response")
+
+             # Display the chat history in a chat-like interface
+             for user_msg, bot_msg in st.session_state['chat_history']:
+                 st.markdown(f"**User:** {user_msg}")
+                 st.markdown(f"**Assistant:** {bot_msg}")
+
+             with st.expander("Relevant context from the source document"):
+                 for source in sources:
+                     st.text_area(f"Source - Page {source['page']}", value=source["content"], height=100)
+
+ if __name__ == "__main__":
+     main()
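
To try the app after this commit, export HF_TOKEN and launch it with `streamlit run app.py`. For orientation, below is a minimal sketch of the same retrieval flow driven outside Streamlit. It is illustrative and not part of the commit: it assumes the same LangChain packages are installed, HF_TOKEN is set, and a local file named sample.pdf exists (the filename and the question are hypothetical).

    # Illustrative sketch, not part of the commit: the app's RAG flow without the UI.
    # Assumes HF_TOKEN is set and sample.pdf exists locally (hypothetical file).
    import os
    from langchain_community.vectorstores import FAISS
    from langchain_community.document_loaders import PyPDFLoader
    from langchain.text_splitter import RecursiveCharacterTextSplitter
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_huggingface import HuggingFaceEndpoint
    from langchain.chains import ConversationalRetrievalChain
    from langchain.memory import ConversationBufferMemory

    # Load and split the PDF, then index the chunks, as load_doc/create_db do.
    pages = PyPDFLoader("sample.pdf").load()
    splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
    vector_db = FAISS.from_documents(splitter.split_documents(pages), HuggingFaceEmbeddings())

    # Build the chain with one of the two models listed in app.py.
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",
        huggingfacehub_api_token=os.getenv("HF_TOKEN"),
        temperature=0.5,
        max_new_tokens=512,
    )
    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=vector_db.as_retriever(), memory=memory, return_source_documents=True
    )

    # Single-turn query; the app's conversation() does the same via invoke().
    result = qa_chain.invoke({"question": "What is this document about?", "chat_history": []})
    print(result["answer"])
    for doc in result["source_documents"]:
        print(f"page {doc.metadata['page'] + 1}: {doc.page_content[:80]}")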