Phoenix21 committed
Commit 827152a · verified · 1 Parent(s): 0eedb96

Update pipeline.py

Files changed (1)
  1. pipeline.py +146 -145
pipeline.py CHANGED
@@ -600,163 +600,164 @@ def run_with_chain(query: str) -> str:
     3) If not refused, check if query is aggression/ethical => route to chain
     4) Otherwise classify => brand/wellness/out-of-scope => RAG => tailor
     """
-    start_time = time.time()
-    try:
-        # 1) Validate
-        if not query or query.strip() == "":
-            return create_error_response("validation", "Empty query.")
-        if len(query.strip()) < 2:
-            return create_error_response("validation", "Too short.")
-        words_in_text = re.findall(r'\b\w+\b', query.lower())
-        if not any(w in english_words for w in words_in_text):
-            return create_error_response("validation", "Unclear words.")
-        if len(query) > 500:
-            return create_error_response("validation", "Too long (>500).")
-        if not handle_rate_limiting(pipeline_state):
-            return create_error_response("rate_limit")
-        # New: Check if the query is a greeting
-        if is_greeting(query):
-            greeting_response = "Hello there!! Welcome to Healthy AI Expert, How may I assist you today?"
-            manage_cache(pipeline_state, query, greeting_response)
-            pipeline_state.update_metrics(start_time)
-            return greeting_response
-
-        if not handle_rate_limiting(pipeline_state):
-            return create_error_response("rate_limit")
-
-        # Cache check
-        cached = manage_cache(pipeline_state, query)
-        if cached:
-            pipeline_state.update_metrics(start_time, is_cache_hit=True)
-            return cached
-
-        # 2) Mistral moderation
         try:
-            mod_res = moderate_text(query)
-            severity = compute_moderation_severity(mod_res)
-
-            # If self-harm => supportive
-            if mod_res.categories.get("selfharm", False):
-                logger.info("Self-harm flagged => providing supportive chain response.")
-                selfharm_resp = pipeline_state.self_harm_chain.run({"query": query})
-                final_tailored = pipeline_state.tailor_chain.run({"response": selfharm_resp}).strip()
                 manage_cache(pipeline_state, query, final_tailored)
                 pipeline_state.update_metrics(start_time)
                 return final_tailored
-
-            # If hate => refuse
-            if mod_res.categories.get("hate", False):
-                logger.info("Hate content => refusal.")
-                refusal_resp = pipeline_state.refusal_chain.run({"topic": "moderation_flagged"})
-                manage_cache(pipeline_state, query, refusal_resp)
                 pipeline_state.update_metrics(start_time)
-                return refusal_resp
-
-            # If "dangerous" or "violence" is flagged, we might still want to
-            # provide a "non-violent advice" approach (like revenge queries).
-            # So we won't automatically refuse. We'll rely on the
-            # is_ethical_conflict() check below.
-
-        except Exception as e:
-            logger.error(f"Moderation error: {e}")
-            severity = 0.0
-
-        # 3) Check for aggression or ethical conflict
-        if is_aggressive_or_harsh(query):
-            logger.info("Detected harsh/aggressive language => frustration_chain.")
-            frustration_resp = pipeline_state.frustration_chain.run({"query": query})
-            final_tailored = pipeline_state.tailor_chain.run({"response": frustration_resp}).strip()
-            manage_cache(pipeline_state, query, final_tailored)
-            pipeline_state.update_metrics(start_time)
-            return final_tailored
-
-        if is_ethical_conflict(query):
-            logger.info("Detected ethical dilemma => ethical_conflict_chain.")
-            ethical_resp = pipeline_state.ethical_conflict_chain.run({"query": query})
-            final_tailored = pipeline_state.tailor_chain.run({"response": ethical_resp}).strip()
-            manage_cache(pipeline_state, query, final_tailored)
-            pipeline_state.update_metrics(start_time)
-            return final_tailored
-
-        # 4) Standard path: classification => brand/wellness/out-of-scope
-        try:
-            class_out = pipeline_state.classification_chain.run({"query": query})
-            classification = class_out.strip().lower()
-        except Exception as e:
-            logger.error(f"Classification error: {e}")
-            if not pipeline_state.handle_error(e):
-                return create_error_response("processing", "Classification error.")
-            return create_error_response("processing")
-
-        if classification in ["outofscope", "out_of_scope"]:
             try:
-                # Politely refuse if truly out-of-scope
-                refusal_text = pipeline_state.refusal_chain.run({"topic": query})
-                tailored_refusal = pipeline_state.tailor_chain.run({"response": refusal_text}).strip()
-                manage_cache(pipeline_state, query, tailored_refusal)
-                pipeline_state.update_metrics(start_time)
-                return tailored_refusal
             except Exception as e:
-                logger.error(f"Refusal chain error: {e}")
                 if not pipeline_state.handle_error(e):
-                    return create_error_response("processing", "Refusal error.")
                 return create_error_response("processing")
-
-        # brand vs wellness
-        if classification == "brand":
-            rag_chain_main = pipeline_state.brand_rag_chain
-            # rag_chain_fallback = pipeline_state.brand_rag_chain_fallback
-        else:
-            rag_chain_main = pipeline_state.wellness_rag_chain
-            # rag_chain_fallback = pipeline_state.wellness_rag_chain_fallback
-
-        # RAG with fallback
-        try:
             try:
-                rag_output = rag_chain_main({"query": query})
-            except Exception as e_main:
-                if "resource exhausted" in str(e_main).lower():
-                    logger.warning("Gemini resource exhausted. Falling back to Groq.")
-                    # rag_output = rag_chain_fallback({"query": query})
                 else:
-                    raise
-
-            if isinstance(rag_output, dict) and "result" in rag_output:
-                csv_ans = rag_output["result"].strip()
-            else:
-                csv_ans = str(rag_output).strip()
-
-            # If not enough => web
-            if "not enough context" in csv_ans.lower() or len(csv_ans) < 40:
-                logger.info("Insufficient RAG => web search.")
-                web_info = do_web_search(query)
-                if web_info:
-                    csv_ans += f"\n\nAdditional info:\n{web_info}"
-        except Exception as e:
-            logger.error(f"RAG error: {e}")
-            if not pipeline_state.handle_error(e):
-                return create_error_response("processing", "RAG error.")
-            return create_error_response("processing")
-
-        # Tailor final
-        try:
-            final_tailored = pipeline_state.tailor_chainWellnessBrand.run({"response": csv_ans}).strip()
-            if severity > 0.5:
-                final_tailored += "\n\n(Please note: This may involve sensitive content.)"
-
-            manage_cache(pipeline_state, query, final_tailored)
-            pipeline_state.update_metrics(start_time)
-            return final_tailored
         except Exception as e:
-            logger.error(f"Tailor chain error: {e}")
-            if not pipeline_state.handle_error(e):
-                return create_error_response("processing", "Tailoring error.")
-            return create_error_response("processing")
-
-    except Exception as e:
-        logger.error(f"Critical error in run_with_chain: {e}")
-        pipeline_state.metrics.errors += 1
-        return create_error_response("general")

 # -------------------------------------------------------
 # Health & Utility
     3) If not refused, check if query is aggression/ethical => route to chain
     4) Otherwise classify => brand/wellness/out-of-scope => RAG => tailor
     """
+    with tracer.new_trace(name="wellness_pipeline_run") as run:
+        start_time = time.time()
         try:
+            # 1) Validate
+            if not query or query.strip() == "":
+                return create_error_response("validation", "Empty query.")
+            if len(query.strip()) < 2:
+                return create_error_response("validation", "Too short.")
+            words_in_text = re.findall(r'\b\w+\b', query.lower())
+            if not any(w in english_words for w in words_in_text):
+                return create_error_response("validation", "Unclear words.")
+            if len(query) > 500:
+                return create_error_response("validation", "Too long (>500).")
+            if not handle_rate_limiting(pipeline_state):
+                return create_error_response("rate_limit")
+            # New: Check if the query is a greeting
+            if is_greeting(query):
+                greeting_response = "Hello there!! Welcome to Healthy AI Expert, How may I assist you today?"
+                manage_cache(pipeline_state, query, greeting_response)
+                pipeline_state.update_metrics(start_time)
+                return greeting_response
+
+            if not handle_rate_limiting(pipeline_state):
+                return create_error_response("rate_limit")
+
+            # Cache check
+            cached = manage_cache(pipeline_state, query)
+            if cached:
+                pipeline_state.update_metrics(start_time, is_cache_hit=True)
+                return cached
+
+            # 2) Mistral moderation
+            try:
+                mod_res = moderate_text(query)
+                severity = compute_moderation_severity(mod_res)
+
+                # If self-harm => supportive
+                if mod_res.categories.get("selfharm", False):
+                    logger.info("Self-harm flagged => providing supportive chain response.")
+                    selfharm_resp = pipeline_state.self_harm_chain.run({"query": query})
+                    final_tailored = pipeline_state.tailor_chain.run({"response": selfharm_resp}).strip()
+                    manage_cache(pipeline_state, query, final_tailored)
+                    pipeline_state.update_metrics(start_time)
+                    return final_tailored
+
+                # If hate => refuse
+                if mod_res.categories.get("hate", False):
+                    logger.info("Hate content => refusal.")
+                    refusal_resp = pipeline_state.refusal_chain.run({"topic": "moderation_flagged"})
+                    manage_cache(pipeline_state, query, refusal_resp)
+                    pipeline_state.update_metrics(start_time)
+                    return refusal_resp
+
+                # If "dangerous" or "violence" is flagged, we might still want to
+                # provide a "non-violent advice" approach (like revenge queries).
+                # So we won't automatically refuse. We'll rely on the
+                # is_ethical_conflict() check below.
+
+            except Exception as e:
+                logger.error(f"Moderation error: {e}")
+                severity = 0.0
+
+            # 3) Check for aggression or ethical conflict
+            if is_aggressive_or_harsh(query):
+                logger.info("Detected harsh/aggressive language => frustration_chain.")
+                frustration_resp = pipeline_state.frustration_chain.run({"query": query})
+                final_tailored = pipeline_state.tailor_chain.run({"response": frustration_resp}).strip()
                 manage_cache(pipeline_state, query, final_tailored)
                 pipeline_state.update_metrics(start_time)
                 return final_tailored
+
+            if is_ethical_conflict(query):
+                logger.info("Detected ethical dilemma => ethical_conflict_chain.")
+                ethical_resp = pipeline_state.ethical_conflict_chain.run({"query": query})
+                final_tailored = pipeline_state.tailor_chain.run({"response": ethical_resp}).strip()
+                manage_cache(pipeline_state, query, final_tailored)
                 pipeline_state.update_metrics(start_time)
+                return final_tailored
+
+            # 4) Standard path: classification => brand/wellness/out-of-scope
             try:
+                class_out = pipeline_state.classification_chain.run({"query": query})
+                classification = class_out.strip().lower()
             except Exception as e:
+                logger.error(f"Classification error: {e}")
                 if not pipeline_state.handle_error(e):
+                    return create_error_response("processing", "Classification error.")
                 return create_error_response("processing")
+
+            if classification in ["outofscope", "out_of_scope"]:
+                try:
+                    # Politely refuse if truly out-of-scope
+                    refusal_text = pipeline_state.refusal_chain.run({"topic": query})
+                    tailored_refusal = pipeline_state.tailor_chain.run({"response": refusal_text}).strip()
+                    manage_cache(pipeline_state, query, tailored_refusal)
+                    pipeline_state.update_metrics(start_time)
+                    return tailored_refusal
+                except Exception as e:
+                    logger.error(f"Refusal chain error: {e}")
+                    if not pipeline_state.handle_error(e):
+                        return create_error_response("processing", "Refusal error.")
+                    return create_error_response("processing")
+
+            # brand vs wellness
+            if classification == "brand":
+                rag_chain_main = pipeline_state.brand_rag_chain
+                # rag_chain_fallback = pipeline_state.brand_rag_chain_fallback
+            else:
+                rag_chain_main = pipeline_state.wellness_rag_chain
+                # rag_chain_fallback = pipeline_state.wellness_rag_chain_fallback
+
+            # RAG with fallback
             try:
+                try:
+                    rag_output = rag_chain_main({"query": query})
+                except Exception as e_main:
+                    if "resource exhausted" in str(e_main).lower():
+                        logger.warning("Gemini resource exhausted. Falling back to Groq.")
+                        # rag_output = rag_chain_fallback({"query": query})
+                    else:
+                        raise
+
+                if isinstance(rag_output, dict) and "result" in rag_output:
+                    csv_ans = rag_output["result"].strip()
                 else:
+                    csv_ans = str(rag_output).strip()
+
+                # If not enough => web
+                if "not enough context" in csv_ans.lower() or len(csv_ans) < 40:
+                    logger.info("Insufficient RAG => web search.")
+                    web_info = do_web_search(query)
+                    if web_info:
+                        csv_ans += f"\n\nAdditional info:\n{web_info}"
+            except Exception as e:
+                logger.error(f"RAG error: {e}")
+                if not pipeline_state.handle_error(e):
+                    return create_error_response("processing", "RAG error.")
+                return create_error_response("processing")
+
+            # Tailor final
+            try:
+                final_tailored = pipeline_state.tailor_chainWellnessBrand.run({"response": csv_ans}).strip()
+                if severity > 0.5:
+                    final_tailored += "\n\n(Please note: This may involve sensitive content.)"
+
+                manage_cache(pipeline_state, query, final_tailored)
+                pipeline_state.update_metrics(start_time)
+                return final_tailored
+            except Exception as e:
+                logger.error(f"Tailor chain error: {e}")
+                if not pipeline_state.handle_error(e):
+                    return create_error_response("processing", "Tailoring error.")
+                return create_error_response("processing")
+
         except Exception as e:
+            logger.error(f"Critical error in run_with_chain: {e}")
+            pipeline_state.metrics.errors += 1
+            return create_error_response("general")

 # -------------------------------------------------------
 # Health & Utility
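
The net change in this commit is that the whole body of run_with_chain now executes inside with tracer.new_trace(name="wellness_pipeline_run") as run: and is re-indented one level; the pipeline logic itself is unchanged. The tracer object is not defined in this hunk, so the snippet below is only a minimal sketch, assuming a hand-rolled context-manager tracer with the same call shape (SimpleTracer and TraceRun are hypothetical names, not taken from pipeline.py).

import time
import uuid
from contextlib import contextmanager
from dataclasses import dataclass, field


@dataclass
class TraceRun:
    # Hypothetical trace record; only illustrates what "run" could hold.
    name: str
    run_id: str = field(default_factory=lambda: uuid.uuid4().hex)
    started_at: float = field(default_factory=time.time)


class SimpleTracer:
    @contextmanager
    def new_trace(self, name: str):
        # Matches the call pattern introduced in this commit:
        #     with tracer.new_trace(name="wellness_pipeline_run") as run:
        run = TraceRun(name=name)
        try:
            yield run  # run_with_chain's body would execute here
        finally:
            # Finalize the trace even when the wrapped body returns early.
            elapsed = time.time() - run.started_at
            print(f"[trace] {run.name} ({run.run_id}) finished in {elapsed:.3f}s")


tracer = SimpleTracer()

if __name__ == "__main__":
    with tracer.new_trace(name="wellness_pipeline_run") as run:
        time.sleep(0.1)  # stand-in for the wrapped pipeline work

In the actual repository the tracer is more plausibly a LangChain/LangSmith-style tracing helper; the point of the sketch is simply that new_trace has to behave as a context manager, so the wrapped body can keep returning responses from any branch while the trace is closed in the tracer's cleanup step.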