Spaces:
Sleeping
Sleeping
from typing import Any, Dict, Optional | |
from easydict import EasyDict | |
import matplotlib.pyplot as plt | |
import gym | |
import copy | |
import numpy as np | |
from ding.envs.env.base_env import BaseEnvTimestep | |
from ding.torch_utils.data_helper import to_ndarray | |
from ding.utils.default_helper import deep_merge_dicts | |
from dizoo.metadrive.env.drive_utils import BaseDriveEnv | |
def draw_multi_channels_top_down_observation(obs, show_time=0.5):
    """
    Overview:
        Show every channel of a 5-channel top-down observation side by side in one matplotlib figure.
        The figure is closed automatically after ``show_time`` seconds by a one-shot timer.
    Arguments:
        - obs (np.ndarray): Observation whose LAST dimension indexes the channels; it must have exactly 5 channels.
        - show_time (float): Seconds the figure stays on screen before it is closed.
    """
    num_channels = obs.shape[-1]
    assert num_channels == 5, 'expected 5 channels in the last dimension, got {}'.format(num_channels)
    channel_names = [
        "Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1",
        "Neighbor at step t-2"
    ]
    fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)
    # Timer intervals are integer milliseconds. Fire exactly once and close *this* figure only, so a
    # lingering timer cannot keep closing whatever figure happens to be current later on.
    timer = fig.canvas.new_timer(interval=int(show_time * 1000))
    timer.single_shot = True
    timer.add_callback(plt.close, fig)
    for idx, (ax, name) in enumerate(zip(axs, channel_names)):
        ax.imshow(obs[..., idx], cmap="bone")
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(name)
    fig.suptitle("Multi-channels Top-down Observation")
    timer.start()
    plt.show()
    # With non-blocking backends ``show`` returns immediately; stop the timer and close this figure explicitly.
    timer.stop()
    plt.close(fig)
class DriveEnvWrapper(gym.Wrapper):
    """
    Overview:
        Environment wrapper to make ``gym.Env`` align with DI-engine definitions, so as to use utilities in DI-engine.
        It changes ``step``, ``reset`` and ``info`` method of ``gym.Env``, while others are straightly delivered.
    Arguments:
        - env (BaseDriveEnv): The environment to be wrapped.
        - cfg (Dict): Config dict. ``None`` selects the class default config; a dict without a ``cfg_type`` key \
            is deep-merged into the default config; a dict with ``cfg_type`` is used unchanged.
    """
    # Default config of this wrapper; ``default_config`` returns a deep copy of it.
    config = dict()

    def __init__(self, env: BaseDriveEnv, cfg: Optional[Dict] = None, **kwargs) -> None:
        if cfg is None:
            self._cfg = self.__class__.default_config()
        elif 'cfg_type' not in cfg:
            # Partial user config: overlay it on top of the defaults.
            self._cfg = deep_merge_dicts(self.__class__.default_config(), cfg)
        else:
            # Already a complete config (it carries ``cfg_type``); use it as-is.
            self._cfg = cfg
        self.env = env
        if not hasattr(self.env, 'reward_space'):
            self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
        # Only an explicit boolean ``True`` enables the bird-view visualization in ``step``.
        self.show_bird_view = self._cfg.get('show_bird_view') is True
        self.action_space = self.env.action_space

    @staticmethod
    def _process_obs(obs: Any) -> Any:
        """Convert ``obs`` to float32 ndarray(s) and move the trailing (channel) axis of 3-D image arrays first."""
        obs = to_ndarray(obs, dtype=np.float32)
        if isinstance(obs, np.ndarray) and obs.ndim == 3:
            return obs.transpose((2, 0, 1))
        if isinstance(obs, dict):
            # Only these two entries are kept; the bird view gets the same axis move as a plain array obs.
            return {'vehicle_state': obs['vehicle_state'], 'birdview': obs['birdview'].transpose((2, 0, 1))}
        return obs

    def reset(self, *args, **kwargs) -> Any:
        """
        Overview:
            Wrapper of ``reset`` method in env. Observations are converted to float32 ``np.ndarray`` (3-D arrays and
            the ``birdview`` entry of dict observations get their last axis moved first), and the per-episode
            return accumulator is reset to zero.
        Returns:
            - Any: Observations from environment.
        """
        obs = self._process_obs(self.env.reset(*args, **kwargs))
        self._eval_episode_return = 0.0
        self._arrive_dest = False
        return obs

    def step(self, action: Any = None) -> BaseEnvTimestep:
        """
        Overview:
            Wrapper of ``step`` method in env. This aims to convert the returns of ``gym.Env`` step method into
            that of ``ding.envs.BaseEnv``, from ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep``
            namedtuple defined in DI-engine. It will also convert actions, observations and reward into
            ``np.ndarray``. When the episode ends, the accumulated return is stored in
            ``info['eval_episode_return']``.
        Arguments:
            - action (Any, optional): Actions sent to env. Defaults to None.
        Returns:
            - BaseEnvTimestep: DI-engine format of env step returns.
        """
        action = to_ndarray(action)
        obs, rew, done, info = self.env.step(action)
        if self.show_bird_view:
            # The drawing helper needs a single channels-last array; for dict observations draw the bird view.
            view = obs['birdview'] if isinstance(obs, dict) else obs
            draw_multi_channels_top_down_observation(view, show_time=0.5)
        self._eval_episode_return += rew
        obs = self._process_obs(obs)
        rew = to_ndarray([rew], dtype=np.float32)
        if done:
            info['eval_episode_return'] = self._eval_episode_return
        return BaseEnvTimestep(obs, rew, done, info)

    @property
    def observation_space(self):
        """Fixed observation space: 5 x 84 x 84 float32 values in [0, 1]."""
        return gym.spaces.Box(0, 1, shape=(5, 84, 84), dtype=np.float32)

    def seed(self, seed: int, dynamic_seed: bool = True) -> None:
        """Store the seed settings and seed NumPy's global RNG (the wrapped env itself is not seeded here)."""
        self._seed = seed
        self._dynamic_seed = dynamic_seed
        np.random.seed(self._seed)

    def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
        """Wrap the env so that every episode is recorded under ``replay_path`` (default ``./video``)."""
        if replay_path is None:
            replay_path = './video'
        self._replay_path = replay_path
        # NOTE(review): ``gym.wrappers.Monitor`` is not available in newer gym releases (superseded by
        # ``RecordVideo``); confirm the pinned gym version still provides it.
        self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True)

    @classmethod
    def default_config(cls: type) -> EasyDict:
        """Return a deep copy of ``cls.config`` as an ``EasyDict`` tagged with ``cfg_type = '<ClassName>Config'``."""
        cfg = EasyDict(cls.config)
        cfg.cfg_type = cls.__name__ + 'Config'
        return copy.deepcopy(cfg)

    def __repr__(self) -> str:
        return repr(self.env)

    def render(self):
        """Render the wrapped env and return whatever its ``render`` returns."""
        return self.env.render()

    def clone(self, caller: str):
        """Build a new wrapper around ``self.env.clone(caller)`` with a deep copy of this wrapper's config."""
        cfg = copy.deepcopy(self._cfg)
        return DriveEnvWrapper(self.env.clone(caller), cfg)