Spaces:
Runtime error
Runtime error
File size: 8,556 Bytes
2a33798 4e5e176 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 4e5e176 2d75a44 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 65ee2b8 2a33798 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
import openai
from .misc import history_to_str
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts.chat import (
PromptTemplate,
ChatPromptTemplate,
SystemMessagePromptTemplate,
HumanMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain import LLMChain
from langchain.callbacks import FileCallbackHandler
from langchain.callbacks import get_openai_callback
from .act import NaiveAct
from memory.env_history import EnvironmentHistory
import tiktoken
from .utils import run_chain
from loguru import logger
class EXE(NaiveAct):
    """LLM acting policy with distilled trial memory.

    Extends ``NaiveAct`` with a three-part memory maintained across trials:
    ``pre_memory`` (suggestions generated before a trial), ``post_memory``
    (summaries generated after a trial), and an ``insight`` string distilled
    from the summaries. ``act`` builds a chat prompt from the game/goal/action
    descriptions plus this memory and queries an (Azure)OpenAI chat model.
    """

    def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
        """Initialize the agent and seed the first pre-trial suggestion.

        Args:
            action_space: environment action space (handled by ``NaiveAct``).
            args: run configuration; must provide ``num_trails``,
                ``game_description``, ``goal_description``,
                ``action_description``, ``action_desc_dict``,
                ``short_mem_num``, ``api_type`` and ``gpt_version``.
            prompts: prompt templates (handled by ``NaiveAct``).
            distiller: object providing ``generate_summary``,
                ``generate_insight`` and ``generate_suggestion``.
            temperature: sampling temperature for the chat model.
            max_tokens: completion token cap passed to the chat model.
            logger: optional pre-configured loguru sink id.
            fixed_suggestion: if set, overrides the distilled suggestion
                (useful for ablations).
            fixed_insight: if set, overrides the distilled insight.
        """
        super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
        self.pre_memory = []   # suggestions produced before each trial
        self.post_memory = []  # per-trial summaries produced after each trial
        self.is_first = True   # True until the first trajectory has been summarized
        self.num_trails = args.num_trails
        self.game_description = args.game_description
        self.goal_description = args.goal_description
        self.action_description = args.action_description
        self.action_desc_dict = args.action_desc_dict
        self.mem_num = args.short_mem_num  # how many recent history entries to show
        self.fixed_suggestion = fixed_suggestion
        self.fixed_insight = fixed_insight
        self._update_mem(None)
        # Deliberately cleared AFTER _update_mem so the very first act() call
        # runs without an insight section in the prompt.
        self.insight = ""

    def num_tokens_from_string(self, string: str) -> int:
        """Return the number of tokens in ``string`` under ``self.encoding``.

        ``self.encoding`` is expected to be a tiktoken encoding set up by the
        parent class.
        """
        return len(self.encoding.encode(string))

    def update_mem(self):
        """Summarize the finished trial and refresh suggestions/insight.

        Builds a trajectory string from the static descriptions plus the full
        environment history, then delegates to ``_update_mem``.
        """
        traj = self.game_description
        traj += self.goal_description
        traj += self.action_description
        traj += str(self.env_history)
        self._update_mem(traj)

    def clear_mem(self):
        """Summarize the current trial, then reset all trial memory."""
        self.update_mem()
        self.pre_memory = []
        self.post_memory = []
        self.is_first = True
        self.env_history.reset()

    def _update_mem(self, traj):
        """Distill memory from ``traj`` and append a fresh suggestion.

        If an externally supplied ``self.memory`` exists, it replaces
        ``post_memory`` wholesale; otherwise the distiller summarizes the
        trajectory (skipped on the very first call, when there is no
        trajectory yet). ``fixed_suggestion``/``fixed_insight`` override the
        distilled values when set. Always resets the environment history.
        """
        if self.memory:
            self.post_memory = self.memory
            self.insight = self.distiller.generate_insight(self.post_memory)
        else:
            if not self.is_first:
                summary = self.distiller.generate_summary(traj, self.post_memory)
                self.post_memory.append(summary)
                self.insight = self.distiller.generate_insight(self.post_memory)
            else:
                self.is_first = False
                self.insight = ""
        suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.insight, self.num_trails)
        if self.fixed_suggestion:
            suggestion = self.fixed_suggestion
        if self.fixed_insight:
            self.insight = self.fixed_insight
        self.pre_memory.append(suggestion)
        self.env_history.reset()

    def _read_mem(self):
        """Render the current insight and latest suggestion as prompt text."""
        insight_str = ""
        if self.insight:
            insight_str += "The insights of the game are listed below: "
            insight_str += f"{self.insight}\n"
        # pre_memory is guaranteed non-empty: _update_mem always appends.
        suggestion_str = "The suggestions are listed below:" + self.pre_memory[-1]
        return insight_str + suggestion_str

    def act(
        self,
        state_description,
        action_description,
        env_info,
        game_description,
        goal_description,
        logfile=None,
    ):
        """Choose an action for the current observation via the chat model.

        Args:
            state_description: textual observation for the current step.
            action_description: textual description of the valid actions.
            env_info: dict of environment extras; ``env_info['history']`` is
                logged when present.
            game_description: textual description of the game.
            goal_description: textual description of the goal.
            logfile: optional path for loguru / langchain file logging.

        Returns:
            Tuple of (action, formatted prompt text, raw LLM response,
            total tokens used, total cost).

        Raises:
            ValueError: if ``self.args.api_type`` is not "azure" or "openai".
        """
        self.game_description = game_description
        self.goal_description = goal_description
        self.env_history.add("observation", state_description)

        # Build the chat backend for the configured provider.
        if self.args.api_type == "azure":
            chat = AzureChatOpenAI(
                openai_api_type=openai.api_type,
                openai_api_version=openai.api_version,
                openai_api_base=openai.api_base,
                openai_api_key=openai.api_key,
                deployment_name=self.args.gpt_version,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
        elif self.args.api_type == "openai":
            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key, model=self.args.gpt_version)
        else:
            # BUGFIX: previously an unrecognized api_type left `chat` unbound
            # and crashed later with an opaque NameError.
            raise ValueError(f"Unsupported api_type: {self.args.api_type!r} (expected 'azure' or 'openai')")

        reply_format_description = \
            "Your response should choose an optimal action from valid action list, and terminated with following format: "

        # System prompt: task framing, optional few-shot examples, the game
        # descriptions, and the distilled memory. Prompt strings (including
        # their original typos) are kept verbatim so model behavior is unchanged.
        template = "Now you are completing a task."
        template += "You need to carefully understand the description of the game. "
        # TODO: few shot example handle
        if self.irr_few_shot_examples:
            template += "Here are some examples of how you should completing a task."
            for examples in self.irr_few_shot_examples:
                template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
        template += "\n\nNow you are in the task.\n"
        template += " {game_description}\n{action_description}\n{goal_description}"
        template += "You are observing something and " \
                    "you need to choose the optimal action acoordingly."
        template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
        template += self._read_mem()
        system_message_prompt = SystemMessagePromptTemplate.from_template(template)
        short_memory_template = HumanMessagePromptTemplate.from_template("{history}\nNext is the observation that the agent gets:\n{state_description}Please select an optimal action to gain higher rewards based on the current state and history. The action description is below: {action_description}. Please think step by step.")
        chat_prompt = ChatPromptTemplate.from_messages(
            [system_message_prompt, short_memory_template])

        # Attach a loguru sink once on the first call when no logger was given.
        if not self.logger and logfile and self.first_call:
            self.logger = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
            self.first_call = False
        # BUGFIX: `handler` used to be bound only when self.logger was unset
        # AND a logfile was given, so LLMChain(callbacks=[handler]) raised
        # NameError otherwise. Build the callback list unconditionally.
        callbacks = [FileCallbackHandler(logfile)] if logfile else []

        total_tokens, total_cost = 0, 0
        max_think_times = 1  # single reasoning pass; loop kept for easy extension
        for i_think in range(max_think_times):
            chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=callbacks, verbose=False)
            with get_openai_callback() as cb:
                response = run_chain(
                    chain,
                    game_description=game_description,
                    goal_description=goal_description,
                    action_description=action_description,
                    state_description=self.env_history.get_last_history(),
                    history=self.env_history.get_histories(self.mem_num),
                    format_instructions=self.parser.get_format_instructions(),
                    reply_format_description=reply_format_description,
                    max_token=self.max_tokens
                )
                total_tokens += cb.total_tokens
                total_cost += cb.total_cost
            action = self.parser.parse(response).action

        self._add_history_after_action(action)
        self.logger.info(f'The GPT response is: {response}.')
        self.logger.info(f'The optimal action is: {action}.')
        if self.pre_memory:
            self.logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
        if self.post_memory:
            self.logger.info(f'The summary is: {self.post_memory[-1]}.')
        if env_info.get('history'):
            self.logger.info(f'History: {history_to_str(env_info["history"])}')

        # Render the final prompt text for the caller (debugging/record-keeping).
        text_prompt = chat_prompt.format_messages(
            game_description=game_description,
            goal_description=goal_description,
            action_description=action_description,
            state_description=self.env_history.get_last_history(),
            history=self.env_history.get_histories(self.mem_num),
            format_instructions=self.parser.get_format_instructions(),
            reply_format_description=reply_format_description,
        )
        text_prompt = f'{text_prompt[0].content}\n{text_prompt[1].content}'
        return action, text_prompt, response, total_tokens, total_cost