Spaces:
Running
on
Zero
Running
on
Zero
from typing import List, Dict, Optional | |
import numpy as np | |
import gym | |
from gym.spaces import Box | |
from robomimic.envs.env_robosuite import EnvRobosuite | |
class RobomimicLowdimWrapper(gym.Env):
    """Gym adapter around a robomimic ``EnvRobosuite`` with flat low-dim observations.

    Each observation is the concatenation (along axis 0) of the arrays stored
    under ``obs_keys`` in the wrapped env's observation dict. Both the action
    space and the observation space are declared as float32 boxes bounded by
    [-1, 1] in every dimension.

    ``reset()`` picks a strategy in this priority order:
      1. ``init_state`` was given -> always restore that simulator state.
      2. ``seed()`` was called    -> deterministic reset for that seed. The
         resulting simulator state is cached per seed, so later resets with
         the same seed use the cheaper ``reset_to``. The pending seed is
         consumed by that single reset.
      3. otherwise                -> plain ``env.reset()``.
    """

    # Immutable default so instances never share (and mutate) one default list.
    _DEFAULT_OBS_KEYS = (
        "object",
        "robot0_eef_pos",
        "robot0_eef_quat",
        "robot0_gripper_qpos",
    )

    def __init__(
        self,
        env: EnvRobosuite,
        obs_keys: Optional[List[str]] = None,
        init_state: Optional[np.ndarray] = None,
        render_hw=(256, 256),
        render_camera_name="agentview",
    ):
        """Wrap ``env`` and build the gym action/observation spaces.

        Args:
            env: robomimic robosuite environment to wrap.
            obs_keys: observation-dict keys concatenated into the flat
                observation, in order. ``None`` selects ``_DEFAULT_OBS_KEYS``.
            init_state: if given, every ``reset()`` restores this simulator
                state via ``env.reset_to``.
            render_hw: ``(height, width)`` passed to ``env.render``.
            render_camera_name: camera name passed to ``env.render``.

        Raises:
            ValueError: if ``obs_keys`` is empty (nothing to concatenate).
        """
        self.env = env
        # Copy so later mutation of the caller's list cannot silently change
        # which keys this wrapper reads.
        keys = self._DEFAULT_OBS_KEYS if obs_keys is None else obs_keys
        self.obs_keys = list(keys)
        if not self.obs_keys:
            raise ValueError("obs_keys must contain at least one key")
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        # seed -> simulator state captured right after the seeded reset.
        # NOTE(review): grows by one entry per distinct seed; never evicted.
        self.seed_state_map = dict()
        # Seed to apply on the next reset(); None means "no pending seed".
        self._seed = None

        # Bounds built directly as float32 to match the Box dtype. Int or
        # float64 bounds make gym warn about lowering precision on cast.
        action_dim = env.action_dimension
        self.action_space = Box(
            low=np.full(action_dim, -1.0, dtype=np.float32),
            high=np.full(action_dim, 1.0, dtype=np.float32),
            dtype=np.float32,
        )

        # Query the env once to learn the flattened observation size.
        # NOTE(review): these bounds are nominal only; observations are not
        # clipped or normalized to [-1, 1] anywhere in this wrapper.
        obs_example = self.get_observation()
        self.observation_space = Box(
            low=np.full(obs_example.shape, -1.0, dtype=np.float32),
            high=np.full(obs_example.shape, 1.0, dtype=np.float32),
            dtype=np.float32,
        )

    def _flatten_obs(self, raw_obs):
        """Concatenate ``raw_obs[key]`` for each key in ``self.obs_keys`` along axis 0."""
        return np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)

    def get_observation(self):
        """Return the current flat observation read from the wrapped env."""
        return self._flatten_obs(self.env.get_observation())

    def seed(self, seed=None):
        """Seed numpy's global RNG and schedule a seeded ``reset()``.

        The seed is remembered and applied (then cleared) by the next
        ``reset()``, unless ``init_state`` is set, which takes precedence.

        Returns:
            ``[seed]``, following the gym ``Env.seed`` convention.
        """
        np.random.seed(seed=seed)
        self._seed = seed
        return [seed]

    def reset(self):
        """Reset the wrapped env (see class docstring for the priority order).

        Returns:
            The flat observation after the reset.
        """
        if self.init_state is not None:
            # Always reset to the same fixed simulator state.
            self.env.reset_to({"states": self.init_state})
        elif self._seed is not None:
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset() is expensive; restore the cached state instead.
                self.env.reset_to({"states": self.seed_state_map[seed]})
            else:
                # robosuite's initializers draw from numpy's global random
                # state, so re-seed it immediately before resetting.
                np.random.seed(seed=seed)
                self.env.reset()
                self.seed_state_map[seed] = self.env.get_state()["states"]
            # A seed applies to exactly one reset.
            self._seed = None
        else:
            # No fixed state and no pending seed: random reset.
            self.env.reset()
        return self.get_observation()

    def step(self, action):
        """Step the wrapped env and flatten its observation.

        Returns:
            ``(obs, reward, done, info)``, where ``obs`` is the flattened
            observation and the other values are passed through unchanged.
        """
        raw_obs, reward, done, info = self.env.step(action)
        return self._flatten_obs(raw_obs), reward, done, info

    def render(self, mode="rgb_array"):
        """Render through the wrapped env at ``render_hw`` from ``render_camera_name``."""
        h, w = self.render_hw
        return self.env.render(
            mode=mode, height=h, width=w, camera_name=self.render_camera_name
        )

    def close(self):
        """Close the inner env object held by the wrapped env (``self.env.env``)."""
        self.env.env.close()