vishwask committed
Commit b7e0851 · verified · 1 parent: e3d8df5

Update app.py

Files changed (1)
  1. app.py: +38 −47
app.py CHANGED
@@ -31,7 +31,8 @@ temperature = 0.1
 max_tokens = 6000
 top_k = 3
 
-def load_doc(list_file_path, chunk_size, chunk_overlap):
+
+def load_doc(list_file_path):
     # Processing for one document only
     # loader = PyPDFLoader(file_path)
     # pages = loader.load()
@@ -46,6 +47,8 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
+
+
 # Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
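
Note: load_doc() loses its chunk_size and chunk_overlap parameters here, and the caller in initialize_database() is updated further down to match, so the splitter must now take its settings from somewhere else, presumably module-level defaults next to the temperature, max_tokens and top_k constants visible in the hunk header. A minimal sketch of what the new body would need, assuming module-level constants and the usual LangChain loader/splitter (the names and values below are assumptions based on the commented loader lines, not shown in this diff):

    # Assumed module-level defaults, alongside temperature/max_tokens/top_k
    chunk_size = 600
    chunk_overlap = 40

    def load_doc(list_file_path):
        # Load every uploaded PDF, then split its pages into overlapping chunks
        loaders = [PyPDFLoader(path) for path in list_file_path]
        pages = []
        for loader in loaders:
            pages.extend(loader.load())
        text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size,        # module default instead of a parameter
            chunk_overlap=chunk_overlap)
        doc_splits = text_splitter.split_documents(pages)
        return doc_splits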
@@ -67,12 +70,7 @@ def load_db():
         embedding_function=embedding)
     return vectordb
 
-
-
-
-
-#list_file_obj = document
-
+
 # Initialize database
 def initialize_database(list_file_obj):
     # Create list of documents (when valid)
@@ -94,7 +92,7 @@ def initialize_database(list_file_obj):
     # print('list_file_path: ', list_file_path)
     print('Collection name: ', collection_name)
     # Load document and create splits
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    doc_splits = load_doc(list_file_path)
     # Create or load vector database
     # global vector_db
     vector_db = create_db(doc_splits, collection_name)
@@ -121,14 +119,6 @@ def initialize_LLM(vector_db):
     return qa_chain
 
 
-def format_chat_history(message, chat_history):
-    formatted_chat_history = []
-    for user_message, bot_message in chat_history:
-        formatted_chat_history.append(f"User: {user_message}")
-        formatted_chat_history.append(f"Assistant: {bot_message}")
-    return formatted_chat_history
-
-
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     #print("formatted_chat_history",formatted_chat_history)
@@ -153,43 +143,47 @@ def conversation(qa_chain, message, history):
     new_history = history + [(message, response_answer)]
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
     return qa_chain, new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
-#document = os.listdir(list_file_obj)
-#qa_chain =
-
 
 def demo():
-    with gr.Blocks(theme='base') as demo:
+    with gr.Block() as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
-        collection_name = gr.State()
-
-        vector_db, collection_name = initialize_database(list_file_obj)
+        collection_name = gr.State()
+
         chatbot = gr.Chatbot(height=300)
-        with gr.Accordion('References', open=True):
-            with gr.Row():
-                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-                source1_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-                source2_page = gr.Number(label="Page", scale=1)
-            with gr.Row():
-                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-                source3_page = gr.Number(label="Page", scale=1)
+        with gr.Accordion("Advanced - Document references", open=False):
             with gr.Row():
-            msg = gr.Textbox(placeholder = 'Ask your question', container = True)
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                source1_page = gr.Number(label="Page", scale=1)
             with gr.Row():
-            qa_chain_button = gr.Button('Start Chatbot')
-            submit_btn = gr.Button('Submit')
-            clear_button = gr.ClearButton([msg, chatbot])
-
-        qa_chain_button.click(initialize_LLM, \
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                source2_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                source3_page = gr.Number(label="Page", scale=1)
+        with gr.Row():
+            msg = gr.Textbox(placeholder="Type message", container=True)
+        with gr.Row():
+            db_btn = gr.Button('Initialize database')
+            qachain_btn = gr.Button('Start chatbot')
+            submit_btn = gr.Button("Submit")
+            clear_btn = gr.ClearButton([msg, chatbot])
+
+        document = list_file_obj
+
+        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+        db_btn.click(initialize_database, \
+            inputs=[document], \
+            outputs=[vector_db, collection_name])
+
+        qachain_btn.click(initialize_LLM, \
             inputs=[vector_db], \
-            outputs=[qa_chain]).then(lambda:[None,"",0,"",0,"",0], \
+            outputs=[qa_chain]).then(lambda:[0], \
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-            queue=False)
-
+            queue=False)
+
+        # Chatbot events
         msg.submit(conversation, \
             inputs=[qa_chain, msg, chatbot], \
            outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
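
Note: three things in this hunk look fragile. gr.Block() is presumably a typo for gr.Blocks(); the replaced line used gr.Blocks(theme='base'), and only Blocks works as the app context manager that later provides .queue() and .launch(). document = list_file_obj reads a name that is never defined inside demo(), so it will raise a NameError at startup; the commented-out upload_btn line suggests a file-upload component was meant to feed it. And lambda:[0] returns one value for the seven declared outputs, where the old lambda:[None,"",0,"",0,"",0] returned seven. A hedged sketch of the wiring with those three points fixed; the upload_btn component and its placement are assumptions, not part of this commit:

    with gr.Blocks() as demo:                      # gr.Blocks, not gr.Block
        vector_db = gr.State()
        qa_chain = gr.State()
        collection_name = gr.State()
        chatbot = gr.Chatbot(height=300)
        with gr.Accordion("Advanced - Document references", open=False):
            with gr.Row():
                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
                source1_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
                source2_page = gr.Number(label="Page", scale=1)
            with gr.Row():
                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
                source3_page = gr.Number(label="Page", scale=1)
        with gr.Row():
            upload_btn = gr.Files(label="Upload PDF documents")  # assumed, not in this commit
        with gr.Row():
            msg = gr.Textbox(placeholder="Type message", container=True)
        with gr.Row():
            db_btn = gr.Button('Initialize database')
            qachain_btn = gr.Button('Start chatbot')
            submit_btn = gr.Button("Submit")
            clear_btn = gr.ClearButton([msg, chatbot])

        db_btn.click(initialize_database,
                     inputs=[upload_btn],           # uploaded files instead of the undefined list_file_obj
                     outputs=[vector_db, collection_name])
        qachain_btn.click(initialize_LLM,
                          inputs=[vector_db],
                          outputs=[qa_chain]).then(
            lambda: [None, "", 0, "", 0, "", 0],    # one value per declared output
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page,
                     doc_source2, source2_page, doc_source3, source3_page],
            queue=False)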
@@ -198,12 +192,9 @@ def demo():
             inputs=[qa_chain, msg, chatbot], \
             outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
-        clear_button.click(lambda:[None,"",0,"",0,"",0], \
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
             queue=False)
     demo.queue().launch(debug=True)
 
-if __name__ == "__main__":
-    demo()
-
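
Note: with the __main__ guard removed, nothing ever calls demo(), so running app.py directly no longer launches the interface. If the Space executes app.py as a script, the two deleted lines (or an unconditional demo() call at module level) would need to come back:

    if __name__ == "__main__":
        demo()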
 