Sharal committed
Commit f013f91 · verified · 1 parent: f3446b9

Delete app.py

Files changed (1)
app.py  +0  -195
app.py DELETED
@@ -1,195 +0,0 @@
- import streamlit as st
- import os
- import tempfile
- from langchain_community.vectorstores import FAISS
- from langchain_community.document_loaders import PyPDFLoader
- from langchain.text_splitter import RecursiveCharacterTextSplitter
- from langchain_huggingface import HuggingFaceEmbeddings, HuggingFaceEndpoint
- from langchain.chains import ConversationalRetrievalChain
- from langchain.memory import ConversationBufferMemory
-
- api_token = os.getenv("HF_TOKEN")
- list_llm = ["meta-llama/Meta-Llama-3-8B-Instruct", "mistralai/Mistral-7B-Instruct-v0.2"]
- list_llm_simple = [os.path.basename(llm) for llm in list_llm]
-
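- # Write each upload to a temp PDF so PyPDFLoader can read it, then split pages into overlapping chunks.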
- def load_doc(uploaded_files):
-     try:
-         temp_files = []
-         for uploaded_file in uploaded_files:
-             temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".pdf")
-             temp_file.write(uploaded_file.read())
-             temp_file.close()
-             temp_files.append(temp_file.name)
-
-         loaders = [PyPDFLoader(x) for x in temp_files]
-         pages = []
-         for loader in loaders:
-             pages.extend(loader.load())
-
-         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1024, chunk_overlap=64)
-         doc_splits = text_splitter.split_documents(pages)
-
-         for temp_file in temp_files:
-             os.remove(temp_file)  # Clean up temporary files
-
-         return doc_splits
-     except Exception as e:
-         st.error(f"Error loading document: {e}")
-         return []
-
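- # Embed the chunks with the default HuggingFace embedding model and index them in an in-memory FAISS store.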
- def create_db(splits):
-     try:
-         embeddings = HuggingFaceEmbeddings()
-         vectordb = FAISS.from_documents(splits, embeddings)
-         return vectordb
-     except Exception as e:
-         st.error(f"Error creating vector database: {e}")
-         return None
-
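- # Assemble the RAG chain: hosted HF endpoint LLM + conversation buffer memory + FAISS retriever.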
- def initialize_llmchain(llm_model, vector_db):
-     try:
-         llm = HuggingFaceEndpoint(
-             repo_id=llm_model,
-             huggingfacehub_api_token=api_token,
-             temperature=0.5,
-             max_new_tokens=4096,
-             top_k=3,
-         )
-         memory = ConversationBufferMemory(
-             memory_key="chat_history",
-             output_key='answer',
-             return_messages=True
-         )
-
-         retriever = vector_db.as_retriever()
-         qa_chain = ConversationalRetrievalChain.from_llm(
-             llm,
-             retriever=retriever,
-             chain_type="stuff",
-             memory=memory,
-             return_source_documents=True,
-             verbose=False,
-         )
-         return qa_chain
-     except Exception as e:
-         st.error(f"Error initializing LLM chain: {e}")
-         return None
-
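- # Step 1 pipeline: uploaded PDFs -> chunks -> vector store, with a status message for the sidebar.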
- def initialize_database(uploaded_files):
-     try:
-         doc_splits = load_doc(uploaded_files)
-         if not doc_splits:
-             return None, "Failed to load documents."
-         vector_db = create_db(doc_splits)
-         if vector_db is None:
-             return None, "Failed to create vector database."
-         return vector_db, "Database created!"
-     except Exception as e:
-         st.error(f"Error initializing database: {e}")
-         return None, "Failed to initialize database."
-
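- # Resolve the selected model index to its repo id and build the QA chain for it.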
- def initialize_LLM(llm_option, vector_db):
-     try:
-         llm_name = list_llm[llm_option]
-         qa_chain = initialize_llmchain(llm_name, vector_db)
-         if qa_chain is None:
-             return None, "Failed to initialize QA chain."
-         return qa_chain, "QA chain initialized. Chatbot is ready!"
-     except Exception as e:
-         st.error(f"Error initializing LLM: {e}")
-         return None, "Failed to initialize LLM."
-
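- # Flatten (user, bot) tuples into plain-text turns before handing them to the chain.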
- def format_chat_history(chat_history):
-     formatted_chat_history = []
-     for user_message, bot_message in chat_history:
-         formatted_chat_history.append(f"User: {user_message}\nAssistant: {bot_message}\n")
-     return formatted_chat_history
-
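- # Run one QA turn: invoke the chain, collect the answer and 1-based source page references, extend history.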
- def conversation(qa_chain, message, history):
-     try:
-         formatted_chat_history = format_chat_history(history)
-         response = qa_chain.invoke({"question": message, "chat_history": formatted_chat_history})
-         response_answer = response["answer"]
-         response_sources = response["source_documents"]
-
-         sources = []
-         for doc in response_sources:
-             sources.append({
-                 "content": doc.page_content.strip(),
-                 "page": doc.metadata["page"] + 1
-             })
-
-         new_history = history + [(message, response_answer)]
-         return qa_chain, new_history, response_answer, sources
-     except Exception as e:
-         st.error(f"Error in conversation: {e}")
-         return qa_chain, history, "", []
-
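- # Streamlit UI: sidebar for setup (upload, indexing, model choice), main pane for the chat itself.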
- def main():
-     st.sidebar.title("PDF Chatbot")
-
-     st.sidebar.markdown("### Step 1 - Upload PDF documents and Initialize RAG pipeline")
-     uploaded_files = st.sidebar.file_uploader("Upload PDF documents", type="pdf", accept_multiple_files=True)
-
-     if uploaded_files:
-         if st.sidebar.button("Create vector database"):
-             with st.spinner("Creating vector database..."):
-                 vector_db, db_message = initialize_database(uploaded_files)
-                 st.sidebar.success(db_message)
-                 st.session_state['vector_db'] = vector_db
-
-     if 'vector_db' not in st.session_state:
-         st.session_state['vector_db'] = None
-
-     if 'qa_chain' not in st.session_state:
-         st.session_state['qa_chain'] = None
-
-     if 'chat_history' not in st.session_state:
-         st.session_state['chat_history'] = []
-
-     st.sidebar.markdown("### Select Large Language Model (LLM)")
-     llm_option = st.sidebar.radio("Available LLMs", list_llm_simple)
-
-     if st.sidebar.button("Initialize Question Answering Chatbot"):
-         with st.spinner("Initializing QA chatbot..."):
-             qa_chain, llm_message = initialize_LLM(list_llm_simple.index(llm_option), st.session_state['vector_db'])
-             st.session_state['qa_chain'] = qa_chain
-             st.sidebar.success(llm_message)
-
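-     # Main pane: chat transcript, retrieved source context, and the question box.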
-     st.title("Chat with your Document")
-
-     if st.session_state['qa_chain']:
-         st.markdown("### Chatbot Response")
-
-         # Display the chat history in a chat-like interface
-         for i, (user_msg, bot_msg) in enumerate(st.session_state['chat_history']):
-             st.markdown(f"**User:** {user_msg}")
-             st.markdown(f"**Assistant:** {bot_msg}")
-
-         st.markdown("### Relevant context from the source document")
-
-         with st.expander("Relevant context from the source document"):
-             if 'sources' in st.session_state:
-                 for i, source in enumerate(st.session_state['sources']):
-                     st.text_area(f"Source {i + 1} - Page {source['page']}", value=source["content"], height=100)
-
-         message = st.text_input("Ask a question", key="message")
-         if st.button("Submit"):
-             if message:
-                 with st.spinner("Generating response..."):
-                     qa_chain, chat_history, response_answer, sources = conversation(st.session_state['qa_chain'], message, st.session_state['chat_history'])
-                     st.session_state['qa_chain'] = qa_chain
-                     st.session_state['chat_history'] = chat_history
-                     st.session_state['sources'] = sources
-
-                 # Display the new response immediately
-                 st.markdown(f"**User:** {message}")
-                 st.markdown(f"**Assistant:** {response_answer}")
-
-                 st.markdown("### Relevant context from the source document")
-                 with st.expander("Relevant context from the source document"):
-                     for i, source in enumerate(sources):
-                         st.text_area(f"Source {i + 1} - Page {source['page']}", value=source["content"], height=100)
-
- if __name__ == "__main__":
-     main()