File size: 7,882 Bytes
1cc747d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import numpy as np
import os
import IPython
import random
import json
import traceback
import pybullet as p
from gensim.utils import (
    save_text,
    add_to_txt,
    extract_code,
    extract_dict,
    extract_list,
    extract_assets,
    format_dict_prompt,
    sample_list_reference,
    generate_feedback,
)


class Agent:
    """
    Class that designs new tasks and code for simulation environments.

    The agent reads prompt templates from ``prompts/<cfg['prompt_folder']>``,
    fills their placeholders (past tasks, assets, reference code) from
    ``memory``, sends them through ``generate_feedback`` and parses the
    responses with the ``extract_*`` helpers.
    """

    # Upper bound on how many past tasks are shown in the task-proposal prompt.
    MAX_PAST_TASKS = 10

    def __init__(self, cfg, memory):
        self.cfg = cfg
        # Directory that generated assets / code are saved into via save_text.
        self.model_output_dir = cfg["model_output_dir"]
        # All prompt templates are looked up relative to this folder.
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        # Provides online_task_buffer / online_asset_buffer / online_code_buffer / chat_log.
        self.memory = memory
        self.chat_log = memory.chat_log
        self.use_template = cfg['use_template']

    # ------------------------------------------------------------------ helpers

    def _prompt_path(self, filename):
        """Return the path of ``filename`` inside this agent's prompt folder."""
        return f"{self.prompt_folder}/{filename}"

    def _has_prompt(self, filename):
        """Return True if the prompt template ``filename`` exists."""
        return os.path.exists(self._prompt_path(filename))

    def _read_prompt(self, filename):
        """Read and return the prompt template ``filename`` (file handle is closed)."""
        with open(self._prompt_path(filename)) as f:
            return f.read()

    @staticmethod
    def _exec_llm_assignment(code, name):
        """Execute LLM-extracted ``code`` and return the value it binds to ``name``.

        The code runs in a fresh namespace so it cannot clobber module globals,
        and a missing assignment is not masked by a value left over from an
        earlier call.

        Raises:
            KeyError: if ``code`` runs but does not assign ``name``.
            Exception: anything raised while executing ``code`` propagates.
        """
        # SECURITY NOTE(review): this executes model-generated text with full
        # interpreter privileges. If the response is only ever a literal,
        # consider parsing it with ast.literal_eval instead.
        namespace = {}
        exec(code, namespace)
        return namespace[name]

    @staticmethod
    def _sample_past_tasks(tasks, max_num):
        """Return ``tasks`` if it has at most ``max_num`` entries, else a random ``max_num``-entry dict subset."""
        if len(tasks) <= max_num:
            return tasks
        # random.sample needs a sequence; dict views are rejected on Python 3.11+.
        return dict(random.sample(list(tasks.items()), max_num))

    def _format_code_references(self, keys):
        """Concatenate the buffered code for each key in ``keys`` as fenced blocks.

        Keys missing from ``memory.online_code_buffer`` are reported and skipped.
        """
        code_buffer = self.memory.online_code_buffer
        chunks = []
        for key in keys:
            if key in code_buffer:
                chunks.append(f'```\n{code_buffer[key]}\n```\n\n')
            else:
                print("missing task reference code:", key)
        return ''.join(chunks)

    # ------------------------------------------------------------- public API

    def propose_task(self, proposed_task_names):
        """Ask the LLM for a new task definition (name, descriptions, assets).

        Args:
            proposed_task_names: unused here; kept for interface compatibility.

        Returns:
            The ``new_task`` dict parsed from the LLM response, or a dummy task
            dict if the response could not be extracted/executed. The result
            is also stored on ``self.new_task``.
        """
        self.chat_log = add_to_txt(self.chat_log, "================= Task and Asset Design!", with_print=True)

        task_prompt_text = self._read_prompt("cliport_prompt_task.txt")
        if self.use_template:
            task_asset_replacement_str = format_dict_prompt(
                self.memory.online_asset_buffer, self.cfg['task_asset_candidate_num'])
            task_prompt_text = task_prompt_text.replace("TASK_ASSET_PROMPT", task_asset_replacement_str)

            task_desc_replacement_str = format_dict_prompt(
                self.memory.online_task_buffer, self.cfg['task_description_candidate_num'])
            print("prompt task description candidates:")
            print(task_desc_replacement_str)
            task_prompt_text = task_prompt_text.replace("TASK_DESCRIPTION_PROMPT", task_desc_replacement_str)

            target_task_name = self.cfg.get('target_task_name')
            if target_task_name:
                task_prompt_text = task_prompt_text.replace("TARGET_TASK_NAME", target_task_name)

        # Show at most MAX_PAST_TASKS previous tasks to keep the prompt bounded.
        print("online_task_buffer size:", len(self.memory.online_task_buffer))
        past_tasks = self._sample_past_tasks(self.memory.online_task_buffer, self.MAX_PAST_TASKS)
        task_prompt_text = task_prompt_text.replace("PAST_TASKNAME_TEMPLATE", format_dict_prompt(past_tasks))

        res = generate_feedback(
            task_prompt_text,
            temperature=self.cfg["gpt_temperature"],
            interaction_txt=self.chat_log,
        )

        # Extract dictionary for task name, descriptions, and assets.
        try:
            task_def = extract_dict(res, prefix="new_task")
            self.new_task = self._exec_llm_assignment(task_def, "new_task")
        except Exception:
            print(str(traceback.format_exc()))
            # NOTE(review): the other keys are hyphenated ("task-name", "assets-used");
            # confirm downstream consumers really read "task_descriptions".
            self.new_task = {"task-name": "dummy", "assets-used": [], "task_descriptions": ""}
        return self.new_task

    def propose_assets(self):
        """Asset generation. Not used for now.

        Returns:
            The assets extracted from the LLM response, or ``{}`` when the
            asset prompt template does not exist.
        """
        asset_file = "cliport_prompt_asset_template.txt"
        if not self._has_prompt(asset_file):
            return {}

        self.chat_log = add_to_txt(self.chat_log, "================= Asset Generation!", with_print=True)
        asset_prompt_text = self._read_prompt(asset_file)
        task_name = self.new_task["task-name"]

        if self.use_template:
            asset_prompt_text = asset_prompt_text.replace("TASK_NAME_TEMPLATE", task_name)
            asset_prompt_text = asset_prompt_text.replace("ASSET_STRING_TEMPLATE", str(self.new_task["assets-used"]))
            print("Template Asset PROMPT: ", asset_prompt_text)

        res = generate_feedback(asset_prompt_text, temperature=0, interaction_txt=self.chat_log)
        print("Save asset to:", self.model_output_dir, task_name + "_asset_output")
        save_text(self.model_output_dir, f'{task_name}_asset_output', res)
        # save_urdf(asset_list)
        return extract_assets(res)

    def api_review(self):
        """Have the LLM review the task API for ``self.new_task``.

        Returns:
            The raw LLM response, or None when the API prompt template does
            not exist.
        """
        api_file = "cliport_prompt_api_template.txt"
        if not self._has_prompt(api_file):
            return None

        self.chat_log = add_to_txt(self.chat_log, "================= API Preview!", with_print=True)
        api_prompt_text = self._read_prompt(api_file)
        if "task-name" in self.new_task:
            api_prompt_text = api_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
        api_prompt_text = api_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))

        return generate_feedback(api_prompt_text, temperature=0, interaction_txt=self.chat_log)

    def template_reference_prompt(self):
        """Select reference code to include in the code-generation prompt.

        If a reference-selection template exists, the LLM picks keys from
        ``memory.online_code_buffer``. Otherwise up to
        ``cfg['task_code_candidate_num']`` buffered entries are sampled at
        random.

        Returns:
            A string of fenced code blocks (possibly empty).
        """
        selection_file = "cliport_prompt_code_reference_selection_template.txt"
        code_buffer = self.memory.online_code_buffer

        if self._has_prompt(selection_file):
            self.chat_log = add_to_txt(self.chat_log, "================= Code Reference!", with_print=True)
            code_reference_question = self._read_prompt(selection_file)
            code_reference_question = code_reference_question.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
            code_reference_question = code_reference_question.replace("TASK_CODE_LIST_TEMPLATE", str(list(code_buffer.keys())))
            code_reference_question = code_reference_question.replace("TASK_STRING_TEMPLATE", str(self.new_task))

            res = generate_feedback(code_reference_question, temperature=0., interaction_txt=self.chat_log)
            code_reference_cmd = extract_list(res, prefix='code_reference')
            code_reference = self._exec_llm_assignment(code_reference_cmd, "code_reference")
            return self._format_code_references(code_reference)

        # No selection template: sample buffered code directly, producing the same
        # string format as the selection branch (the caller passes it to str.replace).
        sample_num = min(self.cfg['task_code_candidate_num'], len(code_buffer))
        sampled_keys = random.sample(list(code_buffer.keys()), sample_num)
        return self._format_code_references(sampled_keys)

    def implement_task(self):
        """Generate code for ``self.new_task`` and save it to ``model_output_dir``.

        Returns:
            ``(code, task_name)`` on success, or None if the LLM response
            yielded an empty task name (nothing is saved in that case).
        """
        code_prompt_text = self._read_prompt("cliport_prompt_code_split_template.txt")
        code_prompt_text = code_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])

        if self.use_template or self._has_prompt("cliport_prompt_code_reference_selection_template.txt"):
            task_code_reference_replace_prompt = self.template_reference_prompt()
            code_prompt_text = code_prompt_text.replace("TASK_CODE_REFERENCE_TEMPLATE", task_code_reference_replace_prompt)

        self.chat_log = add_to_txt(self.chat_log, "================= Code Generation!", with_print=True)
        # Applied on every path; a no-op when the template has no such placeholder.
        code_prompt_text = code_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))

        res = generate_feedback(code_prompt_text, temperature=0, interaction_txt=self.chat_log)
        code, task_name = extract_code(res)

        # Reject before saving so an empty name does not write "_code_output".
        if not task_name:
            print("empty task name:", task_name)
            return None

        print("Save code to:", self.model_output_dir, task_name + "_code_output")
        save_text(self.model_output_dir, task_name + "_code_output", code)
        return code, task_name