Update app.py
app.py
CHANGED
@@ -31,7 +31,8 @@ temperature = 0.1
 max_tokens = 6000
 top_k = 3
 
-def load_doc(list_file_path, chunk_size, chunk_overlap):
+
+def load_doc(list_file_path):
     # Processing for one document only
     # loader = PyPDFLoader(file_path)
     # pages = loader.load()
@@ -46,6 +47,8 @@ def load_doc(list_file_path, chunk_size, chunk_overlap):
     doc_splits = text_splitter.split_documents(pages)
     return doc_splits
 
+
+
 # Create vector database
 def create_db(splits, collection_name):
     embedding = HuggingFaceEmbeddings()
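For orientation: the first hunk drops the `chunk_size`/`chunk_overlap` parameters from `load_doc`, and this hunk shows its tail. A minimal sketch of what the full function presumably looks like after this commit, assuming LangChain's `PyPDFLoader` and `RecursiveCharacterTextSplitter` (only the commented loader calls and the split/return lines are visible in the diff; the chunk sizes here are illustrative, since the parameters are gone):

```python
# Sketch only: the real body is elided by the diff context. Assumes LangChain's
# PyPDFLoader and RecursiveCharacterTextSplitter; chunk sizes are illustrative.
from langchain.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

def load_doc(list_file_path):
    # Load every uploaded PDF and concatenate their pages
    loaders = [PyPDFLoader(x) for x in list_file_path]
    pages = []
    for loader in loaders:
        pages.extend(loader.load())
    # Split pages into overlapping chunks for embedding
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=600, chunk_overlap=40)
    doc_splits = text_splitter.split_documents(pages)
    return doc_splits
```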
@@ -67,12 +70,7 @@ def load_db():
         embedding_function=embedding)
     return vectordb
 
-
-
-
-
-#list_file_obj = document
-
+
 # Initialize database
 def initialize_database(list_file_obj):
     # Create list of documents (when valid)
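The `load_db` context lines (`embedding_function=embedding`, `return vectordb`) suggest a Chroma-style vector store; a minimal sketch of the `create_db`/`load_db` pair under that assumption (the store class itself never appears in the diff):

```python
# Sketch only: the vector-store class is not visible in the hunk; Chroma is an
# assumption based on the embedding_function= keyword in the context lines.
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import Chroma

def create_db(splits, collection_name):
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma.from_documents(
        documents=splits,
        embedding=embedding,
        collection_name=collection_name)
    return vectordb

def load_db():
    embedding = HuggingFaceEmbeddings()
    vectordb = Chroma(embedding_function=embedding)
    return vectordb
```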
@@ -94,7 +92,7 @@ def initialize_database(list_file_obj):
     # print('list_file_path: ', list_file_path)
     print('Collection name: ', collection_name)
     # Load document and create splits
-    doc_splits = load_doc(list_file_path, chunk_size, chunk_overlap)
+    doc_splits = load_doc(list_file_path)
     # Create or load vector database
     # global vector_db
     vector_db = create_db(doc_splits, collection_name)
@@ -121,14 +119,6 @@ def initialize_LLM(vector_db):
     return qa_chain
 
 
-def format_chat_history(message, chat_history):
-    formatted_chat_history = []
-    for user_message, bot_message in chat_history:
-        formatted_chat_history.append(f"User: {user_message}")
-        formatted_chat_history.append(f"Assistant: {bot_message}")
-    return formatted_chat_history
-
-
 def conversation(qa_chain, message, history):
     formatted_chat_history = format_chat_history(message, history)
     #print("formatted_chat_history",formatted_chat_history)
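One thing worth flagging: the hunk deletes `format_chat_history`, yet the context lines directly below still call it from `conversation`. Unless the helper is reintroduced elsewhere in the file, this raises a `NameError` at chat time. The deleted helper, verbatim from the `-` lines, in case it needs restoring:

```python
# The helper removed by this hunk, reproduced verbatim from the '-' lines above.
def format_chat_history(message, chat_history):
    formatted_chat_history = []
    for user_message, bot_message in chat_history:
        formatted_chat_history.append(f"User: {user_message}")
        formatted_chat_history.append(f"Assistant: {bot_message}")
    return formatted_chat_history
```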
@@ -153,43 +143,47 @@ def conversation(qa_chain, message, history):
     new_history = history + [(message, response_answer)]
     # return gr.update(value=""), new_history, response_sources[0], response_sources[1]
     return qa_chain, new_history, response_source1, response_source1_page, response_source2, response_source2_page, response_source3, response_source3_page
-
-#document = os.listdir(list_file_obj)
-#qa_chain =
-
 
 def demo():
-    with gr.
+    with gr.Blocks() as demo:
         vector_db = gr.State()
         qa_chain = gr.State()
-        collection_name = gr.State()
-
-        vector_db, collection_name = initialize_database(list_file_obj)
+        collection_name = gr.State()
+
         chatbot = gr.Chatbot(height=300)
-        with gr.Accordion(
-        with gr.Row():
-            doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
-            source1_page = gr.Number(label="Page", scale=1)
-        with gr.Row():
-            doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
-            source2_page = gr.Number(label="Page", scale=1)
-        with gr.Row():
-            doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
-            source3_page = gr.Number(label="Page", scale=1)
+        with gr.Accordion("Advanced - Document references", open=False):
+            with gr.Row():
+                doc_source1 = gr.Textbox(label="Reference 1", lines=2, container=True, scale=20)
+                source1_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source2 = gr.Textbox(label="Reference 2", lines=2, container=True, scale=20)
+                source2_page = gr.Number(label="Page", scale=1)
+            with gr.Row():
+                doc_source3 = gr.Textbox(label="Reference 3", lines=2, container=True, scale=20)
+                source3_page = gr.Number(label="Page", scale=1)
         with gr.Row():
-
+            msg = gr.Textbox(placeholder="Type message", container=True)
         with gr.Row():
-
-
-
-
-
+            db_btn = gr.Button('Initialize database')
+            qachain_btn = gr.Button('Start chatbot')
+            submit_btn = gr.Button("Submit")
+            clear_btn = gr.ClearButton([msg, chatbot])
+
+        document = list_file_obj
+
+        #upload_btn.upload(upload_file, inputs=[upload_btn], outputs=[document])
+        db_btn.click(initialize_database, \
+                     inputs=[document], \
+                     outputs=[vector_db, collection_name])
+
+        qachain_btn.click(initialize_LLM, \
                          inputs=[vector_db], \
-                         outputs=[qa_chain]).then(lambda:[
+                         outputs=[qa_chain]).then(lambda:[None,"",0,"",0,"",0], \
                          inputs=None, \
                          outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
-                         queue=False)
-
+                         queue=False)
+
+        # Chatbot events
         msg.submit(conversation, \
                    inputs=[qa_chain, msg, chatbot], \
                    outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
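The wiring above is Gradio's standard Blocks event pattern: `gr.State` values persist across callbacks, and `.click(...).then(...)` chains a follow-up step (here, clearing the chat panel) after the first handler finishes. A self-contained sketch of the same pattern, with hypothetical names rather than the ones in app.py:

```python
# Minimal sketch of the click/.then chaining pattern used above; names are
# hypothetical, not taken from app.py.
import gradio as gr

def slow_setup():
    return "ready"          # stands in for initialize_LLM

with gr.Blocks() as ui:
    state = gr.State()      # survives across events, like vector_db/qa_chain
    out = gr.Textbox()
    btn = gr.Button("Init")
    # First run slow_setup, then clear the textbox once it finishes
    btn.click(slow_setup, inputs=None, outputs=[state]).then(
        lambda: "", inputs=None, outputs=[out], queue=False)

ui.launch()
```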
@@ -198,12 +192,9 @@ def demo():
                    inputs=[qa_chain, msg, chatbot], \
                    outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
                    queue=False)
-
+        clear_btn.click(lambda:[None,"",0,"",0,"",0], \
                        inputs=None, \
                        outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
                        queue=False)
     demo.queue().launch(debug=True)
 
-if __name__ == "__main__":
-    demo()
-
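Finally, note that the commit removes the `__main__` guard, so running `python app.py` no longer launches the interface by itself; presumably `demo()` is now called from somewhere else. If the module should stay directly runnable, the removed guard was simply:

```python
# Verbatim from the '-' lines in the last hunk.
if __name__ == "__main__":
    demo()
```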