File size: 8,556 Bytes
2a33798
 
4e5e176
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65ee2b8
2a33798
 
 
 
 
 
 
 
 
 
65ee2b8
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65ee2b8
2a33798
 
 
65ee2b8
 
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65ee2b8
2a33798
 
 
 
 
 
 
 
 
 
 
 
4e5e176
 
 
 
 
 
 
 
 
 
 
 
2d75a44
2a33798
 
 
 
 
 
 
 
 
 
 
 
65ee2b8
 
2a33798
 
 
 
 
 
 
65ee2b8
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
65ee2b8
2a33798
 
 
 
 
 
 
 
 
65ee2b8
 
2a33798
 
65ee2b8
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65ee2b8
 
 
 
 
 
2a33798
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import openai
from .misc import history_to_str
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.prompts.chat import (
    PromptTemplate,
    ChatPromptTemplate,
    SystemMessagePromptTemplate,
    HumanMessagePromptTemplate,
)
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain import LLMChain
from langchain.callbacks import FileCallbackHandler
from langchain.callbacks import get_openai_callback
from .act import NaiveAct
from memory.env_history import EnvironmentHistory
import tiktoken
from .utils import run_chain
from loguru import logger



class EXE(NaiveAct):
    def __init__(self, action_space, args, prompts, distiller, temperature=0., max_tokens=None, logger=None, fixed_suggestion=None, fixed_insight=None):
        """EXE agent: a NaiveAct variant that keeps distilled trial memory.

        Args:
            action_space: environment action space, forwarded to NaiveAct.
            args: run configuration; the fields read here are num_trails,
                game_description, goal_description, action_description,
                action_desc_dict and short_mem_num.
            prompts: prompt templates, forwarded to NaiveAct.
            distiller: object providing generate_summary / generate_insight /
                generate_suggestion (used by _update_mem).
            temperature: LLM sampling temperature.
            max_tokens: LLM completion token limit (None = provider default).
            logger: optional logger sink; if None, act() lazily creates one.
            fixed_suggestion: if truthy, overrides the distilled suggestion.
            fixed_insight: if truthy, overrides the distilled insight.
        """
        super().__init__(action_space, args, prompts, distiller, temperature, max_tokens, logger)
        # pre_memory: suggestions produced before a trial starts;
        # post_memory: trajectory summaries produced after a trial ends.
        self.pre_memory = []
        self.post_memory = []
        self.is_first = True
        self.num_trails = args.num_trails
        self.game_description = args.game_description
        self.goal_description = args.goal_description
        self.action_description = args.action_description
        self.action_desc_dict = args.action_desc_dict
        # Number of recent history entries fed back into the prompt.
        self.mem_num = args.short_mem_num
        self.fixed_suggestion = fixed_suggestion
        self.fixed_insight = fixed_insight
        # Seed the first suggestion (no trajectory yet).
        self._update_mem(None)
        # NOTE(review): this overwrites whatever insight _update_mem(None)
        # just stored (non-empty when self.memory is set) — confirm the
        # blank reset is intended.
        self.insight = ""

    def num_tokens_from_string(self, string: str) -> int:
        """Return how many tokens *string* encodes to under self.encoding."""
        return len(self.encoding.encode(string))
    
    def update_mem(self,):
        """Distill the finished episode into memory via _update_mem.

        The trajectory handed to the distiller is the static task text
        (game, goal, action descriptions) followed by the episode history.
        """
        trajectory = "".join((
            self.game_description,
            self.goal_description,
            self.action_description,
            str(self.env_history),
        ))
        self._update_mem(trajectory)

    def clear_mem(self):
        """Archive the current trajectory, then wipe per-trial memory.

        update_mem() first distills the just-finished episode (appending a
        summary to post_memory and a new suggestion to pre_memory); both
        lists are then cleared so the next trial starts fresh.
        """
        self.update_mem()
        # NOTE(review): the suggestion that update_mem() just appended to
        # pre_memory is discarded by the reset below — confirm intended.
        self.pre_memory = []
        self.post_memory = []
        self.is_first = True
        self.env_history.reset()
        # self._update_mem(None)

    def _update_mem(self, traj):
        """Refresh insight/suggestion memory from a trajectory.

        Args:
            traj: trajectory text to summarize, or None on the very first
                call (when there is nothing to summarize yet).

        Side effects: may append to post_memory, always appends one
        suggestion to pre_memory, sets self.insight, and resets env_history.
        """
        if self.memory:
            # Pre-loaded memory (set elsewhere, e.g. by NaiveAct): use it
            # directly instead of summarizing the trajectory.
            self.post_memory = self.memory
            self.insight = self.distiller.generate_insight(self.post_memory)
        else:
            if not self.is_first:
                # Summarize the finished episode and re-derive the insight
                # from the accumulated summaries.
                summary = self.distiller.generate_summary(traj, self.post_memory)
                self.post_memory.append(summary)
                self.insight = self.distiller.generate_insight(self.post_memory)
            else:
                # First call: no trajectory yet, so no insight.
                self.is_first = False
                self.insight = ""
        suggestion = self.distiller.generate_suggestion(self.game_description, self.goal_description, self.action_description, self.pre_memory, self.post_memory, self.insight, self.num_trails)
        # Optional fixed overrides (used for controlled experiments).
        if self.fixed_suggestion:
            suggestion = self.fixed_suggestion
        if self.fixed_insight:
            self.insight = self.fixed_insight
        self.pre_memory.append(suggestion)
        self.env_history.reset()
        
    def _read_mem(self, ):
        """Render the stored insight (if any) and latest suggestion as prompt text."""
        parts = []
        if self.insight:
            parts.append("The insights of the game are listed below: ")
            parts.append(f"{self.insight}\n")
        # Assumes _update_mem has run at least once so pre_memory is non-empty.
        parts.append("The suggestions are listed below:" + self.pre_memory[-1])
        return "".join(parts)
    
    def act(
        self,
        state_description,
        action_description,
        env_info,
        game_description,
        goal_description,
        logfile=None,
    ):
        """Choose an action for the current observation via one LLM call.

        Builds a system prompt from few-shot examples, task descriptions and
        distilled memory (_read_mem), queries the chat model through an
        LLMChain, parses the reply into an action, and logs the exchange.

        Args:
            state_description: textual observation added to env_history.
            action_description: valid-action text injected into the prompt.
            env_info: dict; its optional 'history' entry is logged.
            game_description / goal_description: task text (also stored on self).
            logfile: path used both for the lazily-created loguru sink and
                for the langchain FileCallbackHandler.

        Returns:
            (action, text_prompt, response, total_tokens, total_cost)
        """
        self.game_description = game_description 
        self.goal_description = goal_description
        self.env_history.add("observation", state_description)

        # Select the chat backend from run configuration.
        if self.args.api_type == "azure":
            chat = AzureChatOpenAI(
                openai_api_type=openai.api_type,
                openai_api_version=openai.api_version,
                openai_api_base=openai.api_base,
                openai_api_key=openai.api_key,
                deployment_name=self.args.gpt_version,
                temperature=self.temperature,
                max_tokens=self.max_tokens
            )
        elif self.args.api_type == "openai":
            chat = ChatOpenAI(temperature=self.temperature, openai_api_key=openai.api_key, model=self.args.gpt_version)
        # NOTE(review): any api_type other than "azure"/"openai" leaves
        # `chat` unbound and the LLMChain below raises NameError — confirm
        # those are the only configurable values.
        # print(self.logger)
        reply_format_description = \
            "Your response should choose an optimal action from valid action list, and terminated with following format: "        
            # only task-relevant examples
        # Assemble the system prompt: instructions, few-shot examples,
        # task text, reply format, then distilled memory.
        template = "Now you are completing a task."
        template += "You need to carefully understand the description of the game. " 
        # TODO: few shot example handle
        if self.irr_few_shot_examples:
            template += "Here are some examples of how you should completing a task."
            for examples in self.irr_few_shot_examples:
                template += "\nQuestion: \n" + examples['question'] + "Answer: \n" + examples['answer']
        
        template += "\n\nNow you are in the task.\n" 
        template += " {game_description}\n{action_description}\n{goal_description}"
        template += "You are observing something and  " \
                "you need to choose the optimal action acoordingly."
        template += 'Response and interact using the format: {reply_format_description}{format_instructions}\n'
        
        template += self._read_mem()
        system_message_prompt = SystemMessagePromptTemplate.from_template(template)
        
        # Human turn: recent history plus the latest observation.
        short_memory_template = HumanMessagePromptTemplate.from_template("{history}\nNext is the observation that the agent gets:\n{state_description}Please select an optimal action to gain higher rewards based on the current state and history. The action description is below: {action_description}. Please think step by step.")
        chat_prompt = ChatPromptTemplate.from_messages(
            [system_message_prompt, short_memory_template])
        # Lazily attach a loguru file sink once, filtering out reflexion
        # memory lines.
        if self.logger:
            pass
        else:
            if logfile:
                # logger.remove()
                if self.first_call:
                    self.logger = logger.add(logfile, colorize=True, enqueue=True, filter=lambda x: '[Reflexion Memory]' not in x['message'])
                    self.first_call = False
        # NOTE(review): FileCallbackHandler is built unconditionally —
        # presumably logfile is always provided; a None logfile would fail
        # here. Verify against callers.
        handler = FileCallbackHandler(logfile)
        total_tokens, total_cost = 0, 0 
        # Single thinking pass; the loop shape allows raising this later.
        max_think_times = 1

        for i_think in range(max_think_times):
            # chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=True)
            chain = LLMChain(llm=chat, prompt=chat_prompt, callbacks=[handler], verbose=False)
            # get_openai_callback tracks token usage/cost for this call.
            with get_openai_callback() as cb:
                response = run_chain(
                    chain,
                    game_description=game_description,
                    goal_description=goal_description,
                    action_description=action_description,
                    state_description = self.env_history.get_last_history(),
                    history=self.env_history.get_histories(self.mem_num),
                    format_instructions=self.parser.get_format_instructions(),
                    reply_format_description=reply_format_description,
                    max_token=self.max_tokens
                )

                total_tokens += cb.total_tokens
                total_cost += cb.total_cost
            action = self.parser.parse(response).action        
        self._add_history_after_action(action)
        self.logger.info(f'The GPT response is: {response}.')
        self.logger.info(f'The optimal action is: {action}.')
        if self.pre_memory:
            self.logger.info(f'The suggestion is: {self.pre_memory[-1]}.')
        if self.post_memory:
            self.logger.info(f'The summary is: {self.post_memory[-1]}.')
        if env_info.get('history'):
            self.logger.info(f'History: {history_to_str(env_info["history"])}')
        # Re-render the prompt as plain text for the caller (system + human).
        text_prompt = chat_prompt.format_messages(
            game_description=game_description,
            goal_description=goal_description,
            action_description=action_description,
            state_description = self.env_history.get_last_history(),
            history=self.env_history.get_histories(self.mem_num),
            format_instructions=self.parser.get_format_instructions(),
            reply_format_description=reply_format_description,
        )
        text_prompt = f'{text_prompt[0].content}\n{text_prompt[1].content}'
        return action, text_prompt, response, total_tokens, total_cost