File size: 8,139 Bytes
1b49043
 
 
42c1e22
1b49043
 
 
 
 
 
 
 
 
 
 
 
42c1e22
1b49043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42c1e22
1b49043
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42c1e22
1b49043
 
 
42c1e22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1b49043
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
import os
from dotenv import load_dotenv
import re
from loguru import logger

from langchain import PromptTemplate, LLMChain
from langchain.agents import initialize_agent, Tool
from langchain.chat_models import AzureChatOpenAI
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.callbacks import get_openai_callback
from langchain.chains.llm import LLMChain
from langchain.llms import AzureOpenAI
from langchain.prompts import PromptTemplate

from utils import lctool_search_allo_api, cut_dialogue_history
from prompts.mod_prompt import MOD_PROMPT, FALLBACK_MESSAGE, MOD_PROMPT_OPTIM_v2
from prompts.ans_prompt import ANS_PREFIX, ANS_FORMAT_INSTRUCTIONS, ANS_SUFFIX, ANS_CHAIN_PROMPT
from prompts.reco_prompt import RECO_PREFIX, RECO_FORMAT_INSTRUCTIONS, RECO_SUFFIX, NO_RECO_OUTPUT

# Load environment variables (Azure OpenAI credentials, deployment names) from a local .env file.
load_dotenv()

class AllofreshChatbot():
    def __init__(self, debug=False):
        """Wire up every LLM, chain, and agent the chatbot needs.

        Args:
            debug (bool): when True, the agents run with verbose logging.
        """
        self.debug = debug

        # Shared Azure OpenAI clients, keyed by model alias ("gpt-4", ...).
        self.llms = self.init_llm()
        # Moderation chain: classifies/filters incoming user queries.
        self.mod_chain = self.init_mod_chain()
        # Answering components: conversation memory, tool-using agent,
        # and the plain (tool-free) chain used by the *_optim paths.
        # NOTE: the redundant `self.ans_memory = None` pre-assignment was
        # removed — it was dead code, unconditionally overwritten below.
        self.ans_memory = self.init_ans_memory()
        self.ans_agent = self.init_ans_agent()
        self.ans_chain = self.init_ans_chain()
        # Recommendation agent (product suggestions after an answer).
        self.reco_agent = self.init_reco_agent()

    def init_llm(self):
        """Build the Azure OpenAI clients used throughout the chatbot.

        Returns:
            dict: model alias ("gpt-4", "gpt-3.5", "gpt-3") -> configured
            client. Chat models share one set of Azure connection kwargs;
            the completion model ("gpt-3") is configured separately.
        """
        env = os.getenv

        # Connection settings common to both chat deployments.
        shared_chat_kwargs = dict(
            openai_api_type=env("OPENAI_API_TYPE"),
            openai_api_base=env("OPENAI_API_BASE"),
            openai_api_version=env("OPENAI_API_VERSION"),
            openai_api_key=env("OPENAI_API_KEY"),
            openai_organization=env("OPENAI_ORGANIZATION"),
        )

        return {
            "gpt-4": AzureChatOpenAI(
                temperature=0,
                deployment_name=env("DEPLOYMENT_NAME_GPT4"),
                model_name=env("MODEL_NAME_GPT4"),
                **shared_chat_kwargs,
            ),
            "gpt-3.5": AzureChatOpenAI(
                temperature=0,
                deployment_name=env("DEPLOYMENT_NAME_GPT3.5"),
                model_name=env("MODEL_NAME_GPT3.5"),
                **shared_chat_kwargs,
            ),
            # Completion-style model: no api_type/api_version kwargs,
            # matching the original configuration.
            "gpt-3": AzureOpenAI(
                temperature=0,
                deployment_name=env("DEPLOYMENT_NAME_GPT3"),
                model_name=env("MODEL_NAME_GPT3"),
                openai_api_base=env("OPENAI_API_BASE"),
                openai_api_key=env("OPENAI_API_KEY"),
                openai_organization=env("OPENAI_ORGANIZATION"),
            ),
        }
    
    def init_mod_chain(self):
        """Build the moderation chain that vets incoming user queries.

        Returns:
            LLMChain: GPT-4-backed chain rendering MOD_PROMPT_OPTIM_v2,
            which expects a single "input" variable.
        """
        template = PromptTemplate(
            input_variables=["input"],
            template=MOD_PROMPT_OPTIM_v2,
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=template)
        
    def init_ans_memory(self):
        """Create the conversation buffer used by the answering components."""
        memory = ConversationBufferMemory(memory_key="chat_history", output_key='output')
        return memory

    def init_ans_agent(self):
        """Build the conversational answering agent.

        Returns:
            AgentExecutor: GPT-4-backed conversational-react agent that can
            query Allofresh's product-search API and reports its
            intermediate steps.
        """
        # Single tool: product lookup against the Allofresh API.
        product_search = Tool(
            name="Product Search",
            func=lctool_search_allo_api,
            description="""

                    To search for products in Allofresh's Database. 

                    Always use this to verify product names. 

                    Outputs product names and prices

                """
        )

        return initialize_agent(
            [product_search],
            self.llms["gpt-4"],
            agent="conversational-react-description",
            verbose=self.debug,
            return_intermediate_steps=True,
            agent_kwargs={
                # ANS_FORMAT_INSTRUCTIONS is only needed for models below
                # GPT-4, so format_instructions stays disabled here.
                'prefix': ANS_PREFIX,
                'suffix': ANS_SUFFIX
            }
        )
    
    def init_ans_chain(self):
        """Build the tool-free answering chain (used by the *_optim paths).

        Returns:
            LLMChain: GPT-4-backed chain rendering ANS_CHAIN_PROMPT from
            "input" and "chat_history" variables.
        """
        template = PromptTemplate(
            input_variables=["input", "chat_history"],
            template=ANS_CHAIN_PROMPT,
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=template)
        
    def init_reco_agent(self):
        """Build the zero-shot product-recommendation agent.

        The agent can either search Allofresh's catalogue for products to
        recommend, or explicitly opt out via the "No Recommendation" tool.

        Returns:
            AgentExecutor: executor wrapping a GPT-4 ZeroShotAgent.
        """
        tools = [
            Tool(
                name="Product Search",
                func=lctool_search_allo_api,
                description="""

                    To search for products in Allofresh's Database. 

                    Always use this to verify product names. 

                    Outputs product names and prices

                """
            ),
            Tool(
                name="No Recommendation",
                # Ignores its input and just emits a fixed opt-out marker.
                func=lambda x: "No recommendation",
                description="""

                    Use this if based on the context you don't need to recommend any products

                """
            )
        ]

        reco_prompt = ZeroShotAgent.create_prompt(
            tools,
            prefix=RECO_PREFIX,
            format_instructions=RECO_FORMAT_INSTRUCTIONS,
            suffix=RECO_SUFFIX,
            input_variables=["input", "agent_scratchpad"]
        )

        tool_names = [tool.name for tool in tools]
        chain = LLMChain(llm=self.llms["gpt-4"], prompt=reco_prompt)
        agent = ZeroShotAgent(llm_chain=chain, allowed_tools=tool_names)
        return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, verbose=self.debug)
        
    def answer(self, query):
        # moderate
        mod_verdict = self.mod_chain.run({"query": query})
        # if pass moderation
        if mod_verdict == "True":
            # answer question
            answer = self.ans_pipeline(query)
            # recommend
            reco = self.reco_agent.run({"input": self.ans_agent.memory.buffer})
            if len(reco) > 0:
                self.ans_agent.memory.chat_memory.add_ai_message(reco)
            # construct output
            return (answer, reco)
        else:
            return (
                FALLBACK_MESSAGE,
                None
            )
        
    def answer_optim_v1(self, query, chat_history):
        """

        We plugged off the tools from the 'answering' component and replaced it with a simple chain

        """
        # moderate
        mod_verdict = self.mod_chain.run({"input": query})
        # if pass moderation
        if mod_verdict == "True":
            # answer question
            return self.ans_chain.run({"input": query, "chat_history": str(chat_history)})
        return FALLBACK_MESSAGE
    
    def answer_optim_v2(self, query, chat_history):
        """Route *query* by moderation verdict: plain chain, agent, or fallback.

        Tool-free fast path plus an agent path: the moderation chain now
        emits a routing verdict instead of a boolean.

        Args:
            query (str): raw user message.
            chat_history: prior conversation; stringified into the prompt.

        Returns:
            str: the chain/agent answer, or FALLBACK_MESSAGE when the
            verdict matches neither route.
        """
        verdict = self.mod_chain.run({"input": query})
        payload = {"input": query, "chat_history": str(chat_history)}

        logger.info(f"mod verdict: {verdict}")

        if verdict == "ANS_AGENT":
            # Knowledge-base route: run the tool-using agent and normalise
            # backslashes in its output.
            agent_result = self.ans_agent(payload)
            return agent_result['output'].replace("\\", "/")
        if verdict == "ANS_CHAIN":
            # No knowledge-base access needed: answer with the plain chain.
            return self.ans_chain.run(payload)
        return FALLBACK_MESSAGE

    def reco_optim_v1(self, chat_history):
        """Run the recommendation agent over *chat_history*.

        Returns:
            The agent's recommendation text, or None when the agent produced
            the NO_RECO_OUTPUT sentinel.
        """
        recommendation = self.reco_agent.run({"input": chat_history})
        if recommendation == NO_RECO_OUTPUT:
            return None
        return recommendation