# allofresh-assistant / allofresh_chatbot.py
# Author: ar-dy
# optim v2: separated response for each component, mod becomes a classifier
# (commit 42c1e22)
import os
from dotenv import load_dotenv
import re
from loguru import logger
# NOTE(review): PromptTemplate and LLMChain are each imported twice
# (here and again from langchain.prompts / langchain.chains.llm below)
from langchain import PromptTemplate, LLMChain
from langchain.agents import initialize_agent, Tool
from langchain.chat_models import AzureChatOpenAI
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.callbacks import get_openai_callback
from langchain.chains.llm import LLMChain
from langchain.llms import AzureOpenAI
from langchain.prompts import PromptTemplate
from utils import lctool_search_allo_api, cut_dialogue_history
from prompts.mod_prompt import MOD_PROMPT, FALLBACK_MESSAGE, MOD_PROMPT_OPTIM_v2
from prompts.ans_prompt import ANS_PREFIX, ANS_FORMAT_INSTRUCTIONS, ANS_SUFFIX, ANS_CHAIN_PROMPT
from prompts.reco_prompt import RECO_PREFIX, RECO_FORMAT_INSTRUCTIONS, RECO_SUFFIX, NO_RECO_OUTPUT

# Load environment variables (Azure OpenAI credentials, deployment names)
# from a local .env file at import time.
load_dotenv()
class AllofreshChatbot():
    """Conversational grocery-shopping assistant for Allofresh.

    Three LLM-backed components:
      * ``mod_chain``  -- moderation chain; in the original pipeline it
        returns "True"/other, in optim v2 it acts as a classifier returning
        "ANS_CHAIN" / "ANS_AGENT" / other.
      * ``ans_chain`` / ``ans_agent`` -- answering components; the chain is
        tool-free, the agent can call the product-search tool.
      * ``reco_agent`` -- zero-shot agent that recommends products based on
        the conversation so far.
    """

    def __init__(self, debug=False):
        """Build all LLM clients, chains and agents.

        Args:
            debug: when True, agents run with verbose logging.
        """
        self.debug = debug
        # shared pool of Azure OpenAI clients, keyed by model family
        self.llms = self.init_llm()
        # moderation chain (classifier for incoming queries)
        self.mod_chain = self.init_mod_chain()
        # answering components
        self.ans_memory = self.init_ans_memory()
        self.ans_agent = self.init_ans_agent()
        self.ans_chain = self.init_ans_chain()
        # recommendation agent
        self.reco_agent = self.init_reco_agent()

    def init_llm(self):
        """Return Azure OpenAI clients keyed by model name.

        All connection settings come from environment variables loaded by
        ``load_dotenv()`` at import time.
        """
        return {
            "gpt-4": AzureChatOpenAI(
                temperature=0,
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT4"),
                model_name=os.getenv("MODEL_NAME_GPT4"),
                openai_api_type=os.getenv("OPENAI_API_TYPE"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_version=os.getenv("OPENAI_API_VERSION"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
            "gpt-3.5": AzureChatOpenAI(
                temperature=0,
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3.5"),
                model_name=os.getenv("MODEL_NAME_GPT3.5"),
                openai_api_type=os.getenv("OPENAI_API_TYPE"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_version=os.getenv("OPENAI_API_VERSION"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
            # completion-style (non-chat) client
            "gpt-3": AzureOpenAI(
                temperature=0,
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3"),
                model_name=os.getenv("MODEL_NAME_GPT3"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
        }

    def init_mod_chain(self):
        """Build the moderation/classification chain.

        The prompt declares a single variable ``input``, so callers must
        invoke the chain with the key "input".
        """
        mod_prompt = PromptTemplate(
            template=MOD_PROMPT_OPTIM_v2,
            input_variables=["input"],
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=mod_prompt)

    def init_ans_memory(self):
        """Conversation buffer memory for the answering components."""
        return ConversationBufferMemory(memory_key="chat_history", output_key='output')

    def init_ans_agent(self):
        """Build the conversational answering agent with the product-search tool.

        NOTE(review): ``initialize_agent`` is called without a ``memory``
        kwarg even though ``self.ans_memory`` exists; ``answer()`` reads
        ``self.ans_agent.memory`` -- confirm the agent actually carries a
        memory object.
        """
        ans_tools = [
            Tool(
                name="Product Search",
                func=lctool_search_allo_api,
                description="""
                To search for products in Allofresh's Database.
                Always use this to verify product names.
                Outputs product names and prices
                """,
            )
        ]
        return initialize_agent(
            ans_tools,
            self.llms["gpt-4"],
            agent="conversational-react-description",
            verbose=self.debug,
            return_intermediate_steps=True,
            agent_kwargs={
                'prefix': ANS_PREFIX,
                # 'format_instructions': ANS_FORMAT_INSTRUCTIONS,  # only needed for below gpt-4
                'suffix': ANS_SUFFIX,
            },
        )

    def init_ans_chain(self):
        """Tool-free answering chain used by the optimized pipelines."""
        ans_prompt = PromptTemplate(
            template=ANS_CHAIN_PROMPT,
            input_variables=["input", "chat_history"],
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=ans_prompt)

    def init_reco_agent(self):
        """Build the zero-shot recommendation agent.

        Tools: product search, plus an escape hatch the LLM can pick when no
        recommendation is appropriate.
        """
        reco_tools = [
            Tool(
                name="Product Search",
                func=lctool_search_allo_api,
                description="""
                To search for products in Allofresh's Database.
                Always use this to verify product names.
                Outputs product names and prices
                """,
            ),
            Tool(
                name="No Recommendation",
                func=lambda x: "No recommendation",
                description="""
                Use this if based on the context you don't need to recommend any products
                """,
            ),
        ]
        prompt = ZeroShotAgent.create_prompt(
            reco_tools,
            prefix=RECO_PREFIX,
            format_instructions=RECO_FORMAT_INSTRUCTIONS,
            suffix=RECO_SUFFIX,
            input_variables=["input", "agent_scratchpad"],
        )
        llm_chain_reco = LLMChain(llm=self.llms["gpt-4"], prompt=prompt)
        agent_reco = ZeroShotAgent(
            llm_chain=llm_chain_reco,
            allowed_tools=[tool.name for tool in reco_tools],
        )
        return AgentExecutor.from_agent_and_tools(
            agent=agent_reco, tools=reco_tools, verbose=self.debug
        )

    def answer(self, query):
        """Original pipeline: moderate, answer, then recommend.

        Args:
            query: raw user message.

        Returns:
            ``(answer, reco)`` on success; ``(FALLBACK_MESSAGE, None)`` when
            the query fails moderation.
        """
        # BUG FIX: the moderation prompt's input variable is "input" (see
        # init_mod_chain and the answer_optim_* methods); the old code passed
        # {"query": query}, which does not match the declared variable.
        mod_verdict = self.mod_chain.run({"input": query})
        if mod_verdict == "True":
            # NOTE(review): ans_pipeline is not defined in this file --
            # confirm it exists elsewhere in the project.
            answer = self.ans_pipeline(query)
            # recommend based on the conversation buffer accumulated so far
            reco = self.reco_agent.run({"input": self.ans_agent.memory.buffer})
            if reco:
                self.ans_agent.memory.chat_memory.add_ai_message(reco)
            return (answer, reco)
        else:
            return (FALLBACK_MESSAGE, None)

    def answer_optim_v1(self, query, chat_history):
        """Optim v1: tools plugged off the answering component; a plain chain
        answers instead.

        Args:
            query: raw user message.
            chat_history: prior conversation; stringified into the prompt.

        Returns:
            The chain's answer, or FALLBACK_MESSAGE if moderation fails.
        """
        mod_verdict = self.mod_chain.run({"input": query})
        if mod_verdict == "True":
            return self.ans_chain.run({"input": query, "chat_history": str(chat_history)})
        return FALLBACK_MESSAGE

    def answer_optim_v2(self, query, chat_history):
        """Optim v2: moderation acts as a classifier that routes the query.

        Routes: "ANS_CHAIN" -> tool-free chain; "ANS_AGENT" -> tool-using
        agent; anything else -> FALLBACK_MESSAGE.
        """
        mod_verdict = self.mod_chain.run({"input": query})
        llm_input = {"input": query, "chat_history": str(chat_history)}
        logger.info(f"mod verdict: {mod_verdict}")
        if mod_verdict == "ANS_CHAIN":
            # no knowledge-base access needed
            return self.ans_chain.run(llm_input)
        elif mod_verdict == "ANS_AGENT":
            # agent consults the product-search tool; normalize backslashes
            # in its output to forward slashes
            res = self.ans_agent(llm_input)
            return res['output'].replace("\\", "/")
        return FALLBACK_MESSAGE

    def reco_optim_v1(self, chat_history):
        """Run the reco agent over the chat history.

        Returns:
            The recommendation string, or None when the agent produced the
            sentinel NO_RECO_OUTPUT.
        """
        reco = self.reco_agent.run({"input": chat_history})
        return reco if reco != NO_RECO_OUTPUT else None