File size: 4,270 Bytes
9b19c29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 |
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Self
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from tianshou.utils.string import ToStringMixin
@dataclass(kw_only=True)
class ActionSpaceInfo(ToStringMixin):
    """A data structure for storing the different attributes of the action space."""

    action_shape: int | Sequence[int]
    """The shape of the action space: an int (the number of actions) for Discrete spaces, the shape tuple for Box spaces."""
    min_action: float
    """The smallest allowable action or, in the continuous case, the lower bound for allowable action values."""
    max_action: float
    """The largest allowable action or, in the continuous case, the upper bound for allowable action values."""

    @property
    def action_dim(self) -> int:
        """Return the size of the action space.

        If ``action_shape`` is an int (as produced for Discrete spaces) it is returned as-is,
        i.e. the number of distinct actions an agent can take in its action space. If it is a
        sequence (as produced for Box spaces) the product of its entries is returned, i.e. the
        total number of action components. Positivity is not validated here.
        """
        if isinstance(self.action_shape, int):
            return self.action_shape
        # Flatten a multi-dimensional shape into its total element count.
        return int(np.prod(self.action_shape))

    @classmethod
    def from_space(cls, space: spaces.Space) -> Self:
        """Instantiate the `ActionSpaceInfo` object from a `Space`; supported spaces are Box and Discrete.

        * Box: ``action_shape`` is ``space.shape``. ``min_action`` is the smallest entry of
          ``space.low`` and ``max_action`` the largest entry of ``space.high``, so per-dimension
          bounds are collapsed into a single scalar range.
        * Discrete: ``action_shape`` is ``space.n`` and the bounds are ``space.start`` and
          ``space.start + space.n - 1`` (both inclusive).

        :param space: the action space to describe.
        :return: the corresponding `ActionSpaceInfo`.
        :raises ValueError: if ``space`` is neither a Box nor a Discrete space.
        """
        if isinstance(space, spaces.Box):
            return cls(
                action_shape=space.shape,
                min_action=float(np.min(space.low)),
                max_action=float(np.max(space.high)),
            )
        elif isinstance(space, spaces.Discrete):
            return cls(
                action_shape=int(space.n),
                min_action=float(space.start),
                # Discrete actions are start, start + 1, ..., start + n - 1.
                max_action=float(space.start + space.n - 1),
            )
        else:
            raise ValueError(
                f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
            )

    def _tostring_additional_entries(self) -> dict[str, Any]:
        # Include the derived action_dim in the string representation.
        return {"action_dim": self.action_dim}
@dataclass(kw_only=True)
class ObservationSpaceInfo(ToStringMixin):
"""A data structure for storing the different attributes of the observation space."""
obs_shape: int | Sequence[int]
"""The shape of the observation space."""
@property
def obs_dim(self) -> int:
"""Return the number of distinct features (must be greater than zero) or dimensions in the observation space."""
if isinstance(self.obs_shape, int):
return self.obs_shape
else:
return int(np.prod(self.obs_shape))
@classmethod
def from_space(cls, space: spaces.Space) -> Self:
"""Instantiate the `ObservationSpaceInfo` object from a `Space`, supported spaces are Box and Discrete."""
if isinstance(space, spaces.Box):
return cls(
obs_shape=space.shape,
)
elif isinstance(space, spaces.Discrete):
return cls(
obs_shape=int(space.n),
)
else:
raise ValueError(
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)
def _tostring_additional_entries(self) -> dict[str, Any]:
return {"obs_dim": self.obs_dim}
@dataclass(kw_only=True)
class SpaceInfo(ToStringMixin):
    """A data structure for storing the attributes of both the action and observation space."""

    action_info: ActionSpaceInfo
    """Stores the attributes of the action space."""
    observation_info: ObservationSpaceInfo
    """Stores the attributes of the observation space."""

    @classmethod
    def from_env(cls, env: gym.Env) -> Self:
        """Instantiate the `SpaceInfo` object from `gym.Env.action_space` and `gym.Env.observation_space`.

        :param env: the environment whose action and observation spaces are described.
        :return: the corresponding `SpaceInfo`.
        :raises ValueError: if either space is not a Box or Discrete space.
        """
        return cls.from_spaces(env.action_space, env.observation_space)

    @classmethod
    def from_spaces(cls, action_space: spaces.Space, observation_space: spaces.Space) -> Self:
        """Instantiate the `SpaceInfo` object from an action space and an observation space.

        The raw gymnasium spaces are converted with `ActionSpaceInfo.from_space` and
        `ObservationSpaceInfo.from_space` respectively.

        :param action_space: the action space (Box or Discrete).
        :param observation_space: the observation space (Box or Discrete).
        :return: the corresponding `SpaceInfo`.
        :raises ValueError: if either space is not a Box or Discrete space.
        """
        action_info = ActionSpaceInfo.from_space(action_space)
        observation_info = ObservationSpaceInfo.from_space(observation_space)
        return cls(
            action_info=action_info,
            observation_info=observation_info,
        )
|