vishwask committed
Commit 8f07cc0 · verified · 1 Parent(s): 8989755

Update app.py

Files changed (1)
  1. app.py  +10 -54
app.py CHANGED
@@ -74,59 +74,17 @@ def load_db():
 # Initialize langchain LLM chain
 def initialize_llmchain(llm_model, temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
     progress(0.1, desc="Initializing HF tokenizer...")
-    # HuggingFacePipeline uses local model
-    # Note: it will download model locally...
-    # tokenizer=AutoTokenizer.from_pretrained(llm_model)
-    # progress(0.5, desc="Initializing HF pipeline...")
-    # pipeline=transformers.pipeline(
-    #     "text-generation",
-    #     model=llm_model,
-    #     tokenizer=tokenizer,
-    #     torch_dtype=torch.bfloat16,
-    #     trust_remote_code=True,
-    #     device_map="auto",
-    #     # max_length=1024,
-    #     max_new_tokens=max_tokens,
-    #     do_sample=True,
-    #     top_k=top_k,
-    #     num_return_sequences=1,
-    #     eos_token_id=tokenizer.eos_token_id
-    # )
-    # llm = HuggingFacePipeline(pipeline=pipeline, model_kwargs={'temperature': temperature})
 
     # HuggingFaceHub uses HF inference endpoints
     progress(0.5, desc="Initializing HF Hub...")
     # Use of trust_remote_code as model_kwargs
     # Warning: langchain issue
     # URL: https://github.com/langchain-ai/langchain/issues/6080
-    if llm_model == "mistralai/Mixtral-8x7B-Instruct-v0.1":
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "load_in_8bit": True}
-        )
-    elif llm_model == "microsoft/phi-2":
-        raise gr.Error("phi-2 model requires 'trust_remote_code=True', currently not supported by langchain HuggingFaceHub...")
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-        )
-    elif llm_model == "TinyLlama/TinyLlama-1.1B-Chat-v1.0":
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": 250, "top_k": top_k}
-        )
-    elif llm_model == "meta-llama/Llama-2-7b-chat-hf":
-        raise gr.Error("Llama-2-7b-chat-hf model requires a Pro subscription...")
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-        )
-    else:
-        llm = HuggingFaceHub(
-            repo_id=llm_model,
-            # model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k, "trust_remote_code": True, "torch_dtype": "auto"}
-            model_kwargs={"temperature": temperature, "max_new_tokens": max_tokens, "top_k": top_k}
-        )
+
+    llm = HuggingFaceHub(repo_id=llm_model, model_kwargs={"temperature": temperature,
+                                                          "max_new_tokens": max_tokens,
+                                                          "top_k": top_k,
+                                                          "load_in_8bit": True})
 
     progress(0.75, desc="Defining buffer memory...")
     memory = ConversationBufferMemory(
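
For context, this hunk collapses the per-model branching into a single HuggingFaceHub call. A minimal standalone sketch of that inference-endpoint path, assuming the HuggingFaceHub import already present in app.py (e.g. from langchain.llms), a valid HUGGINGFACEHUB_API_TOKEN in the environment, and slider defaults taken from the hunks further down (the max_new_tokens value is an assumption, since that slider's default is not shown in this diff):

    # Sketch only, not part of the commit.
    import os
    from langchain.llms import HuggingFaceHub

    os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", "hf_...")  # placeholder, use a real token

    llm = HuggingFaceHub(
        repo_id="mistralai/Mistral-7B-Instruct-v0.2",  # hard-coded later in this diff
        model_kwargs={
            "temperature": 0.1,      # slider_temperature default below
            "max_new_tokens": 1024,  # assumed value
            "top_k": 3,              # slider_topk default below
            "load_in_8bit": True,
        },
    )
    print(llm("Summarize retrieval-augmented generation in one sentence."))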
@@ -239,7 +197,7 @@ def demo():
         qa_chain = gr.State()
         collection_name = gr.State()
         pdf_directory = '/home/user/app/pdfs'
-
+        llm_model = "mistralai/Mistral-7B-Instruct-v0.2"
 
         def process_pdfs():
             # List all PDF files in the directory
@@ -255,7 +213,7 @@ def demo():
         with gr.Row():
             # document = gr.Files(value = process_pdfs, height=100, file_count="multiple",visible=True,
             #     file_types=["pdf"], interactive=True, label="Upload your PDF documents (single or multiple)")
-            document = gr.Files(**pdf_dict)
+            document = gr.Files(**pdf_dict, visible = False)
         with gr.Row():
             db_btn = gr.Radio(["ChromaDB"], label="Vector database type", value = "ChromaDB", type="index", info="Choose your vector database",visible=False)
         with gr.Accordion("Advanced options - Document text splitter", open=False, visible=False):
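
The **pdf_dict expansion in this hunk refers to a dict defined outside the lines shown. Judging from the commented-out gr.Files call kept just above it, it plausibly bundles the same keyword arguments; the reconstruction below is an illustration under that assumption, not code from this commit:

    # Hypothetical reconstruction of pdf_dict, inferred from the commented-out gr.Files call.
    pdf_dict = {
        "value": process_pdfs,   # pre-populate with the PDFs found under pdf_directory
        "height": 100,
        "file_count": "multiple",
        "file_types": ["pdf"],
        "interactive": True,
        "label": "Upload your PDF documents (single or multiple)",
    }
    document = gr.Files(**pdf_dict, visible=False)  # the new line above hides the widget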
@@ -269,9 +227,7 @@ def demo():
             db_btn = gr.Button("Generate vector database...")
 
 
-        with gr.Row():
-            llm_btn = gr.Radio(list_llm_simple, \
-                label="LLM models", value = list_llm_simple[0], type="index", info="Choose your LLM model")
+
         with gr.Accordion("Advanced options - LLM model", open=False, visible=False):
             with gr.Row():
                 slider_temperature = gr.Slider(value = 0.1,visible=False)
@@ -280,7 +236,7 @@ def demo():
             with gr.Row():
                 slider_topk = gr.Slider(value = 3, visible=False)
         with gr.Row():
-            llm_progress = gr.Textbox(value="None",label="QA chain initialization")
+            llm_progress = gr.Textbox(value="None",label="QA chain initialization", visible=False)
         with gr.Row():
             qachain_btn = gr.Button("Initialize question-answering chain...")
 
@@ -308,7 +264,7 @@ def demo():
             inputs=[document, slider_chunk_size, slider_chunk_overlap], \
             outputs=[vector_db, collection_name, db_progress])
         qachain_btn.click(initialize_LLM, \
-            inputs=[llm_btn, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
+            inputs=[llm_model, slider_temperature, slider_maxtokens, slider_topk, vector_db], \
             outputs=[qa_chain, llm_progress]).then(lambda:[None,"",0,"",0,"",0], \
             inputs=None, \
             outputs=[chatbot, doc_source1, source1_page, doc_source2, source2_page, doc_source3, source3_page], \
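
The initialize_LLM callback wired to qachain_btn.click is likewise outside the hunks shown here. From its inputs/outputs lists it presumably forwards its arguments to initialize_llmchain and returns the chain plus a status string for llm_progress, roughly as in this assumed sketch:

    # Hypothetical shape of the initialize_LLM callback; the real function is not part of this diff.
    def initialize_LLM(llm_model, llm_temperature, max_tokens, top_k, vector_db, progress=gr.Progress()):
        qa_chain = initialize_llmchain(llm_model, llm_temperature, max_tokens, top_k, vector_db, progress)
        return qa_chain, "Complete!"  # second value feeds the llm_progress textbox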