Spaces:
Runtime error
Runtime error
File size: 8,139 Bytes
1b49043 42c1e22 1b49043 42c1e22 1b49043 42c1e22 1b49043 42c1e22 1b49043 42c1e22 1b49043 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 |
import os
from dotenv import load_dotenv
import re
from loguru import logger
from langchain import PromptTemplate, LLMChain
from langchain.agents import initialize_agent, Tool
from langchain.chat_models import AzureChatOpenAI
from langchain.agents import ZeroShotAgent, AgentExecutor
from langchain.chains.conversation.memory import ConversationBufferMemory
from langchain.callbacks import get_openai_callback
from langchain.chains.llm import LLMChain
from langchain.llms import AzureOpenAI
from langchain.prompts import PromptTemplate
from utils import lctool_search_allo_api, cut_dialogue_history
from prompts.mod_prompt import MOD_PROMPT, FALLBACK_MESSAGE, MOD_PROMPT_OPTIM_v2
from prompts.ans_prompt import ANS_PREFIX, ANS_FORMAT_INSTRUCTIONS, ANS_SUFFIX, ANS_CHAIN_PROMPT
from prompts.reco_prompt import RECO_PREFIX, RECO_FORMAT_INSTRUCTIONS, RECO_SUFFIX, NO_RECO_OUTPUT
# Load variables from a local .env file into the process environment so the
# os.getenv(...) lookups in AllofreshChatbot.init_llm can find them.
load_dotenv()
class AllofreshChatbot:
    """Allofresh shopping assistant assembled from LangChain chains and agents.

    Components built in ``__init__`` (all on the ``"gpt-4"`` client):

    * ``mod_chain``  -- moderation / routing chain; its verdict decides whether
      and how a query is answered.
    * ``ans_chain``  -- plain LLM chain answering from ``input`` and
      ``chat_history`` without tools.
    * ``ans_agent``  -- conversational ReAct agent with a "Product Search" tool.
    * ``reco_agent`` -- zero-shot agent that produces product recommendations.
    * ``ans_memory`` -- conversation buffer used by the legacy :meth:`answer`.
    """

    # Tool descriptions shown to the agents' LLM.
    _PRODUCT_SEARCH_DESCRIPTION = (
        "\n"
        "To search for products in Allofresh's Database.\n"
        "Always use this to verify product names.\n"
        "Outputs product names and prices\n"
    )
    _NO_RECO_DESCRIPTION = (
        "\n"
        "Use this if based on the context you don't need to recommend any products\n"
    )

    def __init__(self, debug=False):
        """Build every LLM client, chain, memory and agent up front.

        Args:
            debug: forwarded as ``verbose`` to both agent executors.
        """
        self.debug = debug
        self.llms = self.init_llm()
        self.mod_chain = self.init_mod_chain()
        self.ans_memory = self.init_ans_memory()
        self.ans_agent = self.init_ans_agent()
        self.ans_chain = self.init_ans_chain()
        self.reco_agent = self.init_reco_agent()

    @staticmethod
    def _azure_chat_llm(deployment_env, model_env):
        """Build a temperature-0 AzureChatOpenAI client from environment variables.

        Args:
            deployment_env: name of the env var holding the deployment name.
            model_env: name of the env var holding the model name.
        """
        return AzureChatOpenAI(
            temperature=0,
            deployment_name=os.getenv(deployment_env),
            model_name=os.getenv(model_env),
            openai_api_type=os.getenv("OPENAI_API_TYPE"),
            openai_api_base=os.getenv("OPENAI_API_BASE"),
            openai_api_version=os.getenv("OPENAI_API_VERSION"),
            openai_api_key=os.getenv("OPENAI_API_KEY"),
            openai_organization=os.getenv("OPENAI_ORGANIZATION"),
        )

    def init_llm(self):
        """Return the LLM clients keyed by short model name.

        Only ``"gpt-4"`` is used by the chains/agents in this class.
        """
        return {
            "gpt-4": self._azure_chat_llm("DEPLOYMENT_NAME_GPT4", "MODEL_NAME_GPT4"),
            "gpt-3.5": self._azure_chat_llm("DEPLOYMENT_NAME_GPT3.5", "MODEL_NAME_GPT3.5"),
            # NOTE(review): unlike the chat clients, this completion client is not
            # given OPENAI_API_TYPE / OPENAI_API_VERSION -- confirm that is intended.
            "gpt-3": AzureOpenAI(
                temperature=0,
                deployment_name=os.getenv("DEPLOYMENT_NAME_GPT3"),
                model_name=os.getenv("MODEL_NAME_GPT3"),
                openai_api_base=os.getenv("OPENAI_API_BASE"),
                openai_api_key=os.getenv("OPENAI_API_KEY"),
                openai_organization=os.getenv("OPENAI_ORGANIZATION"),
            ),
        }

    def init_mod_chain(self):
        """Return the moderation/routing chain; it takes a single ``input`` key."""
        mod_prompt = PromptTemplate(
            template=MOD_PROMPT_OPTIM_v2,
            input_variables=["input"],
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=mod_prompt)

    def init_ans_memory(self):
        """Return the conversation buffer used by :meth:`answer`."""
        return ConversationBufferMemory(memory_key="chat_history", output_key='output')

    def _product_search_tool(self):
        """Return the "Product Search" tool shared by the answer and reco agents."""
        return Tool(
            name="Product Search",
            func=lctool_search_allo_api,
            description=self._PRODUCT_SEARCH_DESCRIPTION,
        )

    def init_ans_agent(self):
        """Return the tool-using conversational answering agent.

        No memory is attached, so every call must supply ``chat_history``
        explicitly. Because ``return_intermediate_steps=True``, results must be
        read from the returned dict's ``"output"`` key (``.run`` is not usable).
        """
        return initialize_agent(
            [self._product_search_tool()],
            self.llms["gpt-4"],
            agent="conversational-react-description",
            verbose=self.debug,
            return_intermediate_steps=True,
            agent_kwargs={
                'prefix': ANS_PREFIX,
                # ANS_FORMAT_INSTRUCTIONS is only needed for models below gpt-4,
                # so it is not passed as 'format_instructions'.
                'suffix': ANS_SUFFIX,
            },
        )

    def init_ans_chain(self):
        """Return the tool-free answering chain over ``input`` and ``chat_history``."""
        ans_prompt = PromptTemplate(
            template=ANS_CHAIN_PROMPT,
            input_variables=["input", "chat_history"],
        )
        return LLMChain(llm=self.llms["gpt-4"], prompt=ans_prompt)

    def init_reco_agent(self):
        """Return the zero-shot recommendation agent executor.

        Besides product search, the agent gets a "No Recommendation" tool it can
        pick when nothing should be recommended.
        """
        reco_tools = [
            self._product_search_tool(),
            Tool(
                name="No Recommendation",
                func=lambda x: "No recommendation",
                description=self._NO_RECO_DESCRIPTION,
            ),
        ]
        prompt = ZeroShotAgent.create_prompt(
            reco_tools,
            prefix=RECO_PREFIX,
            format_instructions=RECO_FORMAT_INSTRUCTIONS,
            suffix=RECO_SUFFIX,
            input_variables=["input", "agent_scratchpad"],
        )
        llm_chain_reco = LLMChain(llm=self.llms["gpt-4"], prompt=prompt)
        agent_reco = ZeroShotAgent(
            llm_chain=llm_chain_reco,
            allowed_tools=[tool.name for tool in reco_tools],
        )
        return AgentExecutor.from_agent_and_tools(
            agent=agent_reco, tools=reco_tools, verbose=self.debug
        )

    def answer(self, query):
        """Legacy pipeline: moderate, answer with the agent, then recommend.

        The exchange (query, answer and any recommendation) is recorded in
        ``self.ans_memory`` so later calls see the conversation.

        Args:
            query: the user's message.

        Returns:
            ``(answer, reco)`` when the moderation chain returns ``"True"``,
            otherwise ``(FALLBACK_MESSAGE, None)``.
        """
        # The moderation prompt declares ``input`` as its only variable.
        mod_verdict = self.mod_chain.run({"input": query})
        if mod_verdict != "True":
            return (FALLBACK_MESSAGE, None)

        # ans_agent has no memory of its own: pass the stored history explicitly
        # and read "output" (the agent also returns intermediate steps).
        res = self.ans_agent({"input": query, "chat_history": self.ans_memory.buffer})
        reply = res["output"]
        self.ans_memory.chat_memory.add_user_message(query)
        self.ans_memory.chat_memory.add_ai_message(reply)

        # Recommend from the updated conversation and remember the suggestion.
        reco = self.reco_agent.run({"input": self.ans_memory.buffer})
        if reco:
            self.ans_memory.chat_memory.add_ai_message(reco)
        return (reply, reco)

    def answer_optim_v1(self, query, chat_history):
        """Moderate ``query`` and answer it with the tool-free chain.

        The tools were removed from the answering step and replaced with a
        simple chain.

        Returns:
            The chain's answer when the moderation chain returns ``"True"``,
            otherwise ``FALLBACK_MESSAGE``.
        """
        mod_verdict = self.mod_chain.run({"input": query})
        if mod_verdict == "True":
            return self.ans_chain.run({"input": query, "chat_history": str(chat_history)})
        return FALLBACK_MESSAGE

    def answer_optim_v2(self, query, chat_history):
        """Route ``query`` via the moderation chain, then answer it.

        A verdict of ``"ANS_CHAIN"`` answers with the tool-free chain, and
        ``"ANS_AGENT"`` answers with the product-search agent. Any other
        verdict returns ``FALLBACK_MESSAGE``.
        """
        mod_verdict = self.mod_chain.run({"input": query})
        llm_input = {"input": query, "chat_history": str(chat_history)}
        logger.info(f"mod verdict: {mod_verdict}")
        if mod_verdict == "ANS_CHAIN":
            return self.ans_chain.run(llm_input)
        if mod_verdict == "ANS_AGENT":
            res = self.ans_agent(llm_input)
            # NOTE(review): backslashes are rewritten to forward slashes; the
            # reason is not visible in this file.
            return res['output'].replace("\\", "/")
        return FALLBACK_MESSAGE

    def reco_optim_v1(self, chat_history):
        """Return a recommendation for ``chat_history``, or None if there is none."""
        reco = self.reco_agent.run({"input": chat_history})
        # The agent signals "nothing to recommend" with the NO_RECO_OUTPUT sentinel.
        return reco if reco != NO_RECO_OUTPUT else None