Update app.py
app.py CHANGED
@@ -101,6 +101,19 @@ def initialize_database(list_file_obj):
     return vector_db, collection_name
 
 
+def initialize_llmchain(vector_db):
+    # Initialize langchain LLM chain
+    llm = HuggingFaceHub(repo_id = llm_model, model_kwargs={"temperature": temperature,
+                                                            "max_new_tokens": max_tokens,
+                                                            "top_k": top_k,
+                                                            "load_in_8bit": True})
+    retriever = vector_db.as_retriever()
+    memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
+    qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, chain_type="stuff",
+                                                     memory=memory, return_source_documents=True, verbose=False)
+
+    return qa_chain
+
 def initialize_LLM(vector_db):
     # print("llm_option",llm_option)
     llm_name = llm_model
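A minimal usage sketch for the new initialize_llmchain helper, assuming the module-level llm_model, temperature, max_tokens, and top_k settings this file relies on, the pre-0.1 langchain import layout, and an uploaded-file list list_file_obj; the model id and generation values below are placeholders, not part of this commit.

# Hypothetical end-to-end use of the helpers above; the model id and
# generation settings are placeholders, not values from this commit.
from langchain.llms import HuggingFaceHub
from langchain.memory import ConversationBufferMemory
from langchain.chains import ConversationalRetrievalChain

llm_model = "mistralai/Mistral-7B-Instruct-v0.1"   # placeholder repo id
temperature, max_tokens, top_k = 0.7, 1024, 3      # placeholder settings
list_file_obj = [...]   # uploaded PDF file objects, as elsewhere in app.py

vector_db, collection_name = initialize_database(list_file_obj)
qa_chain = initialize_llmchain(vector_db)

# The chain carries its own ConversationBufferMemory, so only the new
# question is passed; sources come back via return_source_documents=True.
result = qa_chain({"question": "What is this document about?"})
print(result["answer"])
for doc in result["source_documents"][:3]:
    print(doc.metadata.get("page"), doc.page_content[:80])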
@@ -146,27 +159,13 @@ def conversation(qa_chain, message, history):
 
 
 
-
 def demo():
     with gr.Blocks(theme='base') as demo:
         #vector_db = gr.State()
         #qa_chain = gr.State()
-        #collection_name = gr.State()
-
-
-        vector_db, collection_name = initialize_database(list_file_obj)
-
-        # Initialize langchain LLM chain
-        llm = HuggingFaceHub(repo_id = llm_model,
-                             model_kwargs={"temperature": temperature,
-                                           "max_new_tokens": max_tokens,
-                                           "top_k": top_k,
-                                           "load_in_8bit": True})
-        retriever = vector_db.as_retriever()
-        memory = ConversationBufferMemory(memory_key="chat_history", output_key='answer', return_messages=True)
-        qa_chain = ConversationalRetrievalChain.from_llm(llm, retriever=retriever, chain_type="stuff",
-                                                         memory=memory, return_source_documents=True, verbose=False)
+        #collection_name = gr.State()
 
+        vector_db, collection_name = initialize_database(list_file_obj)
         chatbot = gr.Chatbot(height=300)
         with gr.Accordion('References', open=True):
             with gr.Row():
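The commented-out gr.State() lines are the usual Gradio idiom for keeping per-session objects such as the QA chain, rather than building everything at Blocks-construction time. A self-contained sketch of that pattern, with illustrative names (build_chain, ask) that are not from app.py:

import gradio as gr

def build_chain():
    # Stand-in for initialize_llmchain(vector_db); any object works.
    return {"ready": True}

def ask(chain, message, history):
    reply = "chain not started yet" if chain is None else f"echo: {message}"
    return chain, "", history + [(message, reply)]

with gr.Blocks() as demo:
    chain = gr.State(None)          # one slot per browser session
    chatbot = gr.Chatbot()
    msg = gr.Textbox()
    start = gr.Button("Start Chatbot")
    start.click(build_chain, inputs=None, outputs=[chain])
    msg.submit(ask, inputs=[chain, msg, chatbot],
               outputs=[chain, msg, chatbot])

demo.launch()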
@@ -181,9 +180,17 @@ def demo():
         with gr.Row():
             msg = gr.Textbox(placeholder = 'Ask your question', container = True)
         with gr.Row():
+            qa_chain_button = gr.Button('Start Chatbot')
             submit_btn = gr.Button('Submit')
             clear_button = gr.ClearButton([msg, chatbot])
 
+        qa_chain_button.click(initialize_LLM, \
+                              inputs=[vector_db], \
+                              outputs=[qa_chain]).then(lambda: [None, "", 0, "", 0, "", 0], \
+                              inputs=None, \
+                              outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
+                              queue=False)
+
         msg.submit(conversation, \
                    inputs=[qa_chain, msg, chatbot], \
                    outputs=[qa_chain, msg, chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
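The button wiring relies on Gradio event chaining: .click() returns an event dependency whose .then() registers a follow-up callback that runs once the first one finishes, and queue=False lets the UI-reset step bypass the request queue. A stripped-down sketch of the same pattern, with illustrative component names:

import gradio as gr
import time

def slow_setup():
    time.sleep(1)                   # stand-in for building the QA chain
    return "chain ready"

with gr.Blocks() as demo:
    status = gr.Textbox(label="status")
    chatbot = gr.Chatbot()
    start = gr.Button("Start")
    # The second step runs only after slow_setup returns; it clears the chat.
    start.click(slow_setup, inputs=None, outputs=[status]) \
         .then(lambda: None, inputs=None, outputs=[chatbot], queue=False)

demo.launch()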