File size: 2,704 Bytes
2a33798
 
 
65ee2b8
 
 
2a33798
65ee2b8
2a33798
 
 
65ee2b8
 
 
 
2a33798
 
 
 
 
 
65ee2b8
 
2a33798
65ee2b8
377ad82
 
65ee2b8
 
 
 
 
2a33798
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c640769
65ee2b8
 
2a33798
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
import ast
import json
import random

from loguru import logger

from deciders.utils import get_completion


class TrajPromptSummarizer():
    """Turn recorded trajectories into textual summaries using an LLM.

    The few-shot examples embedded in every prompt are read once, at
    construction time, from ``./distillers/traj_summary_few_shot_examples.txt``
    (resolved against the current working directory). Each summary is
    obtained by sending the assembled prompt to ``get_completion``.
    """

    def __init__(self, args=None, logfile=None):
        """Load the few-shot examples and optionally register a log sink.

        Args:
            args: Configuration object. ``generate`` reads ``api_type`` and
                ``gpt_version`` from it, so it must not be ``None`` by the
                time ``generate`` is called.
            logfile: Optional sink handed to ``logger.add``. Only records
                whose message contains ``'[Reflexion Memory]'`` reach it.
        """
        self.args = args
        with open("./distillers/traj_summary_few_shot_examples.txt", 'r') as f:
            self.FEW_SHOT_EXAMPLES = f.read()

        if logfile:
            # Sinks that are already registered are left in place; this only
            # adds one more sink restricted to the reflexion messages.
            logger.add(
                logfile,
                colorize=True,
                enqueue=True,
                filter=lambda x: '[Reflexion Memory]' in x['message'],
            )

    @staticmethod
    def _normalize_action(raw_action):
        """Return the action value to print for one transition.

        A list action (either a real list or a string holding a list
        literal such as ``"[2]"``) is reduced to ``float(first_item) - 1``.
        Everything else, including strings that are not Python literals and
        empty lists, is returned unchanged.
        """
        parsed = raw_action
        if isinstance(raw_action, str):
            try:
                # literal_eval only accepts Python literals, so text coming
                # from the trajectory file can never execute code here.
                parsed = ast.literal_eval(raw_action)
            except (ValueError, TypeError, SyntaxError):
                return raw_action
        if isinstance(parsed, list) and parsed:
            return float(parsed[0]) - 1
        return raw_action

    def generate_from_file(self, file_path, max_step_num=200):
        """Summarize every trajectory stored in a JSON file.

        The file must contain a list of trajectories; each trajectory is a
        list of transition dicts. The first transition supplies
        ``game_description`` and ``goal_description``; every transition
        supplies ``observation``, ``action``, ``reward`` and ``cum_reward``.

        Args:
            file_path: Path of the JSON file to read (decoded as UTF-8).
            max_step_num: Only the last ``max_step_num`` transitions of each
                trajectory are written into the prompt.

        Returns:
            A list with one summary per trajectory, in file order. Summaries
            produced earlier are passed to ``generate`` as memory for the
            later ones.
        """
        with open(file_path, 'r', encoding='utf-8') as infile:
            data = json.load(infile)

        mem = []
        for traj in data:
            header = traj[0]
            steps = traj[-max_step_num:]
            lines = [header['game_description'], header['goal_description']]
            for transition in steps:
                action = self._normalize_action(transition['action'])
                lines.append(transition['observation'])
                lines.append(f"Action: {action}")
                lines.append(f"Reward: {transition['reward']}")
            # The reported performance is the cumulative reward stored on
            # the last transition that was written out.
            lines.append(f"Your performance is: {steps[-1]['cum_reward']}")
            traj_text = '\n'.join(lines) + '\n'

            reflection = self.generate(traj_text, mem, max_len_mem=5)
            mem.append(reflection)
        return mem

    def _generate_summary_query(self, traj, memory):
        """Build the prompt asking the model to summarize one trajectory.

        Args:
            traj: Text rendering of the trajectory to summarize.
            memory: Summaries of earlier trajectories. When non-empty they
                are listed after the trajectory as ``Trial #i`` entries,
                numbered from 0 within the list that was passed in.

        Returns:
            The full prompt string, ending with ``Summary:``.
        """
        query: str = f"""You will be given the history of a past experience in which you were placed in an environment and given a task to complete. Summarize your trajectory and reasoning the relation between your policy and the obtained result. Here are two examples:

        {self.FEW_SHOT_EXAMPLES}

        {traj}"""
        if len(memory) > 0:
            query += '\n\nPlans from past attempts:\n'
            query += ''.join(
                f'Trial #{i}: {m}\n' for i, m in enumerate(memory)
            )

        query += '\n\nSummary:'
        return query

    def generate(self, traj, memory, max_len_mem=5):
        """Request a summary of ``traj`` from the completion backend.

        Args:
            traj: Text rendering of the trajectory to summarize.
            memory: Summaries of earlier trajectories.
            max_len_mem: At most this many of the most recent ``memory``
                entries are included in the prompt. A value of zero or less
                includes none.

        Returns:
            The value returned by ``get_completion`` for the prompt.
        """
        # Slicing with -0 would return the whole list, so a non-positive
        # limit is handled explicitly.
        recent_memory = memory[-max_len_mem:] if max_len_mem > 0 else []
        reflection_query = self._generate_summary_query(traj, recent_memory)
        reflection = get_completion(
            reflection_query,
            api_type=self.args.api_type,
            engine=self.args.gpt_version,
        )
        logger.info(f'[Reflexion Memory]The reflexion prompt is: {reflection_query}.')
        logger.info(f'[Reflexion Memory]The reflexion response is: {reflection}.')
        return reflection