|
import numpy as np |
|
import os |
|
import IPython |
|
|
|
import random |
|
import json |
|
from gensim.utils import save_text |
|
|
|
|
|
class Memory:
    """Maintain in-memory buffers of task descriptions, assets, and task code.

    On construction the offline memory (JSON files under
    ``cfg["prompt_data_path"]``) is loaded into the online buffers, and the
    source of every listed task file that is found under ``cliport/tasks/``
    or ``cliport/generated_tasks/`` is read into ``online_code_buffer``.
    """

    # Directories searched, in this order, for a task code file's source.
    _TASK_CODE_DIRS = ("cliport/tasks", "cliport/generated_tasks")
    # Directory that newly accepted task code is written to.
    _GENERATED_TASK_DIR = "cliport/generated_tasks"

    def __init__(self, cfg):
        """Create the memory and populate it from the offline JSON files.

        Args:
            cfg: mapping read for ``prompt_folder``, ``prompt_data_path`` and
                ``load_memory`` here; ``save_run`` later reads
                ``model_output_dir``.
        """
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.data_path = cfg["prompt_data_path"]
        self.cfg = cfg

        # Chat history; save_run concatenates these entries into one string.
        self.chat_log = []
        self.online_task_buffer = {}
        self.online_code_buffer = {}
        self.online_asset_buffer = {}

        # Seed the online buffers with everything currently stored offline.
        base_tasks, base_assets, base_task_codes = self.load_offline_memory()
        self.online_task_buffer.update(base_tasks)
        self.online_asset_buffer.update(base_assets)

        for task_file in base_task_codes:
            code = self._read_task_code(task_file)
            # Task files present in neither directory are skipped silently.
            if code is not None:
                self.online_code_buffer[task_file] = code

        print(f"load {len(self.online_code_buffer)} tasks for memory from offline to online:")
        cache_embedding_path = "outputs/task_cache_embedding.npz"

        # NOTE: task_code_embedding is only set when the cache file exists;
        # code using it should not assume the attribute is always present.
        if os.path.exists(cache_embedding_path):
            print("task code embeding:", cache_embedding_path)
            self.task_code_embedding = np.load(cache_embedding_path)

    @staticmethod
    def _read_text(path):
        """Return the UTF-8 text content of ``path``, closing the file afterwards."""
        with open(path, encoding="utf-8") as fhandle:
            return fhandle.read()

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document at ``path``, closing the file afterwards."""
        with open(path, encoding="utf-8") as fhandle:
            return json.load(fhandle)

    @staticmethod
    def _write_json(path, data):
        """Overwrite ``path`` with ``data`` serialized as indented JSON."""
        with open(path, "w", encoding="utf-8") as fhandle:
            json.dump(data, fhandle, indent=4)

    @classmethod
    def _read_task_code(cls, task_file):
        """Return the source of ``task_file`` from the first task dir that has it, else None."""
        for task_dir in cls._TASK_CODE_DIRS:
            path = os.path.join(task_dir, task_file)
            if os.path.exists(path):
                return cls._read_text(path)
        return None

    def save_run(self, new_task):
        """Save the full chat history of this run.

        The entries of ``self.chat_log`` are concatenated and passed to
        ``save_text`` with ``cfg["model_output_dir"]`` and the name
        ``<task-name>_full_output``.

        Args:
            new_task: task description dict; must contain ``"task-name"``.
        """
        output_name = f'{new_task["task-name"]}_full_output'
        print("save all interaction to :", output_name)
        save_text(self.cfg['model_output_dir'], output_name, "".join(self.chat_log))

    def save_task_to_online(self, new_task, code):
        """Record a task in the online buffers only (nothing is written to disk).

        Args:
            new_task: task description dict; must contain ``"task-name"``.
            code: Python source of the task, stored under
                ``<task_name_with_underscores>.py``.
        """
        task_name = new_task['task-name']
        self.online_task_buffer[task_name] = new_task
        self.online_code_buffer[task_name.replace("-", "_") + ".py"] = code

    def save_task_to_offline(self, new_task, code):
        """Persist a task's code and description to the offline memory.

        Intended for tasks that passed reflection and the environment test.
        The code is written to ``cliport/generated_tasks/<task_name>.py`` and
        registered in ``generated_task_codes.json`` only if that file name is
        not registered yet; the task description is always written (or
        overwritten) in ``generated_tasks.json``.

        Args:
            new_task: task description dict; must contain ``"task-name"``.
            code: Python source of the task.
        """
        generated_task_code_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_task_codes.json"
        )
        generated_task_codes = self._load_json(generated_task_code_path)
        new_file_path = new_task["task-name"].replace("-", "_") + ".py"

        if new_file_path not in generated_task_codes:
            generated_task_codes.append(new_file_path)
            python_file_path = os.path.join(self._GENERATED_TASK_DIR, new_file_path)
            print(f"save {new_task['task-name']} to ", python_file_path)
            with open(python_file_path, "w", encoding="utf-8") as fhandle:
                fhandle.write(code)
            self._write_json(generated_task_code_path, generated_task_codes)
        else:
            # new_file_path already carries the ".py" suffix.
            print(f"{new_file_path} already exists.")

        generated_task_path = os.path.join(
            self.cfg["prompt_data_path"], "generated_tasks.json"
        )
        generated_tasks = self._load_json(generated_task_path)
        generated_tasks[new_task["task-name"]] = new_task
        self._write_json(generated_task_path, generated_tasks)

    def load_offline_memory(self):
        """Load task descriptions, assets, and task code file names from disk.

        Reads ``base_tasks.json``, ``base_assets.json`` and
        ``base_task_codes.json`` from ``self.data_path``. When
        ``cfg["load_memory"]`` is truthy, the generated tasks and generated
        task code file names are merged in too; generated assets are not
        loaded.

        Returns:
            tuple ``(base_tasks, base_assets, base_task_codes)``.
        """
        base_tasks = self._load_json(os.path.join(self.data_path, "base_tasks.json"))
        base_assets = self._load_json(os.path.join(self.data_path, "base_assets.json"))
        base_task_codes = self._load_json(os.path.join(self.data_path, "base_task_codes.json"))

        if self.cfg["load_memory"]:
            print("original base task num:", len(base_tasks))
            base_tasks.update(
                self._load_json(os.path.join(self.data_path, "generated_tasks.json"))
            )
            generated_task_codes = self._load_json(
                os.path.join(self.data_path, "generated_task_codes.json")
            )
            # Append only unseen names, preserving order, so nothing is listed twice.
            for task in generated_task_codes:
                if task not in base_task_codes:
                    base_task_codes.append(task)
            print("current base task num:", len(base_tasks))

        return base_tasks, base_assets, base_task_codes
|
|