File size: 4,793 Bytes
b0c64f6
7997061
b0c64f6
 
 
a79a41b
b0c64f6
8cdf335
 
 
 
7997061
b0c64f6
 
 
 
 
 
 
 
 
 
 
99474e2
b0c64f6
99474e2
b0c64f6
8cdf335
b0c64f6
 
 
 
 
 
8cdf335
 
 
 
 
b0c64f6
8cdf335
 
 
b0c64f6
 
 
 
 
7997061
 
b0c64f6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a79a41b
7997061
a79a41b
 
8cdf335
a79a41b
b0c64f6
8cdf335
b0c64f6
 
 
 
a79a41b
b0c64f6
 
8cdf335
 
b0c64f6
 
a79a41b
b0c64f6
 
 
a79a41b
b0c64f6
 
 
a79a41b
 
b0c64f6
 
8cdf335
b0c64f6
 
a79a41b
 
b0c64f6
 
a79a41b
 
7997061
8cdf335
7997061
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
# pipeline.py

import os
import getpass
import pandas as pd
from typing import Optional, Dict, Any

try:
    from langchain.runnables.base import Runnable
except ImportError:
    from langchain_core.runnables.base import Runnable

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS
from langchain.chains import RetrievalQA

from smolagents import CodeAgent, DuckDuckGoSearchTool, ManagedAgent, LiteLLMModel
import litellm

from classification_chain import get_classification_chain
from refusal_chain import get_refusal_chain
from tailor_chain import get_tailor_chain
from cleaner_chain import get_cleaner_chain

from langchain.llms.base import LLM

# Environment keys: any key missing (or empty) in the environment is
# requested interactively, without echo, and stored for the rest of the process.
for _env_name, _prompt_text in (
    ("GEMINI_API_KEY", "Enter your Gemini API Key: "),
    ("GROQ_API_KEY", "Enter your GROQ API Key: "),
):
    if not os.environ.get(_env_name):
        os.environ[_env_name] = getpass.getpass(_prompt_text)

def build_or_load_vectorstore(csv_path: str, store_dir: str) -> FAISS:
    """Build or load the FAISS vector store for a CSV file (placeholder).

    NOTE(review): the real implementation is elided in this file. The body is
    only ``pass``, so as written this function returns ``None`` rather than a
    ``FAISS`` instance. Restore the original implementation before relying on
    the return value.

    Args:
        csv_path: Path of the CSV file -- presumably the rows to index;
            confirm against the real implementation.
        store_dir: Directory name for the store -- presumably where the
            index is persisted/loaded from; TODO confirm.
    """
    # ... [unchanged code for building/loading vectorstore] ...
    # Use your previously provided implementation here.
    # For brevity, not repeating this section.
    pass

def build_rag_chain(llm_model: LiteLLMModel, vectorstore: FAISS) -> RetrievalQA:
    """Build a retrieval-QA chain over ``vectorstore`` (placeholder).

    NOTE(review): the real implementation is elided in this file. The body is
    only ``pass``, so as written this function returns ``None`` rather than a
    ``RetrievalQA`` chain. Restore the original implementation before use.

    Args:
        llm_model: The model to answer with -- presumably wrapped for
            LangChain inside the real implementation; TODO confirm.
        vectorstore: The FAISS store to retrieve context from.
    """
    # ... [unchanged code for building a RAG chain] ...
    pass

# Helper chains, each created once at import time by its project-local factory.
classification_chain = get_classification_chain()
refusal_chain = get_refusal_chain()
tailor_chain = get_tailor_chain()
cleaner_chain = get_cleaner_chain()

# Single Gemini model (via LiteLLM) shared by both RAG chains and both agents.
gemini_llm = LiteLLMModel(model_id="gemini/gemini-pro", api_key=os.environ.get("GEMINI_API_KEY"))

# One CSV source and one FAISS store directory per knowledge domain.
wellness_csv = "AIChatbot.csv"
brand_csv = "BrandAI.csv"
wellness_store_dir = "faiss_wellness_store"
brand_store_dir = "faiss_brand_store"

# NOTE(review): build_or_load_vectorstore and build_rag_chain are placeholders
# (body is `pass`) in this file, so the four names below are None until the
# real implementations are restored.
wellness_vectorstore = build_or_load_vectorstore(wellness_csv, wellness_store_dir)
brand_vectorstore = build_or_load_vectorstore(brand_csv, brand_store_dir)

wellness_rag_chain = build_rag_chain(gemini_llm, wellness_vectorstore)
brand_rag_chain = build_rag_chain(gemini_llm, brand_vectorstore)

# Web-search setup: a CodeAgent holding the DuckDuckGo tool is wrapped as a
# managed agent named "web_search" and handed to a tool-less manager agent.
search_tool = DuckDuckGoSearchTool()
web_agent = CodeAgent(tools=[search_tool], model=gemini_llm)
managed_web_agent = ManagedAgent(agent=web_agent, name="web_search", description="Runs web search for you.")
manager_agent = CodeAgent(tools=[], model=gemini_llm, managed_agents=[managed_web_agent])

def do_web_search(query: str) -> str:
    """Run ``query`` through the manager agent and return its answer as text.

    Args:
        query: The user's question; it is embedded in a fixed prompt prefix
            before being passed to ``manager_agent.run``.

    Returns:
        The agent's final answer as a ``str``. A ``None`` answer becomes the
        empty string; any other non-string answer is converted with ``str()``.
    """
    print("DEBUG: Attempting web search for more info...")
    search_query = f"Give me relevant info: {query}"
    response = manager_agent.run(search_query)
    # The agent's final answer is not guaranteed to be a string; normalize it
    # so the declared ``-> str`` contract holds for callers that merge text.
    if response is None:
        return ""
    return response if isinstance(response, str) else str(response)

def run_with_chain_context(inputs: Dict[str, Any]) -> Dict[str, str]:
    """Classify the user's query and route it to the matching answer path.

    ``inputs`` must contain ``"input"`` (the user query) and may contain
    ``"chat_history"`` (defaults to an empty list). The stripped
    classification text selects the route:

    * ``"Wellness"`` -- answer from the wellness RAG chain; when that answer
      is empty or contains one of the "unsure" phrases, a web search result
      is merged in as well.
    * ``"Brand"`` -- answer from the brand RAG chain only.
    * anything else, including ``"OutOfScope"`` -- a tailored refusal.

    Returns:
        A dict with the single key ``"answer"`` holding the stripped text
        produced by the tailor chain.

    Raises:
        KeyError: If ``inputs`` has no ``"input"`` entry.
    """
    query = inputs["input"]
    history = inputs.get("chat_history", [])

    print("DEBUG: Starting run_with_chain_context...")
    label = classification_chain.invoke({"query": query}).get("text", "").strip()
    print("DEBUG: Classification =>", label)

    if label == "Wellness":
        # The RAG chain takes the question under "query", not "input".
        rag_output = wellness_rag_chain.invoke({"query": query, "chat_history": history})
        kb_answer = rag_output["result"].strip()

        # Fall back to the web when the knowledge base gave nothing useful.
        unsure_markers = ("i do not know", "not sure", "no context", "cannot answer")
        lowered = kb_answer.lower()
        needs_web = not kb_answer or any(marker in lowered for marker in unsure_markers)
        web_text = do_web_search(query) if needs_web else ""

        merged = cleaner_chain.merge(kb=kb_answer, web=web_text)
        return {"answer": tailor_chain.run({"response": merged}).strip()}

    if label == "Brand":
        rag_output = brand_rag_chain.invoke({"query": query, "chat_history": history})
        kb_answer = rag_output["result"].strip()
        merged = cleaner_chain.merge(kb=kb_answer, web="")
        return {"answer": tailor_chain.run({"response": merged}).strip()}

    # "OutOfScope" and any unrecognized label share the same refusal path.
    refusal = refusal_chain.run({})
    return {"answer": tailor_chain.run({"response": refusal}).strip()}

# Runnable wrapper for my_memory_logic.py
class PipelineRunnable(Runnable[Dict[str, Any], Dict[str, str]]):
    """Expose ``run_with_chain_context`` through the ``Runnable`` interface."""

    def invoke(
        self,
        input: Dict[str, Any],  # parameter name is fixed by the Runnable interface
        config: Optional[Any] = None,
        **kwargs: Any,
    ) -> Dict[str, str]:
        """Run the pipeline on ``input`` and return its ``{"answer": ...}`` dict.

        Args:
            input: Pipeline inputs; forwarded unchanged to
                ``run_with_chain_context``.
            config: Accepted for interface compatibility; not used.
            **kwargs: Accepted for compatibility with the base
                ``Runnable.invoke`` signature, so callers that forward extra
                keyword arguments do not hit a ``TypeError``; not used.

        Returns:
            The dict returned by ``run_with_chain_context``.
        """
        return run_with_chain_context(input)

# Module-level instance of the wrapper, created once at import time.
pipeline_runnable = PipelineRunnable()