Chatbot2 / pipeline.py
Phoenix21's picture
Update pipeline.py
8cdf335 verified
raw
history blame
4.79 kB
# pipeline.py
import os
import getpass
import pandas as pd
from typing import Optional, Dict, Any
try:
from langchain.runnables.base import Runnable
except ImportError:
from langchain_core.runnables.base import Runnable
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
import litellm
from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain
from langchain.llms.base import LLM
# Environment keys
# Prompt interactively (via getpass, so input is not echoed) for each API key
# that is missing or empty in the environment, then export it into os.environ
# so clients created later in this module can read it (gemini_llm reads
# GEMINI_API_KEY below).
# NOTE(review): GROQ_API_KEY is not read anywhere in this file's visible code;
# presumably one of the imported chain modules uses it — confirm.
_API_KEY_PROMPTS = (
    ("GEMINI_API_KEY", "Enter your Gemini API Key: "),
    ("GROQ_API_KEY", "Enter your GROQ API Key: "),
)
for _key_name, _prompt_text in _API_KEY_PROMPTS:
    if os.environ.get(_key_name):
        continue  # already configured; don't prompt
    os.environ[_key_name] = getpass.getpass(_prompt_text)
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    """Return the FAISS vector store for the knowledge base in *csv_path*.

    Judging by the name and parameters, this is meant to load a persisted
    index from *store_dir* or build one from the CSV. The real implementation
    is elided in this file (see the comments below) — TODO confirm.

    Args:
        csv_path: Path to the knowledge-base CSV file.
        store_dir: Directory associated with this knowledge base's FAISS index.

    Returns:
        Intended: a ``FAISS`` vector store.

    NOTE(review): as written the body is only ``pass``, so this returns None
    despite the ``FAISS`` annotation. The module-level callers then store
    None, and the RAG chains built from it cannot work until the real
    implementation is restored.
    """
    # ... [unchanged code for building/loading vectorstore] ...
    # Use your previously provided implementation here.
    # For brevity, not repeating this section.
    pass
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    """Return a retrieval-QA chain answering over *vectorstore* with *llm_model*.

    Callers in this file invoke the result with
    ``{"query": ..., "chat_history": ...}`` and read the answer from
    ``result["result"]``.

    Args:
        llm_model: The model used to generate answers.
        vectorstore: The store documents are retrieved from.

    Returns:
        Intended: a ``RetrievalQA`` chain.

    NOTE(review): the body is only ``pass``, so this currently returns None.
    Also, RetrievalQA expects a LangChain LLM, while *llm_model* is a
    smolagents ``LiteLLMModel``. Presumably the real implementation adapts it
    (the file imports LangChain's ``LLM`` base class, possibly for such a
    wrapper) — confirm.
    """
    # ... [unchanged code for building a RAG chain] ...
    pass
# --- Module-level pipeline setup --------------------------------------------
# Everything below runs at import time, in this order. The request handlers
# further down read these module globals directly.

# Chain objects produced by the project's factory functions (imported above).
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()
# Shared model for the RAG chains and the smolagents agents; its API key comes
# from the GEMINI_API_KEY environment variable.
# NOTE(review): confirm the "gemini/gemini-pro" model id is still served.
gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
# Knowledge-base sources: each CSV is paired with its own FAISS store directory.
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
# One retrieval-QA chain per knowledge base, both using gemini_llm.
# NOTE(review): in the version of this file shown here, both builders are
# placeholders that return None. These globals stay None until the real
# implementations are restored — confirm.
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
# Web-search fallback: a CodeAgent equipped with the DuckDuckGo tool is wrapped
# as a ManagedAgent named "web_search". manager_agent, which has no tools of
# its own, can then delegate to it.
# NOTE(review): ManagedAgent is deprecated/removed in newer smolagents releases
# (agents take name/description directly) — confirm the pinned version.
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])
def do_web_search(query: str) -> str:
    """Ask the manager agent to research *query* on the web.

    The manager agent delegates to the managed ``web_search`` agent, which
    uses the DuckDuckGo search tool. Web content only supplements the
    knowledge-base answer. Any failure raised by the agent run is therefore
    reported and turned into an empty string, rather than aborting the whole
    response. An empty string is also the value used for "no web content"
    when the web step is skipped.

    Args:
        query: The user's question.

    Returns:
        The agent's answer as text, or ``""`` if the run failed or produced
        no answer.
    """
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    try:
        response = manager_agent.run(search_query)
    except Exception as exc:  # external search/LLM boundary: best-effort only
        print(f"DEBUG: Web search failed: {exc!r}")
        return ""
    # The agent's final answer is not guaranteed to be a str; normalize it so
    # the return type matches the annotation.
    return "" if response is None else str(response)
# Phrases that mark a knowledge-base answer as a non-answer, which triggers the
# web-search fallback on the Wellness path. They are matched against the
# lowercased answer, so keep them lowercase.
_UNCERTAIN_PHRASES = ("i do not know", "not sure", "no context", "cannot answer")


def _chain_text(result: Any) -> str:
    """Return the stripped text of a chain result.

    Accepts a dict carrying the text under ``"text"`` (the shape read
    originally) or a plain string, e.g. from a chain ending in a string
    output parser. A missing or None ``"text"`` value yields ``""``.
    """
    if isinstance(result, str):
        return result.strip()
    return (result.get("text") or "").strip()


def _refuse() -> Dict[str, str]:
    """Build the tailored refusal for out-of-scope or unrecognized queries."""
    refusal_text = refusal_chain.run({})
    return {"answer": tailor_chain.run({"response": refusal_text}).strip()}


def _needs_web_fallback(kb_answer: str) -> bool:
    """Return True when *kb_answer* is empty or admits it cannot answer."""
    if not kb_answer:
        return True
    lowered = kb_answer.lower()
    return any(phrase in lowered for phrase in _UNCERTAIN_PHRASES)


def _answer_from_kb(
    rag_chain: Any,
    user_query: str,
    chat_history: Any,
    *,
    allow_web: bool,
) -> Dict[str, str]:
    """Answer from a knowledge-base RAG chain, optionally adding web results.

    Args:
        rag_chain: Retrieval-QA chain for the knowledge base.
        user_query: The user's question.
        chat_history: Passed through to the chain unchanged.
        allow_web: When True and the KB answer is empty or uncertain, a web
            search result is merged in as well.

    Returns:
        ``{"answer": <cleaned, tailored text>}``.
    """
    # RetrievalQA reads the question from "query" (not "input") and returns
    # the answer under "result".
    rag_result = rag_chain.invoke({"query": user_query, "chat_history": chat_history})
    kb_answer = rag_result["result"].strip()
    if allow_web and _needs_web_fallback(kb_answer):
        web_answer = do_web_search(user_query)
    else:
        web_answer = ""
    merged = cleaner_chain.merge(kb=kb_answer, web=web_answer)
    return {"answer": tailor_chain.run({"response": merged}).strip()}


def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
    """Classify the user's query and route it to the matching answer path.

    Routing by classifier label:
      * ``"Wellness"``: wellness knowledge base, with a web-search fallback
        when the KB answer is empty or uncertain.
      * ``"Brand"``: brand knowledge base only.
      * ``"OutOfScope"`` or any other label: a tailored refusal.

    Args:
        inputs: Mapping with the user query under ``"input"`` (required) and
            an optional ``"chat_history"`` (defaults to ``[]``).

    Returns:
        ``{"answer": <final response text>}``.

    Raises:
        KeyError: If ``inputs`` has no ``"input"`` key.
    """
    user_query = inputs["input"]
    chat_history = inputs.get("chat_history", [])
    print("DEBUG: Starting run_with_chain_context...")

    classification = _chain_text(classification_chain.invoke({"query": user_query}))
    print("DEBUG: Classification =>", classification)

    if classification == "Wellness":
        return _answer_from_kb(wellness_rag_chain, user_query, chat_history, allow_web=True)
    if classification == "Brand":
        return _answer_from_kb(brand_rag_chain, user_query, chat_history, allow_web=False)
    # "OutOfScope" and any unrecognized label share the same refusal path.
    return _refuse()
# Runnable wrapper for my_memory_logic.py
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
    """Expose ``run_with_chain_context`` as a LangChain ``Runnable``.

    This lets the pipeline be composed with LangChain runnables, such as the
    memory logic mentioned above, which call ``.invoke``.
    """

    def invoke(
        self,
        input: Dict[str, Any],  # name matches Runnable.invoke; keep for keyword callers
        config: Optional[Any] = None,
        **kwargs: Any,
    ) -> Dict[str, str]:
        """Run the pipeline on *input* and return ``{"answer": ...}``.

        Args:
            input: Mapping with the user query under ``"input"`` and an
                optional ``"chat_history"``.
            config: Accepted for ``Runnable`` compatibility; ignored.
            **kwargs: Accepted because ``Runnable.invoke`` declares
                ``**kwargs`` and wrappers may forward extra keywords; ignored.
        """
        return run_with_chain_context(input)


pipeline_runnable = PipelineRunnable()