File size: 5,678 Bytes
ff66cf3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
import numpy as np
import os
import IPython

import random
import json
from gensim.utils import save_text


class Memory:
    """
    Maintain buffers of task descriptions, assets, and task code.

    On construction the offline memory (JSON files under
    ``cfg["prompt_data_path"]``) is loaded into the online buffers. The
    ``save_task_to_*`` methods then record new tasks either in the online
    buffers only (``save_task_to_online``) or on disk (``save_task_to_offline``
    and ``save_task_to_offline_topdown``).
    """

    def __init__(self, cfg):
        """
        Load offline memory into the online buffers.

        Args:
            cfg: config mapping. Keys read by this class are
                ``prompt_folder``, ``prompt_data_path``, ``load_memory`` and
                ``model_output_dir``.
        """
        self.prompt_folder = f"prompts/{cfg['prompt_folder']}"
        self.data_path = cfg["prompt_data_path"]
        self.cfg = cfg

        # a chat history is a list of strings
        self.chat_log = []
        self.online_task_buffer = {}
        self.online_code_buffer = {}
        self.online_asset_buffer = {}

        # directly load current offline memory into online memory
        base_tasks, base_assets, base_task_codes = self.load_offline_memory()
        self.online_task_buffer.update(base_tasks)
        self.online_asset_buffer.update(base_assets)

        # Load each code file. The original cliport task directory is checked
        # before the generated one; files present in neither are skipped.
        for task_file in base_task_codes:
            for task_dir in ("cliport/tasks/", "cliport/generated_tasks/"):
                code_path = task_dir + task_file
                if os.path.exists(code_path):
                    self.online_code_buffer[task_file] = self._read_text(code_path)
                    break

        print(f"load {len(self.online_code_buffer)} tasks for memory from offline to online:")
        cache_embedding_path = "outputs/task_cache_embedding.npz"

        # NOTE: ``task_code_embedding`` is only set when the cache file exists.
        if os.path.exists(cache_embedding_path):
            print("task code embeding:", cache_embedding_path)
            self.task_code_embedding = np.load(cache_embedding_path)

    @staticmethod
    def _read_text(path):
        """Return the full text of ``path``, closing the file afterwards."""
        with open(path) as fhandle:
            return fhandle.read()

    @staticmethod
    def _load_json(path):
        """Parse and return the JSON document at ``path``, closing the file afterwards."""
        with open(path) as fhandle:
            return json.load(fhandle)

    @staticmethod
    def _code_file_name(task_name):
        """Map a task name such as ``"build-car"`` to its code file name ``"build_car.py"``."""
        return task_name.replace("-", "_") + ".py"

    def save_run(self, new_task):
        """
        Save the concatenated chat history for ``new_task``.

        The chat log entries are joined without separators and written via
        ``save_text`` to ``cfg["model_output_dir"]`` under the name
        ``"<task-name>_full_output"``.
        """
        print("save all interaction to :", f'{new_task["task-name"]}_full_output')
        unroll_chatlog = "".join(self.chat_log)
        save_text(
            self.cfg['model_output_dir'], f'{new_task["task-name"]}_full_output', unroll_chatlog
        )

    def save_task_to_online(self, new_task, code):
        """
        Record a task in the online buffers only (nothing is written to disk).

        Args:
            new_task: task description dict; its ``"task-name"`` is the key.
            code: the task's source code, stored under its ``.py`` file name.
        """
        self.online_task_buffer[new_task['task-name']] = new_task

        # The online code buffer stores actual code, keyed by file name, in
        # contrast to the offline JSON index that only lists file names.
        self.online_code_buffer[self._code_file_name(new_task["task-name"])] = code

    def save_task_to_offline(self, new_task, code, generate_task_path='generated_tasks'):
        """
        Persist a task's description and code to disk.

        Call this once the task has passed reflection and the environment test.
        If the code file is not yet listed in
        ``<prompt_data_path>/<generate_task_path>_codes.json``, the code is
        written to ``cliport/<generate_task_path>/<file>.py`` and the index is
        updated. Either way, the description is (re)written into
        ``<prompt_data_path>/<generate_task_path>.json``.

        Args:
            new_task: task description dict with a ``"task-name"`` key.
            code: the task's source code.
            generate_task_path: base name for the index files and the code directory.
        """
        # NOTE(review): with the default ``generate_task_path`` this resolves to
        # "generated_tasks_codes.json", while load_offline_memory reads
        # "generated_task_codes.json" -- confirm which file is intended.
        generated_task_code_path = os.path.join(
            self.cfg["prompt_data_path"], f"{generate_task_path}_codes.json"
        )
        generated_task_codes = self._load_json(generated_task_code_path)
        new_file_path = self._code_file_name(new_task["task-name"])

        if new_file_path not in generated_task_codes:
            generated_task_codes.append(new_file_path)

            python_file_path = f"cliport/{generate_task_path}/{new_file_path}"
            print(f"save {new_task['task-name']} to ", python_file_path)

            with open(python_file_path, "w") as fhandle:
                fhandle.write(code)

            with open(generated_task_code_path, "w") as outfile:
                json.dump(generated_task_codes, outfile, indent=4)
        else:
            # new_file_path already carries the ".py" suffix.
            print(f"{new_file_path} already exists.")

        # save task descriptions
        generated_task_path = os.path.join(
            self.cfg["prompt_data_path"], f"{generate_task_path}.json"
        )
        generated_tasks = self._load_json(generated_task_path)
        generated_tasks[new_task["task-name"]] = new_task

        with open(generated_task_path, "w") as outfile:
            json.dump(generated_tasks, outfile, indent=4)

    def save_task_to_offline_topdown(self, new_task, code, generate_task_path='topdown_generated_tasks'):
        """
        Write a task's code to ``cliport/<generate_task_path>/<file>.py``.

        Unlike save_task_to_offline, no JSON index or task description file is
        read or updated.
        """
        # The earlier version appended to an undefined ``generated_task_codes``
        # list, which raised NameError before any file was written. That list
        # was never persisted, so the append has been dropped.
        new_file_path = self._code_file_name(new_task["task-name"])

        python_file_path = f"cliport/{generate_task_path}/{new_file_path}"
        print(f"save {new_task['task-name']} to ", python_file_path)

        with open(python_file_path, "w") as fhandle:
            fhandle.write(code)

    def load_offline_memory(self):
        """
        Load the offline task descriptions, assets, and code file names.

        Reads ``base_tasks.json``, ``base_assets.json`` and
        ``base_task_codes.json`` from ``self.data_path``. When
        ``cfg["load_memory"]`` is truthy, ``generated_tasks.json`` is merged
        into the tasks, and code file names from ``generated_task_codes.json``
        that are not already listed are appended.

        Returns:
            (base_tasks, base_assets, base_task_codes): the two dicts and the
            list of code file names.
        """
        base_tasks = self._load_json(os.path.join(self.data_path, "base_tasks.json"))
        base_assets = self._load_json(os.path.join(self.data_path, "base_assets.json"))
        base_task_codes = self._load_json(os.path.join(self.data_path, "base_task_codes.json"))

        if self.cfg["load_memory"]:
            generated_task_path = os.path.join(self.data_path, "generated_tasks.json")
            generated_task_code_path = os.path.join(self.data_path, "generated_task_codes.json")

            print("original base task num:", len(base_tasks))
            base_tasks.update(self._load_json(generated_task_path))
            # Generated assets (generated_assets.json) are deliberately not merged.

            for task in self._load_json(generated_task_code_path):
                if task not in base_task_codes:
                    base_task_codes.append(task)

            print("current base task num:", len(base_tasks))
        return base_tasks, base_assets, base_task_codes