sabretoothedhugs's picture
v2
9b19c29
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, Self
import gymnasium as gym
import numpy as np
from gymnasium import spaces
from tianshou.utils.string import ToStringMixin
@dataclass(kw_only=True)
class ActionSpaceInfo(ToStringMixin):
"""A data structure for storing the different attributes of the action space."""
action_shape: int | Sequence[int]
"""The shape of the action space."""
min_action: float
"""The smallest allowable action or in the continuous case the lower bound for allowable action value."""
max_action: float
"""The largest allowable action or in the continuous case the upper bound for allowable action value."""
@property
def action_dim(self) -> int:
"""Return the number of distinct actions (must be greater than zero) an agent can take it its action space."""
if isinstance(self.action_shape, int):
return self.action_shape
else:
return int(np.prod(self.action_shape))
@classmethod
def from_space(cls, space: spaces.Space) -> Self:
"""Instantiate the `ActionSpaceInfo` object from a `Space`, supported spaces are Box and Discrete."""
if isinstance(space, spaces.Box):
return cls(
action_shape=space.shape,
min_action=float(np.min(space.low)),
max_action=float(np.max(space.high)),
)
elif isinstance(space, spaces.Discrete):
return cls(
action_shape=int(space.n),
min_action=float(space.start),
max_action=float(space.start + space.n - 1),
)
else:
raise ValueError(
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)
def _tostring_additional_entries(self) -> dict[str, Any]:
return {"action_dim": self.action_dim}
@dataclass(kw_only=True)
class ObservationSpaceInfo(ToStringMixin):
"""A data structure for storing the different attributes of the observation space."""
obs_shape: int | Sequence[int]
"""The shape of the observation space."""
@property
def obs_dim(self) -> int:
"""Return the number of distinct features (must be greater than zero) or dimensions in the observation space."""
if isinstance(self.obs_shape, int):
return self.obs_shape
else:
return int(np.prod(self.obs_shape))
@classmethod
def from_space(cls, space: spaces.Space) -> Self:
"""Instantiate the `ObservationSpaceInfo` object from a `Space`, supported spaces are Box and Discrete."""
if isinstance(space, spaces.Box):
return cls(
obs_shape=space.shape,
)
elif isinstance(space, spaces.Discrete):
return cls(
obs_shape=int(space.n),
)
else:
raise ValueError(
f"Unsupported space type: {space.__class__}. Currently supported types are Discrete and Box.",
)
def _tostring_additional_entries(self) -> dict[str, Any]:
return {"obs_dim": self.obs_dim}
@dataclass(kw_only=True)
class SpaceInfo(ToStringMixin):
    """A data structure for storing the attributes of both the action and observation space."""

    action_info: ActionSpaceInfo
    """Stores the attributes of the action space."""
    observation_info: ObservationSpaceInfo
    """Stores the attributes of the observation space."""

    @classmethod
    def from_env(cls, env: gym.Env) -> Self:
        """Instantiate the `SpaceInfo` object from `env.action_space` and `env.observation_space`.

        :param env: the environment whose action and observation spaces are described.
        :return: a new `SpaceInfo` built via `from_spaces`.
        :raises ValueError: if either space is not a `Box` or `Discrete` space.
        """
        return cls.from_spaces(env.action_space, env.observation_space)

    @classmethod
    def from_spaces(cls, action_space: spaces.Space, observation_space: spaces.Space) -> Self:
        """Instantiate the `SpaceInfo` object from an action space and an observation space.

        The raw gymnasium spaces are converted with `ActionSpaceInfo.from_space` and
        `ObservationSpaceInfo.from_space`, respectively.

        :param action_space: the action space to describe.
        :param observation_space: the observation space to describe.
        :return: a new `SpaceInfo` holding the converted action and observation info.
        :raises ValueError: if either space is not a `Box` or `Discrete` space.
        """
        action_info = ActionSpaceInfo.from_space(action_space)
        observation_info = ObservationSpaceInfo.from_space(observation_space)
        return cls(
            action_info=action_info,
            observation_info=observation_info,
        )