Phoenix21 committed on
Commit a79a41b · verified · 1 Parent(s): 9b11728
Files changed (1)
pipeline.py +33 -43
pipeline.py CHANGED
@@ -2,7 +2,7 @@
 import os
 import getpass
 import pandas as pd
-from typing import Optional, List
+from typing import Optional, Dict, Any
 
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -12,17 +12,16 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import litellm
 
-# Import your classification/refusal/tailor/cleaner chains
+# For classification/refusal/tailor/cleaner logic
 from classification_chain import get_classification_chain
 from refusal_chain import get_refusal_chain
 from tailor_chain import get_tailor_chain
 from cleaner_chain import get_cleaner_chain
 
-# For RAG chain building
 from langchain.llms.base import LLM
 
 ###############################################################################
-# 1) Environment: set up keys
+# 1) Environment Setup
 ###############################################################################
 if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
@@ -30,7 +29,7 @@ if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
 
 ###############################################################################
-# 2) Build or Load VectorStore
+# 2) VectorStore Building/Loading
 ###############################################################################
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
@@ -70,7 +69,6 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
-            # We'll treat the entire prompt as 'user' content
            messages = [{"role": "user", "content": prompt}]
             return llm_model(messages, stop_sequences=stop)
 
@@ -89,7 +87,7 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     return rag_chain
 
 ###############################################################################
-# 4) Initialize your sub-chains
+# 4) Init Sub-Chains
 ###############################################################################
 classification_chain = get_classification_chain()
 refusal_chain = get_refusal_chain()
@@ -97,7 +95,7 @@ tailor_chain = get_tailor_chain()
 cleaner_chain = get_cleaner_chain()
 
 ###############################################################################
-# 5) Build VectorStores & RAG Chains
+# 5) Build VectorStores & RAG
 ###############################################################################
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
@@ -126,22 +124,21 @@ def do_web_search(query: str) -> str:
 ###############################################################################
 # 6) Orchestrator: run_with_chain_context
 ###############################################################################
-def run_with_chain_context(query: str, chat_history: list) -> str:
+def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
     """
-    Like run_with_chain, but also references `chat_history`.
-    We'll do single-turn classification, but pass chat_history
-    to the RAG chain if needed.
-
-    Example usage:
-      chat_history = []
-      question = "What is Self-Reflection?"
-      resp1 = run_with_chain_context(question, chat_history)
-      # then chat_history.extend([...]) with HumanMessage/AIMessage
+    This function is called by the RunnableWithMessageHistory in my_memory_logic.py
+    inputs: { "input": <user_query>, "chat_history": <list of messages> }
+    Returns: { "answer": <final response> }
     """
-    print("DEBUG: Starting run_with_chain_context...")
 
-    # 1) Classification (no multi-turn, just single-turn classification)
-    class_result = classification_chain.invoke({"query": query})
+    user_query = inputs["input"]  # The user's new question
+    # You can optionally use inputs.get("chat_history") if needed
+    chat_history = inputs.get("chat_history", [])
+
+    print("DEBUG: Starting run_with_chain_context...")
+    print(f"User query: {user_query}")
+    # 1) Classification
+    class_result = classification_chain.invoke({"query": user_query})
     classification = class_result.get("text", "").strip()
     print("DEBUG: Classification =>", classification)
 
@@ -149,41 +146,34 @@ def run_with_chain_context(query: str, chat_history: list) -> str:
     if classification == "OutOfScope":
         refusal_text = refusal_chain.run({})
         final_refusal = tailor_chain.run({"response": refusal_text})
-        return final_refusal.strip()
+        return {"answer": final_refusal.strip()}
 
-    # 3) If Wellness => call wellness_rag_chain with chat_history
+    # 3) If Wellness => wellness RAG => if insufficient => web => unify => tailor
     if classification == "Wellness":
-        # pass the conversation to .invoke(...) so it can see it if needed
-        rag_result = wellness_rag_chain.invoke({
-            "input": query,
-            "chat_history": chat_history  # pass the entire list of prior messages
-        })
+        # pass chat_history if your chain can use it
+        rag_result = wellness_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
         csv_answer = rag_result["result"].strip()
         if not csv_answer:
-            web_answer = do_web_search(query)
+            web_answer = do_web_search(user_query)
         else:
             lower_ans = csv_answer.lower()
             if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
-                web_answer = do_web_search(query)
+                web_answer = do_web_search(user_query)
            else:
                 web_answer = ""
-
         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
-        final_answer = tailor_chain.run({"response": final_merged})
-        return final_answer.strip()
+        final_answer = tailor_chain.run({"response": final_merged}).strip()
+        return {"answer": final_answer}
 
-    # 4) If Brand => brand_rag_chain with chat_history
+    # 4) If Brand => brand RAG => tailor => return
     if classification == "Brand":
-        rag_result = brand_rag_chain.invoke({
-            "input": query,
-            "chat_history": chat_history
-        })
+        rag_result = brand_rag_chain.invoke({"input": user_query, "chat_history": chat_history})
         csv_answer = rag_result["result"].strip()
         final_merged = cleaner_chain.merge(kb=csv_answer, web="")
-        final_answer = tailor_chain.run({"response": final_merged})
-        return final_answer.strip()
+        final_answer = tailor_chain.run({"response": final_merged}).strip()
+        return {"answer": final_answer}
 
-    # fallback => refusal
+    # 5) fallback => refusal
     refusal_text = refusal_chain.run({})
-    final_refusal = tailor_chain.run({"response": refusal_text})
-    return final_refusal.strip()
+    final_refusal = tailor_chain.run({"response": refusal_text}).strip()
+    return {"answer": final_refusal}
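
Note on the GeminiLangChainLLM adapter in the build_rag_chain hunk: RetrievalQA expects a LangChain-style LLM, so the code wraps the smolagents LiteLLMModel in a subclass of langchain.llms.base.LLM. That base class also expects an _llm_type property, which the hunk does not show. A minimal self-contained sketch of the same pattern, assuming the LiteLLMModel call signature used in the diff (the factory name make_gemini_llm and the _llm_type string are illustrative, not code from this commit):

from typing import List, Optional

from langchain.llms.base import LLM
from smolagents import LiteLLMModel


def make_gemini_llm(llm_model: LiteLLMModel) -> LLM:
    """Wrap a smolagents LiteLLMModel behind LangChain's LLM interface."""

    class GeminiLangChainLLM(LLM):
        @property
        def _llm_type(self) -> str:
            # Identifier LangChain uses when logging/serializing this LLM
            return "gemini-litellm"

        def _call(self, prompt: str, stop: Optional[List[str]] = None, **kwargs) -> str:
            # As in the diff: treat the entire prompt as a single 'user' message
            messages = [{"role": "user", "content": prompt}]
            return llm_model(messages, stop_sequences=stop)

    return GeminiLangChainLLM()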
 