from dotenv import load_dotenv
from langchain import hub
from langchain.agents import AgentExecutor, Tool, create_tool_calling_agent
from langchain.tools.retriever import create_retriever_tool
from langchain_community.chat_message_histories import ChatMessageHistory
from langchain_community.document_loaders import WebBaseLoader
from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import SerpAPIWrapper, WikipediaAPIWrapper
from langchain_community.vectorstores import FAISS
from langchain_core.runnables.history import RunnableWithMessageHistory
from langchain_google_genai import ChatGoogleGenerativeAI, GoogleGenerativeAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

class Search_Class:
    """Conversational agent that combines SerpAPI web search, Wikipedia, and a
    FAISS-backed stock-list retriever, using a Gemini chat model with
    per-session chat memory."""

    def __init__(self):
        self.setup_env()
        self.setup_llm()
        self.setup_embeddings()
        self.setup_vector_store()
        self.setup_tools()
        self.setup_memory()
        self.setup_agent()
        
    def setup_env(self):
        # Load API keys (e.g. GOOGLE_API_KEY, SERPAPI_API_KEY) from a local .env file.
        load_dotenv()

    def setup_llm(self):
        # Gemini 1.5 Flash as the reasoning model; temperature 0 keeps answers deterministic.
        self.llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0)
        # self.llm = ChatGroq(model="llama3-70b-8192", temperature=0)

    def setup_embeddings(self):
        # Load the stock-list page, prepare Gemini embeddings, and split the page
        # into overlapping chunks for retrieval.
        self.loader = WebBaseLoader("https://www.etmoney.com/stocks/list-of-stocks")
        self.docs = self.loader.load()
        self.embeddings = GoogleGenerativeAIEmbeddings(model="models/embedding-001")
        self.documents = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200).split_documents(self.docs)

    def setup_vector_store(self):
        # Build an in-memory FAISS index over the chunked documents and expose it as a retriever.
        self.vectordb = FAISS.from_documents(self.documents, self.embeddings)
        self.retriever = self.vectordb.as_retriever()

    def setup_tools(self):
        self.search_tool = Tool(
            name="Search",
            description="A search engine. Useful for when you need to answer questions about current events. Input should be a search query.",
            func=SerpAPIWrapper().run
        )
        # Keep Wikipedia answers short: one result, truncated to 200 characters.
        wiki_query = WikipediaQueryRun(
            api_wrapper=WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=200)
        )
        self.wiki_tool = Tool(
            name="Wikipedia",
            description="A wrapper around Wikipedia. Useful for when you need to answer general questions about people, places, companies, facts, historical events, or other subjects. Input should be a search query.",
            func=wiki_query.run
        )
        self.retriever_tool = create_retriever_tool(
            self.retriever, "stock_search", "For any information related to stock prices use this tool"
        )
        self.tools = [self.search_tool, self.wiki_tool, self.retriever_tool]
        self.names = [tool.name for tool in self.tools]

    def setup_memory(self):
        # Single in-memory message history; the same store backs every call in this session.
        self.memory = ChatMessageHistory(session_id="test-session")
        self.chat_history = []

    def setup_agent(self):
        # Pull the prompt from LangChain Hub, build a tool-calling agent, and wrap
        # the executor so each call reads and writes the per-session chat history.
        self.prompt = hub.pull("satvikjain/react_smaller")
        self.agent = create_tool_calling_agent(self.llm, self.tools, self.prompt)
        self.agent_executor = AgentExecutor(
            agent=self.agent,
            tools=self.tools,
            verbose=True,
            handle_parsing_errors=True,
            return_intermediate_steps=True
        )
        self.agent_with_chat_history = RunnableWithMessageHistory(
            self.agent_executor,
            lambda session_id: self.memory,
            input_messages_key="input",
            history_messages_key="chat_history"
        )

    def run(self, user_input="Hi"):
        # Invoke the agent; the prompt expects the tool list and tool names as inputs,
        # and the session id selects the shared message history.
        response = self.agent_with_chat_history.invoke(
            {"input": user_input, "tools": self.tools, "tool_names": self.names},
            config={"configurable": {"session_id": "test-session"}}
        )
        self.chat_history.append([user_input, response["output"]])
        return self.chat_history
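

# Minimal usage sketch (an assumption about how the class is driven, not part of the
# original file): requires GOOGLE_API_KEY and SERPAPI_API_KEY to be set in .env.
if __name__ == "__main__":
    search_agent = Search_Class()
    history = search_agent.run("What is the latest news about the Indian stock market?")
    print(history[-1][1])  # the agent's answer to the most recent question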