Phoenix21 committed on
Commit 99474e2 · verified · 1 Parent(s): 1f10543

Updated pipeline.py for the history feature
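The new run_with_chain_context entry point threads a caller-owned chat history through each turn. A minimal usage sketch (assumptions: HumanMessage/AIMessage come from langchain.schema, matching the docstring's hint; the import path is hypothetical, and importing the module runs its top-level setup, including the API-key prompts):

    from langchain.schema import HumanMessage, AIMessage
    from pipeline import run_with_chain_context  # hypothetical import path

    chat_history = []
    for question in ["What is Self-Reflection?", "How can I practice it daily?"]:
        answer = run_with_chain_context(question, chat_history)
        # Record both sides of the turn so later calls can see prior context
        chat_history.append(HumanMessage(content=question))
        chat_history.append(AIMessage(content=answer))
        print(answer)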

Files changed (1): pipeline.py +48 -28
pipeline.py CHANGED
@@ -2,7 +2,7 @@
 import os
 import getpass
 import pandas as pd
-from typing import Optional
+from typing import Optional, List
 
 from langchain.docstore.document import Document
 from langchain.embeddings import HuggingFaceEmbeddings
@@ -12,17 +12,17 @@ from langchain.chains import RetrievalQA
 from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
 import litellm
 
-# We import the chain builders from our separate files
+# Import your classification/refusal/tailor/cleaner chains
 from classification_chain import get_classification_chain
 from refusal_chain import get_refusal_chain
 from tailor_chain import get_tailor_chain
-from cleaner_chain import get_cleaner_chain, CleanerChain
+from cleaner_chain import get_cleaner_chain
 
-# We also import the relevant RAG logic here or define it directly
-# (We define build_rag_chain in this file for clarity)
+# For RAG chain building
+from langchain.llms.base import LLM
 
 ###############################################################################
-# 1) Environment: set up keys if missing
+# 1) Environment: set up keys
 ###############################################################################
 if not os.environ.get("GEMINI_API_KEY"):
     os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
@@ -30,7 +30,7 @@ if not os.environ.get("GROQ_API_KEY"):
     os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
 
 ###############################################################################
-# 2) build_or_load_vectorstore
+# 2) Build or Load VectorStore
 ###############################################################################
 def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     if os.path.exists(store_dir):
@@ -43,18 +43,22 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
     df = pd.read_csv(csv_path)
     df = df.loc[:, ~df.columns.str.contains('^Unnamed')]
     df.columns = df.columns.str.strip()
+
     if "Answer" in df.columns:
         df.rename(columns={"Answer": "Answers"}, inplace=True)
     if "Question" not in df.columns and "Question " in df.columns:
         df.rename(columns={"Question ": "Question"}, inplace=True)
+
     if "Question" not in df.columns or "Answers" not in df.columns:
         raise ValueError("CSV must have 'Question' and 'Answers' columns.")
+
     docs = []
     for _, row in df.iterrows():
         q = str(row["Question"])
         ans = str(row["Answers"])
         doc = Document(page_content=ans, metadata={"question": q})
         docs.append(doc)
+
     embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/multi-qa-mpnet-base-dot-v1")
     vectorstore = FAISS.from_documents(docs, embedding=embeddings)
     vectorstore.save_local(store_dir)
@@ -63,15 +67,17 @@ def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
 ###############################################################################
 # 3) Build RAG chain for Gemini
 ###############################################################################
-from langchain.llms.base import LLM
 def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     class GeminiLangChainLLM(LLM):
         def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
+            # We'll treat the entire prompt as 'user' content
             messages = [{"role": "user", "content": prompt}]
             return llm_model(messages, stop_sequences=stop)
+
         @property
         def _llm_type(self) -> str:
             return "custom_gemini"
+
     retriever = vectorstore.as_retriever(search_type="similarity", search_kwargs={"k": 3})
     gemini_as_llm = GeminiLangChainLLM()
     rag_chain = RetrievalQA.from_chain_type(
@@ -83,35 +89,29 @@ def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
     return rag_chain
 
 ###############################################################################
-# 4) Initialize all the separate chains
+# 4) Initialize your sub-chains
 ###############################################################################
-# Classification chain
 classification_chain = get_classification_chain()
-# Refusal chain
 refusal_chain = get_refusal_chain()
-# Tailor chain
 tailor_chain = get_tailor_chain()
-# Cleaner chain
 cleaner_chain = get_cleaner_chain()
 
 ###############################################################################
-# 5) Build our vectorstores + RAG chains
+# 5) Build VectorStores & RAG Chains
 ###############################################################################
 wellness_csv = "AIChatbot.csv"
 brand_csv = "BrandAI.csv"
 wellness_store_dir = "faiss_wellness_store"
 brand_store_dir = "faiss_brand_store"
 
+gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
+
 wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
 brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
 
-gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
 wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
 brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
 
-###############################################################################
-# 6) Tools / Agents for web search
-###############################################################################
 search_tool = DuckDuckGoSearchTool()
 web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
 managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
@@ -124,24 +124,40 @@ def do_web_search(query: str) -> str:
     return response
 
 ###############################################################################
-# 7) Orchestrator: run_with_chain
+# 6) Orchestrator: run_with_chain_context
 ###############################################################################
-def run_with_chain(query: str) -> str:
-    print("DEBUG: Starting run_with_chain...")
-    # 1) Classify
+def run_with_chain_context(query: str, chat_history: list) -> str:
+    """
+    Like run_with_chain, but also references `chat_history`.
+    We'll do single-turn classification, but pass chat_history
+    to the RAG chain if needed.
+
+    Example usage:
+        chat_history = []
+        question = "What is Self-Reflection?"
+        resp1 = run_with_chain_context(question, chat_history)
+        # then chat_history.extend([...]) with HumanMessage/AIMessage
+    """
+    print("DEBUG: Starting run_with_chain_context...")
+
+    # 1) Classification (no multi-turn, just single-turn classification)
     class_result = classification_chain.invoke({"query": query})
     classification = class_result.get("text", "").strip()
     print("DEBUG: Classification =>", classification)
 
-    # If OutOfScope => refusal => tailor => return
+    # 2) If OutOfScope => refusal => tailor => return
     if classification == "OutOfScope":
         refusal_text = refusal_chain.run({})
         final_refusal = tailor_chain.run({"response": refusal_text})
         return final_refusal.strip()
 
-    # If Wellness => wellness RAG => if insufficient => web => unify => tailor
+    # 3) If Wellness => call wellness_rag_chain with chat_history
     if classification == "Wellness":
-        rag_result = wellness_rag_chain({"query": query})
+        # pass the conversation to .invoke(...) so it can see it if needed
+        rag_result = wellness_rag_chain.invoke({
+            "input": query,
+            "chat_history": chat_history  # pass the entire list of prior messages
+        })
         csv_answer = rag_result["result"].strip()
         if not csv_answer:
             web_answer = do_web_search(query)
@@ -151,19 +167,23 @@ def run_with_chain(query: str) -> str:
             web_answer = do_web_search(query)
         else:
             web_answer = ""
+
         final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
 
-    # If Brand => brand RAG => tailor => return
+    # 4) If Brand => brand_rag_chain with chat_history
     if classification == "Brand":
-        rag_result = brand_rag_chain({"query": query})
+        rag_result = brand_rag_chain.invoke({
+            "input": query,
+            "chat_history": chat_history
+        })
         csv_answer = rag_result["result"].strip()
         final_merged = cleaner_chain.merge(kb=csv_answer, web="")
         final_answer = tailor_chain.run({"response": final_merged})
         return final_answer.strip()
 
-    # fallback
+    # fallback => refusal
     refusal_text = refusal_chain.run({})
    final_refusal = tailor_chain.run({"response": refusal_text})
     return final_refusal.strip()
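For reference, build_or_load_vectorstore expects the CSV to carry "Question" and "Answers" columns (it also tolerates an "Answer" header and a trailing-space "Question " header, renaming both before validation). An illustrative shape, not the actual AIChatbot.csv contents:

    Question,Answers
    What is mindfulness?,Mindfulness is the practice of attending to the present moment without judgment.
    How much sleep do adults need?,Most adults do best with 7 to 9 hours per night.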
 
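One caveat on the new .invoke payloads: a stock RetrievalQA chain uses "query" as its default input key, so unless the RetrievalQA.from_chain_type call in build_rag_chain (truncated in the hunk above) overrides input_key, invoking with {"input": ..., "chat_history": ...} fails input validation, and the history is ignored by the default combine step in any case. A hedged sketch of the call shape a stock RetrievalQA accepts:

    # Assumes wellness_rag_chain is an unmodified RetrievalQA (input_key="query");
    # extra keys pass validation but are not used by the default chain.
    rag_result = wellness_rag_chain.invoke({"query": query})
    csv_answer = rag_result["result"].strip()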