Pijush2023 commited on
Commit
61e3841
·
verified ·
1 Parent(s): 13eb1f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -9
app.py CHANGED
@@ -356,15 +356,27 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
356
  prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
357
 
358
  if retrieval_mode == "VDB":
359
- qa_chain = RetrievalQA.from_chain_type(
360
- llm=selected_model,
361
- chain_type="stuff",
362
- retriever=retriever,
363
- chain_type_kwargs={"prompt": prompt_template}
364
- )
365
- response = qa_chain({"query": message})
366
- logging.debug(f"Vector response: {response}")
367
- return response['result'], extract_addresses(response['result'])
 
 
 
 
 
 
 
 
 
 
 
 
368
  elif retrieval_mode == "KGF":
369
  response = chain_neo4j.invoke({"question": message})
370
  logging.debug(f"Knowledge-Graph response: {response}")
@@ -373,6 +385,7 @@ def generate_answer(message, choice, retrieval_mode, selected_model):
373
  return "Invalid retrieval mode selected.", []
374
 
375
 
 
376
  # def bot(history, choice, tts_choice, retrieval_mode):
377
  # if not history:
378
  # return history
 
356
  prompt_template = QA_CHAIN_PROMPT_1 if choice == "Details" else QA_CHAIN_PROMPT_2
357
 
358
  if retrieval_mode == "VDB":
359
+ if selected_model == "GPT-4o":
360
+ # Use Langchain with GPT-4o
361
+ qa_chain = RetrievalQA.from_chain_type(
362
+ llm=chat_model,
363
+ chain_type="stuff",
364
+ retriever=retriever,
365
+ chain_type_kwargs={"prompt": prompt_template}
366
+ )
367
+ response = qa_chain({"query": message})
368
+ logging.debug(f"Vector response: {response}")
369
+ return response['result'], extract_addresses(response['result'])
370
+ elif selected_model == "Phi-3.5":
371
+ # Directly use the Phi-3.5 model for text generation
372
+ response = selected_model(message, **{
373
+ "max_new_tokens": 500,
374
+ "return_full_text": False,
375
+ "temperature": 0.0,
376
+ "do_sample": False,
377
+ })[0]['generated_text']
378
+ logging.debug(f"Phi-3.5 response: {response}")
379
+ return response, extract_addresses(response)
380
  elif retrieval_mode == "KGF":
381
  response = chain_neo4j.invoke({"question": message})
382
  logging.debug(f"Knowledge-Graph response: {response}")
 
385
  return "Invalid retrieval mode selected.", []
386
 
387
 
388
+
389
  # def bot(history, choice, tts_choice, retrieval_mode):
390
  # if not history:
391
  # return history