# pipeline.py
import os
import getpass
import pandas as pd
from typing import Optional, Dict, Any
# Runnable lives in langchain_core on current releases; fall back to the legacy
# langchain.schema.runnable location for older installs.
try:
    from langchain_core.runnables.base import Runnable
except ImportError:
    from langchain.schema.runnable import Runnable
from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA
from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
import litellm
from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain
from langchain.llms.base import LLM
# Environment keys
if not os.environ.get("GEMINI_API_KEY"):
    os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")
def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    # ... [unchanged code for building/loading vectorstore] ...
    # Use your previously provided implementation here.
    # For brevity, not repeating this section.
    pass
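
# Hedged sketch (hypothetical helper, not the original implementation): one way
# build_or_load_vectorstore could work, assuming the CSV holds a single "text"
# column (assumed schema) and that the MiniLM sentence-transformers model is
# available for embeddings.
def _sketch_build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    if os.path.isdir(store_dir):
        # Reload a previously persisted index instead of re-embedding the CSV.
        # (allow_dangerous_deserialization is required on recent langchain versions.)
        return FAISS.load_local(store_dir, embeddings, allow_dangerous_deserialization=True)
    df = pd.read_csv(csv_path)
    docs = [Document(page_content=str(row["text"])) for _, row in df.iterrows()]
    store = FAISS.from_documents(docs, embeddings)
    store.save_local(store_dir)
    return store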
def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    # ... [unchanged code for building a RAG chain] ...
    pass
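
# Hedged sketch (hypothetical helper): a plain RetrievalQA setup that matches the
# rag_chain.invoke({"query": ...})["result"] access pattern used below. Assumes
# the llm is LangChain-compatible (see the wrapper sketch above); k=4 is an
# arbitrary retriever setting.
def _sketch_build_rag_chain(llm: LLM, vectorstore: FAISS) -> RetrievalQA:
    retriever = vectorstore.as_retriever(search_kwargs={"k": 4})
    return RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=retriever)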
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()
gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)
wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])
def do_web_search(query: str) -> str:
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    response = manager_agent.run(search_query)
    return response
def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
    user_query = inputs["input"]
    chat_history = inputs.get("chat_history", [])
    print("DEBUG: Starting run_with_chain_context...")

    # Classify the query as Wellness, Brand, or OutOfScope.
    class_result = classification_chain.invoke({"query": user_query})
    classification = class_result.get("text", "").strip()
    print("DEBUG: Classification =>", classification)

    if classification == "OutOfScope":
        refusal_text = refusal_chain.run({})
        final_refusal = tailor_chain.run({"response": refusal_text})
        return {"answer": final_refusal.strip()}

    if classification == "Wellness":
        # Use the correct key "query" instead of "input": RetrievalQA expects it.
        rag_result = wellness_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
        csv_answer = rag_result["result"].strip()
        # Fall back to web search when the KB answer is empty or non-committal.
        if not csv_answer:
            web_answer = do_web_search(user_query)
        else:
            lower_ans = csv_answer.lower()
            if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
                web_answer = do_web_search(user_query)
            else:
                web_answer = ""
        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
        final_answer = tailor_chain.run({"response": final_merged}).strip()
        return {"answer": final_answer}

    if classification == "Brand":
        rag_result = brand_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
        csv_answer = rag_result["result"].strip()
        final_merged = cleaner_chain.merge(kb=csv_answer, web="")
        final_answer = tailor_chain.run({"response": final_merged}).strip()
        return {"answer": final_answer}

    # Unrecognized classification: treat it like OutOfScope and refuse politely.
    refusal_text = refusal_chain.run({})
    final_refusal = tailor_chain.run({"response": refusal_text}).strip()
    return {"answer": final_refusal}
# Runnable wrapper for my_memory_logic.py
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
    def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
        return run_with_chain_context(input)

pipeline_runnable = PipelineRunnable()
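
# Hedged usage sketch: a quick manual smoke test with a hypothetical query;
# the guard keeps it inert when my_memory_logic.py imports this module.
if __name__ == "__main__":
    out = pipeline_runnable.invoke({"input": "How can I improve my sleep?", "chat_history": []})
    print(out["answer"])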