Update app.py
app.py CHANGED
@@ -73,7 +73,7 @@ def load_db():
 
 
 # Initialize langchain LLM chain
-def initialize_llmchain(
+def initialize_llmchain(temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
 
     # HuggingFaceHub uses HF inference endpoints
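The body of initialize_llmchain is mostly outside this hunk; only the tokenizer progress call and the comment about HF inference endpoints are visible. As a rough sketch of what an initializer with the new signature could look like, assuming the app builds a LangChain conversational retrieval chain over a hosted HF model (the repo id and chain classes below are illustrative, not taken from app.py):

import gradio as gr
from langchain_huggingface import HuggingFaceEndpoint
from langchain.chains import ConversationalRetrievalChain
from langchain.memory import ConversationBufferMemory

# Hedged sketch only: initialize_llmchain's real body is not part of this
# diff, so the model id and chain construction here are assumptions.
def initialize_llmchain_sketch(temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Initializing HF tokenizer...")
    # HuggingFaceHub uses HF inference endpoints: the model runs remotely,
    # so only the sampling parameters travel with each request.
    llm = HuggingFaceEndpoint(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # illustrative model id
        temperature=temperature,
        max_new_tokens=max_tokens,
        top_k=top_k,
    )
    memory = ConversationBufferMemory(
        memory_key="chat_history", output_key="answer", return_messages=True)
    progress(0.8, desc="Building retrieval chain...")
    return ConversationalRetrievalChain.from_llm(
        llm,
        retriever=vector_db.as_retriever(),
        memory=memory,
        return_source_documents=True,
    )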
@@ -140,11 +140,11 @@ def initialize_database(list_file_obj, chunk_size, chunk_overlap, progress=gr.Pr
     return vector_db, collection_name, "Complete!"
 
 
-def initialize_LLM(
+def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     # print("llm_option",llm_option)
     llm_name = llm_model
     print("llm_name: ",llm_name)
-    qa_chain = initialize_llmchain(
+    qa_chain = initialize_llmchain(llm_temperature, max_tokens, top_k, vector_db, progress)
     return qa_chain, "Complete!"
 
 
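Note the progress pass-through in this hunk: initialize_LLM takes the gr.Progress tracker that Gradio injects through the default argument and forwards it to initialize_llmchain, so both functions report to a single progress bar. A minimal sketch of the pattern, with illustrative function names:

import gradio as gr

def _build_chain(progress):
    # Later stages report on the same tracker the event handler received.
    progress(0.5, desc="Building chain...")
    return "qa_chain"

def init_handler(temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    progress(0.1, desc="Starting...")
    qa_chain = _build_chain(progress)
    return qa_chain, "Complete!"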
@@ -266,7 +266,7 @@ def demo():
             inputs=[document, slider_chunk_size, slider_chunk_overlap], \
             outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM, \
-            inputs=[
+            inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db], \
             outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
                 inputs=None, \
                 outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
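The demo() hunk finishes the refactor on the UI side: the three sliders and the vector_db state now feed initialize_LLM positionally, and the chained .then() clears the chatbot and the three source/page fields once initialization completes. A self-contained sketch of this wiring (component defaults are illustrative; only the click/then structure mirrors the diff):

import gradio as gr

def initialize_LLM(llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
    # Stand-in handler; the real one builds the LangChain QA chain.
    progress(1.0, desc="Complete!")
    return f"chain(temp={llm_temperature}, max={max_tokens}, k={top_k})", "Complete!"

with gr.Blocks() as demo:
    vector_db = gr.State("placeholder-db")  # the real app keeps the vector store here
    qa_chain = gr.State()
    slider_temperature = gr.Slider(0.01, 1.0, value=0.7, label="Temperature")
    slider_maxtokens = gr.Slider(128, 4096, value=1024, step=128, label="Max new tokens")
    slider_topk = gr.Slider(1, 10, value=3, step=1, label="top-k")
    qachain_btn = gr.Button("Initialize QA chain")
    llm_progress = gr.Textbox(label="Chain initialization")
    chatbot = gr.Chatbot()
    doc_source1 = gr.Textbox(label="Source 1")
    source1_page = gr.Number(label="Page")
    doc_source2 = gr.Textbox(label="Source 2")
    source2_page = gr.Number(label="Page")
    doc_source3 = gr.Textbox(label="Source 3")
    source3_page = gr.Number(label="Page")
    # Slider values map positionally onto the handler's parameters; .then()
    # resets the chat and source fields after initialization finishes.
    qachain_btn.click(initialize_LLM,
        inputs=[slider_temperature, slider_maxtokens, slider_topk, vector_db],
        outputs=[qa_chain, llm_progress]).then(
            lambda: [None, "", 0, "", 0, "", 0],
            inputs=None,
            outputs=[chatbot, doc_source1, source1_page,
                     doc_source2, source2_page, doc_source3, source3_page])

if __name__ == "__main__":
    demo.launch()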