Spaces:
Sleeping
Sleeping
from easydict import EasyDict | |
from typing import Callable | |
def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]:
    """
    Overview:
        Get the rollout length scheduler that adapts rollout length based
        on the current environment steps.
    Arguments:
        - cfg (:obj:`EasyDict`): Scheduler config. ``cfg.type`` selects the schedule:
            - ``'linear'``: reads ``rollout_start_step``, ``rollout_end_step``,
              ``rollout_length_min`` and ``rollout_length_max``. The length ramps
              linearly from min (at start step) to max (at end step), is clamped
              to ``[min, max]`` outside that range and truncated to ``int``.
              If start step equals end step, the schedule degenerates to a step
              function: min before that step, max from that step onward.
            - ``'constant'``: reads ``rollout_length`` and always returns it.
    Returns:
        - scheduler (:obj:`Callable`): The function that takes envstep and
          returns the current rollout length.
    Raises:
        - KeyError: If ``cfg.type`` is neither ``'linear'`` nor ``'constant'``.
    """
    if cfg.type == 'linear':
        x0 = cfg.rollout_start_step
        x1 = cfg.rollout_end_step
        y0 = cfg.rollout_length_min
        y1 = cfg.rollout_length_max
        if x1 == x0:
            # Zero-width ramp: the slope below would divide by zero, so fall
            # back to a step function that jumps from min to max at that step.
            return lambda x: int(y1 if x >= x1 else y0)
        w = (y1 - y0) / (x1 - x0)
        # Linear interpolation anchored at (x0, y0), clamped to [y0, y1].
        return lambda x: int(min(max(w * (x - x0) + y0, y0), y1))
    elif cfg.type == 'constant':
        # The value is looked up on every call, so the scheduler reflects
        # whatever ``cfg.rollout_length`` holds at that moment.
        return lambda x: cfg.rollout_length
    else:
        raise KeyError("not implemented key: {}".format(cfg.type))