ppo-AntBulletEnv-v0 / rl_algo_impls /wrappers /initial_step_truncate_wrapper.py
sgoodfriend's picture
PPO playing AntBulletEnv-v0 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
b9803e8
raw
history blame
964 Bytes
import gym
import numpy as np
from typing import Any, Dict, Tuple, Union
from rl_algo_impls.wrappers.vectorable_wrapper import VecotarableWrapper
# Observation alias used in this module's annotations: a NumPy array or a dict.
ObsType = Union[np.ndarray, dict]
# Action alias used in this module's annotations: a scalar, a NumPy array, or a dict.
ActType = Union[int, float, np.ndarray, dict]
class InitialStepTruncateWrapper(VecotarableWrapper):
    """Force a single early ``done`` after a fixed number of initial steps.

    The wrapper counts calls to :meth:`step` starting from construction.
    On the call where the count reaches ``initial_steps_to_truncate`` it
    reports ``done=True`` regardless of what the wrapped env returned, and
    from then on it passes results through unchanged. The counter is not
    reset by this class, so the forced ``done`` happens at most once.
    """

    def __init__(self, env: gym.Env, initial_steps_to_truncate: int) -> None:
        """Wrap *env* and arm the one-off truncation.

        Args:
            env: Environment to wrap.
            initial_steps_to_truncate: Number of ``step`` calls after which
                ``done`` is forced to ``True``. ``0`` disables the forced
                truncation entirely.
        """
        super().__init__(env)
        self.initial_steps_to_truncate = initial_steps_to_truncate
        # With a threshold of zero there is nothing to truncate, so the
        # wrapper starts out already in pass-through mode.
        self.initialized = initial_steps_to_truncate == 0
        self.steps = 0

    def step(self, action: ActType) -> Tuple[ObsType, float, bool, Dict[str, Any]]:
        """Step the wrapped env, overriding ``done`` once at the threshold.

        Args:
            action: Action forwarded to the wrapped environment.

        Returns:
            The ``(obs, rew, done, info)`` tuple from the wrapped env, with
            ``done`` replaced by ``True`` on the single step where the
            step count reaches ``initial_steps_to_truncate``.
        """
        obs, rew, done, info = self.env.step(action)

        # Pass-through mode: the forced truncation already happened (or
        # was disabled at construction).
        if self.initialized:
            return obs, rew, done, info

        self.steps += 1
        threshold_reached = self.steps >= self.initial_steps_to_truncate
        if not threshold_reached:
            return obs, rew, done, info

        print(f"Truncation at {self.steps} steps")
        # Flip to pass-through so later steps are never touched again.
        self.initialized = True
        return obs, rew, True, info