File size: 3,185 Bytes
9b19c29 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
from typing import Any
import numpy as np
from tianshou.data import (
HERReplayBuffer,
HERReplayBufferManager,
PrioritizedReplayBuffer,
PrioritizedReplayBufferManager,
ReplayBuffer,
ReplayBufferManager,
)
class VectorReplayBuffer(ReplayBufferManager):
    """VectorReplayBuffer contains n ReplayBuffer with the same size.

    It is used for storing transitions from different environments yet keeping the
    order of time.

    :param total_size: the total size of VectorReplayBuffer. Each of the
        ``buffer_num`` sub-buffers is allocated ``ceil(total_size / buffer_num)``
        slots, so the combined capacity may slightly exceed ``total_size`` when it
        is not divisible by ``buffer_num``.
    :param buffer_num: the number of ReplayBuffer it uses, which are under the same
        configuration. Must be positive.

    Other input arguments (stack_num/ignore_obs_next/save_only_last_obs/sample_avail)
    are forwarded unchanged to every sub-buffer and are the same as
    :class:`~tianshou.data.ReplayBuffer`.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        assert buffer_num > 0, f"buffer_num must be positive, got {buffer_num}"
        # Exact integer ceiling division: float division would round once
        # total_size exceeds 2**53; the result is identical for smaller sizes.
        size = int(-(-total_size // buffer_num))
        buffer_list = [ReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)
class PrioritizedVectorReplayBuffer(PrioritizedReplayBufferManager):
    """PrioritizedVectorReplayBuffer contains n PrioritizedReplayBuffer with same size.

    It is used for storing transitions from different environments yet keeping the
    order of time.

    :param total_size: the total size of PrioritizedVectorReplayBuffer. Each of the
        ``buffer_num`` sub-buffers is allocated ``ceil(total_size / buffer_num)``
        slots, so the combined capacity may slightly exceed ``total_size`` when it
        is not divisible by ``buffer_num``.
    :param buffer_num: the number of PrioritizedReplayBuffer it uses, which are
        under the same configuration. Must be positive.

    Other input arguments (alpha/beta/stack_num/ignore_obs_next/save_only_last_obs/
    sample_avail) are forwarded unchanged to every sub-buffer and are the same as
    :class:`~tianshou.data.PrioritizedReplayBuffer`.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        assert buffer_num > 0, f"buffer_num must be positive, got {buffer_num}"
        # Exact integer ceiling division: float division would round once
        # total_size exceeds 2**53; the result is identical for smaller sizes.
        size = int(-(-total_size // buffer_num))
        buffer_list = [PrioritizedReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)

    def set_beta(self, beta: float) -> None:
        """Forward ``beta`` to ``set_beta`` of every sub-buffer in ``self.buffers``.

        :param beta: the new beta value, with the same meaning as the ``beta``
            argument of :class:`~tianshou.data.PrioritizedReplayBuffer`.
        """
        # NOTE(review): assumes self.buffers is populated by the manager base
        # class with the buffer_list passed in __init__ — confirm.
        for buffer in self.buffers:
            buffer.set_beta(beta)
class HERVectorReplayBuffer(HERReplayBufferManager):
    """HERVectorReplayBuffer contains n HERReplayBuffer with same size.

    It is used for storing transitions from different environments yet keeping the
    order of time.

    :param total_size: the total size of HERVectorReplayBuffer. Each of the
        ``buffer_num`` sub-buffers is allocated ``ceil(total_size / buffer_num)``
        slots, so the combined capacity may slightly exceed ``total_size`` when it
        is not divisible by ``buffer_num``.
    :param buffer_num: the number of HERReplayBuffer it uses, which are
        under the same configuration. Must be positive.

    Other input arguments are forwarded unchanged to every sub-buffer and are the
    same as :class:`~tianshou.data.HERReplayBuffer`.

    .. seealso::

        Please refer to :class:`~tianshou.data.ReplayBuffer` for other APIs' usage.
    """

    def __init__(self, total_size: int, buffer_num: int, **kwargs: Any) -> None:
        assert buffer_num > 0, f"buffer_num must be positive, got {buffer_num}"
        # Exact integer ceiling division: float division would round once
        # total_size exceeds 2**53; the result is identical for smaller sizes.
        size = int(-(-total_size // buffer_num))
        buffer_list = [HERReplayBuffer(size, **kwargs) for _ in range(buffer_num)]
        super().__init__(buffer_list)
|