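"""Train or evaluate a tianshou PPO agent on a gym environment whose
observations are rendered as text by the summarizer/translator pipeline."""
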
import argparse
import sys
# Make the repository root importable so that `prompts` and `envs` resolve.
sys.path.insert(0, sys.path[0] + "/../")
import prompts as task_prompts
import envs
import os
from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
import gym
from torch.optim.lr_scheduler import LambdaLR
import torch
from tianshou.data import Collector, VectorReplayBuffer, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic
from RL_based.utils import Net_GRU_Bert_tianshou, Net_Bert_CLS_tianshou, Net_Bert_CNN_tianshou, Net_GRU_nn_emb_tianshou
from tianshou.utils import WandbLogger
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer.utils import test_episode

import warnings
warnings.filterwarnings('ignore')

class MaxStepLimitWrapper(gym.Wrapper):
    """Truncate an episode once a fixed number of steps has elapsed."""

    def __init__(self, env, max_steps=200):
        super().__init__(env)
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self, **kwargs):
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1

        if self.current_step >= self.max_steps:
            # A step limit is a time limit, not a terminal state, so report it
            # via `truncated` (gym >= 0.26 convention).
            truncated = True
            info['episode_step_limit'] = self.max_steps

        return observation, reward, terminated, truncated, info

class SimpleTextWrapper(gym.Wrapper):
    """Expose observations as plain strings."""

    def reset(self, **kwargs):
        observation, _ = self.env.reset(**kwargs)
        return str(observation), {}

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        return str(observation), reward, terminated, truncated, info

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train or evaluate a PPO agent on a gym environment with text observations.')
    parser.add_argument('--init_summarizer', type=str, required=True, help='The name of the init summarizer to use.')
    parser.add_argument('--curr_summarizer', type=str, required=True, help='The name of the curr summarizer to use.')
    parser.add_argument('--future_summarizer', type=str, help='The name of the future summarizer to use.')
    parser.add_argument('--env', type=str, default='base_env', help='The name of the environment wrapper class in the envs registry.')
    parser.add_argument('--env_name', type=str, default='CartPole-v1', help='The id of the gym environment to use.')
    parser.add_argument('--decider', type=str, default="naive_actor", help='The actor used to select actions')
    parser.add_argument('--render', type=str, default="rgb_array", help='The render mode')
    parser.add_argument('--future_horizon', type=int, help='The horizon of looking to future')
    parser.add_argument(
        "--prompt_level",
        type=int,
        default=1,
        help="The level of prompts",
    )
    parser.add_argument(
        "--past_horizon", type=int, help="The horizon of looking back"
    )
    parser.add_argument(
        "--max_episode_len", type=int, default=200, help="The max length of an episode"
    )
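
    # Example invocation (summarizer keys are placeholders; the real ones live in envs.REGISTRY):
    #   python this_script.py --init_summarizer <init_key> --curr_summarizer <curr_key> --env_name CartPole-v1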

    ### RL training arguments
    parser.add_argument('--max_length', type=int, default=128, help='The token length of the observation')
    # trans_model_name
    parser.add_argument('--trans_model_name', type=str, default='bert-base-uncased', help='The name of the pretrained transformer to use.')
    parser.add_argument('--model_name', type=str, default='bert-embedding', help='The name of the model to use.')
    parser.add_argument('--vector_env', type=str, default='dummy', help='The name of the vector env to use.')
    parser.add_argument('--eval', action='store_true', default=False, help='Whether to only eval the model')
    parser.add_argument('--policy-path', type=str, default=None, help='The path to the policy to be evaluated')
    parser.add_argument('--collect_one_episode', action='store_true', default=False, help='Whether to only collect one episode')
    parser.add_argument('--lr', type=float, default=0.0003, help='The learning rate of the model')
    parser.add_argument('--step_per_epoch', type=int, default=10000, help='The number of steps per epoch')
    parser.add_argument('--step_per_collect', type=int, default=2000, help='The number of steps per collect')
    parser.add_argument('--lr_decay', action='store_true', default=False, help='Whether to decay the learning rate')
    parser.add_argument('--epoch', type=int, default=400, help='The number of epochs to train')
    parser.add_argument('--resume_path', type=str, default=None, help='The path to the policy to be resumed')
    parser.add_argument('--taxi_specific_env', action='store_true', default=False, help='Whether to use taxi specific env')
    args = parser.parse_args()
    args_dict = vars(args)

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # Get the specified translator, environment, and ChatGPT model
    env_class = envs.REGISTRY[args.env]
    init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer])
    curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])
    if args.future_summarizer:
        future_summarizer = FutureSummarizer(
            envs.REGISTRY[args.future_summarizer],
            envs.REGISTRY["cart_policies"],
            future_horizon=args.future_horizon,
        )
    else:
        future_summarizer = None

    wandb_log_config = {
        "env": args.env_name,
        "init_summarizer": args.init_summarizer,
        "curr_summarizer": args.curr_summarizer,
        "future_summarizer": args.future_summarizer,
    }
    wandb_log_config.update(args_dict)

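    # Training runs log to Weights & Biases and mirror everything to a local TensorBoard writer.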
    if not args.eval:
        logger = WandbLogger(
            project="LLM-decider-bench-RL",
            entity="llm-bench-team",
            config=wandb_log_config,
        )
        random_name = logger.wandb_run.name
        log_path = os.path.join('/home/ubuntu/LLM-Decider-Bench/RL_based/results', args.env_name, random_name)
        writer = SummaryWriter(log_dir=log_path)
        writer.add_text("args", str(args))
        logger.load(writer)
        def save_best_fn(policy):
            torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

    sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
    if args.prompt_level == 5:
        prompts_class = task_prompts.REGISTRY[(args.env_name,args.decider)]()
    else:
        prompts_class = task_prompts.REGISTRY[(args.decider)]()
    translator = Translator(
        init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
    )
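    # Wrap the raw gym env so observations pass through the translator; the
    # taxi-specific env is used as-is (its text wrapper is applied in make_env below).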
    if args.taxi_specific_env:
        environment = gym.make(args.env_name, render_mode=args.render)
    else:
        environment = env_class(
            gym.make(args.env_name, render_mode=args.render), translator
        )
    
    # Set the translation level; only level 1 (no past or future horizon) is supported here.
    translate_level = 1
    if args.past_horizon and args.future_horizon is None:
        raise NotImplementedError
        # translate_level = 2
    if args.past_horizon is None and args.future_horizon:
        raise NotImplementedError
        # translate_level = 3
    if args.past_horizon and args.future_horizon:
        raise NotImplementedError
        # translate_level = 3.5


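    # Build vectorized training/test environments (dummy runs in-process, subproc in parallel workers).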
    if args.vector_env == 'dummy':
        ThisEnv = DummyVectorEnv
    elif args.vector_env == 'subproc':
        ThisEnv = SubprocVectorEnv
    else:
        raise ValueError(f"Unknown vector env: {args.vector_env}")
    def make_env():
        # Enforce the episode length with MaxStepLimitWrapper, honoring --max_episode_len.
        if args.taxi_specific_env:
            env = MaxStepLimitWrapper(SimpleTextWrapper(gym.make(args.env_name, render_mode=args.render)), max_steps=args.max_episode_len)
        else:
            env = env_class(MaxStepLimitWrapper(gym.make(args.env_name, render_mode=args.render), max_steps=args.max_episode_len), translator)
        return env
    train_envs = ThisEnv([make_env for _ in range(20)])
    test_envs = ThisEnv([make_env for _ in range(10)])
    # model & optimizer
    def get_net():
        if args.model_name == "bert-embedding":
            net = Net_GRU_Bert_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[64, 64], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        elif args.model_name == "bert-CLS-embedding":
            net = Net_Bert_CLS_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        elif args.model_name == "bert-CNN-embedding":
            net = Net_Bert_CNN_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        elif args.model_name == "nn_embedding":
            net = Net_GRU_nn_emb_tianshou(hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        else:
            raise ValueError(f"Unknown model name: {args.model_name}")
        return net
    net = get_net()
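    # Actor and critic share the same feature-extraction backbone, so one optimizer updates both.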
    actor = Actor(net, environment.action_space.n, device=device).to(device)
    critic = Critic(net, device=device).to(device)
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    # PPO policy
    dist = torch.distributions.Categorical
    lr_scheduler = None
    if args.lr_decay:
        # Decay the learning rate linearly to zero over the total number of policy updates.
        max_update_num = args.step_per_epoch // args.step_per_collect * args.epoch
        lr_scheduler = LambdaLR(optim, lr_lambda=lambda update: 1 - update / max_update_num)
    policy = PPOPolicy(actor, critic, optim, dist, action_space=environment.action_space, lr_scheduler=lr_scheduler).to(device)
    # collector
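    # VectorReplayBuffer splits its 20000-transition capacity evenly across the 20 training envs.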
    train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)), exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)

    if not args.eval:
        # trainer
        # test train_collector and start filling replay buffer

        if args.resume_path:
            # map_location=device keeps the load working on CPU-only machines as well.
            policy.load_state_dict(torch.load(args.resume_path, map_location=device))
            print("Loaded agent from:", args.resume_path)

        # Warm up the replay buffer before the first policy update.
        train_collector.collect(n_step=256 * 20)
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            max_epoch=args.epoch,
            step_per_epoch=args.step_per_epoch,  # transitions collected per epoch
            repeat_per_collect=4,
            episode_per_test=10,
            batch_size=256,
            logger=logger,
            # Transitions collected before each network update; using the CLI value
            # keeps the lr-decay schedule above consistent with the actual training loop.
            step_per_collect=args.step_per_collect,
            save_best_fn=save_best_fn,
            # stop_fn=lambda mean_reward: mean_reward >= environment.spec.reward_threshold,
        )
        print(result)
    else:
        assert args.policy_path is not None, "--policy-path is required in eval mode"
        policy.load_state_dict(torch.load(args.policy_path, map_location=device))
        test_collector = Collector(policy, test_envs)
        result = test_episode(policy, test_collector, None, None, n_episode=10)
        print(result)
        if args.collect_one_episode:
            replaybuffer = ReplayBuffer(size=1000)
            test_collector_1 = Collector(policy, environment, replaybuffer)
            test_collector_1.reset_env()
            test_collector_1.reset_buffer()
            policy.eval()
            result = test_collector_1.collect(n_episode=1)
            output_path = f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt"
            print('Writing sampled episode to', output_path)
            sample_result = replaybuffer.sample(0)  # batch_size=0 returns the whole buffer
            with open(output_path, "w") as f:
                print(sample_result, file=f)