File size: 5,678 Bytes
ff66cf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
import numpy as np
import os
import IPython
import random
import json
from gensim.utils import save_text
class Memory:
    """Maintain a buffer of generated tasks, their assets, and their code.

    The "offline" memory is the set of JSON files under ``cfg["prompt_data_path"]``
    plus the task source files under ``cliport/``; the "online" memory is the
    in-process dictionaries kept on this object.
    """

    def __init__(self, cfg):
        """Load the offline memory into the online buffers.

        Args:
            cfg: mapping that is read for the keys ``"prompt_folder"`` and
                ``"prompt_data_path"`` here, ``"load_memory"`` in
                ``load_offline_memory`` and ``"model_output_dir"`` in
                ``save_run``.
        """
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.data_path = cfg["prompt_data_path"]
        self.cfg = cfg

        # a chat history is a list of strings
        self.chat_log = []
        self.online_task_buffer = {}
        self.online_code_buffer = {}
        self.online_asset_buffer = {}

        # directly load current offline memory into online memory
        base_tasks, base_assets, base_task_codes = self.load_offline_memory()
        self.online_task_buffer.update(base_tasks)
        self.online_asset_buffer.update(base_assets)

        # Load each code file. The original cliport task folder takes
        # precedence over the generated one; names found in neither folder
        # are skipped silently.
        for task_file in base_task_codes:
            for code_dir in ("cliport/tasks/", "cliport/generated_tasks/"):
                code_path = code_dir + task_file
                if os.path.exists(code_path):
                    with open(code_path) as fhandle:
                        self.online_code_buffer[task_file] = fhandle.read()
                    break

        print(f"load {len(self.online_code_buffer)} tasks for memory from offline to online:")

        # NOTE(review): ``task_code_embedding`` is only assigned when the cache
        # file exists; code reading it must cope with the attribute missing.
        cache_embedding_path = "outputs/task_cache_embedding.npz"
        if os.path.exists(cache_embedding_path):
            print("task code embeding:", cache_embedding_path)
            self.task_code_embedding = np.load(cache_embedding_path)

    @staticmethod
    def _read_json(path):
        """Return the parsed content of the JSON file at ``path``."""
        with open(path) as infile:
            return json.load(infile)

    @staticmethod
    def _write_json(path, payload):
        """Write ``payload`` to ``path`` as JSON indented by four spaces."""
        with open(path, "w") as outfile:
            json.dump(payload, outfile, indent=4)

    def save_run(self, new_task):
        """Save the concatenated chat history through ``save_text``.

        Args:
            new_task: mapping with a ``"task-name"`` entry, used to build the
                output name ``<task-name>_full_output``.
        """
        output_name = f'{new_task["task-name"]}_full_output'
        print("save all interaction to :", output_name)
        unroll_chatlog = "".join(self.chat_log)
        save_text(self.cfg['model_output_dir'], output_name, unroll_chatlog)

    def save_task_to_online(self, new_task, code):
        """Save the task for online bootstrapping without dumping it offline.

        Args:
            new_task: mapping with a ``"task-name"`` entry.
            code: source code of the task.
        """
        self.online_task_buffer[new_task['task-name']] = new_task
        code_file_name = new_task["task-name"].replace("-", "_") + ".py"
        # The online code buffer stores the actual code, whereas the offline
        # code registry only stores file names.
        self.online_code_buffer[code_file_name] = code

    def save_task_to_offline(self, new_task, code, generate_task_path='generated_tasks'):
        """Save the task description and code to the offline memory.

        The code file is written (and its name registered) only when the name
        is not in the registry yet; the task description is stored either way.

        Args:
            new_task: mapping with a ``"task-name"`` entry.
            code: source code of the task.
            generate_task_path: name of the folder under ``cliport/`` that
                receives the code; also the stem of the two JSON files under
                ``cfg["prompt_data_path"]`` that are updated.
        """
        # NOTE(review): with the default argument this resolves to
        # "generated_tasks_codes.json", while load_offline_memory reads
        # "generated_task_codes.json" - confirm which name is intended.
        generated_task_code_path = os.path.join(
            self.cfg["prompt_data_path"], f"{generate_task_path}_codes.json"
        )
        generated_task_codes = self._read_json(generated_task_code_path)
        new_file_path = new_task["task-name"].replace("-", "_") + ".py"

        if new_file_path not in generated_task_codes:
            generated_task_codes.append(new_file_path)
            python_file_path = f"cliport/{generate_task_path}/{new_file_path}"
            print(f"save {new_task['task-name']} to ", python_file_path)
            with open(python_file_path, "w") as fhandle:
                fhandle.write(code)
            self._write_json(generated_task_code_path, generated_task_codes)
        else:
            # new_file_path already carries the ".py" suffix.
            print(f"{new_file_path} already exists.")

        # save task descriptions
        generated_task_path = os.path.join(
            self.cfg["prompt_data_path"], f"{generate_task_path}.json"
        )
        generated_tasks = self._read_json(generated_task_path)
        generated_tasks[new_task["task-name"]] = new_task
        self._write_json(generated_task_path, generated_tasks)

    def save_task_to_offline_topdown(self, new_task, code, generate_task_path='topdown_generated_tasks'):
        """Write the task code to ``cliport/<generate_task_path>/``.

        Unlike ``save_task_to_offline`` this neither checks nor updates any
        JSON registry; an existing file with the same name is overwritten.

        Args:
            new_task: mapping with a ``"task-name"`` entry.
            code: source code of the task.
            generate_task_path: name of the target folder under ``cliport/``.
        """
        new_file_path = new_task["task-name"].replace("-", "_") + ".py"
        python_file_path = f"cliport/{generate_task_path}/{new_file_path}"
        print(f"save {new_task['task-name']} to ", python_file_path)
        with open(python_file_path, "w") as fhandle:
            fhandle.write(code)

    def load_offline_memory(self):
        """Read the current task descriptions, assets, and code file names.

        Returns:
            A tuple ``(base_tasks, base_assets, base_task_codes)`` parsed from
            ``base_tasks.json``, ``base_assets.json`` and
            ``base_task_codes.json`` under ``self.data_path``. When
            ``cfg["load_memory"]`` is truthy, the generated tasks and generated
            code file names are merged into the first and third element.
        """
        base_tasks = self._read_json(os.path.join(self.data_path, "base_tasks.json"))
        base_assets = self._read_json(os.path.join(self.data_path, "base_assets.json"))
        base_task_codes = self._read_json(
            os.path.join(self.data_path, "base_task_codes.json")
        )

        if self.cfg["load_memory"]:
            generated_task_path = os.path.join(self.data_path, "generated_tasks.json")
            generated_task_code_path = os.path.join(
                self.data_path, "generated_task_codes.json"
            )

            print("original base task num:", len(base_tasks))
            base_tasks.update(self._read_json(generated_task_path))
            # NOTE: generated assets (generated_assets.json) are not merged
            # into base_assets.

            # Append generated code file names, skipping ones already listed.
            for task in self._read_json(generated_task_code_path):
                if task not in base_task_codes:
                    base_task_codes.append(task)

            print("current base task num:", len(base_tasks))

        return base_tasks, base_assets, base_task_codes
|