Spaces:
Sleeping
Sleeping
"""Contains methods for step compatibility, from old-to-new and new-to-old API.""" | |
from typing import Tuple, Union | |
import numpy as np | |
from gym.core import ObsType | |
# Per-slot building blocks shared by both step-return signatures. Each slot
# accepts either the single-environment form or the batched (vector) form.
_ObservationSlot = Union[ObsType, np.ndarray]
_RewardSlot = Union[float, np.ndarray]
_FlagSlot = Union[bool, np.ndarray]
_InfoSlot = Union[dict, list]

# Old step API: (observation, reward, done, info).
DoneStepType = Tuple[_ObservationSlot, _RewardSlot, _FlagSlot, _InfoSlot]

# New step API: (observation, reward, terminated, truncated, info).
TerminatedTruncatedStepType = Tuple[
    _ObservationSlot, _RewardSlot, _FlagSlot, _FlagSlot, _InfoSlot
]
def convert_to_terminated_truncated_step_api(
    step_returns: Union[DoneStepType, TerminatedTruncatedStepType],
    is_vector_env: bool = False,
) -> TerminatedTruncatedStepType:
    """Function to transform step returns to new step API irrespective of input API.

    A 5-item tuple is treated as already being in the new API and is returned
    unchanged. A 4-item tuple ``(obs, rew, done, info)`` is split into
    ``terminated`` and ``truncated`` using the ``"TimeLimit.truncated"`` entry
    of ``info``. That entry is popped, so the ``info`` object passed in is
    modified in place and the same object is returned.

    Args:
        step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
        is_vector_env (bool): Whether the step_returns are from a vector environment

    Returns:
        step_returns (tuple): (obs, rew, terminated, truncated, info)

    Raises:
        AssertionError: If ``step_returns`` has neither 4 nor 5 items.
        TypeError: If ``is_vector_env`` is truthy and ``info`` is neither a list nor a dict.
    """
    if len(step_returns) == 5:
        return step_returns

    assert (
        len(step_returns) == 4
    ), f"Expected step returns of length 4 or 5, actual length: {len(step_returns)}"
    observations, rewards, dones, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    # Truthiness (rather than `is False`) keeps falsy non-bool flags such as
    # 0 or np.bool_(False) on the single-environment path.
    if not is_vector_env:
        truncated = infos.pop("TimeLimit.truncated", False)
        return (
            observations,
            rewards,
            dones and not truncated,
            dones and truncated,
            infos,
        )

    if isinstance(infos, list):
        # One info dict per sub-environment; a missing key means "not truncated".
        truncated = np.array(
            [info.pop("TimeLimit.truncated", False) for info in infos]
        )
    elif isinstance(infos, dict):
        # A single dict of batched values; default to all-False when the key is absent.
        truncated = infos.pop(
            "TimeLimit.truncated", np.zeros(len(dones), dtype=bool)
        )
    else:
        raise TypeError(
            f"Unexpected value of infos, as is_vector_env=True, expects `info` to be a list or dict, actual type: {type(infos)}"
        )

    return (
        observations,
        rewards,
        np.logical_and(dones, np.logical_not(truncated)),
        np.logical_and(dones, truncated),
        infos,
    )
def convert_to_done_step_api(
    step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
    is_vector_env: bool = False,
) -> DoneStepType:
    """Function to transform step returns to old step API irrespective of input API.

    A 4-item tuple is treated as already being in the old API and is returned
    unchanged. A 5-item tuple ``(obs, rew, terminated, truncated, info)`` is
    collapsed into a single ``done`` flag (``terminated or truncated``). When an
    episode ended, ``info["TimeLimit.truncated"]`` is written as
    ``truncated and not terminated``; the ``info`` object passed in is modified
    in place and the same object is returned.

    Args:
        step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
        is_vector_env (bool): Whether the step_returns are from a vector environment

    Returns:
        step_returns (tuple): (obs, rew, done, info)

    Raises:
        AssertionError: If ``step_returns`` has neither 4 nor 5 items.
        TypeError: If ``is_vector_env`` is truthy and ``info`` is neither a list nor a dict.
    """
    if len(step_returns) == 4:
        return step_returns

    assert (
        len(step_returns) == 5
    ), f"Expected step returns of length 4 or 5, actual length: {len(step_returns)}"
    observations, rewards, terminated, truncated, infos = step_returns

    # Cases to handle - info single env / info vector env (list) / info vector env (dict)
    # Truthiness (rather than `is False`) keeps falsy non-bool flags such as
    # 0 or np.bool_(False) on the single-environment path.
    if not is_vector_env:
        if truncated or terminated:
            infos["TimeLimit.truncated"] = truncated and not terminated
        return (
            observations,
            rewards,
            terminated or truncated,
            infos,
        )

    if isinstance(infos, list):
        # One info dict per sub-environment; only finished ones get the key.
        for info, env_truncated, env_terminated in zip(infos, truncated, terminated):
            if env_truncated or env_terminated:
                info["TimeLimit.truncated"] = env_truncated and not env_terminated
    elif isinstance(infos, dict):
        # A single dict of batched values; written only if any sub-environment finished.
        if np.logical_or(np.any(truncated), np.any(terminated)):
            infos["TimeLimit.truncated"] = np.logical_and(
                truncated, np.logical_not(terminated)
            )
    else:
        raise TypeError(
            f"Unexpected value of infos, as is_vector_env=True, expects `info` to be a list or dict, actual type: {type(infos)}"
        )

    return (
        observations,
        rewards,
        np.logical_or(terminated, truncated),
        infos,
    )
def step_api_compatibility(
    step_returns: Union[TerminatedTruncatedStepType, DoneStepType],
    output_truncation_bool: bool = True,
    is_vector_env: bool = False,
) -> Union[TerminatedTruncatedStepType, DoneStepType]:
    """Convert step returns to whichever step API `output_truncation_bool` selects.

    Done (old) step API refers to step() method returning (observation, reward, done, info)
    Terminated Truncated (new) step API refers to step() method returning (observation, reward, terminated, truncated, info)
    (Refer to docs for details on the API change)

    Args:
        step_returns (tuple): Items returned by step(). Can be (obs, rew, done, info) or (obs, rew, terminated, truncated, info)
        output_truncation_bool (bool): Whether the output should return two booleans (new API) or one (old) (True by default)
        is_vector_env (bool): Whether the step_returns are from a vector environment

    Returns:
        step_returns (tuple): Depending on `output_truncation_bool` bool, it can return (obs, rew, done, info) or (obs, rew, terminated, truncated, info)

    Examples:
        This function can be used to ensure compatibility in step interfaces with conflicting API. Eg. if env is written in old API,
        wrapper is written in new API, and the final step output is desired to be in old API.

        >>> obs, rew, done, info = step_api_compatibility(env.step(action), output_truncation_bool=False)
        >>> obs, rew, terminated, truncated, info = step_api_compatibility(env.step(action), output_truncation_bool=True)
        >>> observations, rewards, terminated, truncated, infos = step_api_compatibility(vec_env.step(action), is_vector_env=True)
    """
    # Pick the converter first, then apply it once with the shared arguments.
    converter = (
        convert_to_terminated_truncated_step_api
        if output_truncation_bool
        else convert_to_done_step_api
    )
    return converter(step_returns, is_vector_env)