Pijush2023 committed on
Commit b27de52 · verified · 1 Parent(s): 7cc3b4e

Update app.py

Files changed (1)
  1. app.py +143 -120
app.py CHANGED
@@ -593,129 +593,152 @@ Detailed Answer:
 
 import traceback
 
- # def generate_answer(message, choice, retrieval_mode, selected_model):
- #     logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
-
- #     # Logic for disabling options for Phi-3.5
- #     if selected_model == "LM-2":
- #         choice = None
- #         retrieval_mode = None
-
- #     try:
- #         # Select the appropriate template based on the choice and model
- #         if choice == "Details" and selected_model == chat_model1:  # GPT-4o-mini
- #             prompt_template = PromptTemplate(input_variables=["context", "question"], template=gpt4o_mini_template_details)
- #         elif choice == "Details":
- #             prompt_template = QA_CHAIN_PROMPT_1
- #         elif choice == "Conversational":
- #             prompt_template = QA_CHAIN_PROMPT_2
- #         else:
- #             prompt_template = QA_CHAIN_PROMPT_1  # Fallback to template1
-
- #         # # Handle hotel-related queries
- #         # if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
- #         #     logging.debug("Handling hotel-related query")
- #         #     response = fetch_google_hotels()
- #         #     logging.debug(f"Hotel response: {response}")
- #         #     return response, extract_addresses(response)
-
- #         # # Handle restaurant-related queries
- #         # if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
- #         #     logging.debug("Handling restaurant-related query")
- #         #     response = fetch_yelp_restaurants()
- #         #     logging.debug(f"Restaurant response: {response}")
- #         #     return response, extract_addresses(response)
-
- #         # # Handle flight-related queries
- #         # if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
- #         #     logging.debug("Handling flight-related query")
- #         #     response = fetch_google_flights()
- #         #     logging.debug(f"Flight response: {response}")
- #         #     return response, extract_addresses(response)
-
- #         # Retrieval-based response
- #         if retrieval_mode == "VDB":
- #             logging.debug("Using VDB retrieval mode")
- #             if selected_model == chat_model:
- #                 logging.debug("Selected model: LM-1")
- #                 retriever = gpt_retriever
- #                 context = retriever.get_relevant_documents(message)
- #                 logging.debug(f"Retrieved context: {context}")
-
- #                 prompt = prompt_template.format(context=context, question=message)
- #                 logging.debug(f"Generated prompt: {prompt}")
-
- #                 qa_chain = RetrievalQA.from_chain_type(
- #                     llm=chat_model,
- #                     chain_type="stuff",
- #                     retriever=retriever,
- #                     chain_type_kwargs={"prompt": prompt_template}
- #                 )
- #                 response = qa_chain({"query": message})
- #                 logging.debug(f"LM-1 response: {response}")
- #                 return response['result'], extract_addresses(response['result'])
-
- #             elif selected_model == chat_model1:
- #                 logging.debug("Selected model: LM-3")
- #                 retriever = gpt_retriever
- #                 context = retriever.get_relevant_documents(message)
- #                 logging.debug(f"Retrieved context: {context}")
-
- #                 prompt = prompt_template.format(context=context, question=message)
- #                 logging.debug(f"Generated prompt: {prompt}")
-
- #                 qa_chain = RetrievalQA.from_chain_type(
- #                     llm=chat_model1,
- #                     chain_type="stuff",
- #                     retriever=retriever,
- #                     chain_type_kwargs={"prompt": prompt_template}
- #                 )
- #                 response = qa_chain({"query": message})
- #                 logging.debug(f"LM-3 response: {response}")
- #                 return response['result'], extract_addresses(response['result'])
-
-
- #             elif selected_model == phi_pipe:
- #                 logging.debug("Selected model: LM-2")
- #                 retriever = phi_retriever
- #                 context_documents = retriever.get_relevant_documents(message)
- #                 context = "\n".join([doc.page_content for doc in context_documents])
- #                 logging.debug(f"Retrieved context for LM-2: {context}")
-
- #                 # Use the correct template variable
- #                 prompt = phi_custom_template.format(context=context, question=message)
- #                 logging.debug(f"Generated LM-2 prompt: {prompt}")
-
- #                 response = selected_model(prompt, **{
- #                     "max_new_tokens": 250,
- #                     "return_full_text": True,
- #                     "temperature": 0.0,
- #                     "do_sample": False,
- #                 })
-
- #                 if response:
- #                     generated_text = response[0]['generated_text']
- #                     logging.debug(f"LM-2 Response: {generated_text}")
- #                     cleaned_response = clean_response(generated_text)
- #                     return cleaned_response, extract_addresses(cleaned_response)
- #                 else:
- #                     logging.error("LM-2 did not return any response.")
- #                     return "No response generated.", []
-
- #         elif retrieval_mode == "KGF":
- #             logging.debug("Using KGF retrieval mode")
- #             response = chain_neo4j.invoke({"question": message})
- #             logging.debug(f"KGF response: {response}")
- #             return response, extract_addresses(response)
- #         else:
- #             logging.error("Invalid retrieval mode selected.")
- #             return "Invalid retrieval mode selected.", []
-
- #     except Exception as e:
- #         logging.error(f"Error in generate_answer: {str(e)}")
- #         logging.error(traceback.format_exc())
- #         return "Sorry, I encountered an error while processing your request.", []
+ def generate_answer(message, choice, retrieval_mode, selected_model):
+     logging.debug(f"generate_answer called with choice: {choice}, retrieval_mode: {retrieval_mode}, and selected_model: {selected_model}")
+
+     # Logic for disabling options for Phi-3.5
+     if selected_model == "LM-2":
+         choice = None
+         retrieval_mode = None
+
+     try:
+         # Select the appropriate template based on the choice and model
+         if choice == "Details" and selected_model == chat_model1:  # GPT-4o-mini
+             prompt_template = PromptTemplate(input_variables=["context", "question"], template=gpt4o_mini_template_details)
+         elif choice == "Details":
+             prompt_template = QA_CHAIN_PROMPT_1
+         elif choice == "Conversational":
+             prompt_template = QA_CHAIN_PROMPT_2
+         else:
+             prompt_template = QA_CHAIN_PROMPT_1  # Fallback to template1
+
+         # # Handle hotel-related queries
+         # if "hotel" in message.lower() or "hotels" in message.lower() and "birmingham" in message.lower():
+         #     logging.debug("Handling hotel-related query")
+         #     response = fetch_google_hotels()
+         #     logging.debug(f"Hotel response: {response}")
+         #     return response, extract_addresses(response)
+
+         # # Handle restaurant-related queries
+         # if "restaurant" in message.lower() or "restaurants" in message.lower() and "birmingham" in message.lower():
+         #     logging.debug("Handling restaurant-related query")
+         #     response = fetch_yelp_restaurants()
+         #     logging.debug(f"Restaurant response: {response}")
+         #     return response, extract_addresses(response)
+
+         # # Handle flight-related queries
+         # if "flight" in message.lower() or "flights" in message.lower() and "birmingham" in message.lower():
+         #     logging.debug("Handling flight-related query")
+         #     response = fetch_google_flights()
+         #     logging.debug(f"Flight response: {response}")
+         #     return response, extract_addresses(response)
+
+         # Retrieval-based response
+         if retrieval_mode == "VDB":
+             logging.debug("Using VDB retrieval mode")
+             if selected_model == chat_model:
+                 logging.debug("Selected model: LM-1")
+                 retriever = gpt_retriever
+                 context = retriever.get_relevant_documents(message)
+                 logging.debug(f"Retrieved context: {context}")
+
+                 prompt = prompt_template.format(context=context, question=message)
+                 logging.debug(f"Generated prompt: {prompt}")
+
+                 qa_chain = RetrievalQA.from_chain_type(
+                     llm=chat_model,
+                     chain_type="stuff",
+                     retriever=retriever,
+                     chain_type_kwargs={"prompt": prompt_template}
+                 )
+                 response = qa_chain({"query": message})
+                 logging.debug(f"LM-1 response: {response}")
+                 return response['result'], extract_addresses(response['result'])
+
+             elif selected_model == chat_model1:
+                 logging.debug("Selected model: LM-3")
+                 retriever = gpt_retriever
+                 context = retriever.get_relevant_documents(message)
+                 logging.debug(f"Retrieved context: {context}")
+
+                 prompt = prompt_template.format(context=context, question=message)
+                 logging.debug(f"Generated prompt: {prompt}")
+
+                 qa_chain = RetrievalQA.from_chain_type(
+                     llm=chat_model1,
+                     chain_type="stuff",
+                     retriever=retriever,
+                     chain_type_kwargs={"prompt": prompt_template}
+                 )
+                 response = qa_chain({"query": message})
+                 logging.debug(f"LM-3 response: {response}")
+                 return response['result'], extract_addresses(response['result'])
+             #-----------------------------------------------------------------------------------------------------------------
+
+             # Modify the Phi-3.5 prompt to include the selected file
+             elif selected_model == phi_pipe:
+                 retriever = phi_retriever
+                 context_documents = retriever.get_relevant_documents(message)
+                 context = "\n".join([doc.page_content for doc in context_documents])
+
+                 prompt = phi_custom_template.format(context=context, question=message, document_name=selected_file)
+                 response = selected_model(prompt, **{
+                     "max_new_tokens": 250,
+                     "return_full_text": True,
+                     "temperature": 0.0,
+                     "do_sample": False,
+                 })
+
+                 if response:
+                     generated_text = response[0]['generated_text']
+                     cleaned_response = clean_response(generated_text)
+                     return cleaned_response
+                 else:
+                     return "No response generated.", []
+
+
+             #------------------------------------------------------------------------------------------------------------
+             # elif selected_model == phi_pipe:
+             #     logging.debug("Selected model: LM-2")
+             #     retriever = phi_retriever
+             #     context_documents = retriever.get_relevant_documents(message)
+             #     context = "\n".join([doc.page_content for doc in context_documents])
+             #     logging.debug(f"Retrieved context for LM-2: {context}")
+
+             #     # Use the correct template variable
+             #     prompt = phi_custom_template.format(context=context, question=message)
+             #     logging.debug(f"Generated LM-2 prompt: {prompt}")
+
+             #     response = selected_model(prompt, **{
+             #         "max_new_tokens": 250,
+             #         "return_full_text": True,
+             #         "temperature": 0.0,
+             #         "do_sample": False,
+             #     })
+
+             #     if response:
+             #         generated_text = response[0]['generated_text']
+             #         logging.debug(f"LM-2 Response: {generated_text}")
+             #         cleaned_response = clean_response(generated_text)
+             #         return cleaned_response, extract_addresses(cleaned_response)
+             #     else:
+             #         logging.error("LM-2 did not return any response.")
+             #         return "No response generated.", []
+
+         elif retrieval_mode == "KGF":
+             logging.debug("Using KGF retrieval mode")
+             response = chain_neo4j.invoke({"question": message})
+             logging.debug(f"KGF response: {response}")
+             return response, extract_addresses(response)
+         else:
+             logging.error("Invalid retrieval mode selected.")
+             return "Invalid retrieval mode selected.", []
+
+     except Exception as e:
+         logging.error(f"Error in generate_answer: {str(e)}")
+         logging.error(traceback.format_exc())
+         return "Sorry, I encountered an error while processing your request.", []
 
 def generate_answer(message, choice, retrieval_mode, selected_model, selected_file):
     # Ensure a file is selected
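
Note that the newly enabled Phi-3.5 (LM-2) branch formats phi_custom_template with document_name=selected_file, while the uncommented generate_answer(message, choice, retrieval_mode, selected_model) takes no selected_file parameter; only the second definition shown below accepts it, and that later definition shadows this one. A minimal sketch of how the branch could receive the file name explicitly and keep the (text, addresses) return shape used by the other branches is given here. The helper name answer_with_phi and the idea of passing the collaborators in as arguments are assumptions for illustration only; phi_pipe, phi_retriever, phi_custom_template, clean_response, and extract_addresses are the names used in the diff.

def answer_with_phi(message, selected_file, phi_pipe, phi_retriever,
                    phi_custom_template, clean_response, extract_addresses):
    # Hypothetical helper: the Phi-3.5 branch with selected_file passed in explicitly.
    context_documents = phi_retriever.get_relevant_documents(message)
    context = "\n".join(doc.page_content for doc in context_documents)
    # Fold the selected document name into the prompt, as the new branch does
    prompt = phi_custom_template.format(
        context=context, question=message, document_name=selected_file
    )
    response = phi_pipe(
        prompt,
        max_new_tokens=250,
        return_full_text=True,
        temperature=0.0,
        do_sample=False,
    )
    if response:
        cleaned = clean_response(response[0]["generated_text"])
        # Assumption: mirror the (text, addresses) tuple returned by the other branches
        return cleaned, extract_addresses(cleaned)
    return "No response generated.", []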