justinj92 committed
Commit 1e9b57f · verified · Parent: 9f02f73

Update app.py

Files changed (1)
app.py: +25 −4
app.py CHANGED
@@ -311,7 +311,7 @@ if not os.path.exists(f"{CFG.Embeddings_path}/index.faiss"):
 embeddings = HuggingFaceInstructEmbeddings(model_name=CFG.embeddings_model_repo, model_kwargs={"device":"cuda"})
 vectordb = FAISS.load_local(f"{CFG.Output_folder}/faiss_index_ml_papers", embeddings, allow_dangerous_deserialization=True)
 
-@spaces.GPU
+
 def build_model(model_repo=CFG.model_name):
     tokenizer = AutoTokenizer.from_pretrained(model_repo)
     model = AutoModelForCausalLM.from_pretrained(model_repo, attn_implementation="flash_attention_2", torch_dtype=torch.bfloat16)
@@ -340,16 +340,37 @@ Question: {question}
 PROMPT = PromptTemplate(template=prompt_template, input_variables=["context", "question"])
 
 retriever = vectordb.as_retriever(search_type="similarity", search_kwargs={"k": CFG.k})
-qa_chain = RetrievalQA(llm=llm, retriever=retriever, prompt_template=PROMPT, return_source_documents=True, verbose=False)
+
 
 def process_llm_response(llm_response):
     ans = textwrap.fill(llm_response['result'], width=1500)
     sources_used = ' \n'.join([f"{source.metadata['source'].split('/')[-1][:-4]} - page: {str(source.metadata['page'])}" for source in llm_response['source_documents']])
     return f"{ans}\n\nSources:\n{sources_used}"
 
-@gr.Interface(fn=process_llm_response, inputs=["text", "state"], outputs="text", title="Chat With LLMs", description="Now Running Phi3-ORPO")
+
+
+
+
+
+@spaces.GPU
 def llm_ans(message, history):
+    tok, model = build_model()
+    terminators = [tok.eos_token_id, 32007, 32011, 32001, 32000]
+    pipe = pipeline(task="text-generation", model=model, tokenizer=tok, eos_token_id=terminators, do_sample=True, max_new_tokens=CFG.max_new_tokens, temperature=CFG.temperature, top_p=CFG.top_p, repetition_penalty=CFG.repetition_penalty)
+    llm = HuggingFacePipeline(pipeline=pipe)
+    qa_chain = RetrievalQA(llm=llm, retriever=retriever, prompt_template=PROMPT, return_source_documents=True, verbose=False)
+
+
     llm_response = qa_chain.invoke(message)
     return process_llm_response(llm_response)
 
-llm_ans.launch()
+
+demo = gr.ChatInterface(
+    fn=llm_ans,
+    examples=[["Write me a poem about Machine Learning."]],
+    # multimodal=False,
+    stop_btn="Stop Generation",
+    title="Chat With LLMs",
+    description="Now Running Phi3-ORPO",
+)
+demo.launch()
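
The net change: `@spaces.GPU` moves off the module-level `build_model` and onto the chat handler that Gradio invokes per request, the model, pipeline, and QA chain are now built inside that handler, and the misapplied `@gr.Interface` decorator plus `llm_ans.launch()` give way to a `gr.ChatInterface` that is launched explicitly. A minimal sketch of that ZeroGPU pattern, with a placeholder handler standing in for the Space's real RAG chain:

import gradio as gr
import spaces

@spaces.GPU  # on a ZeroGPU Space, the GPU is attached only while this handler runs
def chat_fn(message, history):
    # the real app builds the model, pipeline, and qa_chain here, then
    # returns process_llm_response(qa_chain.invoke(message))
    return f"echo: {message}"  # placeholder logic, not the Space's chain

demo = gr.ChatInterface(fn=chat_fn, title="Chat With LLMs")
demo.launch()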
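The commit keeps the direct `RetrievalQA(llm=..., prompt_template=...)` construction and only moves it inside `llm_ans`. For reference, LangChain's documented constructor for this kind of chain is the `from_chain_type` classmethod, which accepts the prompt through `chain_type_kwargs`; a sketch assuming the `llm`, `retriever`, and `PROMPT` objects defined in app.py:

from langchain.chains import RetrievalQA

qa_chain = RetrievalQA.from_chain_type(
    llm=llm,                       # HuggingFacePipeline built in llm_ans
    chain_type="stuff",            # stuff all retrieved docs into one prompt
    retriever=retriever,
    return_source_documents=True,  # process_llm_response reads source_documents
    chain_type_kwargs={"prompt": PROMPT},
)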
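The hard-coded terminator IDs (32007, 32011, 32001, 32000) are Phi-3 special-token IDs. A hedged alternative is to resolve the stop tokens by name, which keeps the list self-documenting; the token strings below are assumptions based on Phi-3's vocabulary, and `tok` is the tokenizer returned by `build_model()`:

# resolve stop tokens by name instead of hard-coding their IDs
stop_tokens = ["<|end|>", "<|assistant|>", "<|endoftext|>"]
terminators = [tok.eos_token_id] + [
    tok.convert_tokens_to_ids(t) for t in stop_tokens
]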