Commit b53e14c · verified
barghavani committed · 1 parent: dc4a5b8

Update app.py: replace the single-PDF chat app (PyMuPDF page rendering, one retrieval chain) with a multi-step RAG interface built on langchain_community, with tabs for document pre-processing, QA chain initialization, and conversation.

Files changed (1): app.py (+225 -138)
app.py CHANGED
@@ -1,151 +1,238 @@
-from typing import Any
 import gradio as gr
-from langchain.embeddings.openai import OpenAIEmbeddings
-from langchain.vectorstores import Chroma
-
+from langchain_community.document_loaders import PyPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain_community.vectorstores import Chroma
 from langchain.chains import ConversationalRetrievalChain
-from langchain.chat_models import ChatOpenAI
-
-from langchain.document_loaders import PyPDFLoader
-
-import fitz
-from PIL import Image
-
-import chromadb
-import re
-import uuid
-
-enable_box = gr.Textbox.update(value = None, placeholder = 'Upload your OpenAI API key', interactive = True)
-disable_box = gr.Textbox.update(value = 'OpenAI API key is Set', interactive = False)
-
-def set_apikey(api_key: str):
-    app.OPENAI_API_KEY = api_key
-    return disable_box
-
-def enable_api_box():
-    return enable_box
-
-def add_text(history, text: str):
-    if not text:
-        raise gr.Error('enter text')
-    history = history + [(text, '')]
-    return history
-
-class my_app:
-    def __init__(self, OPENAI_API_KEY: str = None) -> None:
-        self.OPENAI_API_KEY: str = OPENAI_API_KEY
-        self.chain = None
-        self.chat_history: list = []
-        self.N: int = 0
-        self.count: int = 0
-
-    def __call__(self, file: str) -> Any:
-        if self.count == 0:
-            self.chain = self.build_chain(file)
-            self.count += 1
-        return self.chain
-
-    def chroma_client(self):
-        # create a chroma client
-        client = chromadb.Client()
-        # create a collection
-        collection = client.get_or_create_collection(name="my-collection")
-        return client
-
-    def process_file(self, file: str):
-        loader = PyPDFLoader(file.name)
-        documents = loader.load()
-        pattern = r"/([^/]+)$"
-        match = re.search(pattern, file.name)
-        file_name = match.group(1)
-        return documents, file_name
-
-    def build_chain(self, file: str):
-        documents, file_name = self.process_file(file)
-        # Load embeddings model
-        embeddings = OpenAIEmbeddings(openai_api_key=self.OPENAI_API_KEY)
-        pdfsearch = Chroma.from_documents(documents, embeddings, collection_name=file_name)
-        chain = ConversationalRetrievalChain.from_llm(
-            ChatOpenAI(temperature=0.0, openai_api_key=self.OPENAI_API_KEY),
-            retriever=pdfsearch.as_retriever(search_kwargs={"k": 1}),
-            return_source_documents=True)
-        return chain
-
-def get_response(history, query, file):
-    if not file:
-        raise gr.Error(message='Upload a PDF')
-    chain = app(file)
-    result = chain({"question": query, 'chat_history': app.chat_history}, return_only_outputs=True)
-    app.chat_history += [(query, result["answer"])]
-    app.N = list(result['source_documents'][0])[1][1]['page']
-    for char in result['answer']:
-        history[-1][-1] += char
-        yield history, ''
-
-def render_file(file):
-    doc = fitz.open(file.name)
-    page = doc[app.N]
-    # Render the page as a PNG image with a resolution of 300 DPI
-    pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
-    image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
-    return image
-
-def render_first(file):
-    doc = fitz.open(file.name)
-    page = doc[0]
-    # Render the page as a PNG image with a resolution of 300 DPI
-    pix = page.get_pixmap(matrix=fitz.Matrix(300/72, 300/72))
-    image = Image.frombytes('RGB', [pix.width, pix.height], pix.samples)
-    return image, []
-
-app = my_app()
-with gr.Blocks() as demo:
-    with gr.Column():
-        with gr.Row():
-            with gr.Column(scale=0.8):
-                api_key = gr.Textbox(placeholder='Enter OpenAI API key', show_label=False, interactive=True).style(container=False)
-            with gr.Column(scale=0.2):
-                change_api_key = gr.Button('Change Key')
-        with gr.Row():
-            chatbot = gr.Chatbot(value=[], elem_id='chatbot').style(height=650)
-            show_img = gr.Image(label='Upload PDF', tool='select').style(height=680)
-        with gr.Row():
-            with gr.Column(scale=0.60):
-                txt = gr.Textbox(
-                    show_label=False,
-                    placeholder="Enter text and press enter",
-                ).style(container=False)
-            with gr.Column(scale=0.20):
-                submit_btn = gr.Button('submit')
-            with gr.Column(scale=0.20):
-                btn = gr.UploadButton("📁 upload a PDF", file_types=[".pdf"]).style()
-
-    api_key.submit(
-        fn=set_apikey,
-        inputs=[api_key],
-        outputs=[api_key])
-    change_api_key.click(
-        fn=enable_api_box,
-        outputs=[api_key])
-    btn.upload(
-        fn=render_first,
-        inputs=[btn],
-        outputs=[show_img, chatbot])
-
-    submit_btn.click(
-        fn=add_text,
-        inputs=[chatbot, txt],
-        outputs=[chatbot],
-        queue=False).success(
-            fn=get_response,
-            inputs=[chatbot, txt, btn],
-            outputs=[chatbot, txt]).success(
-            fn=render_file,
-            inputs=[btn],
-            outputs=[show_img])
-
-demo.queue()
-demo.launch()
+from langchain_community.embeddings import OpenAIEmbeddings
+from langchain_community.chat_models import ChatOpenAI  # fixed: ChatOpenAI is exported from chat_models, not llms
+from langchain.memory import ConversationBufferMemory
+from pathlib import Path
+import chromadb
+from unidecode import unidecode
+import re
+
+# Assume list_llm, list_llm_simple, and other previously defined variables and functions are available
+
+def get_openai_api_key():
+    """Prompt the user to input their OpenAI API key."""
+    api_key = input("Please enter your OpenAI API key: ")
+    return api_key
+
+def load_doc(list_file_path, chunk_size, chunk_overlap):
+    # Processing for one document only
+    # loader = PyPDFLoader(file_path)
+    # pages = loader.load()
+    loaders = [PyPDFLoader(x) for x in list_file_path]
+    pages = []
+    for loader in loaders:
+        pages.extend(loader.load())
+    # text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=50)
+    text_splitter = RecursiveCharacterTextSplitter(
+        chunk_size=chunk_size,
+        chunk_overlap=chunk_overlap)
+    doc_splits = text_splitter.split_documents(pages)
+    return doc_splits
+
+def create_db(splits, collection_name, api_key):
+    """Adjusted to include the OpenAI API key for embeddings."""
+    embedding = OpenAIEmbeddings(openai_api_key=api_key)  # use the user-supplied OpenAI API key
+    new_client = chromadb.EphemeralClient()
+    vectordb = Chroma.from_documents(
+        documents=splits,
+        embedding=embedding,
+        client=new_client,
+        collection_name=collection_name,
+    )
+    return vectordb
+
+def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, api_key):
+    """Adjusted to include the OpenAI API key for the LLM initialization."""
+    llm = ChatOpenAI(openai_api_key=api_key, temperature=temperature, max_tokens=max_tokens, model_name=llm_model)
+    # memory_key/output_key are needed so the memory stores only the answer
+    # when return_source_documents=True
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key="answer", return_messages=True)
+    retriever = vector_db.as_retriever(search_kwargs={"k": top_k})  # honor the top-k slider
+    qa_chain = ConversationalRetrievalChain.from_llm(
+        llm=llm,
+        retriever=retriever,
+        memory=memory,
+        return_source_documents=True,
+    )
+    return qa_chain
+
+def create_collection_name(filepath):
+    # Extract filename without extension
+    collection_name = Path(filepath).stem
+    # Fix potential issues from naming convention
+    ## Remove space
+    collection_name = collection_name.replace(" ", "-")
+    ## ASCII transliterations of Unicode text
+    collection_name = unidecode(collection_name)
+    ## Remove special characters
+    # collection_name = re.findall("[\dA-Za-z]*", collection_name)[0]
+    collection_name = re.sub('[^A-Za-z0-9]+', '-', collection_name)
+    ## Limit length to 50 characters
+    collection_name = collection_name[:50]
+    ## Minimum length of 3 characters
+    if len(collection_name) < 3:
+        collection_name = collection_name + 'xyz'
+    print('Filepath: ', filepath)
+    print('Collection name: ', collection_name)
+    return collection_name
+
+# Initialize database
+def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Progress()):
+    # Create list of documents (when valid)
+    list_file_path = [x.name for x in list_file_obj if x is not None]
+    # Create collection_name for vector database
+    progress(0.1, desc="Creating collection name...")
+    collection_name = create_collection_name(list_file_path[0])
+    progress(0.25, desc="Loading document...")
+    # Load document and create splits
+    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    # Create or load vector database
+    progress(0.5, desc="Generating vector database...")
+    # global vector_db
+    vector_db = create_db(doc_splits, collection_name, api_key)  # fixed: pass the module-level api_key that create_db expects
+    progress(0.9, desc="Done!")
+    return vector_db, collection_name, "Complete!"
+
+def initialize_LLM(llm_option, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
+    # print("llm_option", llm_option)
+    llm_name = list_llm[llm_option]
+    print("llm_name: ", llm_name)
+    # fixed: the last argument must be the API key, not the progress tracker
+    qa_chain = initialize_llmchain(llm_name, llm_temperature, max_tokens, top_k, vector_db, api_key)
+    return qa_chain, "Complete!"
+
+def format_chat_history(message, chat_history):
+    formatted_chat_history = []
+    for user_message, bot_message in chat_history:
+        formatted_chat_history.append(f"User: {user_message}")
+        formatted_chat_history.append(f"Assistant: {bot_message}")
+    return formatted_chat_history
+
+def conversation(qa_chain, message, history):
+    formatted_chat_history = format_chat_history(message, history)
+    # print("formatted_chat_history", formatted_chat_history)
+
+    # Generate response using QA chain
+    response = qa_chain({"question": message, "chat_history": formatted_chat_history})
+    response_answer = response["answer"]
+    if response_answer.find("Helpful Answer:") != -1:
+        response_answer = response_answer.split("Helpful Answer:")[-1]
+    # The UI shows three references, so the retriever must return at least three sources
+    response_sources = response["source_documents"]
+    response_source1 = response_sources[0].page_content.strip()
+    response_source2 = response_sources[1].page_content.strip()
+    response_source3 = response_sources[2].page_content.strip()
+    # Langchain sources are zero-based
+    response_source1_page = response_sources[0].metadata["page"] + 1
+    response_source2_page = response_sources[1].metadata["page"] + 1
+    response_source3_page = response_sources[2].metadata["page"] + 1
+    # print('chat response: ', response_answer)
+    # print('DB source', response_sources)
+
+    # Append user message and response to chat history
+    new_history = history + [(message, response_answer)]
+    # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
+    return qa_chain, gr.update(value=""), new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
+
+def upload_file(file_obj):
+    list_file_path = []
+    for idx, file in enumerate(file_obj):
+        file_path = file.name  # fixed: use the loop variable, not the list itself
+        list_file_path.append(file_path)
+        # print(file_path)
+    # initialize_database(file_path, progress)
+    return list_file_path
+
+def demo():
+    with gr.Blocks(theme="base") as demo:
+        vector_db = gr.State()
+        qa_chain = gr.State()
+        collection_name = gr.State()
+
+        gr.Markdown(
+            """<h2><center>PDF-based chatbot (by Dr. Aloke Upadhaya)</center></h2>
+            <h3>Ask any questions about your PDF documents, along with follow-ups</h3>
+            """)
+        with gr.Tab("Step 1 - Document pre-processing"):
+            with gr.Row():
+                document = gr.Files(height=100, file_count="multiple", file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
+                # upload_btn = gr.UploadButton("Loading document...", height=100, file_count="multiple", file_types=["pdf"], scale=1)
+            with gr.Row():
+                db_choice = gr.Radio(["ChromaDB"], label="Vector database type", value="ChromaDB", type="index", info="Choose your vector database")  # renamed from db_btn so the button below does not shadow it
+            with gr.Accordion("Advanced options - Document text splitter", open=False):
+                with gr.Row():
+                    slider_chunk_size = gr.Slider(minimum=100, maximum=1000, value=600, step=20, label="Chunk size", info="Chunk size", interactive=True)
+                with gr.Row():
+                    slider_chunk_overlap = gr.Slider(minimum=10, maximum=200, value=40, step=10, label="Chunk overlap", info="Chunk overlap", interactive=True)
+            with gr.Row():
+                db_progress = gr.Textbox(label="Vector database initialization", value="None")
+            with gr.Row():
+                db_btn = gr.Button("Generate vector database...")
+
+        with gr.Tab("Step 2 - QA chain initialization"):
+            with gr.Row():
+                llm_btn = gr.Radio(list_llm_simple,
+                    label="LLM models", value=list_llm_simple[0], type="index", info="Choose your LLM model")
+            with gr.Accordion("Advanced options - LLM model", open=False):
+                with gr.Row():
+                    slider_temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.7, step=0.1, label="Temperature", info="Model temperature", interactive=True)
+                with gr.Row():
+                    slider_maxtokens = gr.Slider(minimum=224, maximum=4096, value=1024, step=32, label="Max Tokens", info="Model max tokens", interactive=True)
+                with gr.Row():
+                    slider_topk = gr.Slider(minimum=1, maximum=10, value=3, step=1, label="top-k samples", info="Model top-k samples", interactive=True)
+            with gr.Row():
+                llm_progress = gr.Textbox(value="None", label="QA chain initialization")
+            with gr.Row():
+                qachain_btn = gr.Button("Initialize question-answering chain...")
+
+        with gr.Tab("Step 3 - Conversation with chatbot"):
+            chatbot = gr.Chatbot(height=300)
+            with gr.Accordion("Advanced - Document references", open=False):
+                with gr.Row():
+                    doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                    source1_page = gr.Number(label="Page", scale=1)
+                with gr.Row():
+                    doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                    source2_page = gr.Number(label="Page", scale=1)
+                with gr.Row():
+                    doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                    source3_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                msg = gr.Textbox(placeholder="Type message", container=True)
+            with gr.Row():
+                submit_btn = gr.Button("Submit")
+                clear_btn = gr.ClearButton([msg, chatbot])
+
+        # Preprocessing events
+        # upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+        db_btn.click(initialize_database,
+            inputs=[document, slider_chunk_size, slider_chunk_overlap],
+            outputs=[vector_db, collection_name, db_progress])
+        qachain_btn.click(initialize_LLM,
+            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db],
+            outputs=[qa_chain, llm_progress]).then(lambda: [None, "", 0, "", 0, "", 0],
+            inputs=None,
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False)
+
+        # Chatbot events
+        msg.submit(conversation,
+            inputs=[qa_chain, msg, chatbot],
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False)
+        submit_btn.click(conversation,
+            inputs=[qa_chain, msg, chatbot],
+            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False)
+        clear_btn.click(lambda: [None, "", 0, "", 0, "", 0],
+            inputs=None,
+            outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page],
+            queue=False)
+    demo.queue().launch(debug=True)
+
+
+if __name__ == "__main__":
+    api_key = get_openai_api_key()  # Get the API key from the user (read as a module-level global by the functions above)
+    demo()  # fixed: demo() takes no arguments
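
Note: the updated app.py reads list_llm and list_llm_simple without defining them (see the "Assume list_llm ..." comment near the top of the file). A minimal sketch of what those placeholders might look like, assuming OpenAI chat models; the identifiers below are illustrative assumptions, not part of this commit:

# Hypothetical definitions for the names the script assumes are defined elsewhere.
# The model identifiers are examples only; substitute the models the Space actually targets.
list_llm = ["gpt-3.5-turbo", "gpt-4"]  # model names passed to ChatOpenAI(model_name=...)
list_llm_simple = ["GPT-3.5 Turbo", "GPT-4"]  # display labels for the gr.Radio component

With these in place, llm_btn (a type="index" Radio) returns an index into list_llm_simple, and initialize_LLM maps that index to the matching entry in list_llm.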