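# Train a PPO agent on the MetaDrive driving simulator with DI-engine.
# Pipeline: config -> vectorized envs -> VAC model + PPO policy -> collect/train/eval loop.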
from easydict import EasyDict
from functools import partial
from tensorboardX import SummaryWriter
from ding.envs import SyncSubprocessEnvManager
from ding.config import compile_config
from ding.model.template import VAC
from ding.policy import PPOPolicy
from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
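
# Experiment config. Fields not set here are filled in with DI-engine defaults by compile_config below.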
metadrive_basic_config = dict(
    exp_name='metadrive_onppo_seed0',
    env=dict(
        metadrive=dict(
            use_render=False,
            traffic_density=0.10,  # Density of traffic vehicles on the road, in [0, 1]
            map='XSOS',  # Map specification: an int or a string shorthand for map_config
            horizon=4000,  # Maximum number of steps per episode
            driving_reward=1.0,  # Reward for making progress along the route
            speed_reward=0.1,  # Reward for driving at high speed
            use_lateral_reward=False,  # Whether to add a lane-keeping reward term
            out_of_road_penalty=40.0,  # Penalty for driving out of the road
            crash_vehicle_penalty=40.0,  # Penalty for colliding with another vehicle
            decision_repeat=20,  # Physics steps per policy action (reciprocal of the decision frequency)
            out_of_route_done=True,  # End the episode when the agent leaves the route
        ),
        manager=dict(
            shared_memory=False,
            max_retry=2,
            context='spawn',
        ),
        n_evaluator_episode=16,
        stop_value=255,  # Stop training once the evaluator's mean episode return reaches this value
        collector_env_num=8,
        evaluator_env_num=8,
    ),
    policy=dict(
        cuda=True,
        action_space='continuous',
        model=dict(
            obs_shape=[5, 84, 84],
            action_shape=2,  # [steering, acceleration/brake]
            action_space='continuous',
            bound_type='tanh',  # Squash actions into their bounded range with tanh
            encoder_hidden_size_list=[128, 128, 64],
        ),
        learn=dict(
            epoch_per_collect=10,
            batch_size=64,
            learning_rate=3e-4,
            entropy_weight=0.001,
            value_weight=0.5,
            clip_ratio=0.02,
            adv_norm=False,
            value_norm=True,
            grad_clip_value=10,
        ),
        collect=dict(n_sample=3000),  # Transitions collected per training iteration
        eval=dict(evaluator=dict(eval_freq=1000)),  # Evaluate every 1000 training iterations
    ),
)
main_config = EasyDict(metadrive_basic_config)
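
# Wrap the raw MetaDrive env so it exposes the DI-engine BaseEnv interface.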
def wrapped_env(env_cfg, wrapper_cfg=None):
    return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)
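
# Entry point: build the pipeline and train until the evaluator reaches stop_value.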
def main(cfg):
    cfg = compile_config(
        cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
    )
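    # Vectorized environments: each manager steps its envs in separate subprocesses ('spawn' context).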
    collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
    collector_env = SyncSubprocessEnvManager(
        env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(collector_env_num)],
        cfg=cfg.env.manager,
    )
    evaluator_env = SyncSubprocessEnvManager(
        env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
        cfg=cfg.env.manager,
    )
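    # Actor-critic model (VAC) with a conv encoder for the 5x84x84 bird's-eye-view observation.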
    model = VAC(**cfg.policy.model)
    policy = PPOPolicy(cfg.policy, model=model)
    tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
    learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
    collector = SampleSerialCollector(
        cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
    )
    evaluator = InteractionSerialEvaluator(
        cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
    )
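    # Main loop: evaluate periodically; otherwise collect n_sample transitions and run PPO updates.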
    learner.call_hook('before_run')
    while True:
        if evaluator.should_eval(learner.train_iter):
            stop, rate = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
            if stop:
                break
        # Sample new transitions from the collector environments, then train on them.
        new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter)
        learner.train(new_data, collector.envstep)
    learner.call_hook('after_run')
    collector.close()
    evaluator.close()
    learner.close()
if __name__ == '__main__':
    main(main_config)
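
# Usage (assuming this script is saved as metadrive_onppo_main.py; the filename is illustrative):
#   python metadrive_onppo_main.py
# TensorBoard summaries are written to ./log/metadrive_onppo_seed0/.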