Spaces:
Runtime error
Runtime error
File size: 7,882 Bytes
8fc2b4e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import numpy as np
import os
import IPython
import random
import json
import traceback
import pybullet as p
from gensim.utils import (
save_text,
add_to_txt,
extract_code,
extract_dict,
extract_list,
extract_assets,
format_dict_prompt,
sample_list_reference,
generate_feedback,
)
class Agent:
    """LLM-driven agent that designs new tasks and task code for simulation environments.

    The agent fills prompt templates from ``prompts/<cfg['prompt_folder']>``
    with entries from ``memory`` (online task / asset / code buffers), sends
    them through ``generate_feedback`` and parses the responses.

    ``self.new_task`` is only assigned by ``propose_task``; ``propose_assets``,
    ``api_review``, ``template_reference_prompt`` and ``implement_task`` read
    it, so ``propose_task`` must run before them.
    """

    # Upper bound on how many past tasks are listed in the task-proposal prompt.
    MAX_PAST_TASKS = 20

    # Shared by template_reference_prompt and implement_task.
    _CODE_REFERENCE_SELECTION_PROMPT = "cliport_prompt_code_reference_selection_template.txt"

    def __init__(self, cfg, memory):
        """
        Args:
            cfg: Configuration mapping. Reads ``model_output_dir``,
                ``prompt_folder`` and ``use_template`` here; individual
                methods read further keys.
            memory: Object exposing ``chat_log`` and the ``online_*_buffer``
                attributes used by the prompting methods.
        """
        self.cfg = cfg
        self.model_output_dir = cfg["model_output_dir"]
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.memory = memory
        self.chat_log = memory.chat_log
        self.use_template = cfg['use_template']

    # ------------------------------------------------------------------ helpers

    def _prompt_path(self, filename):
        """Return the path of prompt template *filename* inside the prompt folder."""
        return f"{self.prompt_folder}/{filename}"

    def _read_prompt(self, filename):
        """Return the text of prompt template *filename*, closing the file afterwards."""
        with open(self._prompt_path(filename), encoding="utf-8") as f:
            return f.read()

    @staticmethod
    def _exec_definition(source, name):
        """Execute LLM-generated *source* and return the value it binds to *name*.

        A fresh namespace is used instead of module ``globals()``. A response
        that fails to define *name* therefore raises ``KeyError`` instead of
        silently reusing a value left over from an earlier call.

        SECURITY: *source* is model output executed verbatim with ``exec``.
        Only run this agent where executing generated code is acceptable.

        Raises:
            KeyError: if *source* does not define *name*.
            Exception: anything raised while executing *source*.
        """
        namespace = {}
        exec(source, namespace)
        return namespace[name]

    @staticmethod
    def _sample_tasks(tasks, max_num):
        """Return *tasks* itself if it has at most *max_num* entries, else a random sub-dict."""
        if len(tasks) <= max_num:
            return tasks
        # random.sample() needs a sequence; dict views are rejected on Python 3.11+.
        return dict(random.sample(list(tasks.items()), max_num))

    @staticmethod
    def _format_code_references(codes):
        """Wrap each code string in a fenced block, in the reference-prompt format."""
        return "".join(f'```\n{code}\n```\n\n' for code in codes)

    # ------------------------------------------------------------- public API

    def propose_task(self, proposed_task_names):
        """Ask the LLM to propose a new task (name, descriptions, assets).

        Args:
            proposed_task_names: Currently unused; kept for interface compatibility.

        Returns:
            The ``new_task`` dict defined by the LLM response, or a dummy task
            dict if the response could not be executed/parsed. The result is
            also stored in ``self.new_task``.
        """
        add_to_txt(self.chat_log, "================= Task and Asset Design!", with_print=True)
        # Both modes start from the same template file.
        task_prompt_text = self._read_prompt("cliport_prompt_task.txt")

        if self.use_template:
            task_asset_replacement_str = format_dict_prompt(
                self.memory.online_asset_buffer, self.cfg['task_asset_candidate_num'])
            task_prompt_text = task_prompt_text.replace("TASK_ASSET_PROMPT", task_asset_replacement_str)

            task_desc_replacement_str = format_dict_prompt(
                self.memory.online_task_buffer, self.cfg['task_description_candidate_num'])
            print("prompt task description candidates:")
            print(task_desc_replacement_str)
            task_prompt_text = task_prompt_text.replace("TASK_DESCRIPTION_PROMPT", task_desc_replacement_str)

            if self.cfg['target_task_name']:
                task_prompt_text = task_prompt_text.replace("TARGET_TASK_NAME", self.cfg['target_task_name'])

        # List (a bounded random subset of) previously generated tasks.
        print("online_task_buffer size:", len(self.memory.online_task_buffer))
        past_tasks = self._sample_tasks(self.memory.online_task_buffer, self.MAX_PAST_TASKS)
        task_prompt_text = task_prompt_text.replace("PAST_TASKNAME_TEMPLATE", format_dict_prompt(past_tasks))

        res = generate_feedback(
            task_prompt_text,
            temperature=self.cfg["gpt_temperature"],
            interaction_txt=self.chat_log,
        )

        # Extract the `new_task = {...}` definition (name, descriptions, assets).
        task_def = extract_dict(res, prefix="new_task")
        try:
            self.new_task = self._exec_definition(task_def, "new_task")
        except Exception:
            self.new_task = {"task-name": "dummy", "assets-used": [], "task_descriptions": ""}
            print(str(traceback.format_exc()))
        return self.new_task

    def propose_assets(self):
        """Asset generation for ``self.new_task``. Not used for now.

        Returns:
            Whatever ``extract_assets`` parses from the LLM response, or an
            empty dict when the asset prompt template does not exist.
        """
        asset_file = "cliport_prompt_asset_template.txt"
        if not os.path.exists(self._prompt_path(asset_file)):
            return {}

        add_to_txt(self.chat_log, "================= Asset Generation!", with_print=True)
        asset_prompt_text = self._read_prompt(asset_file)
        task_name = self.new_task["task-name"]

        if self.use_template:
            asset_prompt_text = asset_prompt_text.replace("TASK_NAME_TEMPLATE", task_name)
            asset_prompt_text = asset_prompt_text.replace("ASSET_STRING_TEMPLATE", str(self.new_task["assets-used"]))
            print("Template Asset PROMPT: ", asset_prompt_text)

        res = generate_feedback(asset_prompt_text, temperature=0, interaction_txt=self.chat_log)
        # Previously referenced an undefined `task_name` here (NameError).
        print("Save asset to:", self.model_output_dir, task_name + "_asset_output")
        save_text(self.model_output_dir, f'{task_name}_asset_output', res)
        return extract_assets(res)

    def api_review(self):
        """Send the API-review prompt for ``self.new_task`` to the LLM.

        Returns:
            The value returned by ``generate_feedback``, or None when the API
            template does not exist. The response used to be discarded, and
            the method always returned None.
        """
        api_file = "cliport_prompt_api_template.txt"
        if not os.path.exists(self._prompt_path(api_file)):
            return None

        add_to_txt(self.chat_log, "================= API Preview!", with_print=True)
        api_prompt_text = self._read_prompt(api_file)
        if "task-name" in self.new_task:
            api_prompt_text = api_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
        # str() of the task dict is always available, so never leave this placeholder in.
        api_prompt_text = api_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))
        return generate_feedback(api_prompt_text, temperature=0, interaction_txt=self.chat_log)

    def template_reference_prompt(self):
        """Build the reference-code section for the code-generation prompt.

        If the reference-selection template exists, the LLM picks references
        by key from ``memory.online_code_buffer``. Otherwise up to
        ``cfg['task_code_candidate_num']`` entries are sampled at random from
        that buffer. A negative value means all of them.

        Returns:
            The selected code snippets, each wrapped in a fenced block.
        """
        code_buffer = self.memory.online_code_buffer
        selection_file = self._CODE_REFERENCE_SELECTION_PROMPT

        if not os.path.exists(self._prompt_path(selection_file)):
            # NOTE(review): this branch previously called
            # sample_list_reference(base_task_codes, sample_num=cfg[...]) with
            # two undefined names (NameError). It now samples from the online
            # code buffer instead.
            sample_num = self.cfg['task_code_candidate_num']
            codes = list(code_buffer.values())
            if 0 <= sample_num < len(codes):
                codes = random.sample(codes, sample_num)
            return self._format_code_references(codes)

        self.chat_log = add_to_txt(self.chat_log, "================= Code Reference!", with_print=True)
        code_reference_question = self._read_prompt(selection_file)
        code_reference_question = code_reference_question.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])
        code_reference_question = code_reference_question.replace("TASK_CODE_LIST_TEMPLATE", str(list(code_buffer.keys())))
        code_reference_question = code_reference_question.replace("TASK_STRING_TEMPLATE", str(self.new_task))
        res = generate_feedback(code_reference_question, temperature=0., interaction_txt=self.chat_log)

        # Extract the `code_reference = [...]` definition chosen by the LLM.
        code_reference_cmd = extract_list(res, prefix='code_reference')
        try:
            code_reference = self._exec_definition(code_reference_cmd, "code_reference")
        except Exception:
            print(str(traceback.format_exc()))
            code_reference = []

        references = []
        for key in code_reference:
            if key in code_buffer:
                references.append(code_buffer[key])
            else:
                print("missing task reference code:", key)
        return self._format_code_references(references)

    def implement_task(self):
        """Generate code for ``self.new_task`` with the LLM and save it.

        Returns:
            ``(code, task_name)`` as produced by ``extract_code``, or None when
            the extracted task name is empty. The code is saved to
            ``model_output_dir`` in either case.
        """
        code_prompt_text = self._read_prompt("cliport_prompt_code_split_template.txt")
        code_prompt_text = code_prompt_text.replace("TASK_NAME_TEMPLATE", self.new_task["task-name"])

        if self.use_template or os.path.exists(self._prompt_path(self._CODE_REFERENCE_SELECTION_PROMPT)):
            task_code_reference_replace_prompt = self.template_reference_prompt()
            code_prompt_text = code_prompt_text.replace("TASK_CODE_REFERENCE_TEMPLATE", task_code_reference_replace_prompt)
        else:
            # The split template was read above, so the old
            # `elif os.path.exists(split_template)` guard was always true.
            self.chat_log = add_to_txt(self.chat_log, "================= Code Generation!", with_print=True)
        # Substitute the task description in both branches, so the literal
        # placeholder never reaches the LLM. This is a no-op if it is absent.
        code_prompt_text = code_prompt_text.replace("TASK_STRING_TEMPLATE", str(self.new_task))

        res = generate_feedback(
            code_prompt_text, temperature=0, interaction_txt=self.chat_log)
        code, task_name = extract_code(res)
        print("Save code to:", self.model_output_dir, task_name + "_code_output")
        save_text(self.model_output_dir, task_name + "_code_output", code)

        if not task_name:
            print("empty task name:", task_name)
            return None
        return code, task_name
|