# Source: hma/sim/robomimic/robomimic_wrapper.py (author: LeroyWaa, commit 246c106 "draft", 2.87 kB)
from typing import List, Dict, Optional
import numpy as np
import gym
from gym.spaces import Box
from robomimic.envs.env_robosuite import EnvRobosuite
class RobomimicLowdimWrapper(gym.Env):
    """Gym-style wrapper around a robomimic ``EnvRobosuite``.

    Observations returned by the wrapped environment are dictionaries; this
    wrapper selects the entries named in ``obs_keys`` and concatenates them
    (in that order, along axis 0) into a single array.

    Reset behaviour depends on how the wrapper is configured:

    * if ``init_state`` was given, every reset restores that state;
    * otherwise, if ``seed()`` was called since the last seeded reset, the
      reset is driven by that seed and the resulting simulator state is
      cached per seed so later resets with the same seed can restore it
      instead of calling the expensive ``env.reset()``;
    * otherwise the wrapped environment is reset normally.
    """

    # Kept as an immutable tuple so instances never share a mutable default.
    _DEFAULT_OBS_KEYS = (
        "object",
        "robot0_eef_pos",
        "robot0_eef_quat",
        "robot0_gripper_qpos",
    )

    def __init__(
        self,
        env: EnvRobosuite,
        obs_keys: Optional[List[str]] = None,
        init_state: Optional[np.ndarray] = None,
        render_hw=(256, 256),
        render_camera_name="agentview",
    ):
        """Store configuration and build the action/observation spaces.

        Args:
            env: Environment to wrap.
            obs_keys: Observation dictionary keys to concatenate, in order.
                ``None`` selects the default low-dimensional keys
                (``object``, ``robot0_eef_pos``, ``robot0_eef_quat``,
                ``robot0_gripper_qpos``).
            init_state: If not ``None``, the state every ``reset()`` restores.
            render_hw: ``(height, width)`` forwarded to ``env.render``.
            render_camera_name: Camera name forwarded to ``env.render``.
        """
        self.env = env
        # Copy the defaults into a fresh list per instance.
        self.obs_keys = (
            list(self._DEFAULT_OBS_KEYS) if obs_keys is None else obs_keys
        )
        self.init_state = init_state
        self.render_hw = render_hw
        self.render_camera_name = render_camera_name
        # Maps a seed to the simulator state produced by resetting with it.
        self.seed_state_map = dict()
        # Seed waiting to be consumed by the next reset(); None means no seed.
        self._seed = None

        # Action space: one entry per action dimension, bounded to [-1, 1].
        low = np.full(env.action_dimension, fill_value=-1)
        high = np.full(env.action_dimension, fill_value=1)
        self.action_space = Box(
            low=low,
            high=high,
        )

        # Observation space: shaped like one example observation queried from
        # the environment.
        # NOTE(review): bounds are fixed at [-1, 1] regardless of the actual
        # observation values -- confirm whether any consumer relies on them.
        obs_example = self.get_observation()
        low = np.full_like(obs_example, fill_value=-1)
        high = np.full_like(obs_example, fill_value=1)
        self.observation_space = Box(
            low=low,
            high=high,
        )

    def _flatten_obs(self, raw_obs) -> np.ndarray:
        """Concatenate the ``obs_keys`` entries of ``raw_obs`` along axis 0."""
        return np.concatenate([raw_obs[key] for key in self.obs_keys], axis=0)

    def get_observation(self) -> np.ndarray:
        """Return the wrapped environment's current observation, flattened."""
        return self._flatten_obs(self.env.get_observation())

    def seed(self, seed=None):
        """Seed numpy's global RNG and remember ``seed`` for the next reset."""
        np.random.seed(seed=seed)
        self._seed = seed

    def reset(self) -> np.ndarray:
        """Reset the environment and return the flattened observation.

        See the class docstring for how the reset mode is chosen.
        """
        if self.init_state is not None:
            # Always reset to the same state, to be compatible with gym.
            self.env.reset_to({"states": self.init_state})
        elif self._seed is not None:
            # Reset to the state associated with the pending seed.
            seed = self._seed
            if seed in self.seed_state_map:
                # env.reset is expensive, so restore the cached state instead.
                self.env.reset_to({"states": self.seed_state_map[seed]})
            else:
                # robosuite's initializers all use numpy's global random
                # state, so reseed it right before resetting.
                np.random.seed(seed=seed)
                self.env.reset()
                state = self.env.get_state()["states"]
                self.seed_state_map[seed] = state
            # The seed applies to a single reset only.
            self._seed = None
        else:
            # Random reset.
            self.env.reset()

        return self.get_observation()

    def step(self, action):
        """Step the wrapped environment.

        Returns:
            ``(obs, reward, done, info)`` where ``obs`` is the flattened
            observation and the other three values are passed through
            unchanged from the wrapped environment.
        """
        raw_obs, reward, done, info = self.env.step(action)
        return self._flatten_obs(raw_obs), reward, done, info

    def render(self, mode="rgb_array"):
        """Render via the wrapped environment using the configured size and camera."""
        h, w = self.render_hw
        return self.env.render(
            mode=mode,
            height=h,
            width=w,
            camera_name=self.render_camera_name,
        )

    def close(self):
        """Close the environment held inside the wrapped ``EnvRobosuite``."""
        self.env.env.close()