barghavani committed on
Commit 24555e8 · verified · 1 Parent(s): 993e5ac

Update app.py

Files changed (1)
  1. app.py +272 -23
app.py CHANGED
@@ -1,35 +1,284 @@
- from pathlib import Path
- from typing import Union
-
- from pypdf import PdfReader
- from transformers import pipeline
  import gradio as gr
- from langchain.llms import HuggingFaceHub
  import os
- from getpass import getpass
+
+ from langchain.document_loaders import PyPDFLoader
+ from langchain.text_splitter import RecursiveCharacterTextSplitter
+ from langchain.vectorstores import Chroma
+ from langchain.chains import ChatVectorDBChain
+ from langchain.embeddings import OpenAIEmbeddings
+ from langchain.chains import ConversationalRetrievalChain
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from langchain_community.llms import HuggingFacePipeline
+ from langchain.chains import ConversationChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain_community.llms import HuggingFaceEndpoint
+ from langchain_community.chat_models.openai import ChatOpenAI
+ from langchain.llms import OpenAI
+
+ from pathlib import Path
+ import chromadb
+ from unidecode import unidecode
+
+ from transformers import AutoTokenizer
+ import transformers
+ import torch
+ import tqdm
+ import accelerate
+ import re
+
+
  TOKEN = os.getenv('HUGGING_FACE_HUB_TOKEN')


- question_answerer =pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-alpha")
-
-
- def get_text_from_pdf(pdf_file: Union[str, Path]) -> str:
-     """Read the PDF from the given path and return a string with its entire content."""
-     reader = PdfReader(pdf_file)
-
-     # Extract text from all pages
-     full_text = ""
-     for page in reader.pages:
-         full_text += page.extract_text()
-     return full_text
-
-
- def answer_doc_question(pdf_file, question):
-     pdf_text = get_text_from_pdf(pdf_file)
-     answer = question_answerer(question, pdf_text)
-     return answer["answer"]
-
-
- pdf_input = gr.File(file_types=[".pdf"], label="Upload a PDF document and ask a question about it.")
- question = gr.Textbox(label="Type a question regarding the uploaded document here.")
- gr.Interface(fn=answer_doc_question, inputs=[pdf_input, question], outputs="text").launch()
+
+ # Load PDF document and create doc splits
+ def load_doc(list_file_path, chunk_size, chunk_overlap):
+     # Processing for one document only:
+     # loader = PyPDFLoader(file_path)
+     # pages = loader.load()
+     loaders = [PyPDFLoader(x) for x in list_file_path]
+     pages = []
+     for loader in loaders:
+         pages.extend(loader.load())
+     # text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
+     text_splitter = RecursiveCharacterTextSplitter(
+         chunk_size=chunk_size,
+         chunk_overlap=chunk_overlap)
+     doc_splits = text_splitter.split_documents(pages)
+     return doc_splits
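+ # Illustrative usage (hypothetical file name): load_doc(["paper.pdf"], chunk_size=600, chunk_overlap=40)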
+
+
+ # Initialize your vector database with OpenAIEmbeddings and persist it
+ def create_db(documents, collection_name, persist_directory="."):
+     embeddings = OpenAIEmbeddings()
+     vectordb = Chroma.from_documents(
+         documents=documents,
+         embedding=embeddings,
+         collection_name=collection_name,
+         persist_directory=persist_directory
+     )
+     vectordb.persist()
+     return vectordb
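+ # Note: OpenAIEmbeddings reads the OPENAI_API_KEY environment variable, so that
+ # key must be set (alongside HUGGING_FACE_HUB_TOKEN) before create_db() is called.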
+
+ # Load vector database
+ def load_db():
+     embedding = HuggingFaceEmbeddings()
+     vectordb = Chroma(
+         # persist_directory=default_persist_directory,
+         embedding_function=embedding)
+     return vectordb
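+ # Note: load_db() is currently unused; initialize_database() below always
+ # rebuilds the store via create_db().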
+
+
+ # Assuming vectordb is correctly initialized and persisted as shown above
+ def initialize_llmchain(vectordb, model_name="gpt-3.5-turbo", temperature=0.9):
+     chat_model = ChatOpenAI(temperature=temperature, model_name=model_name)
+     chat_vector_db_chain = ChatVectorDBChain.from_llm(
+         llm=chat_model,
+         vectorstore=vectordb,
+         return_source_documents=True
+     )
+     return chat_vector_db_chain
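+ # Note: ChatVectorDBChain is deprecated in LangChain in favour of
+ # ConversationalRetrievalChain (imported above), which takes
+ # retriever=vectordb.as_retriever() instead of a vector store.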
+
+ # Example usage (illustrative; `documents` is a placeholder, so this sketch is
+ # kept commented out and the Gradio entry point below remains the only main):
+ # documents = [...]  # your documents here
+ # collection_name = "your_collection_name"
+ # vectordb = create_db(documents, collection_name)
+ #
+ # # Initialize the chain with the vector database
+ # chat_vector_db_chain = initialize_llmchain(vectordb)
+ #
+ # # Use the chain to process a query
+ # query = "your query here"
+ # result = chat_vector_db_chain({"question": query, "chat_history": []})
+ #
+ # print("Answer:")
+ # print(result["answer"])
+
+ # Generate collection name for vector database
+ # - Use filepath as input, ensuring unicode text
+ def create_collection_name(filepath):
+     # Extract filename without extension
+     collection_name = Path(filepath).stem
+     # Fix potential issues from naming convention
+     ## Remove space
+     collection_name = collection_name.replace(" ", "-")
+     ## ASCII transliterations of Unicode text
+     collection_name = unidecode(collection_name)
+     ## Remove special characters
+     # collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
+     collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+     ## Limit length to 50 characters
+     collection_name = collection_name[:50]
+     ## Minimum length of 3 characters
+     if len(collection_name) < 3:
+         collection_name = collection_name + 'xyz'
+     print('Filepath: ', filepath)
+     print('Collection name: ', collection_name)
+     return collection_name
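+ # e.g. a file named "résumé 2024.pdf" (hypothetical) yields "resume-2024"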
+
+
+ # Initialize database
+ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+     # Create list of documents (when valid)
+     list_file_path = [x.name for x in list_file_obj if x is not None]
+     # Create collection_name for vector database
+     progress(0.1, desc="Creating collection name...")
+     collection_name = create_collection_name(list_file_path[0])
+     progress(0.25, desc="Loading document...")
+     # Load document and create splits
+     doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+     # Create or load vector database
+     progress(0.5, desc="Generating vector database...")
+     # global vector_db
+     vector_db = create_db(doc_splits, collection_name)
+     progress(0.9, desc="Done!")
+     return vector_db, collection_name, "Complete!"
+
+
+ def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+     # print("llm_option", llm_option)
+     llm_name = list_llm[llm_option]
+     print("llm_name: ", llm_name)
+     # initialize_llmchain() above only accepts (vectordb, model_name, temperature),
+     # so max_tokens and top_k are currently unused
+     qa_chain = initialize_llmchain(vector_db, model_name=llm_name, temperature=llm_temperature)
+     return qa_chain, "Complete!"
+
+
+ def format_chat_history(message, chat_history):
+     formatted_chat_history = []
+     for user_message, bot_message in chat_history:
+         formatted_chat_history.append(f"User: {user_message}")
+         formatted_chat_history.append(f"Assistant: {bot_message}")
+     return formatted_chat_history
+
+
+ def conversation(qa_chain, message, history):
+     formatted_chat_history = format_chat_history(message, history)
+     # print("formatted_chat_history", formatted_chat_history)
+
+     # Generate response using QA chain
+     response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+     response_answer = response["answer"]
+     if response_answer.find("Helpful Answer:") != -1:
+         response_answer = response_answer.split("Helpful Answer:")[-1]
+     response_sources = response["source_documents"]
+     response_source1 = response_sources[0].page_content.strip()
+     response_source2 = response_sources[1].page_content.strip()
+     response_source3 = response_sources[2].page_content.strip()
+     # Langchain sources are zero-based
+     response_source1_page = response_sources[0].metadata["page"] + 1
+     response_source2_page = response_sources[1].metadata["page"] + 1
+     response_source3_page = response_sources[2].metadata["page"] + 1
+     # print('chat response: ', response_answer)
+     # print('DB source', response_sources)
+
+     # Append user message and response to chat history
+     new_history = history + [(message, response_answer)]
+     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+     return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
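+ # Note: the three hard-coded references assume the chain returns at least
+ # three source documents; fewer retrieved chunks would raise an IndexError.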
+
+
+ def upload_file(file_obj):
+     list_file_path = []
+     for idx, file in enumerate(file_obj):
+         file_path = file.name
+         list_file_path.append(file_path)
+     # print(file_path)
+     # initialize_database(file_path, progress)
+     return list_file_path
+
+
+ def demo():
+     with gr.Blocks(theme="base") as demo:
+         vector_db = gr.State()
+         qa_chain = gr.State()
+         collection_name = gr.State()
+
+         gr.Markdown(
+             """<center><h2>PDF-based chatbot (by Dr. Aloke Upadhaya)</h2>
+             <h3>Ask any questions about your PDF documents, along with follow-ups</h3></center>
+             """)
+
+         with gr.Tab("Step 1 - Document pre-processing"):
+             with gr.Row():
+                 document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+                 # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+             with gr.Row():
+                 db_choice = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")
+             with gr.Accordion("Advanced options - Document text splitter", open=False):
+                 with gr.Row():
+                     slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+                 with gr.Row():
+                     slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+             with gr.Row():
+                 db_progress = gr.Textbox(label="Vector database initialization", value="None")
+             with gr.Row():
+                 db_btn = gr.Button("Generate vector database...")
+
+         with gr.Tab("Step 2 - QA chain initialization"):
+             with gr.Row():
+                 llm_btn = gr.Radio(list_llm_simple, \
+                     label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
+             with gr.Accordion("Advanced options - LLM model", open=False):
+                 with gr.Row():
+                     slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                 with gr.Row():
+                     slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                 with gr.Row():
+                     slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+             with gr.Row():
+                 llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+             with gr.Row():
+                 qachain_btn = gr.Button("Initialize question-answering chain...")
+
+         with gr.Tab("Step 3 - Conversation with chatbot"):
+             chatbot = gr.Chatbot(height=300)
+             with gr.Accordion("Advanced - Document references", open=False):
+                 with gr.Row():
+                     doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                     source1_page = gr.Number(label="Page", scale=1)
+                 with gr.Row():
+                     doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                     source2_page = gr.Number(label="Page", scale=1)
+                 with gr.Row():
+                     doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                     source3_page = gr.Number(label="Page", scale=1)
+             with gr.Row():
+                 msg = gr.Textbox(placeholder="Type message", container=True)
+             with gr.Row():
+                 submit_btn = gr.Button("Submit")
+                 clear_btn = gr.ClearButton([msg, chatbot])
+
+         # Preprocessing events
+         # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+         db_btn.click(initialize_database, \
+             inputs=[document, slider_chunk_size, slider_chunk_overlap], \
+             outputs=[vector_db, collection_name, db_progress])
+         qachain_btn.click(initialize_LLM, \
+             inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+             outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0], \
+             inputs=None, \
+             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+
+         # Chatbot events
+         msg.submit(conversation, \
+             inputs=[qa_chain, msg, chatbot], \
+             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+         submit_btn.click(conversation, \
+             inputs=[qa_chain, msg, chatbot], \
+             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+         clear_btn.click(lambda: [None, "", 0, "", 0, "", 0], \
+             inputs=None, \
+             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+             queue=False)
+     demo.queue().launch(debug=True)
+
+
+ if __name__ == "__main__":
+     demo()