# NOTE(review): the three lines that were here ("Spaces:", "Runtime error",
# "Runtime error") were artifacts from the paste/export tool, not source code;
# they are replaced with this comment so the module can be imported.
import os
import re

from dotenv import load_dotenv
from loguru import logger

# NOTE(review): PromptTemplate and LLMChain are imported twice (top-level
# `langchain` re-exports and their canonical submodules). Both spellings are
# kept deliberately -- the later import wins and they normally resolve to the
# same classes, but confirm against the pinned langchain version before pruning.
from langchain import PromptTemplate, LLMChain
from langchain.agents import initialize_agent, Tool
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.callbacks import get_openai_callback
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.chains.llm import LLMChain
from langchain.chat_models import AzureChatOpenAI
from langchain.llms import AzureOpenAI
from langchain.prompts import PromptTemplate

from utils import lctool_search_allo_api, cut_dialogue_history
from prompts.mod_prompt import MOD_PROMPT, FALLBACK_MESSAGE, MOD_PROMPT_OPTIM_v2
from prompts.ans_prompt import ANS_PREFIX, ANS_FORMAT_INSTRUCTIONS, ANS_SUFFIX, ANS_CHAIN_PROMPT
from prompts.reco_prompt import RECO_PREFIX, RECO_FORMAT_INSTRUCTIONS, RECO_SUFFIX, NO_RECO_OUTPUT

# Pull Azure OpenAI credentials / deployment names from a local .env file
# into the process environment before any LLM client is constructed.
load_dotenv()
class AllofreshChatbot():
    """End-to-end Allofresh chatbot pipeline.

    Composes three LLM-backed components:
      * a moderation chain that classifies an incoming query and decides how
        (or whether) it should be answered,
      * answering components (a tool-using conversational agent and a plain
        tool-free chain) that produce the reply,
      * a recommendation agent that proposes products based on the
        conversation so far.

    All Azure OpenAI connection details come from environment variables
    (loaded by ``load_dotenv()`` at import time).
    """

    def __init__(self, debug=False):
        # debug=True turns on LangChain's verbose agent tracing.
        self.debug = debug
        # Shared LLM clients, keyed by model family ("gpt-4", "gpt-3.5", "gpt-3").
        self.llms = self.init_llm()
        # Moderation chain: gatekeeper in front of the answering components.
        self.mod_chain = self.init_mod_chain()
        # Answering components. The memory must exist before the agent so it
        # can be attached in init_ans_agent().
        self.ans_memory = self.init_ans_memory()
        self.ans_agent = self.init_ans_agent()
        self.ans_chain = self.init_ans_chain()
        # Product-recommendation agent.
        self.reco_agent = self.init_reco_agent()

    def init_llm(self):
        """Return the shared Azure OpenAI clients, keyed by model family.

        Returns:
            dict: ``{"gpt-4": AzureChatOpenAI, "gpt-3.5": AzureChatOpenAI,
            "gpt-3": AzureOpenAI}``, each configured entirely from
            environment variables.
        """
        return {
            "gpt-4": AzureChatOpenAI(
                temperature=0,
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT4"),
                model_name=os.getenv("MODEL_NAME_GPT4"),
                openai_api_type=os.getenv("OPENAI_API_TYPE"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_version=os.getenv("OPENAI_API_VERSION"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
            "gpt-3.5": AzureChatOpenAI(
                temperature=0,
                # NOTE(review): this env var name literally contains a dot;
                # most shells cannot export such a name -- confirm it is set
                # via the .env file rather than the shell environment.
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3.5"),
                model_name=os.getenv("MODEL_NAME_GPT3.5"),
                openai_api_type=os.getenv("OPENAI_API_TYPE"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_version=os.getenv("OPENAI_API_VERSION"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
            "gpt-3": AzureOpenAI(
                temperature=0,
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3"),
                model_name=os.getenv("MODEL_NAME_GPT3"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
        }

    def init_mod_chain(self):
        """Build the moderation chain (GPT-4 + the optimized moderation prompt).

        The prompt declares a single input variable named ``input``; every
        caller must therefore invoke ``run({"input": ...})``.
        """
        mod_prompt = PromptTemplate(
            template=MOD_PROMPT_OPTIM_v2,
            input_variables=["input"]
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=mod_prompt)

    def init_ans_memory(self):
        """Create the conversation buffer shared by the answering agent."""
        return ConversationBufferMemory(memory_key="chat_history", output_key='output')

    def init_ans_agent(self):
        """Build the tool-using conversational agent that answers queries."""
        ans_tools = [
            Tool(
                name="Product Search",
                func=lctool_search_allo_api,
                description="""
                To search for products in Allofresh's Database.
                Always use this to verify product names.
                Outputs product names and prices
                """
            )
        ]
        return initialize_agent(
            ans_tools,
            self.llms["gpt-4"],
            agent="conversational-react-description",
            verbose=self.debug,
            return_intermediate_steps=True,
            # Bug fix: the buffer created in __init__ was never attached to the
            # agent, yet answer() reads self.ans_agent.memory -- wire it in.
            memory=self.ans_memory,
            agent_kwargs={
                'prefix': ANS_PREFIX,
                # 'format_instructions': ANS_FORMAT_INSTRUCTIONS, # only needed for below gpt-4
                'suffix': ANS_SUFFIX
            }
        )

    def init_ans_chain(self):
        """Build the tool-free answering chain used by the optimized paths."""
        ans_prompt = PromptTemplate(
            template=ANS_CHAIN_PROMPT,
            input_variables=["input", "chat_history"]
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=ans_prompt)

    def init_reco_agent(self):
        """Build the zero-shot agent that recommends products from context."""
        reco_tools = [
            Tool(
                name="Product Search",
                func=lctool_search_allo_api,
                description="""
                To search for products in Allofresh's Database.
                Always use this to verify product names.
                Outputs product names and prices
                """
            ),
            Tool(
                name="No Recommendation",
                # Escape hatch so the agent can legitimately decline to recommend.
                func=lambda x: "No recommendation",
                description="""
                Use this if based on the context you don't need to recommend any products
                """
            )
        ]
        prompt = ZeroShotAgent.create_prompt(
            reco_tools,
            prefix=RECO_PREFIX,
            format_instructions=RECO_FORMAT_INSTRUCTIONS,
            suffix=RECO_SUFFIX,
            input_variables=["input", "agent_scratchpad"]
        )
        llm_chain_reco = LLMChain(llm=self.llms["gpt-4"], prompt=prompt)
        agent_reco = ZeroShotAgent(
            llm_chain=llm_chain_reco,
            allowed_tools=[tool.name for tool in reco_tools]
        )
        return AgentExecutor.from_agent_and_tools(
            agent=agent_reco, tools=reco_tools, verbose=self.debug
        )

    def answer(self, query):
        """Moderate *query*, answer it, then append a product recommendation.

        Returns:
            tuple: ``(answer, reco)`` on success, or ``(FALLBACK_MESSAGE, None)``
            when the query fails moderation.
        """
        # Bug fix: the moderation prompt's input variable is "input" (see
        # init_mod_chain) and both optimized paths already pass that key; the
        # old {"query": ...} key raised a missing-input error. strip() keeps
        # stray whitespace in the LLM's verdict from breaking the comparison.
        mod_verdict = self.mod_chain.run({"input": query}).strip()
        if mod_verdict == "True":
            # NOTE(review): ans_pipeline is not defined anywhere in this file --
            # confirm it exists on the class elsewhere (or was lost in a merge).
            answer = self.ans_pipeline(query)
            # Recommend based on the full conversation accumulated so far.
            reco = self.reco_agent.run({"input": self.ans_agent.memory.buffer})
            if len(reco) > 0:
                # Persist the recommendation in history for future turns.
                self.ans_agent.memory.chat_memory.add_ai_message(reco)
            return (answer, reco)
        else:
            return (FALLBACK_MESSAGE, None)

    def answer_optim_v1(self, query, chat_history):
        """
        We plugged off the tools from the 'answering' component and replaced it with a simple chain
        """
        mod_verdict = self.mod_chain.run({"input": query}).strip()
        if mod_verdict == "True":
            return self.ans_chain.run({"input": query, "chat_history": str(chat_history)})
        return FALLBACK_MESSAGE

    def answer_optim_v2(self, query, chat_history):
        """Route the query via moderation to either the plain chain or the agent.

        The verdict is expected to be ``ANS_CHAIN`` (no knowledge-base access
        needed), ``ANS_AGENT`` (product search needed), or anything else
        (fall back to the canned message).
        """
        mod_verdict = self.mod_chain.run({"input": query}).strip()
        llm_input = {"input": query, "chat_history": str(chat_history)}
        logger.info(f"mod verdict: {mod_verdict}")
        if mod_verdict == "ANS_CHAIN":
            return self.ans_chain.run(llm_input)
        elif mod_verdict == "ANS_AGENT":
            res = self.ans_agent(llm_input)
            # Normalize backslashes in the agent output to forward slashes.
            return res['output'].replace("\\", "/")
        return FALLBACK_MESSAGE

    def reco_optim_v1(self, chat_history):
        """Run the recommendation agent; return None when it declines."""
        reco = self.reco_agent.run({"input": chat_history})
        return reco if reco != NO_RECO_OUTPUT else None