sabretoothedhugs's picture
v2
9b19c29
from typing import Any
import numpy as np
from tianshou.data import (
HERReplayBuffer,
HERReplayBufferManager,
PrioritizedReplayBuffer,
PrioritizedReplayBufferManager,
ReplayBuffer,
ReplayBufferManager,
)
class VectorReplayBuffer(ReplayBufferManager):
    """VectorReplayBuffer contains n ReplayBuffers of the same size.

    It is used for storing transitions from different environments while keeping
    their temporal order.

    :param total_size: the total size of VectorReplayBuffer. Each of the
        ``buffer_num`` sub-buffers gets ``ceil(total_size / buffer_num)`` slots, so
        the combined capacity may be slightly larger than ``total_size``.
    :param buffer_num: the number of ReplayBuffers it uses, all created with the
        same configuration. Must be positive.

    Other keyword arguments (stack_num/ignore_obs_next/save_only_last_obs/
    sample_avail) are forwarded unchanged to every sub-buffer and mean the same as
    in :class:`~tianshou.data.ReplayBuffer`.

    :raises ValueError: if ``buffer_num`` is not positive.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if buffer_num <= 0:
            raise ValueError(f"buffer_num must be positive, got {buffer_num}.")
        # Round up so the sub-buffers together hold at least ``total_size`` items.
        size = int(np.ceil(total_size / buffer_num))
        buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)
class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
    """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffers of the same size.

    It is used for storing transitions from different environments while keeping
    their temporal order.

    :param total_size: the total size of PrioritizedVectorReplayBuffer. Each of the
        ``buffer_num`` sub-buffers gets ``ceil(total_size / buffer_num)`` slots, so
        the combined capacity may be slightly larger than ``total_size``.
    :param buffer_num: the number of PrioritizedReplayBuffers it uses, all created
        with the same configuration. Must be positive.

    Other keyword arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/
    sample_avail) are forwarded unchanged to every sub-buffer and mean the same as
    in :class:`~tianshou.data.PrioritizedReplayBuffer`.

    :raises ValueError: if ``buffer_num`` is not positive.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if buffer_num <= 0:
            raise ValueError(f"buffer_num must be positive, got {buffer_num}.")
        # Round up so the sub-buffers together hold at least ``total_size`` items.
        size = int(np.ceil(total_size / buffer_num))
        buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)

    def set_beta(self, beta: float) -> None:
        """Set the importance-sampling exponent ``beta``.

        Updates this manager's own ``beta`` and the ``beta`` of every sub-buffer.

        :param beta: the new importance-sampling exponent.
        """
        # The manager is itself a PrioritizedReplayBuffer (via
        # PrioritizedReplayBufferManager) and, in tianshou, computes the importance
        # weights of sampled batches from its own beta. Updating only the
        # sub-buffers therefore left sampling unchanged.
        # NOTE(review): re-verify if the manager's weighting logic changes.
        super().set_beta(beta)
        # Keep the sub-buffers in sync so their reported configuration matches.
        for buffer in self.buffers:
            buffer.set_beta(beta)
class HERVectorReplayBuffer(HERReplayBufferManager):
    """HERVectorReplayBuffer contains n HERReplayBuffers of the same size.

    It is used for storing transitions from different environments while keeping
    their temporal order.

    :param total_size: the total size of HERVectorReplayBuffer. Each of the
        ``buffer_num`` sub-buffers gets ``ceil(total_size / buffer_num)`` slots, so
        the combined capacity may be slightly larger than ``total_size``.
    :param buffer_num: the number of HERReplayBuffers it uses, all created with the
        same configuration. Must be positive.

    Other keyword arguments are forwarded unchanged to every sub-buffer and mean
    the same as in :class:`~tianshou.data.HERReplayBuffer`.

    :raises ValueError: if ``buffer_num`` is not positive.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if buffer_num <= 0:
            raise ValueError(f"buffer_num must be positive, got {buffer_num}.")
        # Round up so the sub-buffers together hold at least ``total_size`` items.
        size = int(np.ceil(total_size / buffer_num))
        buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)