Rohan12345 committed
Commit 9e44a6f · verified · 1 parent: a853a16

Update app.py

Files changed (1): app.py (+231 −30)
app.py CHANGED
@@ -1,34 +1,40 @@
 import gradio as gr
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
-from langchain_community.embeddings import HuggingFaceEmbeddings
 from pathlib import Path
 from unidecode import unidecode

-def summarize_document(document_text):
-    # Your summarization code here
-    summary = "The document covers various topics such as X, Y, and Z, providing detailed insights into each aspect."
-    return summary
-
-def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
-    list_file_path = [x.name for x in list_file_obj if x is not None]
-    collection_name = create_collection_name(list_file_path[0])
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
-    vector_db = create_db(doc_splits, collection_name)
-    return vector_db, collection_name, "Complete!"

 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
     text_splitter = RecursiveCharacterTextSplitter(
-        chunk_size = chunk_size,
-        chunk_overlap = chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits

 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
@@ -40,22 +46,217 @@ def create_db(splits, collection_name):
     )
     return vectordb

-def create_collection_name(filepath):
-    collection_name = Path(filepath).stem
-    collection_name = unidecode(collection_name)
-    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
-    collection_name = collection_name[:50]
-    if len(collection_name) < 3:
-        collection_name = collection_name + 'xyz'
-    if not collection_name[0].isalnum():
-        collection_name = 'A' + collection_name[1:]
-    if not collection_name[-1].isalnum():
-        collection_name = collection_name[:-1] + 'Z'
-    return collection_name

 def demo():
-    with gr.Interface(summarize_document, inputs="text", outputs="text", title="PDF Summarizer") as iface:
-        iface.launch()

-if __name__ == "__main__":
-    demo()
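The removed `create_collection_name` helper enforced ChromaDB's collection-name constraints (3 to 63 characters, alphanumeric first and last character); the updated file below derives collection names from the PDF file stem instead, so that sanitization is reapplied in the event wiring. For reference, a sketch of what the removed helper produced (runnable only with the old definitions in scope):

```python
print(create_collection_name("reports/Résumé 2024.pdf"))  # -> 'Resume-2024'
```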
 
app.py (updated):

 import gradio as gr
+import os
 from langchain_community.document_loaders import PyPDFLoader
 from langchain.text_splitter import RecursiveCharacterTextSplitter
 from langchain_community.vectorstores import Chroma
+from langchain.chains import ConversationalRetrievalChain
+from langchain_community.embeddings import HuggingFaceEmbeddings
+from langchain_community.llms import HuggingFacePipeline
+from langchain.chains import ConversationChain
+from langchain.memory import ConversationBufferMemory
+from langchain_community.llms import HuggingFaceEndpoint
 from pathlib import Path
+import chromadb
 from unidecode import unidecode
+from transformers import AutoTokenizer
+import transformers
+import torch
+import tqdm
+import accelerate
+import re
+
+list_llm = ["HuggingFaceH4/zephyr-7b-beta", "mistralai/Mistral-7B-Instruct-v0.2"]
+list_llm_simple = [os.path.basename(llm) for llm in list_llm]
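Both entries in `list_llm` are repo ids served through the Hugging Face Inference API, which `HuggingFaceEndpoint` calls over HTTP, so the Space needs a valid access token at runtime. A minimal standalone check, assuming the token is exported as `HUGGINGFACEHUB_API_TOKEN` (the environment variable langchain's `HuggingFaceEndpoint` reads by default):

```python
import os
from langchain_community.llms import HuggingFaceEndpoint

# Hypothetical smoke test, separate from the app itself
assert "HUGGINGFACEHUB_API_TOKEN" in os.environ, "export a HF access token first"
llm = HuggingFaceEndpoint(repo_id="HuggingFaceH4/zephyr-7b-beta", max_new_tokens=64)
print(llm.invoke("Say hello in one short sentence."))
```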
+
+# Load PDF document and create doc splits
 def load_doc(list_file_path, chunk_size, chunk_overlap):
     loaders = [PyPDFLoader(x) for x in list_file_path]
     pages = []
     for loader in loaders:
         pages.extend(loader.load())
     text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap)
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
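`chunk_size` and `chunk_overlap` are character counts for `RecursiveCharacterTextSplitter`; the overlap keeps sentences that straddle a chunk boundary retrievable from either side. A small standalone illustration (values chosen for readability, not the app's defaults):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

text = ("Retrieval quality depends on chunking. " * 5).strip()
splitter = RecursiveCharacterTextSplitter(chunk_size=80, chunk_overlap=20)
for i, chunk in enumerate(splitter.split_text(text)):
    print(i, len(chunk), repr(chunk[:40]))  # every chunk stays within 80 characters
```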
+
+# Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
     new_client = chromadb.EphemeralClient()
     vectordb = Chroma.from_documents(
         documents=splits,
         embedding=embedding,
         client=new_client,
         collection_name=collection_name,
     )
     return vectordb
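`chromadb.EphemeralClient()` keeps the whole index in memory, so every app restart re-embeds the PDFs from scratch. If that becomes slow for large documents, a persistent client is a near drop-in swap; a sketch, assuming a writable `./chroma_db` directory (the path name is illustrative):

```python
import chromadb
from langchain_community.vectorstores import Chroma

def create_persistent_db(splits, collection_name, embedding):
    # Same call shape as create_db above, but the index survives restarts
    client = chromadb.PersistentClient(path="./chroma_db")
    return Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        client=client,
        collection_name=collection_name,
    )
```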
+
+# Initialize langchain LLM chain
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    progress(0.1, desc="Initializing HF Hub...")
+    # Both entries in list_llm are served through the HF Inference API, so a
+    # single HuggingFaceEndpoint setup covers them; the previous Mixtral-only
+    # branch left llm undefined for every model actually in the list, and
+    # load_in_8bit was dropped since it has no effect on a hosted endpoint
+    llm = HuggingFaceEndpoint(
+        repo_id=llm_model,
+        temperature=temperature,
+        max_new_tokens=max_tokens,
+        top_k=top_k,
+    )
+    progress(0.75, desc="Defining buffer memory...")
+    memory = ConversationBufferMemory(
+        memory_key="chat_history",
+        output_key='answer',
+        return_messages=True
+    )
+    retriever = vector_db.as_retriever()
+    progress(0.8, desc="Defining retrieval chain...")
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm,
+        retriever=retriever,
+        chain_type="stuff",
+        memory=memory,
+        return_source_documents=True,
+        verbose=False,
+    )
+    progress(0.9, desc="Done!")
+    return qa_chain
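The chain can be sanity-checked outside the UI by invoking it once with an empty history; `answer` and `source_documents` are the keys `conversation()` depends on. A minimal sketch, assuming a `vector_db` already built by `create_db` and a valid HF token in the environment:

```python
# Hypothetical smoke test for the retrieval chain
qa = initialize_llmchain(list_llm[0], temperature=0.7, max_tokens=1024, top_k=3, vector_db=vector_db)
result = qa({"question": "What is this document about?", "chat_history": []})
print(result["answer"])
print(result["source_documents"][0].metadata)  # includes the 0-indexed "page"
```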
+
+def format_chat_history(message, chat_history):
+    # message is accepted for call-site symmetry but not used in the formatting
+    formatted_chat_history = []
+    for user_message, bot_message in chat_history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
+    return formatted_chat_history
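For reference, this is the shape the chain receives as `chat_history`:

```python
history = [("What is chapter 2 about?", "It covers evaluation metrics.")]
print(format_chat_history("next question", history))
# -> ['User: What is chapter 2 about?', 'Assistant: It covers evaluation metrics.']
```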
+
+def conversation(qa_chain, message, history):
+    formatted_chat_history = format_chat_history(message, history)
+    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    # Assumes the retriever returns at least three source documents
+    # (the default retriever k is 4)
+    response_sources = response["source_documents"]
+    response_source1 = response_sources[0].page_content.strip()
+    response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
+    # PDF page metadata is 0-indexed; convert to human-readable page numbers
+    response_source1_page = response_sources[0].metadata["page"] + 1
+    response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
+    new_history = history + [(message, response_answer)]
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+def upload_file(file_obj):
+    list_file_path = []
+    for file in file_obj:
+        # fixed: read each file's own path; the original used file_obj.name,
+        # the attribute of the list rather than of the current file
+        file_path = file.name
+        list_file_path.append(file_path)
+    return list_file_path
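With that fix, the function reduces to mapping each Gradio file wrapper to its server-side temp path; an equivalent one-liner, kept here only as a sketch:

```python
def upload_file_short(file_obj):
    # gr.Files passes wrappers whose .name is the uploaded temp-file path
    return [f.name for f in file_obj if f is not None]
```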

 def demo():
+    with gr.Blocks(theme="base") as demo:
+        vector_db = gr.State()
+        qa_chain = gr.State()
+        collection_name = gr.State()
+
+        gr.Markdown(
+            """<center><h2>PDF-based chatbot</h2></center>
+            <h3>Ask any questions about your PDF documents</h3>""")
+
+        with gr.Tab("Step 1 - Upload PDF"):
+            with gr.Row():
+                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+
+        with gr.Tab("Step 2 - Process document"):
+            with gr.Row():
+                # renamed from db_btn, which the Button below silently overwrote
+                db_choice = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
+            with gr.Accordion("Advanced options - Document text splitter", open=False):
+                with gr.Row():
+                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+                with gr.Row():
+                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+            with gr.Row():
+                db_progress = gr.Textbox(label="Vector database initialization", value="None")
+            with gr.Row():
+                db_btn = gr.Button("Generate vector database")
+
+        with gr.Tab("Step 3 - Initialize QA chain"):
+            with gr.Row():
+                # renamed from llm_btn for the same reason
+                llm_choice = gr.Radio(list_llm_simple, label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
+            with gr.Accordion("Advanced options - LLM model", open=False):
+                with gr.Row():
+                    slider_temperature = gr.Slider(minimum=0.01, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                with gr.Row():
+                    slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                with gr.Row():
+                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+            with gr.Row():
+                llm_progress = gr.Textbox(label="LLM initialization", value="None")
+            with gr.Row():
+                llm_btn = gr.Button("Initialize LLM chain")
+
+        with gr.Tab("Step 4 - Chat"):
+            with gr.Row():
+                message = gr.Textbox(label="Your question")
+                ask_btn = gr.Button("Ask")
+            with gr.Row():
+                answer = gr.Textbox(label="Answer", value="Ask your question to get an answer")
+                chat_history = gr.Textbox(label="Chat history", value="Chat history")
+                source1 = gr.Textbox(label="Source 1", value="Source 1")
+                source2 = gr.Textbox(label="Source 2", value="Source 2")
+                source3 = gr.Textbox(label="Source 3", value="Source 3")
+
+        # Gradio has no @demo.func decorator and no demo.upload_file /
+        # demo.load_doc registration API, so the duplicated handler
+        # definitions were dropped and the module-level helpers above are
+        # wired to the UI with standard event bindings instead.
+        chat_state = gr.State([])
+
+        def initialize_database(list_file_obj, chunk_size, chunk_overlap):
+            list_file_path = [x.name for x in list_file_obj if x is not None]
+            # sanitize the file stem to satisfy ChromaDB collection-name rules
+            collection = re.sub('[^A-Za-z0-9]+', '-', unidecode(Path(list_file_path[0]).stem))[:50]
+            doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+            return create_db(doc_splits, collection), collection, "Complete!"
+
+        def initialize_llm(llm_option, temperature, max_tokens, top_k, vector_db):
+            # the Radio is declared with type="index"; map it back to a repo id
+            qa = initialize_llmchain(list_llm[llm_option], temperature, max_tokens, top_k, vector_db)
+            return qa, "Complete!"
+
+        def ask(qa_chain, user_message, history):
+            qa_chain, _, new_history, s1, p1, s2, p2, s3, p3 = conversation(qa_chain, user_message, history)
+            history_text = "\n".join(f"User: {u}\nAssistant: {a}" for u, a in new_history)
+            return (qa_chain, "", new_history, new_history[-1][1], history_text,
+                    f"(page {p1}) {s1}", f"(page {p2}) {s2}", f"(page {p3}) {s3}")
+
+        db_btn.click(initialize_database,
+                     inputs=[document, slider_chunk_size, slider_chunk_overlap],
+                     outputs=[vector_db, collection_name, db_progress])
+        llm_btn.click(initialize_llm,
+                      inputs=[llm_choice, slider_temperature, slider_maxtokens, slider_topk, vector_db],
+                      outputs=[qa_chain, llm_progress])
+        ask_btn.click(ask,
+                      inputs=[qa_chain, message, chat_state],
+                      outputs=[qa_chain, message, chat_state, answer, chat_history, source1, source2, source3])
+
+    demo.launch()
+
+if __name__ == "__main__":
+    demo()
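The module-level helpers also compose into a headless pipeline, which is useful for checking retrieval quality without launching the UI. A sketch with a hypothetical `sample.pdf`, using the same defaults as the sliders and assuming a HF token in the environment:

```python
splits = load_doc(["sample.pdf"], chunk_size=600, chunk_overlap=40)
db = create_db(splits, "sample-pdf")
chain = initialize_llmchain(list_llm[0], 0.7, 1024, 3, db)
chain, _, history, s1, p1, s2, p2, s3, p3 = conversation(chain, "Summarize the document.", [])
print(history[-1][1])   # the model's answer
print(p1, s1[:80])      # top source page and an excerpt
```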