# Hugging Face Space "Chatbot2" — pipeline.py
# Author: Phoenix21 · commit 5912bb3 "Update pipeline.py" (verified) · 8.42 kB
# pipeline.py
import os
import getpass
import pandas as pd
from typing import Optional, Dict, Any
# (Optional) from langchain.schema import RunnableConfig
# Runnable comes from langchain_core (langchain >= 0.1). Older langchain 0.0.x
# releases exposed it as langchain.schema.runnable.Runnable; there is no
# langchain.runnables.base module.
from langchain_core.runnables.base import Runnable
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
import litellm
# Classification/Refusal/Tailor/Cleaner
from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain
from langchain.llms.base import LLM
###############################################################################
# 1) Environment keys
###############################################################################
# Prompt interactively for any API key that is not already set in the
# environment. GEMINI_API_KEY is read back from os.environ when gemini_llm is
# built; GROQ_API_KEY is presumably consumed by the imported sub-chains
# (not visible here — confirm).
for _env_var, _prompt in (
    ("GEMINI_API_KEY", "Enter your Gemini API Key: "),
    ("GROQ_API_KEY", "Enter your GROQ API Key: "),
):
    if os.environ.get(_env_var):
        continue
    os.environ[_env_var] = getpass.getpass(_prompt)
del _env_var, _prompt
###############################################################################
# 2) Build or load VectorStore
###############################################################################
# Sentence-transformers model used to embed the answers. An existing store must
# be loaded with the same model that built it.
_DEFAULT_EMBEDDING_MODEL = "sentence-transformers/multi-qa-mpnet-base-dot-v1"


def _read_qa_frame(csv_path: str) -> pd.DataFrame:
    """Read a Q&A CSV and normalise it to 'Question' / 'Answers' columns.

    Rows without an answer are dropped, and missing questions become "".

    Raises:
        ValueError: if the CSV lacks a question or answer column.
    """
    df = pd.read_csv(csv_path)
    # Drop pandas' auto-generated index columns ("Unnamed: 0", ...).
    df = df.loc[:, ~df.columns.str.contains("^Unnamed")]
    # Strip header whitespace so e.g. "Question " is matched as "Question".
    df.columns = df.columns.str.strip()
    # Accept the singular spelling, but never rename onto an existing
    # "Answers" column (that would create duplicate column labels).
    if "Answer" in df.columns and "Answers" not in df.columns:
        df = df.rename(columns={"Answer": "Answers"})
    if "Question" not in df.columns or "Answers" not in df.columns:
        raise ValueError("CSV must have 'Question' and 'Answers' columns.")
    # Without this, empty cells would be indexed as the literal string "nan".
    return df.dropna(subset=["Answers"]).assign(
        Question=lambda frame: frame["Question"].fillna("")
    )


def build_or_load_vectorstore(
    csv_path: str,
    store_dir: str,
    embedding_model: str = _DEFAULT_EMBEDDING_MODEL,
) -> FAISS:
    """Load the FAISS index saved in *store_dir*, or build it from *csv_path*.

    When *store_dir* exists it is loaded as-is and *csv_path* is not read.
    Otherwise each CSV row becomes one Document. Its page_content is the
    answer and metadata["question"] is the question. The index is embedded,
    saved to *store_dir*, and returned.

    Args:
        csv_path: CSV with 'Question' and 'Answers' (or 'Answer') columns.
        store_dir: directory used by FAISS save_local / load_local.
        embedding_model: HuggingFace model name. It must match the model that
            built an existing store.

    Raises:
        ValueError: if the CSV lacks the required columns or has no usable rows.
    """
    if os.path.exists(store_dir):
        print(f"DEBUG: Found existing FAISS store at '{store_dir}'. Loading from disk.")
        embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
        # load_local unpickles the docstore, and recent langchain releases
        # refuse to do so unless this flag is set. It is acceptable only
        # because store_dir is written by save_local below. SECURITY: never
        # point store_dir at files from an untrusted source.
        return FAISS.load_local(
            store_dir, embeddings, allow_dangerous_deserialization=True
        )

    print(f"DEBUG: Building new store from CSV: {csv_path}")
    df = _read_qa_frame(csv_path)
    docs = [
        Document(page_content=answer, metadata={"question": question})
        for question, answer in zip(
            df["Question"].astype(str), df["Answers"].astype(str)
        )
        if answer.strip()  # skip whitespace-only answers
    ]
    if not docs:
        # FAISS.from_documents fails obscurely on an empty list.
        raise ValueError(f"No question/answer rows found in {csv_path!r}.")

    embeddings = HuggingFaceEmbeddings(model_name=embedding_model)
    vectorstore = FAISS.from_documents(docs, embedding=embeddings)
    vectorstore.save_local(store_dir)
    return vectorstore
###############################################################################
# 3) Build RAG chain
###############################################################################
def build_rag_chain(
    llm_model: LiteLLMModel, vectorstore: FAISS, k: int = 3
) -> RetrievalQA:
    """Wrap *vectorstore* in a "stuff" RetrievalQA chain answered by *llm_model*.

    Args:
        llm_model: smolagents LiteLLMModel that generates the answer.
        vectorstore: FAISS index whose top-*k* similar documents become context.
        k: number of retrieved documents stuffed into the prompt (default 3).

    Returns:
        A RetrievalQA chain that returns source documents as well as the answer.
        Its default input key is "query" and its answer key is "result".
    """

    class GeminiLangChainLLM(LLM):
        """Adapter exposing the smolagents model through LangChain's LLM API."""

        def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
            messages = [{"role": "user", "content": prompt}]
            response = llm_model(messages, stop_sequences=stop)
            # Newer smolagents releases return a ChatMessage object rather than
            # a plain string. LangChain's LLM contract requires a str.
            content = getattr(response, "content", response)
            return "" if content is None else str(content)

        @property
        def _llm_type(self) -> str:
            return "custom_gemini"

    retriever = vectorstore.as_retriever(
        search_type="similarity", search_kwargs={"k": k}
    )
    return RetrievalQA.from_chain_type(
        llm=GeminiLangChainLLM(),
        chain_type="stuff",
        retriever=retriever,
        return_source_documents=True,
    )
###############################################################################
# 4) Initialize sub-chains
###############################################################################
# Each helper chain is built once at import time and reused by every request.
# The tuple is evaluated left to right, so construction order is unchanged.
classification_chain, refusal_chain, tailor_chain, cleaner_chain = (
    get_classification_chain(),
    get_refusal_chain(),
    get_tailor_chain(),
    get_cleaner_chain(),
)
###############################################################################
# 5) Build vectorstores & RAG
###############################################################################
# Knowledge-base CSVs and the directories their FAISS indexes are cached in.
wellness_csv, wellness_store_dir = "AIChatbot.csv", "faiss_wellness_store"
brand_csv, brand_store_dir = "BrandAI.csv", "faiss_brand_store"

# One shared Gemini model (via LiteLLM) serves the RAG chains and the agents.
gemini_llm = LiteLLMModel(
    model_id="gemini/gemini-pro",
    api_key=os.environ.get("GEMINI_API_KEY"),
)

# Load (or build and cache) each index, then wrap it in a RetrievalQA chain.
wellness_vectorstore, brand_vectorstore = (
    build_or_load_vectorstore(wellness_csv, wellness_store_dir),
    build_or_load_vectorstore(brand_csv, brand_store_dir),
)
wellness_rag_chain, brand_rag_chain = (
    build_rag_chain(gemini_llm, wellness_vectorstore),
    build_rag_chain(gemini_llm, brand_vectorstore),
)

# Web-search fallback: a DuckDuckGo-equipped CodeAgent exposed as a managed
# sub-agent of a tool-less manager CodeAgent.
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(model=gemini_llm, tools=[search_tool])
managed_web_agent = ManagedAgent(
    agent=web_agent,
    name="web_search",
    description="Runs web search for you.",
)
manager_agent = CodeAgent(
    model=gemini_llm,
    tools=[],
    managed_agents=[managed_web_agent],
)
def do_web_search(query: str) -> str:
    """Ask the manager agent to gather web information about *query*.

    This is a best-effort fallback. If the agent raises, the error is printed
    and "" is returned, so the caller can still answer from its knowledge
    base alone.

    Returns:
        The agent's final answer as a string, or "" on failure / no answer.
    """
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    try:
        response = manager_agent.run(search_query)
    except Exception as exc:  # agent/tool failures must not sink the whole reply
        print(f"DEBUG: Web search failed: {exc!r}")
        return ""
    # agent.run may return a non-string final answer; honour the -> str contract.
    return "" if response is None else str(response)
###############################################################################
# 6) Orchestrator function: returns a dict => {"answer": "..."}
###############################################################################
# Lower-case substrings that mark a knowledge-base answer as unhelpful and
# trigger the web-search fallback. "don't know" matches the wording LangChain's
# default "stuff" QA prompt tells the model to use when it lacks context.
_UNCERTAIN_PHRASES = (
    "i do not know",
    "don't know",
    "don\u2019t know",
    "not sure",
    "no context",
    "cannot answer",
)


def _refuse() -> Dict[str, str]:
    """Return the tailored refusal for out-of-scope or unrecognised queries."""
    refusal_text = refusal_chain.run({})
    return {"answer": tailor_chain.run({"response": refusal_text}).strip()}


def _needs_web_search(kb_answer: str) -> bool:
    """Return True when *kb_answer* is empty or contains an uncertainty phrase."""
    lowered = kb_answer.lower()
    return not kb_answer or any(phrase in lowered for phrase in _UNCERTAIN_PHRASES)


def _answer_from_kb(
    rag_chain: RetrievalQA, user_query: str, *, allow_web: bool
) -> Dict[str, str]:
    """Answer *user_query* from *rag_chain*, optionally adding web results, then tailor it."""
    # RetrievalQA reads its question from the "query" key; passing "input"
    # makes the chain reject the call for a missing input key.
    rag_result = rag_chain.invoke({"query": user_query})
    kb_answer = rag_result["result"].strip()
    if allow_web and _needs_web_search(kb_answer):
        web_answer = do_web_search(user_query)
    else:
        web_answer = ""
    merged = cleaner_chain.merge(kb=kb_answer, web=web_answer)
    return {"answer": tailor_chain.run({"response": merged}).strip()}


def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
    """Classify the user's query and route it to the matching answer path.

    Called by PipelineRunnable.invoke.

    Args:
        inputs: {"input": <user_query>, "chat_history": <list of messages>}.
            "chat_history" is optional and currently unused, because the
            RetrievalQA chains only read the query.

    Returns:
        {"answer": <final string>}.

    Routing (labels compared case-insensitively):
        - "Wellness": wellness knowledge base. A web search is added when the
          KB answer is empty or sounds uncertain.
        - "Brand": brand knowledge base only.
        - "OutOfScope" or anything else: tailored refusal.
    """
    user_query = inputs["input"]

    class_result = classification_chain.invoke({"query": user_query})
    # Chain-style results are dicts carrying the label under "text". Tolerate
    # a plain-string result too.
    if isinstance(class_result, dict):
        raw_label = class_result.get("text", "")
    else:
        raw_label = class_result
    label = str(raw_label).strip().casefold()

    if label == "wellness":
        return _answer_from_kb(wellness_rag_chain, user_query, allow_web=True)
    if label == "brand":
        return _answer_from_kb(brand_rag_chain, user_query, allow_web=False)
    # "OutOfScope" and any unrecognised label both get the refusal.
    return _refuse()
###############################################################################
# 7) Build a "Runnable" wrapper so .with_listeners() works
###############################################################################
# Runnable is imported from langchain_core.runnables.base at the top of the
# file. Do not re-import it from "langchain.runnables.base": no such module exists.
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
    """Expose run_with_chain_context(...) as a LangChain Runnable.

    RunnableWithMessageHistory records history through listeners attached with
    .with_listeners(). Those listeners travel as callbacks inside ``config``,
    so invoke() must execute the pipeline as a callback-managed run for them
    to fire.
    """

    def invoke(
        self,
        input: Dict[str, Any],
        config: Optional[Any] = None,
        **kwargs: Any,
    ) -> Dict[str, str]:
        """Run the pipeline on *input* ({"input": ..., "chat_history": ...}).

        Extra keyword arguments are accepted for Runnable-signature
        compatibility and ignored.
        """
        # _call_with_config opens a run with the callbacks in `config` and
        # closes it with the output. That triggers the on_start/on_end
        # listeners which RunnableWithMessageHistory relies on to save history.
        return self._call_with_config(run_with_chain_context, input, config)


# Export an instance of PipelineRunnable for use in my_memory_logic.py
pipeline_runnable = PipelineRunnable()