Update app.py
app.py CHANGED
@@ -311,7 +311,7 @@ if not os.path.exists(f"{CFG.Embeddings_path}/index.faiss"):
 embeddings = HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo, model_kwargs={"device":"cuda"})
 vectordb = FAISS.load_local(f"{CFG.Output_folder}/faiss_index_ml_papers", embeddings, allow_dangerous_deserialization=True)
 
-
+
 def build_model(model_repo=CFG.model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_repo)
     model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
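Note on the unchanged `build_model` helper above: `attn_implementation="flash_attention_2"` only works when the `flash-attn` package is installed and the GPU is Ampere or newer; otherwise `from_pretrained` raises at load time. A guarded fallback would keep the Space booting on other hardware. This is a suggested sketch, not part of the commit; it reuses `CFG` from app.py, and the `sdpa` fallback is an assumption about acceptable behavior:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def build_model(model_repo=CFG.model_name):
    tokenizer = AutoTokenizer.from_pretrained(model_repo)
    try:
        # Fast path: FlashAttention-2 kernels (needs the flash-attn wheel
        # and an Ampere-or-newer GPU).
        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            attn_implementation="flash_attention_2",
            torch_dtype=torch.bfloat16,
        )
    except (ImportError, ValueError):
        # Suggested fallback (not in this commit): PyTorch's built-in
        # scaled-dot-product attention, available on any CUDA GPU.
        model = AutoModelForCausalLM.from_pretrained(
            model_repo,
            attn_implementation="sdpa",
            torch_dtype=torch.bfloat16,
        )
    return tokenizer, model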
@@ -340,16 +340,37 @@ Question: {question}
 PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
 
 retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": CFG.k})
-
+
 
 def process_llm_response(llm_response):
     ans = textwrap.fill(llm_response['result'], width=1500)
     sources_used = ' \n'.join([f"{source.metadata['source'].split('/')[-1][:-4]} - page: {str(source.metadata['page'])}" for source in llm_response['source_documents']])
     return f"{ans}\n\nSources:\n{sources_used}"
 
-
+
+
+
+
+
+@spaces.GPU
 def llm_ans(message, history):
+    tok, model = build_model()
+    terminators = [tok.eos_token_id, 32007, 32011, 32001, 32000]
+    pipe = pipeline(task="text-generation", model=model, tokenizer=tok, eos_token_id=terminators, do_sample=True, max_new_tokens=CFG.max_new_tokens, temperature=CFG.temperature, top_p=CFG.top_p, repetition_penalty=CFG.repetition_penalty)
+    llm = HuggingFacePipeline(pipeline=pipe)
+    qa_chain = RetrievalQA(llm=llm, retriever=retriever, prompt_template=PROMPT, return_source_documents=True, verbose=False)
+
+
     llm_response = qa_chain.invoke(message)
     return process_llm_response(llm_response)
 
-
+
+demo = gr.ChatInterface(
+    fn=llm_ans,
+    examples=[["Write me a poem about Machine Learning."]],
+    # multimodal=False,
+    stop_btn="Stop Generation",
+    title="Chat With LLMs",
+    description="Now Running Phi3-ORPO",
+)
+demo.launch()
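Review note on the chain construction added at new line 361: in released LangChain versions, `RetrievalQA` is not built by passing `llm=`, `retriever=`, and `prompt_template=` to the class directly (the class expects a `combine_documents_chain`), so that call will likely fail validation at runtime. The documented path is the `RetrievalQA.from_chain_type` factory, with the prompt routed through `chain_type_kwargs`. A sketch of the equivalent construction, reusing the `llm`, `retriever`, and `PROMPT` objects from this diff:

from langchain.chains import RetrievalQA

# "stuff" concatenates the k retrieved chunks into PROMPT's {context} slot.
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=retriever,
    chain_type_kwargs={"prompt": PROMPT},
    return_source_documents=True,
    verbose=False,
)

The rest of the flow is unchanged: `qa_chain.invoke(message)` still returns a dict carrying `result` and `source_documents`, which is exactly what `process_llm_response` consumes.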
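A smaller maintainability point: the hard-coded terminator ids (32000, 32001, 32007, 32011) are Phi-3 special tokens, and they will silently stop matching if `CFG.model_name` ever points at a different model. Resolving them through the tokenizer is safer. A sketch, assuming Phi-3's usual special-token names (`<|end|>`, `<|endoftext|>`, `<|assistant|>`, `<|user|>`):

# Resolve stop-token ids by name instead of hard-coding integer ids.
stop_token_names = ["<|end|>", "<|endoftext|>", "<|assistant|>", "<|user|>"]
terminators = [tok.eos_token_id] + [
    tid
    for tid in tok.convert_tokens_to_ids(stop_token_names)
    if tid is not None and tid != tok.unk_token_id
]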
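Design note on the `@spaces.GPU` change: on ZeroGPU hardware the GPU is only attached for the duration of the decorated call, which is why the model, pipeline, and chain are all rebuilt inside `llm_ans`; the trade-off is that every chat turn pays the full model-load latency. On a persistent-GPU Space the usual pattern is to build once and reuse. A minimal sketch of that variant, reusing the imports and globals already present in app.py (the module-level cache is illustrative, not part of this commit):

_QA_CHAIN = None  # built lazily on the first request, then reused

def llm_ans(message, history):
    global _QA_CHAIN
    if _QA_CHAIN is None:
        tok, model = build_model()
        pipe = pipeline(
            task="text-generation",
            model=model,
            tokenizer=tok,
            eos_token_id=[tok.eos_token_id, 32007, 32011, 32001, 32000],
            do_sample=True,
            max_new_tokens=CFG.max_new_tokens,
            temperature=CFG.temperature,
            top_p=CFG.top_p,
            repetition_penalty=CFG.repetition_penalty,
        )
        llm = HuggingFacePipeline(pipeline=pipe)
        _QA_CHAIN = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type="stuff",
            retriever=retriever,
            chain_type_kwargs={"prompt": PROMPT},
            return_source_documents=True,
        )
    return process_llm_response(_QA_CHAIN.invoke(message))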