# RL_based/train_PPO.py
import argparse
import sys
sys.path.insert(0, sys.path[0]+"/../")
import prompts as task_prompts
import envs
import os
from envs.translator import InitSummarizer, CurrSummarizer, FutureSummarizer, Translator
import gym
from torch.optim.lr_scheduler import LambdaLR
import torch
from tianshou.data import Collector, VectorReplayBuffer, ReplayBuffer
from tianshou.env import DummyVectorEnv, SubprocVectorEnv
from tianshou.policy import PPOPolicy, ICMPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils.net.common import ActorCritic
from tianshou.utils.net.discrete import Actor, Critic, IntrinsicCuriosityModule
from RL_based.utils import Net_GRU_Bert_tianshou, Net_Bert_CLS_tianshou, Net_Bert_CNN_tianshou, Net_GRU_nn_emb_tianshou
from tianshou.utils import WandbLogger
from torch.utils.tensorboard import SummaryWriter
from tianshou.trainer.utils import test_episode
import warnings
warnings.filterwarnings('ignore')


class MaxStepLimitWrapper(gym.Wrapper):
    """End an episode once a fixed number of steps has been taken."""

    def __init__(self, env, max_steps=200):
        super(MaxStepLimitWrapper, self).__init__(env)
        self.max_steps = max_steps
        self.current_step = 0

    def reset(self, **kwargs):
        self.current_step = 0
        return self.env.reset(**kwargs)

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        self.current_step += 1
        if self.current_step >= self.max_steps:
            # Signal episode end at the step cap and record the limit in `info`.
            terminated = True
            info['episode_step_limit'] = self.max_steps
        return observation, reward, terminated, truncated, info


class SimpleTextWrapper(gym.Wrapper):
    """Expose observations as plain strings so they can be tokenized downstream."""

    def __init__(self, env):
        super(SimpleTextWrapper, self).__init__(env)
        self.env = env

    def reset(self, **kwargs):
        observation, _ = self.env.reset(**kwargs)
        return str(observation), {}

    def step(self, action):
        observation, reward, terminated, truncated, info = self.env.step(action)
        return str(observation), reward, terminated, truncated, info
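
# The two wrappers above are composed in make_env() below, e.g. (sketch of the
# taxi-specific branch):
#   MaxStepLimitWrapper(SimpleTextWrapper(gym.make(env_name)), max_steps=200)
# SimpleTextWrapper stringifies observations; MaxStepLimitWrapper caps episode length.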


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Train a PPO agent in a gym environment whose observations are translated to text.')
    parser.add_argument('--init_summarizer', type=str, required=True, help='The name of the init summarizer to use.')
    parser.add_argument('--curr_summarizer', type=str, required=True, help='The name of the curr summarizer to use.')
    parser.add_argument('--future_summarizer', type=str, help='The name of the future summarizer to use.')
    parser.add_argument('--env', type=str, default='base_env', help='The name of the environment wrapper registered in envs.REGISTRY.')
    parser.add_argument('--env_name', type=str, default='CartPole-v1', help='The id of the gym environment to use.')
    parser.add_argument('--decider', type=str, default="naive_actor", help='The actor used to select actions')
    parser.add_argument('--render', type=str, default="rgb_array", help='The render mode')
    parser.add_argument('--future_horizon', type=int, help='The horizon of looking into the future')
    parser.add_argument('--prompt_level', type=int, default=1, help='The level of prompts')
    parser.add_argument('--past_horizon', type=int, help='The horizon of looking back')
    parser.add_argument('--max_episode_len', type=int, default=200, help='The max length of an episode')

    ### for RL training
    parser.add_argument('--max_length', type=int, default=128, help='The token length of the observation')
    parser.add_argument('--trans_model_name', type=str, default='bert-base-uncased', help='The name of the pretrained transformer to use.')
    parser.add_argument('--model_name', type=str, default='bert-embedding', help='The name of the model to use.')
    parser.add_argument('--vector_env', type=str, default='dummy', help='The name of the vector env to use.')
    parser.add_argument('--eval', action='store_true', default=False, help='Whether to only evaluate the model')
    parser.add_argument('--policy-path', type=str, default=None, help='The path to the policy to be evaluated')
    parser.add_argument('--collect_one_episode', action='store_true', default=False, help='Whether to only collect one episode')
    parser.add_argument('--lr', type=float, default=0.0003, help='The learning rate of the model')
    parser.add_argument('--step_per_epoch', type=int, default=10000, help='The number of steps per epoch')
    parser.add_argument('--step_per_collect', type=int, default=2000, help='The number of steps per collect')
    parser.add_argument('--lr_decay', action='store_true', default=False, help='Whether to decay the learning rate')
    parser.add_argument('--epoch', type=int, default=400, help='The number of epochs to train')
    parser.add_argument('--resume_path', type=str, default=None, help='The path to the policy to be resumed')
    parser.add_argument('--taxi_specific_env', action='store_true', default=False, help='Whether to use the taxi-specific env')
    args = parser.parse_args()
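    # Example invocation (run from the repository root; the summarizer names are
    # placeholders for keys registered in envs.REGISTRY):
    #   python RL_based/train_PPO.py --init_summarizer <init_key> --curr_summarizer <curr_key> \
    #       --env_name CartPole-v1 --model_name nn_embedding --lr_decay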
    args_dict = vars(args)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Look up the registered environment class and summarizers
    env_class = envs.REGISTRY[args.env]
    init_summarizer = InitSummarizer(envs.REGISTRY[args.init_summarizer])
    curr_summarizer = CurrSummarizer(envs.REGISTRY[args.curr_summarizer])
    if args.future_summarizer:
        future_summarizer = FutureSummarizer(
            envs.REGISTRY[args.future_summarizer],
            envs.REGISTRY["cart_policies"],
            future_horizon=args.future_horizon,
        )
    else:
        future_summarizer = None

    wandb_log_config = {
        "env": args.env_name,
        "init_summarizer": args.init_summarizer,
        "curr_summarizer": args.curr_summarizer,
        "future_summarizer": args.future_summarizer,
    }
    wandb_log_config.update(args_dict)

    if not args.eval:
        logger = WandbLogger(
            project="LLM-decider-bench-RL",
            entity="llm-bench-team",
            config=wandb_log_config,
        )
        random_name = logger.wandb_run.name
        log_path = os.path.join('/home/ubuntu/LLM-Decider-Bench/RL_based/results', args.env_name, random_name)
        writer = SummaryWriter(log_dir=log_path)
        writer.add_text("args", str(args))
        logger.load(writer)

    def save_best_fn(policy):
        torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))
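    # Note: save_best_fn relies on `log_path`, which is only defined in training mode
    # (i.e. when --eval is not set); it is only passed to the trainer below.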

    sampling_env = envs.REGISTRY["sampling_wrapper"](gym.make(args.env_name))
    if args.prompt_level == 5:
        prompts_class = task_prompts.REGISTRY[(args.env_name, args.decider)]()
    else:
        prompts_class = task_prompts.REGISTRY[(args.decider)]()
    translator = Translator(
        init_summarizer, curr_summarizer, future_summarizer, env=sampling_env
    )
    if args.taxi_specific_env:
        environment = gym.make(args.env_name, render_mode=args.render)
    else:
        environment = env_class(
            gym.make(args.env_name, render_mode=args.render), translator
        )
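    # With --taxi_specific_env the raw gym environment is used directly here; otherwise
    # the environment is wrapped by env_class together with the translator.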

    # Set the translation level
    translate_level = 1
    if args.past_horizon is None and args.future_horizon is None:
        translate_level = 1
    if args.past_horizon and args.future_horizon is None:
        raise NotImplementedError
        # translate_level = 2
    if args.past_horizon is None and args.future_horizon:
        raise NotImplementedError
        # translate_level = 3
    if args.past_horizon and args.future_horizon:
        raise NotImplementedError
        # translate_level = 3.5

    if args.vector_env == 'dummy':
        ThisEnv = DummyVectorEnv
    elif args.vector_env == 'subproc':
        ThisEnv = SubprocVectorEnv
    else:
        raise ValueError(f"Unknown vector env: {args.vector_env}")

    def make_env():
        if args.taxi_specific_env:
            env = MaxStepLimitWrapper(SimpleTextWrapper(gym.make(args.env_name, render_mode=args.render)), max_steps=200)
            env._max_episode_steps = args.max_episode_len
        else:
            env = env_class(MaxStepLimitWrapper(gym.make(args.env_name, render_mode=args.render), max_steps=200), translator)
            env._max_episode_steps = args.max_episode_len
        return env

    train_envs = ThisEnv([make_env for _ in range(20)])
    test_envs = ThisEnv([make_env for _ in range(10)])
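    # 20 environments are used for training rollouts and 10 for evaluation; with
    # --vector_env subproc each runs in its own process, while dummy steps them
    # sequentially in the main process.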

    # model & optimizer
    def get_net():
        if args.model_name == "bert-embedding":
            net = Net_GRU_Bert_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[64, 64], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        elif args.model_name == "bert-CLS-embedding":
            net = Net_Bert_CLS_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        elif args.model_name == "bert-CNN-embedding":
            net = Net_Bert_CNN_tianshou(state_shape=environment.observation_space.shape, hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        elif args.model_name == "nn_embedding":
            net = Net_GRU_nn_emb_tianshou(hidden_sizes=[256, 128], device=device, max_length=args.max_length, trans_model_name=args.trans_model_name)
        else:
            raise ValueError(f"Unknown model_name: {args.model_name}")
        return net

    net = get_net()
    actor = Actor(net, environment.action_space.n, device=device).to(device)
    critic = Critic(net, device=device).to(device)
    actor_critic = ActorCritic(actor, critic)
    optim = torch.optim.Adam(actor_critic.parameters(), lr=args.lr)

    # PPO policy
    dist = torch.distributions.Categorical
    lr_scheduler = None
    if args.lr_decay:
        # Decay the learning rate linearly to zero over the expected number of policy updates
        max_update_num = args.step_per_epoch // args.step_per_collect * args.epoch
        lr_scheduler = LambdaLR(optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num)
    policy = PPOPolicy(actor, critic, optim, dist, action_space=environment.action_space, lr_scheduler=lr_scheduler).to(device)
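    # PPO-specific hyperparameters (clip ratio, GAE lambda, entropy/value coefficients, ...)
    # are left at tianshou's defaults, since none are passed to PPOPolicy above.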

    # collector
    train_collector = Collector(policy, train_envs, VectorReplayBuffer(20000, len(train_envs)), exploration_noise=True)
    test_collector = Collector(policy, test_envs, exploration_noise=True)
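    # The VectorReplayBuffer's total capacity of 20000 transitions is split evenly
    # across the training environments (here 20 sub-buffers of 1000 each).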

    if not args.eval:
        # trainer
        if args.resume_path:
            policy.load_state_dict(torch.load(args.resume_path, map_location=device))
            print("Loaded agent from: ", args.resume_path)
        # Pre-fill the replay buffer before the first update
        train_collector.collect(n_step=256 * 20)
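        # Note: step_per_epoch and step_per_collect are hardcoded in the trainer call
        # below, so the --step_per_epoch / --step_per_collect CLI values only affect the
        # LR decay schedule above; keep them consistent when --lr_decay is enabled.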
        result = onpolicy_trainer(
            policy,
            train_collector,
            test_collector,
            max_epoch=args.epoch,
            step_per_epoch=50000,  # the number of transitions collected per epoch
            repeat_per_collect=4,
            episode_per_test=10,
            batch_size=256,
            logger=logger,
            step_per_collect=1000,  # the number of transitions the collector would collect before the network update
            save_best_fn=save_best_fn,
            # stop_fn=lambda mean_reward: mean_reward >= environment.spec.reward_threshold,
        )
        print(result)
    else:
        assert args.policy_path is not None, "--policy-path is required when --eval is set"
        policy.load_state_dict(torch.load(args.policy_path, map_location=device))
        test_collector = Collector(policy, test_envs)
        result = test_episode(policy, test_collector, None, None, n_episode=10)
        print(result)

    if args.collect_one_episode:
        # Roll out a single episode in the (non-vectorized) environment and dump the
        # collected transitions to a text file for inspection.
        replaybuffer = ReplayBuffer(size=1000)
        test_collector_1 = Collector(policy, environment, replaybuffer)
        test_collector_1.reset_env()
        test_collector_1.reset_buffer()
        policy.eval()
        result = test_collector_1.collect(n_episode=1)
        output_path = f"/home/ubuntu/LLM-Decider-Bench/RL_based/checkpoints/{args.env_name}/output.txt"
        print('sample results', output_path)
        sample_result = replaybuffer.sample(0)  # sample(0) returns all stored transitions
        # The checkpoints/<env_name> directory must already exist for this write to succeed.
        with open(output_path, "w") as f:
            print(sample_result, file=f)