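"""Routing pipeline for a wellness/brand chatbot.

Classifies each incoming query, answers it from the matching FAISS-backed RAG
store, falls back to a smolagents web search when the knowledge-base answer is
missing or unsure, then merges and tailors the final response."""
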
import os
import getpass
import pandas as pd
from typing import Optional, Dict, Any

# Runnable lives in langchain_core in newer releases; fall back to the legacy
# langchain path for older installs.
try:
    from langchain_core.runnables.base import Runnable
except ImportError:
    from langchain.schema.runnable import Runnable

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
import litellm

from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain

from langchain.llms.base import LLM
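
# The `litellm` and `LLM` imports above suggest a LangChain-compatible Gemini
# wrapper was intended but left out of this file. A minimal sketch, assuming
# litellm's standard completion API; the class name `GeminiLangChainLLM` is
# illustrative, not canonical.
class GeminiLangChainLLM(LLM):
    """Thin LangChain LLM that routes prompts through litellm to Gemini."""

    model_id: str = "gemini/gemini-pro"

    @property
    def _llm_type(self) -> str:
        return "gemini_litellm"

    def _call(
        self,
        prompt: str,
        stop: Optional[list] = None,
        run_manager: Optional[Any] = None,
        **kwargs: Any,
    ) -> str:
        # litellm.completion returns an OpenAI-style response object;
        # litellm reads GEMINI_API_KEY from the environment (set below).
        response = litellm.completion(
            model=self.model_id,
            messages=[{"role": "user", "content": prompt}],
            stop=stop,
            **kwargs,
        )
        return response["choices"][0]["message"]["content"]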


if not os.environ.get("GEMINI_API_KEY"):
    os.environ["GEMINI_API_KEY"] = getpass.getpass("Enter your Gemini API Key: ")
if not os.environ.get("GROQ_API_KEY"):
    os.environ["GROQ_API_KEY"] = getpass.getpass("Enter your GROQ API Key: ")


def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    """Load the FAISS index from disk if present; otherwise build it from the CSV.

    The original body was a stub. This is a minimal sketch: the embedding model
    and the "Answer" text column are assumptions; adjust to the actual schema.
    """
    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
    if os.path.exists(store_dir):
        return FAISS.load_local(store_dir, embeddings)
    df = pd.read_csv(csv_path)
    docs = [Document(page_content=str(row["Answer"])) for _, row in df.iterrows()]
    vectorstore = FAISS.from_documents(docs, embeddings)
    vectorstore.save_local(store_dir)
    return vectorstore


def build_rag_chain(llm_model: LLM, vectorstore: FAISS) -> RetrievalQA:
    """Wrap the vector store's retriever in a stuff-style RetrievalQA chain.

    The original body was a stub. Note that RetrievalQA expects a LangChain LLM,
    not a smolagents LiteLLMModel, hence the wrapper sketched above.
    """
    return RetrievalQA.from_chain_type(
        llm=llm_model,
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 3}),
    )


classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()

gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))

wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"

wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)

# The RAG chains take the LangChain-compatible wrapper (sketched above), while
# gemini_llm, the smolagents LiteLLMModel, drives the agents below.
gemini_langchain_llm = GeminiLangChainLLM()
wellness_rag_chain = build_rag_chain(gemini_langchain_llm, wellness_vectorstore)
brand_rag_chain = build_rag_chain(gemini_langchain_llm, brand_vectorstore)

search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])


def do_web_search(query: str) -> str:
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    response = manager_agent.run(search_query)
    return response


def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
    user_query = inputs["input"]
    chat_history = inputs.get("chat_history", [])

    print("DEBUG: Starting run_with_chain_context...")
    class_result = classification_chain.invoke({"query": user_query})
    classification = class_result.get("text", "").strip()
    print("DEBUG: Classification =>", classification)

    if classification == "OutOfScope":
        refusal_text = refusal_chain.run({})
        final_refusal = tailor_chain.run({"response": refusal_text})
        return {"answer": final_refusal.strip()}

    if classification == "Wellness":
        rag_result = wellness_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
        csv_answer = rag_result["result"].strip()
        if not csv_answer:
            web_answer = do_web_search(user_query)
        else:
            lower_ans = csv_answer.lower()
            # Fall back to web search when the KB answer is an explicit "don't know".
            if any(phrase in lower_ans for phrase in ["i do not know", "not sure", "no context", "cannot answer"]):
                web_answer = do_web_search(user_query)
            else:
                web_answer = ""
        final_merged = cleaner_chain.merge(kb=csv_answer, web=web_answer)
        final_answer = tailor_chain.run({"response": final_merged}).strip()
        return {"answer": final_answer}

    if classification == "Brand":
        rag_result = brand_rag_chain.invoke({"query": user_query, "chat_history": chat_history})
        csv_answer = rag_result["result"].strip()
        final_merged = cleaner_chain.merge(kb=csv_answer, web="")
        final_answer = tailor_chain.run({"response": final_merged}).strip()
        return {"answer": final_answer}

    # Anything unclassified is treated as out of scope.
    refusal_text = refusal_chain.run({})
    final_refusal = tailor_chain.run({"response": refusal_text}).strip()
    return {"answer": final_refusal}


class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
    """Adapts the routing pipeline to the LangChain Runnable interface."""

    def invoke(self, input: Dict[str, Any], config: Optional[Any] = None) -> Dict[str, str]:
        return run_with_chain_context(input)


pipeline_runnable = PipelineRunnable()
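

# Example invocation; the query string is illustrative.
if __name__ == "__main__":
    result = pipeline_runnable.invoke({"input": "What are some tips for better sleep?", "chat_history": []})
    print(result["answer"])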