Spaces:
Running
Running
File size: 5,705 Bytes
375a1cf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 |
import numpy as np
from gym import utils
from gym.envs.mujoco import MuJocoPyEnv
from gym.spaces import Box
# Viewer-camera attribute overrides; AntEnv.viewer_setup copies each
# key/value pair onto the viewer's camera object.
DEFAULT_CAMERA_CONFIG = dict(distance=4.0)
class AntEnv(MuJocoPyEnv, utils.EzPickle):
    """Ant locomotion environment built on ``MuJocoPyEnv``.

    Each step's reward is the torso's x-velocity plus a healthy bonus,
    minus a quadratic control cost and a quadratic cost on the clipped
    external contact forces (``sim.data.cfrc_ext``).

    The observation concatenates ``qpos`` (optionally without its first two
    entries), ``qvel`` and the flattened, clipped contact forces. Its
    declared length is 111 when those two ``qpos`` entries are dropped and
    113 otherwise.
    """

    metadata = {
        "render_modes": [
            "human",
            "rgb_array",
            "depth_array",
        ],
        "render_fps": 20,
    }

    def __init__(
        self,
        xml_file="ant.xml",
        ctrl_cost_weight=0.5,
        contact_cost_weight=5e-4,
        healthy_reward=1.0,
        terminate_when_unhealthy=True,
        healthy_z_range=(0.2, 1.0),
        contact_force_range=(-1.0, 1.0),
        reset_noise_scale=0.1,
        exclude_current_positions_from_observation=True,
        **kwargs
    ):
        """Store the reward/termination settings and build the MuJoCo env.

        Args:
            xml_file: model file handed to ``MuJocoPyEnv``.
            ctrl_cost_weight: scale of the squared-action control cost.
            contact_cost_weight: scale of the squared contact-force cost.
            healthy_reward: per-step bonus (see ``healthy_reward`` property).
            terminate_when_unhealthy: end the episode once unhealthy.
            healthy_z_range: ``(min, max)`` bounds applied to
                ``state_vector()[2]`` by ``is_healthy``.
            contact_force_range: ``(min, max)`` clip bounds for contacts.
            reset_noise_scale: magnitude of the noise added on reset.
            exclude_current_positions_from_observation: drop the first two
                ``qpos`` entries from observations.
            **kwargs: forwarded to both ``EzPickle`` and ``MuJocoPyEnv``.
        """
        # EzPickle receives the constructor arguments positionally, so this
        # order must keep matching the signature above.
        utils.EzPickle.__init__(
            self,
            xml_file,
            ctrl_cost_weight,
            contact_cost_weight,
            healthy_reward,
            terminate_when_unhealthy,
            healthy_z_range,
            contact_force_range,
            reset_noise_scale,
            exclude_current_positions_from_observation,
            **kwargs
        )

        # Cost weights.
        self._ctrl_cost_weight = ctrl_cost_weight
        self._contact_cost_weight = contact_cost_weight
        self._contact_force_range = contact_force_range

        # Health / termination settings.
        self._healthy_reward = healthy_reward
        self._healthy_z_range = healthy_z_range
        self._terminate_when_unhealthy = terminate_when_unhealthy

        # Reset and observation settings.
        self._reset_noise_scale = reset_noise_scale
        self._exclude_current_positions_from_observation = (
            exclude_current_positions_from_observation
        )

        # Excluding the two leading qpos entries shrinks the vector by two.
        obs_size = 111 if exclude_current_positions_from_observation else 113
        observation_space = Box(
            low=-np.inf, high=np.inf, shape=(obs_size,), dtype=np.float64
        )

        # The literal 5 is presumably the frame skip later read back as
        # self.frame_skip in step() -- confirm against MuJocoPyEnv.
        MuJocoPyEnv.__init__(
            self, xml_file, 5, observation_space=observation_space, **kwargs
        )

    @property
    def healthy_reward(self):
        """Per-step survival bonus.

        The bonus is paid while the agent is healthy. When
        ``terminate_when_unhealthy`` is set, it is paid unconditionally.
        """
        earns_bonus = self.is_healthy or self._terminate_when_unhealthy
        return float(earns_bonus) * self._healthy_reward

    def control_cost(self, action):
        """Return ``ctrl_cost_weight * sum(action ** 2)``."""
        return self._ctrl_cost_weight * np.sum(np.square(action))

    @property
    def contact_forces(self):
        """``sim.data.cfrc_ext`` clipped element-wise to ``contact_force_range``."""
        lower, upper = self._contact_force_range
        return np.clip(self.sim.data.cfrc_ext, lower, upper)

    @property
    def contact_cost(self):
        """Return ``contact_cost_weight * sum(contact_forces ** 2)``."""
        squared = np.square(self.contact_forces)
        return self._contact_cost_weight * np.sum(squared)

    @property
    def is_healthy(self):
        """True when the state is all finite and ``state[2]`` is in range.

        ``state[2]`` is presumably the torso height; the bounds come from
        ``healthy_z_range`` and are inclusive.
        """
        z_low, z_high = self._healthy_z_range
        state = self.state_vector()
        return np.isfinite(state).all() and z_low <= state[2] <= z_high

    @property
    def terminated(self):
        """Whether the episode ends: unhealthy with termination enabled."""
        if not self._terminate_when_unhealthy:
            return False
        return not self.is_healthy

    def step(self, action):
        """Apply ``action`` for ``frame_skip`` sim steps and score the result.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.
            ``truncated`` is always ``False``, and ``info`` carries the
            individual reward terms and the torso's planar position and
            velocity.
        """
        torso_xy_start = self.get_body_com("torso")[:2].copy()
        self.do_simulation(action, self.frame_skip)
        torso_xy_end = self.get_body_com("torso")[:2].copy()

        # Finite-difference velocity over the elapsed time self.dt.
        x_velocity, y_velocity = (torso_xy_end - torso_xy_start) / self.dt

        forward_reward = x_velocity
        healthy_reward = self.healthy_reward
        ctrl_cost = self.control_cost(action)
        contact_cost = self.contact_cost

        # Sum rewards and costs separately first to keep the original
        # floating-point association.
        reward = (forward_reward + healthy_reward) - (ctrl_cost + contact_cost)

        terminated = self.terminated
        observation = self._get_obs()
        info = {
            "reward_forward": forward_reward,
            "reward_ctrl": -ctrl_cost,
            "reward_contact": -contact_cost,
            "reward_survive": healthy_reward,
            "x_position": torso_xy_end[0],
            "y_position": torso_xy_end[1],
            "distance_from_origin": np.linalg.norm(torso_xy_end, ord=2),
            "x_velocity": x_velocity,
            "y_velocity": y_velocity,
            "forward_reward": forward_reward,
        }

        if self.render_mode == "human":
            self.render()
        return observation, reward, terminated, False, info

    def _get_obs(self):
        """Concatenate positions, velocities and clipped contact forces."""
        qpos = self.sim.data.qpos.flat.copy()
        if self._exclude_current_positions_from_observation:
            qpos = qpos[2:]
        qvel = self.sim.data.qvel.flat.copy()
        contacts = self.contact_forces.flat.copy()
        return np.concatenate((qpos, qvel, contacts))

    def reset_model(self):
        """Perturb the initial state with noise and return the observation.

        Positions receive uniform noise in ``[-scale, scale)``. Velocities
        receive Gaussian noise with standard deviation ``scale``, where
        ``scale`` is ``reset_noise_scale``.
        """
        scale = self._reset_noise_scale
        # Keep the draw order (uniform, then normal) so seeded resets match.
        qpos = self.init_qpos + self.np_random.uniform(
            low=-scale, high=scale, size=self.model.nq
        )
        qvel = self.init_qvel + scale * self.np_random.standard_normal(
            self.model.nv
        )
        self.set_state(qpos, qvel)
        return self._get_obs()

    def viewer_setup(self):
        """Copy ``DEFAULT_CAMERA_CONFIG`` onto the viewer's camera.

        Array values are written in place into the existing camera
        attribute. Every other value replaces the attribute.
        """
        assert self.viewer is not None
        for attr_name, attr_value in DEFAULT_CAMERA_CONFIG.items():
            if not isinstance(attr_value, np.ndarray):
                setattr(self.viewer.cam, attr_name, attr_value)
                continue
            getattr(self.viewer.cam, attr_name)[:] = attr_value
|